jaxns
Nested sampling with JAX.
Package Contents
- plot_diagnostics(results, save_name=None)
Plot diagnostics of the nested sampling run.
- Parameters:
results (jaxns.internals.types.NestedSamplerResults) – NestedSamplerResult
save_name – file to save figure to.
- plot_cornerplot(results, variables=None, with_parametrised=False, save_name=None, kde_overlay=False)
Plots a cornerplot of the posterior samples.
- Parameters:
results (jaxns.internals.types.NestedSamplerResults) – NestedSamplerResult
variables (Optional[List[str]]) – list of variable names to plot. Plots all collected samples by default.
save_name (Optional[str]) – file to save result to.
kde_overlay (bool) – whether to overlay a KDE on the histograms.
with_parametrised (bool) – whether to include parametrised samples.
- class DefaultNestedSampler(model, max_samples, num_live_points=None, s=None, k=None, c=None, num_parallel_workers=1, difficult_model=False, parameter_estimation=False, init_efficiency_threshold=0.1, verbose=False)
A static nested sampler that uses a 1-dimensional slice sampler for the sampling step, based on the phantom-powered algorithm. Robust defaults are provided for all parameters.
Initialises the nested sampler.
s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – a model to perform nested sampling on
max_samples (Union[int, float]) – maximum number of samples to take
num_live_points (Optional[int]) – approximate number of live points to use. Defaults to c * (k + 1).
s (Optional[int]) – number of slices to use per dimension. Defaults to 4.
k (Optional[int]) – number of phantom samples to use. Defaults to 0.
c (Optional[int]) – number of parallel Markov chains to use. Defaults to 20 * D, where D is the dimensionality of the model.
num_parallel_workers (int) – number of parallel workers to use. Defaults to 1. Experimental feature.
difficult_model (bool) – if True, uses more robust default settings. Defaults to False.
parameter_estimation (bool) – if True, uses more robust default settings for parameter estimation. Defaults to False.
init_efficiency_threshold (float) – if > 0, first use uniform sampling until the acceptance efficiency falls to this value; 0 turns it off.
verbose (bool) – whether to use JAX printing
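The default sizing can be checked by hand. The helper below is hypothetical (it is not part of jaxns) and only mirrors the defaults listed above: num_live_points = c * (k + 1) with k = 0 and c = 20 * D. The library may choose other values when difficult_model or parameter_estimation is set.

```python
from typing import Optional


def default_num_live_points(D: int, k: Optional[int] = None, c: Optional[int] = None) -> int:
    """Hypothetical helper: approximate number of live points for a D-dimensional model."""
    if k is None:
        k = 0        # documented default number of phantom samples
    if c is None:
        c = 20 * D   # documented default number of parallel Markov chains
    return c * (k + 1)


print(default_num_live_points(D=3))       # 60, i.e. 20 * 3 * (0 + 1)
print(default_num_live_points(D=3, k=2))  # 180, i.e. 60 * (2 + 1)
```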
- property nested_sampler: jaxns.nested_sampler.bases.BaseAbstractNestedSampler
- Return type:
jaxns.nested_sampler.bases.BaseAbstractNestedSampler
- __call__(key, term_cond=None)
Performs nested sampling with the given termination conditions.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
term_cond (Optional[jaxns.internals.types.TerminationCondition]) – termination conditions. If not given, see TerminationCondition for defaults.
- Returns:
termination reason, state
- Return type:
Tuple[jaxns.internals.types.IntArray, jaxns.internals.types.StaticStandardNestedSamplerState]
- to_results(termination_reason, state, trim=True)
Convert the state to results.
Note: Requires static context.
- Parameters:
termination_reason (jaxns.internals.types.IntArray) – termination reason
state (jaxns.internals.types.StaticStandardNestedSamplerState) – state to convert
trim (bool) – if True, trims the results to the number of samples taken; requires static context.
- Returns:
results
- Return type:
jaxns.internals.types.NestedSamplerResults
- static trim_results(results)
Trims the results to the number of samples taken. Requires static context.
- Parameters:
results (jaxns.internals.types.NestedSamplerResults) – results to trim
- Returns:
trimmed results
- Return type:
jaxns.internals.types.NestedSamplerResults
- class ApproximateNestedSampler(*args, **kwargs)
Bases:
DefaultNestedSampler
A static nested sampler that uses a 1-dimensional slice sampler for the sampling step, based on the phantom-powered algorithm. Robust defaults are provided for all parameters.
Initialises the nested sampler.
s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330
- Parameters:
model – a model to perform nested sampling on
max_samples – maximum number of samples to take
num_live_points – approximate number of live points to use. Defaults to c * (k + 1).
s – number of slices to use per dimension. Defaults to 4.
k – number of phantom samples to use. Defaults to 0.
c – number of parallel Markov chains to use. Defaults to 20 * D, where D is the dimensionality of the model.
num_parallel_workers – number of parallel workers to use. Defaults to 1. Experimental feature.
difficult_model – if True, uses more robust default settings. Defaults to False.
parameter_estimation – if True, uses more robust default settings for parameter estimation. Defaults to False.
init_efficiency_threshold – if > 0, first use uniform sampling until the acceptance efficiency falls to this value; 0 turns it off.
verbose – whether to use JAX printing
- class ExactNestedSampler(*args, **kwargs)
Bases:
ApproximateNestedSampler
A static nested sampler that uses a 1-dimensional slice sampler for the sampling step, based on the phantom-powered algorithm. Robust defaults are provided for all parameters.
Initialises the nested sampler.
s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330
- Parameters:
model – a model to perform nested sampling on
max_samples – maximum number of samples to take
num_live_points – approximate number of live points to use. Defaults to c * (k + 1).
s – number of slices to use per dimension. Defaults to 4.
k – number of phantom samples to use. Defaults to 0.
c – number of parallel Markov chains to use. Defaults to 20 * D, where D is the dimensionality of the model.
num_parallel_workers – number of parallel workers to use. Defaults to 1. Experimental feature.
difficult_model – if True, uses more robust default settings. Defaults to False.
parameter_estimation – if True, uses more robust default settings for parameter estimation. Defaults to False.
init_efficiency_threshold – if > 0, first use uniform sampling until the acceptance efficiency falls to this value; 0 turns it off.
verbose – whether to use JAX printing
- class TerminationCondition
Bases:
NamedTuple
Contains the termination conditions for the nested sampling run.
- Parameters:
ess – effective sample size threshold; the run terminates once the ESS (Kish’s estimate) exceeds this value.
evidence_uncert – evidence uncertainty threshold; the run terminates once the uncertainty in the evidence falls below this value.
live_evidence_frac – Deprecated; use dlogZ instead.
dlogZ – Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Defaults to log(1 + 1e-2).
max_samples – Terminate if the number of samples exceeds this.
max_num_likelihood_evaluations – Terminate if the number of likelihood evaluations exceeds this.
log_L_contour – Terminate if this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour. Uncollected live points are not considered.
efficiency_threshold – Terminate if the efficiency (num_samples / num_likelihood_evaluations) is less than this, for the last shrinkage iteration.
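Most of these conditions reduce to two quantities: Kish's effective sample size and the dlogZ ratio. The plain-Python sketch below computes both in log space. The names logsumexp, kish_ess and dlogZ_reached are illustrative and are not jaxns functions; jaxns itself works on JAX arrays.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def kish_ess(log_weights):
    """Kish's effective sample size (sum w)**2 / sum(w**2), evaluated in log space."""
    log_sum_w = logsumexp(log_weights)
    log_sum_w2 = logsumexp([2.0 * lw for lw in log_weights])
    return math.exp(2.0 * log_sum_w - log_sum_w2)


def dlogZ_reached(log_Z_current, log_Z_remaining, dlogZ=math.log(1.0 + 1e-2)):
    """True when log(Z_current + Z_remaining) - log(Z_current) < dlogZ."""
    return logsumexp([log_Z_current, log_Z_remaining]) - log_Z_current < dlogZ


print(kish_ess([0.0, 0.0, 0.0, 0.0]))     # approximately 4: four equal weights
print(dlogZ_reached(0.0, math.log(1e-3)))  # True: the remaining mass is 0.1% of Z_current
```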
- resample(key, samples, log_weights, S=None, replace=False)
Resample the weighted samples into uniformly weighted samples.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
samples (Union[jaxns.internals.types.XType, jaxns.internals.types.UType]) – samples from nested sampled results
log_weights (jax.numpy.ndarray) – log-posterior weights
S (int) – number of samples to generate. Will use Kish’s estimate of ESS if None.
replace (bool) – whether to sample with replacement
- Returns:
equally weighted samples
- Return type:
jaxns.internals.types.XType
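The following is a minimal sketch of what resampling does, written with Python lists and random.Random rather than JAX. It makes these assumptions:

- samples is a list.
- Drawing without replacement works by repeatedly picking one sample in proportion to its weight and removing it from the pool.
- When S is None, it falls back on Kish's ESS.

The signature only mimics the documented one; this is not the jaxns implementation.

```python
import math
import random


def resample(rng, samples, log_weights, S=None, replace=False):
    """Turn weighted samples into (approximately) equally weighted ones.

    Plain-Python illustration: samples is a list and rng a random.Random.
    """
    m = max(log_weights)
    weights = [math.exp(lw - m) for lw in log_weights]  # shift by the max to avoid underflow
    if S is None:
        # Kish's estimate of the effective sample size, rounded to an integer.
        S = max(1, round(sum(weights) ** 2 / sum(w * w for w in weights)))
    if replace:
        return rng.choices(samples, weights=weights, k=S)
    if S > sum(1 for w in weights if w > 0):
        raise ValueError("S exceeds the number of samples with non-zero weight")
    # Without replacement: draw one sample at a time in proportion to its
    # weight and remove it from the pool before the next draw.
    pool = list(zip(samples, weights))
    out = []
    for _ in range(S):
        idx = rng.choices(range(len(pool)), weights=[w for _, w in pool], k=1)[0]
        out.append(pool.pop(idx)[0])
    return out


rng = random.Random(0)
print(sorted(resample(rng, ["a", "b", "c", "d"], [0.0, 0.0, -math.inf, -math.inf])))  # ['a', 'b']
```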
- marginalise_static_from_U(key, U_samples, model, log_weights, ESS, fun)
Marginalises function over posterior samples, where ESS is static.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
U_samples (jaxns.internals.types.UType) – array of U samples
model (jaxns.framework.bases.BaseAbstractModel) – model
log_weights (jax.numpy.ndarray) – log weights from nested sampling
ESS (int) – static effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns:
expectation over resampled samples
- Return type:
_V
- marginalise_dynamic_from_U(key, U_samples, model, log_weights, ESS, fun)
Marginalises function over posterior samples, where ESS can be dynamic.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
U_samples (jaxns.internals.types.UType) – array of U samples
model (jaxns.framework.bases.BaseAbstractModel) – model
log_weights (jax.numpy.ndarray) – log weights from nested sampling
ESS (jax.numpy.ndarray) – dynamic effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns:
expectation of fun over resampled samples.
- Return type:
_V
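- marginalise_static(key, samples, log_weights, ESS, fun)
Marginalises function over posterior samples, where ESS is static.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
samples (dict) – dict of batched arrays of nested sampling samples
log_weights (jax.numpy.ndarray) – log weights from nested sampling
ESS (int) – static effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns: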
expectation over resampled samples
- Return type:
_V
- marginalise_dynamic(key, samples, log_weights, ESS, fun)
Marginalises function over posterior samples, where ESS can be dynamic.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
samples (dict) – dict of batched arrays of nested sampling samples
log_weights (jax.numpy.ndarray) – log weights from nested sampling
ESS (jax.numpy.ndarray) – dynamic effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns:
expectation of fun over resampled samples.
- Return type:
_V
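Marginalisation amounts to resampling followed by an average. The sketch below makes two simplifying assumptions:

- It resamples with replacement.
- It takes a list of dicts instead of jaxns' dict of batched arrays.

The name marginalise is hypothetical. The static/dynamic distinction above has no counterpart here: under JAX tracing a static ESS fixes the shape of the resampled array at compile time, whereas plain Python has no such constraint.

```python
import math
import random


def marginalise(rng, samples, log_weights, ESS, fun):
    """Estimate the posterior expectation of fun(**x) by resampling ESS points.

    samples is a list of dicts (one dict of variables per sample). Points are
    drawn with replacement in proportion to exp(log_weights), then fun is averaged.
    """
    m = max(log_weights)
    weights = [math.exp(lw - m) for lw in log_weights]
    draws = rng.choices(samples, weights=weights, k=ESS)
    return sum(fun(**d) for d in draws) / ESS


samples = [{"x": 1.0, "y": 2.0}, {"x": 5.0, "y": 0.0}]
log_weights = [0.0, -math.inf]  # all posterior mass on the first sample
print(marginalise(random.Random(0), samples, log_weights, 50, lambda x, y: x + y))  # 3.0
```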
- maximum_a_posteriori_point(results)
Get the MAP point of a nested sampling result by choosing the sample with the largest L(x) p(x).
- Parameters:
results (NestedSamplerResult) – Nested sampler result
- Returns:
dict of samples at MAP-point.
- Return type:
jaxns.internals.types.XType
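The MAP point is picked from the existing samples rather than found by optimisation. The sketch below uses hypothetical names and assumes the per-sample log-likelihoods and log-prior densities are available as lists. Evaluating a function at the MAP point, as evaluate_map_estimate does, is then a single call.

```python
def map_point(samples, log_L, log_prior):
    """Return the sample with the largest L(x) p(x), i.e. the largest log_L + log_prior."""
    best = max(range(len(samples)), key=lambda i: log_L[i] + log_prior[i])
    return samples[best]


samples = [{"x": 0.0}, {"x": 1.0}, {"x": 2.0}]
log_L = [-1.0, -0.5, -0.4]      # the last sample has the highest likelihood...
log_prior = [-0.1, -1.2, -2.0]  # ...but the first has the highest L(x) p(x)
point = map_point(samples, log_L, log_prior)
print(point)                              # {'x': 0.0}
print((lambda x: x ** 2 + 1.0)(**point))  # 1.0, a function evaluated at the MAP point
```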
- evaluate_map_estimate(results, fun)
Evaluates a function at the maximum a posteriori (MAP) sample point.
- Parameters:
results (jaxns.internals.types.NestedSamplerResults) – results from run
fun (callable(**kwargs)) – function to evaluate
- Returns:
estimate at MAP sample point
- Return type:
_V
- summary(results, with_parametrised=False, f_obj=None)
Gives a summary of the results of a nested sampling run.
- Parameters:
results (NestedSamplerResults) – Nested sampler result
with_parametrised (bool) – whether to include parametrised samples
f_obj (Optional[Union[str, TextIO]]) – file-like object to write summary to. If None, prints to stdout.
- analytic_posterior_samples(model, S=60)
Compute the evidence with brute-force over a regular grid.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – model
S (int) – resolution of grid
- Returns:
log(Z)
- sample_evidence(key, num_live_points_per_sample, log_L_samples, S=100)
Sample the evidence distribution by stochastically simulating the shrinkage distribution.
Note: this produces approximate samples, since there is also uncertainty in the placement of the contours during shrinkage. Incorporating this stochasticity into the simulation would require running the entire nested sampling procedure many times.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
num_live_points_per_sample (jaxns.internals.types.IntArray) – the number of live points for each sample
log_L_samples (jaxns.internals.types.FloatArray) – the log-L of samples
S (int) – The number of samples to produce
- Returns:
samples of log(Z)
- Return type:
jaxns.internals.types.FloatArray
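The shrinkage simulation can be sketched with the standard nested sampling result that, with n_i live points, each step shrinks the enclosed prior volume by a factor t_i ~ Beta(n_i, 1). We assume sample_evidence simulates something equivalent. This plain-Python version, with hypothetical names, sums only over the supplied samples and ignores any remaining live-point contribution.

```python
import math
import random


def logsumexp(xs):
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def sample_log_evidence(rng, num_live_points_per_sample, log_L_samples, S=100):
    """Return S draws of log(Z) obtained by simulating prior-volume shrinkage.

    At step i the volume shrinks by t_i ~ Beta(n_i, 1), drawn as u ** (1 / n_i)
    with u uniform on (0, 1]. Each sample contributes L_i * (X_{i-1} - X_i),
    where X_i = t_i * X_{i-1} and X_0 = 1.
    """
    draws = []
    for _ in range(S):
        log_X_prev = 0.0  # log X_0 = log 1
        terms = []
        for n, log_L in zip(num_live_points_per_sample, log_L_samples):
            u = 1.0 - rng.random()       # uniform on (0, 1]
            log_t = math.log(u) / n      # log of a Beta(n, 1) draw
            shrink = -math.expm1(log_t)  # 1 - t_i, accurate when t_i is close to 1
            if shrink > 0.0:
                terms.append(log_L + log_X_prev + math.log(shrink))
            log_X_prev += log_t
        draws.append(logsumexp(terms) if terms else -math.inf)
    return draws


# Constant likelihood L = 1: Z = 1 - X_N, so log(Z) should be very close to 0.
print(sample_log_evidence(random.Random(0), [10] * 500, [0.0] * 500, S=3))
```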
- bruteforce_posterior_samples(model, S=60)
Compute the posterior with brute-force over a regular grid.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – model
S (int) – resolution of grid
- Returns:
samples and log-weights
- Return type:
Tuple[jaxns.internals.types.XType, jax.numpy.ndarray]
- bruteforce_evidence(model, S=60)
Compute the evidence with brute-force over a regular grid.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – model
S (int) – resolution of grid
- Returns:
log(Z)
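Brute-force evidence is plain numerical integration. The sketch below assumes a one-dimensional model already expressed as a likelihood over the unit interval (the usual nested sampling U-space) and uses a midpoint grid. It returns both the evidence and grid log-weights, mirroring the two brute-force functions above; bruteforce_1d is a hypothetical helper. If S is the per-dimension resolution, a D-dimensional grid needs S**D likelihood evaluations, which restricts this approach to low-dimensional checks.

```python
import math


def logsumexp(xs):
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def bruteforce_1d(log_likelihood_of_u, S=60):
    """Grid approximation of Z = integral of L(u) du over [0, 1], plus posterior log-weights."""
    us = [(i + 0.5) / S for i in range(S)]                      # midpoints of S equal cells
    log_w = [log_likelihood_of_u(u) - math.log(S) for u in us]  # log(L(u) * cell width)
    log_Z = logsumexp(log_w)
    log_post = [lw - log_Z for lw in log_w]                     # normalised posterior log-weights
    return log_Z, us, log_post


log_Z, us, log_post = bruteforce_1d(lambda u: math.log(2.0 * u))  # L(u) = 2u integrates to 1
print(log_Z)  # close to 0.0, since the midpoint rule is exact for a linear L
```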
- save_pytree(pytree, save_file)
Saves nested sampler results to an npz file.
- Parameters:
pytree (NamedTuple) – Nested sampler result
save_file (str) – filename
- save_results(results, save_file)
Saves nested sampler results to an npz file.
- Parameters:
results (NestedSamplerResults) – Nested sampler result
save_file (str) – filename
- load_pytree(save_file)
Loads saved nested sampler results from an npz file.
- Parameters:
save_file (str) – filename
- Returns:
NestedSamplerResults
- load_results(save_file)
Loads saved nested sampler results from an npz file.
- Parameters:
save_file (str) – filename
- Returns:
NestedSamplerResults
- Return type:
jaxns.internals.types.NestedSamplerResults
- class Model(prior_model, log_likelihood, params=None)
Bases:
jaxns.framework.bases.BaseAbstractModel
Represents a Bayesian model in terms of a generative prior, and likelihood function.
- Parameters:
prior_model (jaxns.framework.bases.PriorModelType) – the generative prior model
log_likelihood (jaxns.internals.types.LikelihoodType) – the log-likelihood function
params (Optional[haiku.MutableParams]) – optional model parameters
- property params
- set_params(params)
Create a new parametrised model with the given parameters.
- Parameters:
params (haiku.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
Model
- __call__(params)
Create a new parametrised model with the given parameters.
This is (and must be) a pure function.
- Parameters:
params (haiku.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
Model
- init_params(rng)
Initialise the parameters of the model.
- Parameters:
rng (jaxns.internals.types.PRNGKey) – PRNG key used to initialise the parameters.
- Returns:
The initialised parameters.
- Return type:
haiku.MutableParams
- sample_U(key)
- Parameters:
key (jaxns.internals.types.PRNGKey) –
- Return type:
jaxns.internals.types.FloatArray
- transform(U)
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.XType
- transform_parametrised(U)
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.XType
- forward(U, allow_nan=False)
- Parameters:
U (jaxns.internals.types.UType) –
allow_nan (bool) –
- Return type:
jaxns.internals.types.FloatArray
- log_prob_prior(U)
- Parameters:
U (jaxns.internals.types.UType) –
- Return type:
jaxns.internals.types.FloatArray
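The U-space methods above carry no description. The sketch below rests on two assumptions, both standard in nested sampling:

- U is a point in the unit hypercube.
- transform maps U to named variables through each prior's quantile function.

The prior and both functions are hypothetical, and statistics.NormalDist stands in for the tfp distributions that Prior accepts.

```python
import random
from statistics import NormalDist


def sample_U(rng, ndims=2):
    """Draw a point uniformly from the unit hypercube [0, 1)**ndims."""
    return [rng.random() for _ in range(ndims)]


def transform(U):
    """Hypothetical prior: x ~ Uniform(-2, 2) and y ~ Normal(0, 1).

    Each coordinate of U is pushed through the quantile function (inverse CDF)
    of the corresponding prior distribution.
    """
    u_x, u_y = U
    return {"x": -2.0 + 4.0 * u_x, "y": NormalDist(0.0, 1.0).inv_cdf(u_y)}


print(transform([0.5, 0.975]))  # x is 0.0 and y is close to 1.96
```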
- class Prior(dist_or_value, name=None)
Bases:
jaxns.framework.bases.BaseAbstractPrior
Represents a generative prior.
- Parameters:
dist_or_value (Union[tfpd, jaxns.framework.bases.BaseAbstractDistribution, jax.numpy.ndarray]) –
name (Optional[str]) –
- property dist: jaxns.framework.bases.BaseAbstractDistribution
- Return type:
jaxns.framework.bases.BaseAbstractDistribution
- property value: jax.numpy.ndarray
- Return type:
jax.numpy.ndarray
- parametrised(random_init=False)
Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated log_prob. The parameter is registered as a hk.Parameter whose name has the suffix _param appended.
- Parameters:
random_init (bool) – whether to initialise the parameter randomly or at the median of the distribution.
- Returns:
A singular prior.
- Return type:
SingularPrior
- exception InvalidPriorName(name=None)
Bases:
Exception
Raised when a prior name is already taken.
- Parameters:
name (Optional[str]) –
- class Bernoulli(*, logits=None, probs=None, name=None)
Bases:
jaxns.framework.bases.BaseAbstractPrior
Bernoulli prior, specified by either logits or probs.
- Parameters:
name (Optional[str]) –
- class Beta(*, concentration0=None, concentration1=None, name=None)
Bases:
jaxns.framework.bases.BaseAbstractPrior
Beta prior with concentration parameters concentration0 and concentration1.
- Parameters:
name (Optional[str]) –
- class Categorical(parametrisation, *, logits=None, probs=None, name=None)
Bases:
jaxns.framework.bases.BaseAbstractPrior
Categorical special prior, specified by either logits or probs.
- Parameters:
parametrisation (Literal[gumbel_max, cdf]) – ‘cdf’ suits discrete parameters whose neighbouring categories are correlated; otherwise ‘gumbel_max’ is better.
logits – log-prob of each category
probs – prob of each category
name (Optional[str]) – optional name
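The two parametrisations can be illustrated in plain Python, assuming they follow the usual constructions:

- 'cdf' maps a single uniform number through the cumulative probabilities.
- 'gumbel_max' adds Gumbel noise, -log(-log u), to each logit and takes the argmax.

Both function names are hypothetical. With 'cdf', nearby values of u land in the same or adjacent categories, which is why it suits correlated neighbouring categories. 'gumbel_max' instead uses one uniform per category.

```python
import math
import random


def categorical_cdf(u, probs):
    """'cdf' style: a single uniform u selects the first category whose CDF exceeds u."""
    cumulative = 0.0
    for i, p in enumerate(probs):
        cumulative += p
        if u < cumulative:
            return i
    return len(probs) - 1  # guard against rounding when u is very close to 1


def categorical_gumbel_max(us, logits):
    """'gumbel_max' style: one uniform per category becomes Gumbel noise; take the argmax."""
    scores = [logit - math.log(-math.log(u)) for logit, u in zip(logits, us)]
    return max(range(len(scores)), key=lambda i: scores[i])


probs = [0.2, 0.3, 0.5]
print([categorical_cdf(u, probs) for u in (0.1, 0.4, 0.9)])  # [0, 1, 2]
```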
- class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)
Bases:
jaxns.framework.bases.BaseAbstractPrior
Prior for a sequence of n random variables uniformly distributed on U[low, high] such that U[i,…] <= U[i+1,…]. For broadcasting, the resulting random variable is sorted elementwise along the first dimension.
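One standard way to build sorted uniforms from the unit hypercube uses order statistics. The largest of n uniforms has CDF x**n, and given that maximum, the remaining values are uniform below it. The sketch below is a hypothetical helper. It is not claimed to be the exact transform jaxns uses, and fix_left/fix_right are not modelled.

```python
import random


def sorted_uniforms(U, low, high):
    """Map n unit-cube coordinates to n non-decreasing values on [low, high].

    Order-statistics construction: x_(n) = U[n-1] ** (1/n), and given x_(i+1)
    the i smaller values are iid uniform on [0, x_(i+1)], so
    x_(i) = x_(i+1) * U[i-1] ** (1/i).
    """
    n = len(U)
    xs = [0.0] * n
    upper = 1.0
    for i in range(n, 0, -1):  # i = n, n-1, ..., 1
        upper = upper * U[i - 1] ** (1.0 / i)
        xs[i - 1] = upper
    return [low + (high - low) * x for x in xs]


print(sorted_uniforms([0.3, 0.9, 0.5], 0.0, 10.0))  # three non-decreasing values in [0, 10]
```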
- class Poisson(*, rate=None, log_rate=None, name=None)
Bases:
jaxns.framework.bases.BaseAbstractPrior
Poisson prior, specified by either rate or log_rate.
- Parameters:
name (Optional[str]) –
- class UnnormalisedDirichlet(*, concentration, name=None)
Bases:
jaxns.framework.bases.BaseAbstractPrior
Represents an unnormalised Dirichlet distribution of K classes; that is, the output is related to the K-simplex via normalisation:
if X ~ UnnormalisedDirichlet(alpha) and Y = X / sum(X), then Y ~ Dirichlet(alpha).
- Parameters:
name (Optional[str]) –
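The normalisation relation above is the classic Gamma construction of the Dirichlet distribution. The sketch assumes, hypothetically, that the unnormalised components are independent Gamma(alpha_k, 1) draws; jaxns may parametrise them differently.

```python
import random


def unnormalised_dirichlet(rng, alpha):
    """Hypothetical construction: independent X_k ~ Gamma(alpha_k, 1)."""
    return [rng.gammavariate(a, 1.0) for a in alpha]


def normalise(xs):
    """Project onto the simplex: Y = X / sum(X)."""
    total = sum(xs)
    return [x / total for x in xs]


rng = random.Random(0)
y = normalise(unnormalised_dirichlet(rng, [1.0, 2.0, 3.0]))
print(sum(y))  # 1.0 up to rounding
```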