jaxns

Nested sampling with JAX.

Package Contents

logger
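plot_diagnostics(results, save_name=None)

Plot diagnostics of the nested sampling run.

Parameters:
  • results (jaxns.internals.types.NestedSamplerResults) – results of the nested sampling run

  • save_name (Optional[str]) – file to save the plot to.
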
plot_cornerplot(results, variables=None, with_parametrised=False, save_name=None, kde_overlay=False)

Plots a cornerplot of the posterior samples.

Parameters:
  • results (jaxns.internals.types.NestedSamplerResults) – results of the nested sampling run

  • variables (Optional[List[str]]) – list of variable names to plot. Plots all collected samples by default.

  • save_name (Optional[str]) – file to save the plot to.

  • kde_overlay (bool) – whether to overlay a KDE on the histograms.

  • with_parametrised (bool) – whether to include parametrised samples.

class DefaultNestedSampler(model, max_samples, num_live_points=None, s=None, k=None, c=None, num_parallel_workers=1, difficult_model=False, parameter_estimation=False, init_efficiency_threshold=0.1, verbose=False)

A static nested sampler that uses a 1-dimensional slice sampler for the sampling step and the phantom-powered algorithm. Robust defaults are provided for all parameters.

Initialises the nested sampler.

The parameters s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – a model to perform nested sampling on

  • max_samples (Union[int, float]) – maximum number of samples to take

  • num_live_points (Optional[int]) – approximate number of live points to use. Defaults to c * (k + 1).

  • s (Optional[int]) – number of slices to use per dimension. Defaults to 4.

  • k (Optional[int]) – number of phantom samples to use. Defaults to 0.

  • c (Optional[int]) – number of parallel Markov chains to use. Defaults to 20 * D, where D is the number of model dimensions.

  • num_parallel_workers (int) – number of parallel workers to use. Defaults to 1. Experimental feature.

  • difficult_model (bool) – if True, uses more robust default settings. Defaults to False.

  • parameter_estimation (bool) – if True, uses more robust default settings for parameter estimation. Defaults to False.

  • init_efficiency_threshold (float) – if > 0, uniform sampling is used first, until the acceptance efficiency drops to this threshold. 0 turns it off.

  • verbose (bool) – whether to use JAX printing
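To see how the documented defaults combine, the hypothetical helper below (not a jaxns function) computes num_live_points = c * (k + 1), using k = 0 and c = 20 * D when they are not given. It assumes D is the number of model dimensions and ignores any adjustments that difficult_model or parameter_estimation may apply.

```python
from typing import Optional


def default_num_live_points(D: int, k: Optional[int] = None, c: Optional[int] = None) -> int:
    """Hypothetical helper mirroring the documented defaults (not part of jaxns)."""
    if k is None:
        k = 0  # documented default: no phantom samples
    if c is None:
        c = 20 * D  # documented default: 20 parallel Markov chains per dimension
    return c * (k + 1)  # documented default for num_live_points


print(default_num_live_points(5))       # 100 = 20 * 5 * (0 + 1)
print(default_num_live_points(5, k=4))  # 500: keeping 4 phantom samples multiplies by k + 1
```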

property num_live_points: int
Return type:

int

property nested_sampler: jaxns.nested_sampler.bases.BaseAbstractNestedSampler
Return type:

jaxns.nested_sampler.bases.BaseAbstractNestedSampler

__repr__()

Return repr(self).

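__call__(key, term_cond=None)

Performs nested sampling with the given termination conditions.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • term_cond (Optional[TerminationCondition]) – termination conditions of the run
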
Returns:

termination reason, state

Return type:

Tuple[jaxns.internals.types.IntArray, jaxns.internals.types.StaticStandardNestedSamplerState]

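to_results(termination_reason, state, trim=True)

Convert the state to results.

Note: Requires static context.

Parameters:
  • termination_reason (jaxns.internals.types.IntArray) – termination reason, as returned by __call__

  • state (jaxns.internals.types.StaticStandardNestedSamplerState) – final state, as returned by __call__

  • trim (bool) – whether to trim the results to the number of samples taken
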
Returns:

results

Return type:

jaxns.internals.types.NestedSamplerResults

static trim_results(results)

Trims the results to the number of samples taken. Requires static context.

Parameters:

results (jaxns.internals.types.NestedSamplerResults) – results to trim

Returns:

trimmed results

Return type:

jaxns.internals.types.NestedSamplerResults

class ApproximateNestedSampler(*args, **kwargs)

Bases: DefaultNestedSampler

A static nested sampler that uses a 1-dimensional slice sampler for the sampling step and the phantom-powered algorithm. Robust defaults are provided for all parameters.

Initialises the nested sampler. It takes the same parameters as DefaultNestedSampler (model, max_samples, num_live_points, s, k, c, num_parallel_workers, difficult_model, parameter_estimation, init_efficiency_threshold, verbose), with the same meanings and defaults; s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330

class ExactNestedSampler(*args, **kwargs)

Bases: ApproximateNestedSampler

A static nested sampler that uses a 1-dimensional slice sampler for the sampling step and the phantom-powered algorithm. Robust defaults are provided for all parameters.

Initialises the nested sampler. It takes the same parameters as DefaultNestedSampler (model, max_samples, num_live_points, s, k, c, num_parallel_workers, difficult_model, parameter_estimation, init_efficiency_threshold, verbose), with the same meanings and defaults; s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330

class TerminationCondition

Bases: NamedTuple

Contains the termination conditions for the nested sampling run.

Parameters:
  • ess – The effective sample size. The run terminates once the ESS (Kish’s estimate) exceeds this value.

  • evidence_uncert – The uncertainty in the evidence. The run terminates once the uncertainty falls below this value.

  • live_evidence_frac – Deprecated; use dlogZ instead.

  • dlogZ – Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Defaults to log(1 + 1e-2).

  • max_samples – Terminate if the number of samples exceeds this.

  • max_num_likelihood_evaluations – Terminate if the number of likelihood evaluations exceeds this.

  • log_L_contour – Terminate once this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour; uncollected live points are not considered.

  • efficiency_threshold – Terminate if the efficiency (num_samples / num_likelihood_evaluations) over the last shrinkage iteration is below this.

ess: FloatArray | None
evidence_uncert: FloatArray | None
live_evidence_frac: FloatArray | None
dlogZ: FloatArray | None
max_samples: IntArray | None
max_num_likelihood_evaluations: IntArray | None
log_L_contour: FloatArray | None
efficiency_threshold: FloatArray | None
__and__(other)
__or__(other)
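The ESS and dlogZ criteria above can be made concrete with a small NumPy sketch. The helper names are hypothetical and this is not the jaxns implementation: kish_ess computes Kish's (sum w)^2 / sum w^2 from log-weights, and dlogZ_reached evaluates log(Z_current + Z_remaining) - log(Z_current) stably with logaddexp, assuming both evidences are given in log space.

```python
import numpy as np


def kish_ess(log_weights: np.ndarray) -> float:
    """Kish's effective sample size (sum w)^2 / sum w^2, computed from log-weights."""
    w = np.exp(log_weights - np.max(log_weights))  # the rescaling cancels in the ratio
    return float(np.sum(w) ** 2 / np.sum(w ** 2))


def dlogZ_reached(log_Z_current: float, log_Z_remaining: float,
                  dlogZ: float = float(np.log(1 + 1e-2))) -> bool:
    """True when log(Z_current + Z_remaining) - log(Z_current) < dlogZ."""
    delta = np.logaddexp(log_Z_current, log_Z_remaining) - log_Z_current
    return bool(delta < dlogZ)


print(kish_ess(np.zeros(50)))            # 50.0: equal weights, so ESS = number of samples
print(dlogZ_reached(0.0, np.log(1e-3)))  # True: log(1.001) < log(1.01)
print(dlogZ_reached(0.0, 0.0))           # False: log(2) > log(1.01)
```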
resample(key, samples, log_weights, S=None, replace=False)

Resample the weighted samples into uniformly weighted samples.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • samples (Union[jaxns.internals.types.XType, jaxns.internals.types.UType]) – samples from nested sampled results

  • log_weights (jax.numpy.ndarray) – log-posterior weights

  • S (int) – number of samples to generate. If None, Kish’s estimate of the ESS is used.

  • replace (bool) – whether to sample with replacement

Returns:

equally weighted samples

Return type:

jaxns.internals.types.XType
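A NumPy sketch of what resampling means here, assuming samples is a dict of arrays that share a leading sample axis: normalise the log-weights, fall back to Kish's ESS when S is None, and draw indices in proportion to the weights. It uses a NumPy Generator in place of a JAX PRNGKey, and resample_sketch is a hypothetical name, not the jaxns function.

```python
import numpy as np


def resample_sketch(rng, samples, log_weights, S=None, replace=False):
    """Turn weighted samples into equally weighted ones (hypothetical helper)."""
    log_weights = np.asarray(log_weights, dtype=float)
    p = np.exp(log_weights - np.max(log_weights))
    p /= p.sum()  # normalised posterior weights
    if S is None:
        # Kish's ESS for normalised weights is 1 / sum p^2; round down to a count.
        S = int(1.0 / np.sum(p ** 2))
    idx = rng.choice(len(p), size=S, replace=replace, p=p)
    return {name: np.asarray(value)[idx] for name, value in samples.items()}


rng = np.random.default_rng(0)
samples = {"x": np.array([0.0, 1.0, 2.0, 3.0])}
log_w = np.log(np.array([0.1, 0.2, 0.3, 0.4]))
out = resample_sketch(rng, samples, log_w)  # ESS = 1 / 0.3, so 3 samples
print(out["x"].shape)
```

Since Kish's ESS never exceeds the number of non-zero weights, sampling without replacement at the default S is always possible.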

marginalise_static_from_U(key, U_samples, model, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS is static.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • U_samples (jaxns.internals.types.UType) – array of U samples

  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • log_weights (jax.numpy.ndarray) – log weights from nested sampling

  • ESS (int) – static effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation over resampled samples

Return type:

_V

marginalise_dynamic_from_U(key, U_samples, model, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS can be dynamic.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • U_samples (jaxns.internals.types.UType) – array of U samples

  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • log_weights (jax.numpy.ndarray) – log weights from nested sampling

  • ESS (jax.numpy.ndarray) – dynamic effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation of fun over resampled samples.

Return type:

_V

marginalise_static(key, samples, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS is static.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • samples (dict) – dict of batched array of nested sampling samples

  • log_weights (jax.numpy.ndarray) – log weights from nested sampling

  • ESS (int) – static effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation over resampled samples

Return type:

_V
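marginalise_static and its variants estimate the posterior expectation of fun by resampling. The NumPy sketch below shows the idea under stated assumptions: samples is a dict of arrays with a shared leading axis, ESS indices are drawn with replacement (a choice made for this sketch), and fun is averaged over them. marginalise_sketch is a hypothetical name, and jaxns's JAX implementation may differ in detail.

```python
import numpy as np


def marginalise_sketch(rng, samples, log_weights, ESS, fun):
    """Monte Carlo estimate of E[fun(**sample)] under the posterior (hypothetical helper)."""
    log_weights = np.asarray(log_weights, dtype=float)
    p = np.exp(log_weights - np.max(log_weights))
    p /= p.sum()
    # Draw ESS equally weighted indices with replacement.
    idx = rng.choice(len(p), size=ESS, replace=True, p=p)
    values = [fun(**{name: np.asarray(v)[i] for name, v in samples.items()}) for i in idx]
    return np.mean(values, axis=0)


rng = np.random.default_rng(0)
samples = {"x": np.array([0.0, 2.0])}
log_w = np.log(np.array([0.5, 0.5]))
print(marginalise_sketch(rng, samples, log_w, ESS=1000, fun=lambda x: x))  # close to 1.0
```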

marginalise_dynamic(key, samples, log_weights, ESS, fun)

Marginalises function over posterior samples, where ESS can be dynamic.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • samples (dict) – dict of batched array of nested sampling samples

  • log_weights (jax.numpy.ndarray) – log weights from nested sampling

  • ESS (jax.numpy.ndarray) – dynamic effective sample size

  • fun (callable(**kwargs)) – function to marginalise

Returns:

expectation of fun over resampled samples.

Return type:

_V

maximum_a_posteriori_point(results)

Get the MAP point of a nested sampling result. It does this by choosing the sample with the largest L(x) p(x).

Parameters:

results (NestedSamplerResults) – Nested sampler result

Returns:

dict of samples at the MAP point.

Return type:

jaxns.internals.types.XType
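maximum_a_posteriori_point picks the sample with the largest L(x) p(x). The NumPy sketch below does the same in log space over hypothetical arrays of log-likelihoods and log-prior densities; the toy numbers show that the MAP sample need not be the maximum-likelihood sample. map_point_sketch is not a jaxns function, and nothing is assumed about how NestedSamplerResults stores these quantities.

```python
import numpy as np


def map_point_sketch(samples, log_L, log_prior):
    """Return the sample maximising L(x) p(x), i.e. log L + log p (hypothetical helper)."""
    i = int(np.argmax(np.asarray(log_L) + np.asarray(log_prior)))
    return {name: np.asarray(v)[i] for name, v in samples.items()}


samples = {"x": np.array([-1.0, 0.5, 2.0])}
log_L = np.array([-3.0, -0.5, -1.0])      # maximum likelihood at x = 0.5
log_prior = np.array([-1.0, -1.0, -0.2])  # the prior favours x = 2.0
map_point = map_point_sketch(samples, log_L, log_prior)  # log L + log p = [-4.0, -1.5, -1.2]
print(map_point["x"])
```

Evaluating a function at the MAP point, as evaluate_map_estimate describes, then amounts to calling fun(**map_point).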

evaluate_map_estimate(results, fun)

Evaluates a function at the MAP point of a nested sampling result.

Parameters:
  • results (NestedSamplerResults) – Nested sampler result

  • fun (callable(**kwargs)) – function to evaluate

Returns:

estimate at MAP sample point

Return type:

_V

summary(results, with_parametrised=False, f_obj=None)

Gives a summary of the results of a nested sampling run.

Parameters:
  • results (NestedSamplerResults) – Nested sampler result

  • with_parametrised (bool) – whether to include parametrised samples

  • f_obj (Optional[Union[str, TextIO]]) – file-like object to write summary to. If None, prints to stdout.

analytic_posterior_samples(model, S=60)

Compute the evidence by brute force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – number of grid points per dimension

Returns:

log(Z)

sample_evidence(key, num_live_points_per_sample, log_L_samples, S=100)

Sample the evidence distribution by stochastically simulating the shrinkage distribution.

Note: this produces approximate samples, since there is also uncertainty in the placement of the contours during shrinkage. Incorporating this stochasticity into the simulation would require running the entire nested sampling procedure many times.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • num_live_points_per_sample (jaxns.internals.types.IntArray) – the number of live points for each sample

  • log_L_samples (jaxns.internals.types.FloatArray) – the log-L of samples

  • S (int) – The number of samples to produce

Returns:

samples of log(Z)

Return type:

jaxns.internals.types.FloatArray
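The shrinkage simulation described above can be sketched in NumPy under the standard nested-sampling model: with n_i live points, each step shrinks the prior volume by t_i ~ Beta(n_i, 1), and Z is approximated by sum_i L_i (X_{i-1} - X_i). Redrawing the t_i gives a distribution over log(Z). sample_log_Z_sketch is a hypothetical name, uses a NumPy Generator instead of a PRNGKey, and jaxns may differ in details such as how the final live points are handled.

```python
import numpy as np


def sample_log_Z_sketch(rng, num_live_points_per_sample, log_L_samples, S=100):
    """Draw S samples of log(Z) by simulating the prior-volume shrinkage."""
    n = np.asarray(num_live_points_per_sample, dtype=float)
    log_L = np.asarray(log_L_samples, dtype=float)
    # Shrinkage factor t_i ~ Beta(n_i, 1), drawn as u^(1/n_i) with u ~ U(0, 1).
    log_t = np.log(rng.uniform(size=(S, n.size))) / n
    log_X = np.cumsum(log_t, axis=1)  # log prior volume after each dead point
    log_X_prev = np.concatenate([np.zeros((S, 1)), log_X[:, :-1]], axis=1)  # X_0 = 1
    # X_{i-1} - X_i = X_{i-1} (1 - t_i)
    log_dX = log_X_prev + np.log1p(-np.exp(log_t))
    return np.logaddexp.reduce(log_L + log_dX, axis=1)  # log sum_i L_i dX_i


rng = np.random.default_rng(0)
# Constant likelihood L = 1: Z = 1 - X_N, which is essentially 1 after 2000 steps.
log_Z = sample_log_Z_sketch(rng, np.full(2000, 100), np.zeros(2000), S=50)
print(log_Z.mean(), log_Z.std())
```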

bruteforce_posterior_samples(model, S=60)

Compute the posterior by brute force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – number of grid points per dimension

Returns:

samples, and log-weight

Return type:

Tuple[jaxns.internals.types.XType, jax.numpy.ndarray]

bruteforce_evidence(model, S=60)

Compute the evidence by brute force over a regular grid.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – model

  • S (int) – number of grid points per dimension

Returns:

log(Z)
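A minimal NumPy sketch of the brute-force idea, assuming a uniform prior on the unit cube [0, 1]^D (that is, working directly in U space) and a midpoint grid with S points per dimension; bruteforce_log_Z_sketch is hypothetical and jaxns's grid may be laid out differently. The cost is S**D likelihood evaluations, which is why this only works in a few dimensions.

```python
import numpy as np


def bruteforce_log_Z_sketch(log_likelihood, D, S=60):
    """Grid estimate of log Z for a uniform prior on [0, 1]^D (hypothetical helper)."""
    centres = (np.arange(S) + 0.5) / S  # midpoints of S equal cells per axis
    grid = np.stack(np.meshgrid(*([centres] * D), indexing="ij"), axis=-1).reshape(-1, D)
    log_L = np.array([log_likelihood(u) for u in grid])  # S**D evaluations
    # Each cell carries prior mass S**-D, so Z ~= sum_j L_j / S**D.
    return np.logaddexp.reduce(log_L) - D * np.log(S)


def log_L_gauss(u):
    """Normalised Gaussian likelihood centred well inside the cube, so Z is close to 1."""
    sigma = 0.1
    return float(np.sum(-0.5 * ((u - 0.5) / sigma) ** 2 - np.log(sigma * np.sqrt(2 * np.pi))))


print(bruteforce_log_Z_sketch(log_L_gauss, D=2))  # close to 0
```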

save_pytree(pytree, save_file)

Saves results of nested sampler to an npz file.

Parameters:
  • pytree (NamedTuple) – Nested sampler result

  • save_file (str) – filename

save_results(results, save_file)

Saves results of nested sampler to an npz file.

Parameters:
  • results (jaxns.internals.types.NestedSamplerResults) – Nested sampler result

  • save_file (str) – filename

load_pytree(save_file)

Loads saved nested sampler results from an npz file.

Parameters:

save_file (str) – filename

Returns:

NestedSamplerResults

load_results(save_file)

Loads saved nested sampler results from an npz file.

Parameters:

save_file (str) – filename

Returns:

NestedSamplerResults

Return type:

jaxns.internals.types.NestedSamplerResults
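The save/load pair above can be illustrated with a minimal sketch, assuming a flat NamedTuple whose fields are all arrays: each field becomes one entry of the .npz archive and is read back by name. ToyResults and the two helpers are hypothetical; nested pytrees would need flattening first, and this is not how jaxns itself serialises NestedSamplerResults.

```python
import os
import tempfile
from typing import NamedTuple

import numpy as np


class ToyResults(NamedTuple):
    """Hypothetical stand-in for a flat results pytree whose leaves are arrays."""
    log_Z_mean: np.ndarray
    log_L_samples: np.ndarray


def save_namedtuple_npz(nt, path):
    # One npz entry per field, keyed by the field name.
    np.savez(path, **nt._asdict())


def load_namedtuple_npz(cls, path):
    # Read every field back by name and rebuild the NamedTuple.
    with np.load(path) as data:
        return cls(**{field: data[field] for field in cls._fields})


res = ToyResults(log_Z_mean=np.array(-1.5), log_L_samples=np.array([-3.0, -2.0, -1.0]))
path = os.path.join(tempfile.mkdtemp(), "results.npz")
save_namedtuple_npz(res, path)
loaded = load_namedtuple_npz(ToyResults, path)
print(loaded.log_L_samples)
```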

PriorModelGen
PriorModelType
class Model(prior_model, log_likelihood, params=None)

Bases: jaxns.framework.bases.BaseAbstractModel

Represents a Bayesian model in terms of a generative prior and a likelihood function.

Parameters:
  • prior_model (jaxns.framework.bases.PriorModelType) – the generative prior model

  • log_likelihood (jaxns.internals.types.LikelihoodType) – the log-likelihood function

  • params (Optional[haiku.MutableParams]) – optional parameters of a parametrised model

property num_params: int
Return type:

int

property params
set_params(params)

Create a new parametrised model with the given parameters.

Parameters:

params (haiku.MutableParams) – The parameters to use.

Returns:

A model with set parameters.

Return type:

Model

__call__(params)

Create a new parametrised model with the given parameters.

This is (and must be) a pure function.

Parameters:

params (haiku.MutableParams) – The parameters to use.

Returns:

A model with set parameters.

Return type:

Model

init_params(rng)

Initialise the parameters of the model.

Parameters:

rng (jaxns.internals.types.PRNGKey) – PRNG key used to initialise the parameters.

Returns:

The initialised parameters.

Return type:

haiku.MutableParams

__hash__()
__repr__()
sample_U(key)
Parameters:

key (jaxns.internals.types.PRNGKey) –

Return type:

jaxns.internals.types.FloatArray

transform(U)
Parameters:

U (jaxns.internals.types.UType) –

Return type:

jaxns.internals.types.XType

transform_parametrised(U)
Parameters:

U (jaxns.internals.types.UType) –

Return type:

jaxns.internals.types.XType

forward(U, allow_nan=False)
Parameters:
  • U (jaxns.internals.types.UType) –

  • allow_nan (bool) –

Return type:

jaxns.internals.types.FloatArray

log_prob_prior(U)
Parameters:

U (jaxns.internals.types.UType) –

Return type:

jaxns.internals.types.FloatArray

prepare_input(U)
Parameters:

U (jaxns.internals.types.UType) –

Return type:

jaxns.internals.types.LikelihoodInputType

sanity_check(key, S)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • S (int) –
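Model pairs a generative prior with a log-likelihood, and transform(U) maps a point U of the unit hypercube to prior samples. The toy sketch below illustrates that idea in plain Python, assuming a generator-style prior model in which each `yield` hands a prior to a driver and receives its value back. ToyUniformPrior and transform_sketch are hypothetical and far simpler than jaxns's own Prior and Model machinery.

```python
from typing import Generator, List


class ToyUniformPrior:
    """Hypothetical prior on [low, high], defined by its quantile (inverse CDF)."""

    def __init__(self, low: float, high: float):
        self.low, self.high = low, high

    def quantile(self, u: float) -> float:
        return self.low + (self.high - self.low) * u


def prior_model() -> Generator[ToyUniformPrior, float, float]:
    # Each yield hands a prior to the driver and receives its sampled value.
    x = yield ToyUniformPrior(0.0, 1.0)
    y = yield ToyUniformPrior(x, x + 2.0)  # later priors may depend on earlier values
    return x + y


def transform_sketch(model_fn, U: List[float]) -> float:
    """Drive the generator, feeding each prior the quantile of the next entry of U."""
    gen = model_fn()
    prior = next(gen)
    for u in U:
        try:
            prior = gen.send(prior.quantile(u))
        except StopIteration as stop:
            return stop.value  # the model's return value
    raise ValueError("U has fewer entries than the model has priors")


print(transform_sketch(prior_model, [0.5, 0.25]))  # x = 0.5, y = 0.5 + 2 * 0.25 = 1.0
```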

class Prior(dist_or_value, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Represents a generative prior.

Parameters:
  • dist_or_value – the distribution of the prior, or a fixed value

  • name (Optional[str]) – optional name

property dist: jaxns.framework.bases.BaseAbstractDistribution
Return type:

jaxns.framework.bases.BaseAbstractDistribution

property value: jax.numpy.ndarray
Return type:

jax.numpy.ndarray

parametrised(random_init=False)

Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated log_prob. The parameter is registered as an hk.Parameter, with the suffix _param added to its name.

Parameters:

random_init (bool) – whether to initialise the parameter randomly or at the median of the distribution.

Returns:

A singular prior.

Return type:

SingularPrior

exception InvalidPriorName(name=None)

Bases: Exception

Raised when a prior name is already taken.

Parameters:

name (Optional[str]) –

class Bernoulli(*, logits=None, probs=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

A Bernoulli prior, parametrised by logits or probs.

Parameters:

name (Optional[str]) – optional name

class Beta(*, concentration0=None, concentration1=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

A Beta prior, parametrised by concentration0 and concentration1.

Parameters:

name (Optional[str]) – optional name

class Categorical(parametrisation, *, logits=None, probs=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

A categorical prior, parametrised by logits or probs.

Parameters:
  • parametrisation (Literal['gumbel_max', 'cdf']) – 'cdf' is good for discrete parameters with correlation between neighbouring categories; otherwise 'gumbel_max' is better.

  • logits – log-prob of each category

  • probs – prob of each category

  • name (Optional[str]) – optional name
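The two parametrisations can be illustrated with a NumPy sketch of how a category might be drawn from unit-cube variates; this shows the general idea under that assumption, not jaxns's code. The 'cdf' mapping inverts the cumulative distribution with a single uniform, so nearby values of u land in neighbouring categories, which is why it suits correlated neighbours. The 'gumbel_max' mapping uses one uniform per category and the Gumbel-max trick, which treats all categories symmetrically.

```python
import numpy as np


def categorical_cdf(u: float, probs: np.ndarray) -> int:
    """One uniform per draw: invert the CDF, so nearby u give neighbouring categories."""
    cdf = np.cumsum(probs)
    k = int(np.searchsorted(cdf, u, side="right"))
    return min(k, len(probs) - 1)  # guard against round-off in the last CDF entry


def categorical_gumbel_max(u: np.ndarray, logits: np.ndarray) -> int:
    """One uniform per category: add Gumbel(0, 1) noise to the logits and take the argmax."""
    gumbel = -np.log(-np.log(u))
    return int(np.argmax(logits + gumbel))


probs = np.array([0.2, 0.3, 0.5])
print([categorical_cdf(u, probs) for u in (0.1, 0.3, 0.9)])  # [0, 1, 2]

rng = np.random.default_rng(0)
draws = [categorical_gumbel_max(rng.uniform(size=3), np.log(probs)) for _ in range(20000)]
print(np.bincount(draws, minlength=3) / len(draws))  # close to probs
```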

class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Prior for a sequence of n random variables uniformly distributed on [low, high] and ordered so that X[i, …] <= X[i+1, …]. When low and high broadcast to higher-dimensional shapes, the result is sorted elementwise along the first dimension.

Parameters:
  • n (int) – number of ordered values within [low, high]

  • low – minimum of the distribution

  • high – maximum of the distribution

  • fix_left (bool) – if True, the leftmost value is fixed to low

  • fix_right (bool) – if True, the rightmost value is fixed to high

  • name (Optional[str]) – optional name
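One standard way to obtain n ordered uniforms from n independent unit-cube variates, without sorting, is the order-statistics recursion X_(n) = u_n^(1/n) and X_(i) = X_(i+1) u_i^(1/i). The sketch below implements that construction in NumPy; whether ForcedIdentifiability uses exactly this mapping is an assumption, sorted_uniform_from_cube is a hypothetical name, and fix_left / fix_right are not handled.

```python
import numpy as np


def sorted_uniform_from_cube(u: np.ndarray, low: float = 0.0, high: float = 1.0) -> np.ndarray:
    """Map n independent U(0, 1) variates to n ascending U(low, high) order statistics."""
    n = len(u)
    x = np.empty(n)
    x[n - 1] = u[n - 1] ** (1.0 / n)  # largest of n uniforms has CDF x**n
    for i in range(n - 2, -1, -1):
        # Given x[i+1], the remaining i + 1 values are i.i.d. uniform on [0, x[i+1]].
        x[i] = x[i + 1] * u[i] ** (1.0 / (i + 1))
    return low + (high - low) * x


rng = np.random.default_rng(0)
x = sorted_uniform_from_cube(rng.uniform(size=5), low=2.0, high=4.0)
print(x)  # ascending values inside [2, 4]
```

Because the mapping is a bijection from the unit cube, every point U yields a valid ordered sequence, which is convenient for samplers that explore U space directly.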

class Poisson(*, rate=None, log_rate=None, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

A Poisson prior, parametrised by rate or log_rate.

Parameters:

name (Optional[str]) – optional name

class UnnormalisedDirichlet(*, concentration, name=None)

Bases: jaxns.framework.bases.BaseAbstractPrior

Represents an unnormalised Dirichlet distribution over K classes: its output maps onto the K-simplex after normalisation.

If X ~ UnnormalisedDirichlet(alpha) and Y = X / sum(X), then Y ~ Dirichlet(alpha).

Parameters:
  • concentration – concentration parameters alpha of the K classes

  • name (Optional[str]) – optional name
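The normalisation relation above can be checked numerically with a standard construction, assumed here rather than taken from jaxns: independent X_k ~ Gamma(alpha_k, 1), divided by their sum, are Dirichlet(alpha) distributed. The helper names are hypothetical.

```python
import numpy as np


def sample_unnormalised_dirichlet(rng, alpha, size):
    """Independent X_k ~ Gamma(alpha_k, 1): one unnormalised draw per row (assumed construction)."""
    alpha = np.asarray(alpha, dtype=float)
    return rng.gamma(shape=alpha, scale=1.0, size=(size, alpha.size))


def normalise(X):
    """Y = X / sum(X): projects each row onto the K-simplex."""
    return X / X.sum(axis=-1, keepdims=True)


rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])
Y = normalise(sample_unnormalised_dirichlet(rng, alpha, size=20000))
print(Y.mean(axis=0))  # approaches alpha / sum(alpha) = [1/6, 1/3, 1/2]
```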