jaxns

Nested sampling with JAX.

Package Contents

PriorModelGen
PriorModelType
jaxify_likelihood(log_likelihood, vectorised=False)

Wraps a non-JAX log likelihood function.

Parameters:
  • log_likelihood (Callable[Ellipsis, numpy.ndarray]) – a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar log-likelihood.

  • vectorised (bool) – if True, log_likelihood must accept batched inputs, i.e. every argument arrives with the same set of batch dimensions, and the function must handle them.

Returns:

A JAX-compatible log-likelihood function.

Return type:

jaxns.internals.types.LikelihoodType
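To make the vectorised contract concrete, the NumPy-only sketch below (no jaxns calls) writes the same isotropic Gaussian log-density both ways. The function names are hypothetical, and it assumes that with vectorised=True every argument carries the same extra leading batch dimensions and one log-likelihood per batch element is expected back.

```python
import numpy as np

def log_likelihood_single(x: np.ndarray, sigma: float) -> np.ndarray:
    # Non-vectorised: x has shape (D,), sigma is a scalar; returns a scalar.
    # log N(x | 0, sigma^2 I), dropping the -D/2 * log(2 pi) constant.
    return np.asarray(-0.5 * np.sum((x / sigma) ** 2) - x.shape[-1] * np.log(sigma))

def log_likelihood_batched(x: np.ndarray, sigma: np.ndarray) -> np.ndarray:
    # Vectorised: x has shape (B..., D), sigma has shape (B...).
    # Reduce only over the last (event) axis so the batch axes survive.
    sigma = np.asarray(sigma)
    return -0.5 * np.sum((x / sigma[..., None]) ** 2, axis=-1) - x.shape[-1] * np.log(sigma)

# Usage: both forms must agree element by element.
rng = np.random.default_rng(0)
xs = rng.normal(size=(5, 3))
sigmas = rng.uniform(0.5, 2.0, size=5)
batched = log_likelihood_batched(xs, sigmas)
looped = np.array([log_likelihood_single(x, s) for x, s in zip(xs, sigmas)])
assert np.allclose(batched, looped)
```

Only the second form fits jaxify_likelihood(..., vectorised=True); the first matches the default vectorised=False.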

class Model(prior_model, log_likelihood, params=None)

Bases: jaxns.framework.bases.BaseAbstractModel

Represents a Bayesian model in terms of a generative prior and a likelihood function.

Parameters:
  • prior_model (jaxns.framework.bases.PriorModelType)

  • log_likelihood (jaxns.internals.types.LikelihoodType)

  • params (Optional[jaxns.framework.context.MutableParams])

__repr__()
property num_params: int
Return type:

int

property params
set_params(params)

Create a new parametrised model with the given parameters.

Parameters:

params (jaxns.framework.context.MutableParams) – The parameters to use.

Returns:

A model with set parameters.

Return type:

Model

__call__(params)

Create a new parametrised model with the given parameters.

This is (and must be) a pure function.

Parameters:

params (jaxns.framework.context.MutableParams) – The parameters to use.

Returns:

A model with set parameters.

Return type:

Model

init_params(rng)

Initialise the parameters of the model.

Parameters:

rng (jaxns.internals.types.PRNGKey) – PRNGKey used to initialise the parameters.

Returns:

The initialised parameters.

Return type:

jaxns.framework.context.MutableParams

__hash__()
sample_U(key)

Sample from the prior model.

Parameters:

key (jaxns.internals.types.PRNGKey) – PRNGKey to use.

Returns:

The sampled U.

Return type:

jaxns.internals.types.UType

sample_W(key)

Sample from the prior model.

Parameters:

key (jaxns.internals.types.PRNGKey) – PRNGKey to use.

Returns:

The sampled W.

Return type:

jaxns.internals.types.WType

transform(U)
Parameters:

U (jaxns.internals.types.UType)

Return type:

jaxns.internals.types.XType

transform_parametrised(U)
Parameters:

U (jaxns.internals.types.UType)

Return type:

jaxns.internals.types.XType

forward(U, allow_nan=False)
Parameters:
  • U (jaxns.internals.types.UType)

  • allow_nan (bool)

Return type:

jaxns.internals.types.FloatArray

log_prob_prior(U)
Parameters:

U (jaxns.internals.types.UType)

Return type:

jaxns.internals.types.FloatArray

prepare_input(U)
Parameters:

U (jaxns.internals.types.UType)

Return type:

jaxns.internals.types.LikelihoodInputType

sanity_check(key, S)
Parameters:
  • key (jaxns.internals.types.PRNGKey)

  • S (int)
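Model combines a generative prior with a likelihood, and transform(U) maps a unit-cube sample (UType) to model values (XType). The toy, stdlib-only sketch below is not jaxns code: it assumes, as the PriorModelGen name suggests, that the prior model is a Python generator, and that each prior can be described by a quantile function. The helpers uniform_quantile, exponential_quantile and run_prior_model are hypothetical; they show how such a generator can turn U into named values, with later priors depending on earlier ones.

```python
import math
from typing import Callable, Dict, Generator, List, Tuple

Quantile = Callable[[float], float]

def uniform_quantile(low: float, high: float) -> Quantile:
    # Inverse CDF of Uniform[low, high].
    return lambda u: low + (high - low) * u

def exponential_quantile(rate: float) -> Quantile:
    # Inverse CDF of Exponential(rate): -log(1 - u) / rate.
    return lambda u: -math.log1p(-u) / rate

def prior_model() -> Generator[Tuple[str, Quantile], float, Tuple[float, float]]:
    # Each yield hands the driver a (name, quantile) pair and receives the sampled value.
    mu = yield ("mu", uniform_quantile(-5.0, 5.0))
    # A later prior may depend on an earlier value.
    sigma = yield ("sigma", exponential_quantile(1.0 / (1.0 + abs(mu))))
    return mu, sigma  # what a likelihood would receive

def run_prior_model(model, U: List[float]) -> Tuple[Dict[str, float], tuple]:
    # Map a unit-cube point U to named values X, one U coordinate per yielded prior.
    gen = model()
    X: Dict[str, float] = {}
    request = next(gen)
    for u in U:
        name, quantile = request
        X[name] = quantile(u)
        try:
            request = gen.send(X[name])
        except StopIteration as stop:
            return X, stop.value
    raise ValueError("U has fewer coordinates than the prior model needs")
```

For example, run_prior_model(prior_model, [0.5, 0.5]) places mu at the centre of [-5, 5] and sigma at the median of the resulting exponential.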

class Prior(dist_or_value, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Represents a generative prior.

Parameters:
  • dist_or_value (Union[tfpd, jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray, jaxns.internals.types.BoolArray])

  • name (Optional[str])

name = None
property dist: jaxns.framework.bases.BaseAbstractDistribution
Return type:

jaxns.framework.bases.BaseAbstractDistribution

property value: jax.Array
Return type:

jax.Array

parametrised(random_init=False)

Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated log_prob. The parameter is registered via get_parameter under the prior's name with an added _param suffix. The prior must have a name.

Parameters:

random_init (bool) – whether to initialise the parameter randomly or at the median of the distribution.

Returns:

A singular prior.

Raises:

ValueError – if the prior has no name.

Return type:

SingularPrior

exception InvalidPriorName(name=None)

Bases: Exception

Raised when a prior name is already taken.

Parameters:

name (Optional[str])

class Bernoulli(*, logits=None, probs=None, name=None)

Bases: SpecialPrior

Bernoulli prior, parametrised by logits or probs.

Parameters:

name (Optional[str])

dist
class Beta(*, concentration0=None, concentration1=None, name=None)

Bases: SpecialPrior

Beta prior, parametrised by concentration0 and concentration1.

Parameters:

name (Optional[str])

class Categorical(parametrisation, *, logits=None, probs=None, name=None)

Bases: SpecialPrior

Categorical prior, parametrised by logits or probs.

Parameters:
  • parametrisation (Literal['gumbel_max', 'cdf']) – ‘cdf’ is better for discrete parameters whose neighbouring categories are correlated; otherwise ‘gumbel_max’ is better.

  • logits – log-probability of each category

  • probs – probability of each category

  • name (Optional[str]) – optional name

dist
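The two parametrisations map uniform draws to a category in different ways. The stdlib sketch below shows the textbook versions, which may differ from jaxns's internals: ‘gumbel_max’ turns one uniform per category into Gumbel noise and takes argmax(logits + noise), while ‘cdf’ inverts the cumulative distribution with a single uniform, so nearby values of u land in the same or adjacent categories. The function names are hypothetical.

```python
import bisect
import itertools
import math
import random
from typing import Sequence

def categorical_gumbel_max(logits: Sequence[float], us: Sequence[float]) -> int:
    # One uniform per category in (0, 1), turned into Gumbel(0, 1) noise: g = -log(-log(u)).
    noisy = [logit - math.log(-math.log(u)) for logit, u in zip(logits, us)]
    # argmax(logits + g) is distributed as Categorical(softmax(logits)).
    return max(range(len(noisy)), key=noisy.__getitem__)

def categorical_cdf(probs: Sequence[float], u: float) -> int:
    # A single uniform in [0, 1): invert the cumulative distribution.
    cdf = list(itertools.accumulate(probs))
    k = bisect.bisect_right(cdf, u * cdf[-1])
    return min(k, len(probs) - 1)  # guard against round-off at the top end

rng = random.Random(0)
probs = [0.2, 0.5, 0.3]
k_gumbel = categorical_gumbel_max([math.log(p) for p in probs],
                                  [rng.uniform(1e-12, 1.0 - 1e-12) for _ in probs])
k_cdf = categorical_cdf(probs, rng.random())
```

The Gumbel form needs K uniforms per draw but treats categories symmetrically; the CDF form needs one uniform and preserves category ordering.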
class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)

Bases: SpecialPrior

Prior for a sequence of n random variables, each uniformly distributed on [low, high], constrained so that X[i,…] <= X[i+1,…]. With broadcast (batched) bounds, the result is sorted elementwise along the first dimension.

Parameters:
  • n (int) – number of samples within [low, high]

  • low – minimum of distribution

  • high – maximum of distribution

  • fix_left (bool) – if True, the leftmost value is fixed to low

  • fix_right (bool) – if True, the rightmost value is fixed to high

  • name (Optional[str])

n
low
high
fix_left = False
fix_right = False
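One standard way to build such a prior from independent unit-cube coordinates uses the order-statistics recursion: the largest of n uniforms is u_n^(1/n), and each next-smaller value is the previous one times u_k^(1/k). The stdlib sketch below assumes jaxns does something equivalent (its actual transform may differ), omits fix_left, fix_right and batching, and uses the hypothetical name sorted_uniform_transform.

```python
import random
from typing import List, Sequence

def sorted_uniform_transform(us: Sequence[float], low: float, high: float) -> List[float]:
    # Map n independent U(0, 1) coordinates to n sorted values in [low, high], distributed
    # like the order statistics of n iid Uniform[low, high] draws.
    n = len(us)
    xs = [0.0] * n
    current = 1.0
    for k in range(n, 0, -1):  # k = n, n-1, ..., 1
        # Largest of k iid uniforms on [0, current] is current * u**(1/k).
        current *= us[k - 1] ** (1.0 / k)
        xs[k - 1] = current
    return [low + (high - low) * x for x in xs]

rng = random.Random(0)
example = sorted_uniform_transform([rng.random() for _ in range(4)], low=-1.0, high=1.0)
```

Because the map is a bijection from the unit cube, every U yields an ordered sequence with no rejection step.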
class Poisson(*, rate=None, log_rate=None, name=None)

Bases: SpecialPrior

Poisson prior, parametrised by rate or log_rate.

Parameters:

name (Optional[str])

dist
class UnnormalisedDirichlet(*, concentration, name=None)

Bases: SpecialPrior

Represents an unnormalised Dirichlet distribution over K classes: the output is related to the K-simplex via normalisation.

X ~ UnnormalisedDirichlet(alpha)
Y = X / sum(X)  ==>  Y ~ Dirichlet(alpha)

Parameters:

name (Optional[str])
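The normalisation relation matches the classical gamma construction of the Dirichlet: draw independent X_k ~ Gamma(alpha_k, 1) and divide by the sum. The stdlib sketch below assumes independent unit-scale gammas as the unnormalised variables (the docs only state the normalisation relation); the function names are hypothetical.

```python
import random
from typing import List, Sequence

def unnormalised_dirichlet_sample(alpha: Sequence[float], rng: random.Random) -> List[float]:
    # Independent Gamma(alpha_k, scale=1) draws; random.gammavariate takes (shape, scale).
    return [rng.gammavariate(a, 1.0) for a in alpha]

def normalise(x: Sequence[float]) -> List[float]:
    # Y = X / sum(X) lies on the simplex and is Dirichlet(alpha)-distributed.
    total = sum(x)
    return [xi / total for xi in x]

rng = random.Random(0)
alpha = [1.0, 2.0, 3.0]
y = normalise(unnormalised_dirichlet_sample(alpha, rng))
```

The sample mean of each Y_k should approach alpha_k / sum(alpha).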

class Empirical(*, samples, support_min=None, support_max=None, resolution=100, name=None)

Bases: SpecialPrior

Represents the empirical distribution of a set of 1D samples, with arbitrary batch dimensions.

Parameters:
  • samples (jax.Array)

  • support_min (Optional[jaxns.internals.types.FloatArray])

  • support_max (Optional[jaxns.internals.types.FloatArray])

  • resolution (int)

  • name (Optional[str])

class TruncationWrapper(prior, low, high, name=None)

Bases: SpecialPrior

Wraps another prior to make it truncated to [low, high].

For a truncated distribution, the quantile function becomes:

Q_truncated(p) = Q_untruncated(p * (F_untruncated(high) - F_untruncated(low)) + F_untruncated(low))

and the CDF becomes:

F_truncated(x) = (F_untruncated(x) - F_untruncated(low)) / (F_untruncated(high) - F_untruncated(low))

prior
low
high
cdf_low
cdf_diff
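Both formulas can be checked numerically. The stdlib sketch below uses statistics.NormalDist as the untruncated distribution (an illustrative choice, not a jaxns requirement); the helper names are hypothetical, and the local variables cdf_low and cdf_diff simply mirror the attribute names listed above.

```python
from statistics import NormalDist

def truncated_quantile(dist: NormalDist, low: float, high: float, p: float) -> float:
    # Q_truncated(p) = Q(p * (F(high) - F(low)) + F(low))
    cdf_low = dist.cdf(low)
    cdf_diff = dist.cdf(high) - cdf_low
    return dist.inv_cdf(p * cdf_diff + cdf_low)

def truncated_cdf(dist: NormalDist, low: float, high: float, x: float) -> float:
    # F_truncated(x) = (F(x) - F(low)) / (F(high) - F(low)), clipped outside [low, high].
    if x <= low:
        return 0.0
    if x >= high:
        return 1.0
    cdf_low = dist.cdf(low)
    return (dist.cdf(x) - cdf_low) / (dist.cdf(high) - cdf_low)

standard_normal = NormalDist(0.0, 1.0)
x_median = truncated_quantile(standard_normal, -0.5, 2.0, 0.5)
```

Feeding any quantile back through truncated_cdf should recover p, which confirms the two transforms are mutual inverses on [low, high].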
class ExplicitDensityPrior(*, axes, density, regular_grid=False, name=None)

Bases: SpecialPrior

Prior defined by an explicit density evaluated on the given axes.

plot_diagnostics(results, save_name=None)

Plot diagnostics of the nested sampling run.

plot_cornerplot(results, variables=None, with_parametrised=False, save_name=None, kde_overlay=False)

Plots a cornerplot of the posterior samples.

Parameters:
  • results (jaxns.nested_samplers.common.types.NestedSamplerResults) – nested sampler results

  • variables (Optional[List[str]]) – list of variable names to plot. Plots all collected samples by default.

  • save_name (Optional[str]) – file to save the plot to.

  • kde_overlay (bool) – whether to overlay a KDE on the histograms.

  • with_parametrised (bool) – whether to include parametrised samples.

class NestedSampler

A static nested sampler that uses a 1-dimensional slice sampler for the sampling step, based on the phantom-powered algorithm. Robust defaults are provided for all parameters.

s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330

Parameters:
  • model – a model to perform nested sampling on

  • max_samples – maximum number of samples to take

  • num_live_points – approximate number of live points to use. Defaults to c * (k + 1).

  • s – number of slices to use per dimension. Defaults to 4.

  • k – number of phantom samples to use. Defaults to 0.

  • c – number of parallel Markov-chains to use. Defaults to 20 * D.

  • devices – devices to use. Defaults to all available devices.

  • difficult_model – if True, uses more robust default settings. Defaults to False.

  • parameter_estimation – if True, uses more robust default settings for parameter estimation. Defaults to False.

  • shell_fraction – fraction of the shell to use for the slice sampler. Defaults to 0.5.

  • gradient_guided – if True, uses gradient guided sampling. Defaults to False.

  • init_efficiency_threshold – if > 0, first use uniform sampling until the acceptance efficiency drops to this value; 0 turns it off. Defaults to 0.1.

  • verbose – whether to log progress.
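For orientation, the stated defaults can be turned into concrete numbers for a model with D parameters. The hypothetical helper below is not part of jaxns; it only encodes s = 4, k = 0, c = 20 * D and num_live_points = c * (k + 1), ignores difficult_model and parameter_estimation, and assumes that "s slices per dimension" means s * D slices per sampling step.

```python
from typing import Dict, Optional

def documented_defaults(D: int, s: int = 4, k: int = 0, c: Optional[int] = None) -> Dict[str, int]:
    # c defaults to 20 * D parallel Markov chains.
    if c is None:
        c = 20 * D
    return {
        "s": s,
        "k": k,
        "c": c,
        # num_live_points defaults to c * (k + 1).
        "num_live_points": c * (k + 1),
        # Assumption: s slices per dimension gives s * D slices per sampling step.
        "total_slices": s * D,
    }

settings = documented_defaults(D=5)
# {'s': 4, 'k': 0, 'c': 100, 'num_live_points': 100, 'total_slices': 20}
```

Each phantom sample adds another c to the live-point count, so k = 1 doubles it.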

model: jaxns.framework.bases.BaseAbstractModel
max_samples: int | float | None = None
num_live_points: int | None = None
num_slices: int | None = None
s: int | float | None = None
k: int | None = None
c: int | None = None
devices: List[jaxlib.xla_client.Device] | None = None
difficult_model: bool = False
parameter_estimation: bool = False
shell_fraction: float = 0.5
gradient_guided: bool = False
init_efficiency_threshold: float = 0.1
verbose: bool = False
__post_init__()
property nested_sampler: jaxns.nested_samplers.abc.AbstractNestedSampler
Return type:

jaxns.nested_samplers.abc.AbstractNestedSampler

__call__(key, term_cond=None)

Performs nested sampling with the given termination conditions.

Returns:

termination reason, state

Return type:

Tuple[jaxns.internals.types.IntArray, jaxns.nested_samplers.common.types.NestedSamplerState]

to_results(termination_reason, state, trim=True)

Convert the state to results.

Note: Requires static context.

Returns:

results

Return type:

jaxns.nested_samplers.common.types.NestedSamplerResults

static trim_results(results)

Trims the results to the number of samples taken. Requires static context.

Parameters:

results (jaxns.nested_samplers.common.types.NestedSamplerResults) – results to trim

Returns:

trimmed results

Return type:

jaxns.nested_samplers.common.types.NestedSamplerResults

resample(key, samples, log_weights, S=None, replace=True)

Resample the weighted samples into uniformly weighted samples.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • samples (Union[jaxns.internals.types.XType, jaxns.internals.types.UType]) – samples from nested sampling results

  • log_weights (jax.Array) – log posterior weights

  • S (int) – number of samples to generate. If None, uses Kish’s estimate of the ESS.

  • replace (bool) – whether to sample with replacement

Returns:

equally weighted samples

Return type:

jaxns.internals.types.XType
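Kish’s estimate is ESS = (sum w)^2 / sum w^2 for weights w = exp(log_weights). The stdlib sketch below computes it stably from log-weights and then draws uniformly weighted samples with replacement from a plain list; it does not reproduce jaxns’s pytree handling or the replace=False path, and the function names are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")

def logsumexp(xs: Sequence[float]) -> float:
    # Stable log(sum(exp(xs))).
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))

def kish_ess(log_weights: Sequence[float]) -> float:
    # Kish's ESS = (sum w)^2 / sum w^2, evaluated in log space.
    return math.exp(2.0 * logsumexp(log_weights) - logsumexp([2.0 * lw for lw in log_weights]))

def resample_with_replacement(rng: random.Random, samples: Sequence[T],
                              log_weights: Sequence[float], S: Optional[int] = None) -> List[T]:
    # Default the number of draws to Kish's ESS when S is None.
    if S is None:
        S = max(1, round(kish_ess(log_weights)))
    m = max(log_weights)
    weights = [math.exp(lw - m) for lw in log_weights]  # rescaled to avoid underflow
    return rng.choices(samples, weights=weights, k=S)

rng = random.Random(0)
draws = resample_with_replacement(rng, ["a", "b", "c"], [0.0, math.log(0.5), -math.inf])
```

With weights (1, 0.5, 0) the ESS is 2.25 / 1.25 = 1.8, so two draws are made, never including "c".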

marginalise_static_from_U(key, U_samples, model, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS is static.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • U_samples (jaxns.internals.types.UType) – array of U samples

  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • log_weights (jax.Array) – log weights from nested sampling

  • ESS (int) – static effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation over resampled samples

Return type:

_V

marginalise_dynamic_from_U(key, U_samples, model, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS can be dynamic.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • U_samples (jaxns.internals.types.UType) – array of U samples

  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • log_weights (jax.Array) – log weights from nested sampling

  • ESS (jax.Array) – dynamic effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation of fun over resampled samples.

Return type:

_V

marginalise_static(key, samples, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS is static.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • samples (dict) – dict of batched array of nested sampling samples

  • log_weights (jax.Array) – log weights from nested sampling

  • ESS (int) – static effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation over resampled samples

Return type:

_V
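Marginalising here means estimating the posterior expectation E[fun(**x)] by resampling ESS uniformly weighted samples and averaging fun over them. The stdlib sketch below follows that reading, with samples as a dict of equally long lists standing in for batched arrays; marginalise_sketch is a hypothetical name, not the jaxns implementation.

```python
import math
import random
from typing import Callable, Dict, Sequence

def marginalise_sketch(rng: random.Random, samples: Dict[str, Sequence[float]],
                       log_weights: Sequence[float], ESS: int,
                       fun: Callable[..., float]) -> float:
    # Draw ESS indices with probability proportional to exp(log_weights).
    m = max(log_weights)
    weights = [math.exp(lw - m) for lw in log_weights]
    idx = rng.choices(range(len(weights)), weights=weights, k=ESS)
    # Call fun with one sample's values as keyword arguments, then average.
    values = [fun(**{name: column[i] for name, column in samples.items()}) for i in idx]
    return sum(values) / ESS

samples = {"x": [1.0, 2.0, 3.0], "y": [10.0, 20.0, 30.0]}
estimate = marginalise_sketch(random.Random(0), samples, [0.0, 0.0, 0.0], ESS=1000,
                              fun=lambda x, y: x * y)
```

With equal weights the estimate approaches (10 + 40 + 90) / 3; a static ESS fixes the number of draws, which is what allows it to be known at trace time.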

marginalise_dynamic(key, samples, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS can be dynamic.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • samples (dict) – dict of batched array of nested sampling samples

  • log_weights (jax.Array) – log weights from nested sampling

  • ESS (jax.Array) – dynamic effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation of fun over resampled samples.

Return type:

_V

maximum_a_posteriori_point(results)

Get the maximum a posteriori (MAP) point of a nested sampling result, i.e. the sample with the largest L(x) p(x).

Parameters:

results (NestedSamplerResult) – Nested sampler result

Returns:

dict of samples at MAP-point.

Return type:

jaxns.internals.types.XType
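Choosing the sample with the largest L(x) p(x) amounts to an argmax over log L + log p. The minimal sketch below works on plain lists and assumes the per-sample log-likelihoods and log-prior densities are already available; map_point_sketch and its inputs are hypothetical.

```python
from typing import Dict, Sequence

def map_point_sketch(samples: Dict[str, Sequence[float]],
                     log_L: Sequence[float], log_prior: Sequence[float]) -> Dict[str, float]:
    # log(L(x) p(x)) = log L(x) + log p(x); pick the index that maximises it.
    scores = [ll + lp for ll, lp in zip(log_L, log_prior)]
    best = max(range(len(scores)), key=scores.__getitem__)
    return {name: values[best] for name, values in samples.items()}

point = map_point_sketch({"x": [0.0, 1.0, 2.0]},
                         log_L=[-1.0, -0.5, -0.2], log_prior=[0.0, 0.0, -5.0])
```

Here the maximum-likelihood sample is x = 2.0, but its low prior density moves the MAP point to x = 1.0.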

evaluate_map_estimate(results, fun)

Evaluates fun at the maximum a posteriori (MAP) sample point.

Returns:

estimate at MAP sample point

Return type:

_V

summary(results, with_parametrised=False, f_obj=None)

Gives a summary of the results of a nested sampling run.

Parameters:
  • results (NestedSamplerResults) – Nested sampler result

  • with_parametrised (bool) – whether to include parametrised samples

  • f_obj (Optional[Union[str, TextIO]]) – file-like object to write summary to. If None, prints to stdout.

analytic_posterior_samples(model, S=60)

Compute the evidence by brute force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – resolution of grid

Returns:

log(Z)

sample_evidence(key, num_live_points_per_sample, log_L_samples, S=100)

Sample the evidence distribution by stochastically simulating the shrinkage distribution.

Note: this produces approximate samples, since there is also uncertainty in the placement of the contours during shrinkage. Incorporating this stochasticity into the simulation would require running the entire nested sampling procedure many times.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • num_live_points_per_sample (jaxns.internals.types.IntArray) – the number of live points for each sample

  • log_L_samples (jaxns.internals.types.FloatArray) – the log-L of samples

  • S (int) – The number of samples to produce

Returns:

samples of log(Z)

Return type:

jaxns.internals.types.FloatArray
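The shrinkage simulation can be sketched under the standard nested-sampling assumption, which the docs do not spell out: with n_i live points at step i the enclosed prior volume shrinks by t_i ~ Beta(n_i, 1), so log X_i = log X_{i-1} + log t_i, and each simulated set of t_i gives Z = sum_i L_i (X_{i-1} - X_i). This stdlib version is not jaxns’s implementation, and the function name is hypothetical.

```python
import math
import random
from typing import List, Sequence

def sample_log_evidence(rng: random.Random, num_live_points_per_sample: Sequence[int],
                        log_L_samples: Sequence[float], S: int = 100) -> List[float]:
    log_Z_samples = []
    for _ in range(S):
        log_X_prev = 0.0  # the enclosed prior volume starts at X_0 = 1
        log_terms = []
        for n, log_L in zip(num_live_points_per_sample, log_L_samples):
            # Shrinkage t ~ Beta(n, 1) via t = u**(1/n); u in [1e-300, 1) keeps the logs finite.
            log_t = math.log(rng.uniform(1e-300, 1.0)) / n
            # log(X_prev - X) = log X_prev + log(1 - t), with 1 - t = -expm1(log t).
            log_terms.append(log_L + log_X_prev + math.log(-math.expm1(log_t)))
            log_X_prev += log_t
        # log Z = log sum_i L_i (X_{i-1} - X_i), via a stable log-sum-exp.
        m = max(log_terms)
        log_Z_samples.append(m + math.log(sum(math.exp(t - m) for t in log_terms)))
    return log_Z_samples
```

As a sanity check, a constant likelihood L gives Z = L (1 - X_N), so once X_N is negligible every sample of log Z sits at log L.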

bruteforce_posterior_samples(model, S=60)

Compute the posterior by brute force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – resolution of grid

Returns:

samples and log-weights

Return type:

Tuple[jaxns.internals.types.XType, jax.Array]

bruteforce_evidence(model, S=60)

Compute the evidence by brute force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – resolution of grid

Returns:

log(Z)
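A brute-force grid estimate can be sketched in one dimension: place S midpoints in the unit interval, map them through the prior’s quantile function Q, and average the likelihood, since Z = ∫ L(x) dπ(x) = ∫_0^1 L(Q(u)) du. The sketch assumes a 1-D model given as a quantile function plus a log-likelihood; how jaxns lays out its grid in higher dimensions is not shown, and the function name is hypothetical.

```python
import math
from typing import Callable

def bruteforce_log_evidence_1d(quantile: Callable[[float], float],
                               log_likelihood: Callable[[float], float],
                               S: int = 60) -> float:
    # Midpoint grid in U-space: u_j = (j + 0.5) / S, each cell carrying prior mass 1 / S.
    log_Ls = [log_likelihood(quantile((j + 0.5) / S)) for j in range(S)]
    m = max(log_Ls)
    # log Z = log((1 / S) * sum_j L(Q(u_j))), evaluated stably.
    return m + math.log(sum(math.exp(l - m) for l in log_Ls) / S)

# Uniform prior on [-5, 5] with a standard normal likelihood: Z is close to 1 / 10.
log_Z = bruteforce_log_evidence_1d(lambda u: -5.0 + 10.0 * u,
                                   lambda x: -0.5 * x * x - 0.5 * math.log(2.0 * math.pi))
```

The cost grows as S^D in D dimensions, which is why this approach only serves as a reference for low-dimensional models.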

save_pytree(pytree, save_file)

Saves nested sampler results to a JSON file.

Parameters:
  • pytree (NamedTuple) – Nested sampler result

  • save_file (str) – filename

save_results(results, save_file)

Saves nested sampler results to a JSON file.

load_pytree(save_file)

Loads saved nested sampler results from a JSON file.

Parameters:

save_file (str) – filename

Returns:

NestedSamplerResults

load_results(save_file)

Loads saved nested sampler results from a JSON file.

Parameters:

save_file (str) – filename

Returns:

NestedSamplerResults

Return type:

jaxns.nested_samplers.common.types.NestedSamplerResults

class ShardedStaticNestedSampler

Bases: jaxns.nested_samplers.abc.AbstractNestedSampler

A static nested sampler that uses a fixed number of live points. This uses a uniform sampler to generate the initial set of samples down to an efficiency threshold, then uses a provided sampler to generate the rest of the samples until the termination condition is met.

Parameters:
  • init_efficiency_threshold – the efficiency threshold to use for the initial uniform sampling. If 0 then turns it off.

  • sampler – the sampler to use after the initial uniform sampling.

  • num_live_points – the number of live points to use.

  • model – the model to use.

  • max_samples – the maximum number of samples to take.

  • devices – the devices to use, default is 1.

  • verbose – whether to log as we go.

model: jaxns.framework.bases.BaseAbstractModel
max_samples: int
init_efficiency_threshold: float
sampler: jaxns.samplers.abc.AbstractSampler
num_live_points: int
shell_fraction: float | None = None
num_dynamic_refinement_iterations: int = 0
refine_threshold: float = 0.01
devices: List[jaxlib.xla_client.Device] | None = None
verbose: bool = False
__post_init__()
class TerminationCondition

Bases: NamedTuple

Contains the termination conditions for the nested sampling run.

Parameters:
  • ess – Terminate if the effective sample size (Kish’s estimate) exceeds this.

  • evidence_uncert – Terminate if the uncertainty in the evidence is less than this.

  • live_evidence_frac – Deprecated; use dlogZ.

  • dlogZ – Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Defaults to log(1 + 1e-2).

  • max_samples – Terminate if the number of samples exceeds this.

  • max_num_likelihood_evaluations – Terminate if the number of likelihood evaluations exceeds this.

  • log_L_contour – Terminate if this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour. Uncollected live points are not considered.

  • efficiency_threshold – Terminate if the efficiency (num_samples / num_likelihood_evaluations) over the last shrinkage iteration is less than this.

  • rtol – Terminate when the relative difference 2*|log_L_max - log_L_min|/|log_L_max + log_L_min| < rtol.

  • atol – Terminate when the absolute difference |log_L_max - log_L_min| < atol.

ess: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
evidence_uncert: jaxns.internals.types.FloatArray | None = None
live_evidence_frac: jaxns.internals.types.FloatArray | None = None
dlogZ: jaxns.internals.types.FloatArray | None = None
max_samples: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
max_num_likelihood_evaluations: jaxns.internals.types.FloatArray | jaxns.internals.types.IntArray | None = None
log_L_contour: jaxns.internals.types.FloatArray | None = None
efficiency_threshold: jaxns.internals.types.FloatArray | None = None
rtol: jaxns.internals.types.FloatArray | None = None
atol: jaxns.internals.types.FloatArray | None = None
peak_XL_frac: jaxns.internals.types.FloatArray | None = None
__and__(other)
__or__(other)
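The dlogZ, rtol and atol criteria are easy to evaluate from log-quantities. The stdlib sketch below implements them as written in the parameter list, treating log(Z_current) and log(Z_remaining) as given inputs (how jaxns estimates Z_remaining is not shown); the function names are hypothetical.

```python
import math

def logaddexp(a: float, b: float) -> float:
    # Stable log(exp(a) + exp(b)).
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))

def dlogZ_reached(log_Z_current: float, log_Z_remaining: float,
                  dlogZ: float = math.log(1.0 + 1e-2)) -> bool:
    # log(Z_current + Z_remaining) - log(Z_current) < dlogZ
    return logaddexp(log_Z_current, log_Z_remaining) - log_Z_current < dlogZ

def rtol_reached(log_L_max: float, log_L_min: float, rtol: float) -> bool:
    # 2 * |log_L_max - log_L_min| / |log_L_max + log_L_min| < rtol
    # (undefined when log_L_max + log_L_min == 0).
    return 2.0 * abs(log_L_max - log_L_min) / abs(log_L_max + log_L_min) < rtol

def atol_reached(log_L_max: float, log_L_min: float, atol: float) -> bool:
    # |log_L_max - log_L_min| < atol
    return abs(log_L_max - log_L_min) < atol
```

With the default dlogZ, a run stops once the remaining evidence is below about 1% of the evidence accumulated so far.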
class NestedSamplerResults

Bases: NamedTuple

Results of the nested sampling run.

log_Z_mean: jaxns.internals.types.FloatArray
log_Z_uncert: jaxns.internals.types.FloatArray
ESS: jaxns.internals.types.FloatArray
H_mean: jaxns.internals.types.FloatArray
samples: jaxns.internals.types.XType
parametrised_samples: jaxns.internals.types.XType
U_samples: jaxns.internals.types.UType
log_L_samples: jaxns.internals.types.FloatArray
log_dp_mean: jaxns.internals.types.FloatArray
log_X_mean: jaxns.internals.types.FloatArray
log_posterior_density: jaxns.internals.types.FloatArray
num_live_points_per_sample: jaxns.internals.types.IntArray
num_likelihood_evaluations_per_sample: jaxns.internals.types.IntArray
total_num_samples: jaxns.internals.types.IntArray
total_phantom_samples: jaxns.internals.types.IntArray
total_num_likelihood_evaluations: jaxns.internals.types.IntArray
log_efficiency: jaxns.internals.types.FloatArray
termination_reason: jaxns.internals.types.IntArray
class NestedSamplerState

Bases: NamedTuple

key: jaxns.internals.types.PRNGKey
next_sample_idx: jaxns.internals.types.IntArray
num_samples: jaxns.internals.types.IntArray
sample_collection: StaticStandardSampleCollection