jaxns
Nested sampling with JAX.
Package Contents
- jaxify_likelihood(log_likelihood, vectorised=False)[source]
Wraps a non-JAX log likelihood function.
- Parameters:
log_likelihood (Callable[Ellipsis, numpy.ndarray]) – a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar log-likelihood.
vectorised (bool) – if True, log_likelihood must handle batched inputs, i.e. every input receives a common set of batch dimensions that the function must handle.
- Returns:
A JAX-compatible log-likelihood function.
- Return type:
jaxns.internals.types.LikelihoodType
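How jaxify_likelihood works internally is not documented here. As a minimal sketch of the general idea (not the jaxns implementation), a NumPy log-likelihood can be exposed to JAX with jax.pure_callback. This assumes a scalar float32 output and covers only the vectorised=False case; the helper name wrap_numpy_log_likelihood is hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np


def wrap_numpy_log_likelihood(np_log_likelihood):
    """Hypothetical helper: expose a NumPy log-likelihood to JAX transformations."""
    # pure_callback needs the result shape/dtype up front; we assume a scalar float32.
    result_spec = jax.ShapeDtypeStruct((), jnp.float32)

    def host_callback(*args):
        # Runs on the host, outside of tracing, with concrete array inputs.
        return np.asarray(np_log_likelihood(*args), dtype=np.float32)

    def jax_log_likelihood(*args):
        return jax.pure_callback(host_callback, result_spec, *args)

    return jax_log_likelihood


def np_gaussian_log_likelihood(x):
    # Plain NumPy code that JAX cannot trace directly.
    return -0.5 * np.sum(np.square(x))


log_likelihood = jax.jit(wrap_numpy_log_likelihood(np_gaussian_log_likelihood))
value = log_likelihood(jnp.array([1.0, 2.0]))
```

The callback is opaque to JAX, so the wrapped function can be jit-compiled but not differentiated.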
- class Model(prior_model, log_likelihood, params=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractModel
Represents a Bayesian model in terms of a generative prior and a likelihood function.
- Parameters:
prior_model (jaxns.framework.bases.PriorModelType)
log_likelihood (jaxns.internals.types.LikelihoodType)
params (Optional[jaxns.framework.context.MutableParams])
- property params
- set_params(params)[source]
Create a new parametrised model with the given parameters.
- Parameters:
params (jaxns.framework.context.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
- __call__(params)[source]
Create a new parametrised model with the given parameters.
This is (and must be) a pure function.
- Parameters:
params (jaxns.framework.context.MutableParams) – The parameters to use.
- Returns:
A model with set parameters.
- Return type:
- init_params(rng)[source]
Initialise the parameters of the model.
- Parameters:
rng (jaxns.internals.types.PRNGKey) – PRNG key used to initialise the parameters.
- Returns:
The initialised parameters.
- Return type:
jaxns.framework.context.MutableParams
- sample_U(key)[source]
Sample from the prior model.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey to use.
- Returns:
The sampled U.
- Return type:
jaxns.internals.types.UType
- sample_W(key)[source]
Sample from the prior model.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey to use.
- Returns:
The sampled W.
- Return type:
jaxns.internals.types.WType
- transform(U)[source]
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.XType
- transform_parametrised(U)[source]
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.XType
- forward(U, allow_nan=False)[source]
- Parameters:
U (jaxns.internals.types.UType)
allow_nan (bool)
- Return type:
jaxns.internals.types.FloatArray
- log_prob_prior(U)[source]
- Parameters:
U (jaxns.internals.types.UType)
- Return type:
jaxns.internals.types.FloatArray
- class Prior(dist_or_value, name=None)[source]
Bases:
jaxns.framework.bases.BaseAbstractPrior
Represents a generative prior.
- Parameters:
dist_or_value (Union[tfpd, jaxns.internals.types.FloatArray, jaxns.internals.types.IntArray, jaxns.internals.types.BoolArray])
name (Optional[str])
- name = None
- property dist: jaxns.framework.bases.BaseAbstractDistribution
- Return type:
jaxns.framework.bases.BaseAbstractDistribution
- parametrised(random_init=False)[source]
Convert this prior into a non-Bayesian parameter that takes a single value in the model but still has an associated log_prob. The parameter is registered via get_parameter, with the suffix _param appended to its name. The prior must have a name.
- Parameters:
random_init (bool) – whether to initialise the parameter randomly or at the median of the distribution.
- Returns:
A singular prior.
- Raises:
ValueError – if the prior has no name.
- Return type:
SingularPrior
- exception InvalidPriorName(name=None)[source]
Bases:
Exception
Raised when a prior name is already taken.
- Parameters:
name (Optional[str])
- class Bernoulli(*, logits=None, probs=None, name=None)[source]
Bases:
SpecialPrior
The base prior class with public methods.
- Parameters:
name (Optional[str])
- dist
- class Beta(*, concentration0=None, concentration1=None, name=None)[source]
Bases:
SpecialPrior
The base prior class with public methods.
- Parameters:
name (Optional[str])
- class Categorical(parametrisation, *, logits=None, probs=None, name=None)[source]
Bases:
SpecialPrior
The base prior class with public methods.
Initialises a Categorical special prior.
- Parameters:
parametrisation (Literal['gumbel_max', 'cdf']) – ‘cdf’ is good for discrete parameters whose neighbouring categories are correlated; otherwise ‘gumbel_max’ is better.
logits – log-prob of each category
probs – prob of each category
name (Optional[str]) – optional name
- dist
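The two parametrisation names suggest two standard ways of turning unit-cube coordinates into a category. The sketch below shows both under that assumption; it is not the jaxns code, and the function names are hypothetical. The 'cdf' variant uses a single coordinate, so neighbouring categories are neighbours in U, while the Gumbel-max variant uses one coordinate per category:

```python
import jax
import jax.numpy as jnp


def categorical_from_cdf(u, probs):
    """'cdf'-style sketch: one uniform u in [0, 1) selects a category by inverting the CDF."""
    cdf = jnp.cumsum(probs)
    idx = jnp.searchsorted(cdf, u, side="right")
    # Guard against round-off making the last CDF entry slightly below 1.
    return jnp.clip(idx, 0, probs.shape[-1] - 1)


def categorical_from_gumbel_max(u, logits):
    """'gumbel_max'-style sketch: K uniforms become standard Gumbel noise on the logits."""
    gumbel = -jnp.log(-jnp.log(u))
    return jnp.argmax(logits + gumbel, axis=-1)


probs = jnp.array([0.2, 0.5, 0.3])
u = jax.random.uniform(jax.random.PRNGKey(0), (20000, 3))
cats = categorical_from_gumbel_max(u, jnp.log(probs))
freqs = jnp.bincount(cats, length=3) / 20000
```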
- class ForcedIdentifiability(*, n, low=None, high=None, fix_left=False, fix_right=False, name=None)[source]
Bases:
SpecialPrior
Prior for a sequence of n random variables uniformly distributed on U[low, high] such that U[i,…] <= U[i+1,…]. To support broadcasting, the resulting random variable is sorted elementwise along the first dimension.
- Parameters:
- n
- low
- high
- fix_left = False
- fix_right = False
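The docstring does not say how unit-cube coordinates are mapped to the ordered values. One standard construction, sketched here as an assumption rather than as the jaxns implementation, builds the order statistics of n uniforms starting from the maximum: the largest of n uniforms has CDF x^n, and given that maximum, the remaining values are uniform below it. The helper name ordered_uniform_transform is hypothetical, and fix_left/fix_right are not modelled:

```python
import jax
import jax.numpy as jnp


def ordered_uniform_transform(U, low, high):
    """Map U of shape (n, ...) in [0, 1) to values in [low, high] sorted along axis 0."""
    n = U.shape[0]
    # Exponents for the successive maxima: n, n-1, ..., 1.
    ks = jnp.arange(n, 0, -1).astype(U.dtype)

    def step(current_max, inputs):
        u, k = inputs
        # The max of k uniforms on [0, current_max] has CDF (x / current_max) ** k.
        x = current_max * u ** (1.0 / k)
        return x, x

    init = jnp.ones(U.shape[1:], dtype=U.dtype)
    _, xs_desc = jax.lax.scan(step, init, (U[::-1], ks))
    xs = xs_desc[::-1]  # ascending order along the first dimension
    return low + (high - low) * xs


U = jax.random.uniform(jax.random.PRNGKey(0), (5, 3))
x = ordered_uniform_transform(U, 2.0, 7.0)
```

Because each step is an inverse-CDF transform, the map stays smooth in U, which suits nested sampling better than sorting.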
- class Poisson(*, rate=None, log_rate=None, name=None)[source]
Bases:
SpecialPrior
The base prior class with public methods.
- Parameters:
name (Optional[str])
- dist
- class UnnormalisedDirichlet(*, concentration, name=None)[source]
Bases:
SpecialPrior
Represents an unnormalised Dirichlet distribution of K classes. That is, the output is related to the K-simplex via normalisation:
if X ~ UnnormalisedDirichlet(alpha) and Y = X / sum(X), then Y ~ Dirichlet(alpha).
- Parameters:
name (Optional[str])
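A well-known way to obtain such an unnormalised vector is to draw independent Gamma(alpha_k, 1) variables; normalising them gives a Dirichlet(alpha) sample. The sketch below illustrates that relationship with plain JAX. Whether jaxns uses this exact construction is an assumption, and the helper name is hypothetical:

```python
import jax
import jax.numpy as jnp


def sample_unnormalised_dirichlet(key, alpha, num_samples):
    # X_k ~ Gamma(alpha_k, 1), drawn independently for each class k.
    return jax.random.gamma(key, alpha, shape=(num_samples,) + alpha.shape)


alpha = jnp.array([1.0, 2.0, 3.0])
X = sample_unnormalised_dirichlet(jax.random.PRNGKey(0), alpha, 20000)
Y = X / jnp.sum(X, axis=-1, keepdims=True)  # Y ~ Dirichlet(alpha)
```

The Dirichlet mean is alpha / sum(alpha), which the normalised samples should reproduce.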
- class Empirical(*, samples, support_min=None, support_max=None, resolution=100, name=None)[source]
Bases:
SpecialPrior
Represents the empirical distribution of a set of 1D samples, with an arbitrary batch dimension.
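The signature exposes a resolution argument but the internals are not documented. A simple way to turn 1D samples into a prior usable from the unit cube is an approximate quantile function on a grid of resolution probabilities. The sketch below shows that idea only; it ignores support_min/support_max and batch dimensions, and empirical_quantile_fn is a hypothetical name:

```python
import jax.numpy as jnp


def empirical_quantile_fn(samples, resolution=100):
    """Hypothetical helper: approximate quantile function of 1D samples."""
    probs = jnp.linspace(0.0, 1.0, resolution)
    # Empirical quantiles on a regular probability grid (linear interpolation).
    knots = jnp.quantile(samples, probs)

    def quantile(u):
        # Map u in [0, 1] to the sample space by piecewise-linear interpolation.
        return jnp.interp(u, probs, knots)

    return quantile


q = empirical_quantile_fn(jnp.arange(101.0))
```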
- class TruncationWrapper(prior, low, high, name=None)[source]
Bases:
SpecialPrior
Wraps another prior to make it truncated.
For a truncated distribution, the quantile function transforms to:
Q_truncated(p) = Q_untruncated(F_untruncated(low) + p * (F_untruncated(high) - F_untruncated(low)))
And the CDF transforms to:
F_truncated(x) = (F_untruncated(x) - F_untruncated(low)) / (F_untruncated(high) - F_untruncated(low))
- Parameters:
- prior
- low
- high
- cdf_low
- cdf_diff
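The two transforms above can be checked numerically. The sketch below uses a standard normal from jax.scipy.stats as the untruncated distribution; the helper names are hypothetical, and the cdf_low/cdf_diff variables only mirror the attribute names listed above by analogy:

```python
import jax.numpy as jnp
from jax.scipy.stats import norm


def truncated_quantile(p, low, high, cdf=norm.cdf, ppf=norm.ppf):
    # Q_truncated(p) = Q(F(low) + p * (F(high) - F(low)))
    cdf_low = cdf(low)
    cdf_diff = cdf(high) - cdf_low
    return ppf(cdf_low + p * cdf_diff)


def truncated_cdf(x, low, high, cdf=norm.cdf):
    # F_truncated(x) = (F(x) - F(low)) / (F(high) - F(low))
    return (cdf(x) - cdf(low)) / (cdf(high) - cdf(low))


p = jnp.linspace(0.01, 0.99, 9)
x = truncated_quantile(p, -1.0, 2.0)
```

Round-tripping through the truncated CDF should recover p, and every quantile lies inside [low, high].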
- class ExplicitDensityPrior(*, axes, density, regular_grid=False, name=None)[source]
Bases:
SpecialPrior
The base prior class with public methods.
- plot_diagnostics(results, save_name=None)[source]
Plot diagnostics of the nested sampling run.
- Parameters:
results (jaxns.nested_samplers.common.types.NestedSamplerResults) – NestedSamplerResult
save_name – file to save figure to.
- plot_cornerplot(results, variables=None, with_parametrised=False, save_name=None, kde_overlay=False)[source]
Plots a cornerplot of the posterior samples.
- Parameters:
results (jaxns.nested_samplers.common.types.NestedSamplerResults) – NestedSamplerResult
variables (Optional[List[str]]) – list of variable names to plot. Plots all collected samples by default.
save_name (Optional[str]) – file to save result to.
kde_overlay (bool) – whether to overlay a KDE on the histograms.
with_parametrised (bool) – whether to include parametrised samples.
- class NestedSampler[source]
A static nested sampler that uses a 1-dimensional slice sampler for the sampling step and the phantom-powered algorithm. Robust defaults are provided for all parameters.
s, k and c are defined in the paper: https://arxiv.org/abs/2312.11330
- Parameters:
model – a model to perform nested sampling on
max_samples – maximum number of samples to take
num_live_points – approximate number of live points to use. Defaults to c * (k + 1).
s – number of slices to use per dimension. Defaults to 4.
k – number of phantom samples to use. Defaults to 0.
c – number of parallel Markov chains to use. Defaults to 20 * D, where D is the number of model dimensions.
devices – devices to use. Defaults to all available devices.
difficult_model – if True, uses more robust default settings. Defaults to False.
parameter_estimation – if True, uses more robust default settings for parameter estimation. Defaults to False.
shell_fraction – fraction of the shell to use for the slice sampler. Defaults to 0.5.
gradient_guided – if True, uses gradient guided sampling. Defaults to False.
init_efficiency_threshold – if > 0 then use uniform sampling first down to this acceptance efficiency. 0 turns it off.
verbose – whether to log progress.
- model: jaxns.framework.bases.BaseAbstractModel
- property nested_sampler: jaxns.nested_samplers.abc.AbstractNestedSampler
- Return type:
- __call__(key, term_cond=None)[source]
Performs nested sampling with the given termination conditions.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
term_cond (Optional[jaxns.nested_samplers.common.types.TerminationCondition]) – termination conditions. If not given, see TerminationCondition for defaults.
- Returns:
termination reason, state
- Return type:
Tuple[jaxns.internals.types.IntArray, jaxns.nested_samplers.common.types.NestedSamplerState]
- to_results(termination_reason, state, trim=True)[source]
Convert the state to results.
Note: Requires static context.
- Parameters:
termination_reason (jaxns.internals.types.IntArray) – termination reason
state (jaxns.nested_samplers.common.types.NestedSamplerState) – state to convert
trim (bool) – if True, trims the results to the number of samples taken; this requires static context.
- Returns:
results
- Return type:
- static trim_results(results)[source]
Trims the results to the number of samples taken. Requires static context.
- Parameters:
results (jaxns.nested_samplers.common.types.NestedSamplerResults) – results to trim
- Returns:
trimmed results
- Return type:
- resample(key, samples, log_weights, S=None, replace=True)[source]
Resample the weighted samples into uniformly weighted samples.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
samples (Union[jaxns.internals.types.XType, jaxns.internals.types.UType]) – samples from the nested sampling results
log_weights (jax.Array) – log posterior weights
S (int) – number of samples to generate. Will use Kish’s estimate of ESS if None.
replace (bool) – whether to sample with replacement
- Returns:
equally weighted samples
- Return type:
jaxns.internals.types.XType
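Kish’s estimate of the effective sample size is ESS = (sum w)^2 / sum w^2. The sketch below (hypothetical helper names, not the jaxns implementation) computes it in log space and resamples a pytree of samples with jax.random.choice. Like the documented default, it falls back to Kish’s ESS when S is None, which needs a concrete value and so only works outside jit:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def kish_ess(log_weights):
    # With normalised weights, ESS = 1 / sum(w ** 2), evaluated in log space.
    log_w = log_weights - logsumexp(log_weights)
    return jnp.exp(-logsumexp(2.0 * log_w))


def resample_sketch(key, samples, log_weights, S=None, replace=True):
    """Hypothetical helper: turn weighted samples into equally weighted ones."""
    if S is None:
        S = int(jnp.round(kish_ess(log_weights)))
    p = jnp.exp(log_weights - logsumexp(log_weights))
    idx = jax.random.choice(key, log_weights.shape[0], shape=(S,), replace=replace, p=p)
    # Apply the same indices to every variable in the samples pytree.
    return jax.tree_util.tree_map(lambda x: x[idx], samples)


log_w = jnp.array([0.0, -jnp.inf, -jnp.inf])
out = resample_sketch(jax.random.PRNGKey(0), {"x": jnp.array([10.0, 20.0, 30.0])}, log_w, S=5)
```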
- marginalise_static_from_U(key, U_samples, model, log_weights, ESS, fun)[source]
Marginalises function over posterior samples, where ESS is static.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
U_samples (jaxns.internals.types.UType) – array of U samples
model (jaxns.framework.bases.BaseAbstractModel) – model
log_weights (jax.Array) – log weights from nested sampling
ESS (int) – static effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns:
expectation over resampled samples
- Return type:
_V
- marginalise_dynamic_from_U(key, U_samples, model, log_weights, ESS, fun)[source]
Marginalises function over posterior samples, where ESS can be dynamic.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
U_samples (jaxns.internals.types.UType) – array of U samples
model (jaxns.framework.bases.BaseAbstractModel) – model
log_weights (jax.Array) – log weights from nested sampling
ESS (jax.Array) – dynamic effective sample size
fun (callable(**kwargs)) – function to marginalise
- Returns:
expectation of fun over resampled samples.
- Return type:
_V
- marginalise_static(key, samples, log_weights, ESS, fun)[source]
Marginalises function over posterior samples, where ESS is static.
- Parameters:
- Returns:
expectation over resampled samples
- Return type:
_V
- marginalise_dynamic(key, samples, log_weights, ESS, fun)[source]
Marginalises function over posterior samples, where ESS can be dynamic.
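- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
samples – posterior samples from nested sampling
log_weights (jax.Array) – log weights from nested sampling
ESS (jax.Array) – dynamic effective sample size
fun (callable(**kwargs)) – function to marginalise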
- Returns:
expectation of fun over resampled samples.
- Return type:
_V
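With a static ESS, marginalisation can be pictured as: resample ESS points according to the posterior weights, evaluate fun on each (called with the variable names as keyword arguments), and average. The sketch below shows that idea with plain JAX; it assumes samples is a dict of arrays and is not the jaxns implementation. A dynamic ESS would need a fixed-size draw plus masking, which is not shown:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def marginalise_static_sketch(key, samples, log_weights, ESS, fun):
    """Hypothetical helper: Monte Carlo expectation of fun under the posterior weights."""
    p = jnp.exp(log_weights - logsumexp(log_weights))
    idx = jax.random.choice(key, log_weights.shape[0], shape=(ESS,), replace=True, p=p)
    resampled = jax.tree_util.tree_map(lambda x: x[idx], samples)
    # Evaluate fun(**sample) for every resampled point, then average over points.
    values = jax.vmap(lambda s: fun(**s))(resampled)
    return jax.tree_util.tree_map(lambda v: jnp.mean(v, axis=0), values)


samples = {"x": jnp.array([1.0, 2.0, 3.0])}
val = marginalise_static_sketch(jax.random.PRNGKey(1), samples, jnp.array([0.0, -jnp.inf, -jnp.inf]), 10, lambda x: x ** 2)
```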
- maximum_a_posteriori_point(results)[source]
Get the MAP point of a nested sampling result by choosing the sample with the largest L(x) p(x).
- Parameters:
results (NestedSamplerResult) – Nested sampler result
- Returns:
dict of samples at MAP-point.
- Return type:
jaxns.internals.types.XType
- evaluate_map_estimate(results, fun)[source]
Evaluates a function at the maximum a posteriori (MAP) sample point.
- Parameters:
results (jaxns.nested_samplers.common.types.NestedSamplerResults) – results from run
fun (callable(**kwargs)) – function to evaluate
- Returns:
estimate at MAP sample point
- Return type:
_V
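Choosing the sample with the largest L(x) p(x) amounts to an argmax over log L + log prior. The sketch below assumes you already have per-sample log-likelihood and log-prior arrays; the helper names are hypothetical and it does not read these from NestedSamplerResults:

```python
import jax
import jax.numpy as jnp


def map_point_sketch(samples, log_L_samples, log_prior_samples):
    # The MAP sample maximises log L(x) + log p(x).
    idx = jnp.argmax(log_L_samples + log_prior_samples)
    return jax.tree_util.tree_map(lambda x: x[idx], samples)


def evaluate_map_sketch(samples, log_L_samples, log_prior_samples, fun):
    # Evaluate fun(**kwargs) at the MAP point.
    return fun(**map_point_sketch(samples, log_L_samples, log_prior_samples))


samples = {"x": jnp.array([0.0, 1.0, 2.0]), "y": jnp.array([5.0, 6.0, 7.0])}
log_L = jnp.array([-3.0, -1.0, -2.0])
log_prior = jnp.array([0.0, 0.0, 0.5])
map_point = map_point_sketch(samples, log_L, log_prior)
```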
- summary(results, with_parametrised=False, f_obj=None)[source]
Gives a summary of the results of a nested sampling run.
- Parameters:
results (NestedSamplerResults) – Nested sampler result
with_parametrised (bool) – whether to include parametrised samples
f_obj (Optional[Union[str, TextIO]]) – file-like object to write summary to. If None, prints to stdout.
- analytic_posterior_samples(model, S=60)[source]
Compute the evidence by brute force over a regular grid.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – model
S (int) – resolution of grid
- Returns:
log(Z)
- sample_evidence(key, num_live_points_per_sample, log_L_samples, S=100)[source]
Sample the evidence distribution by stochastically simulating the shrinkage distribution.
Note: this produces approximate samples, since there is also uncertainty in the placement of the contours during shrinkage. Incorporating this stochasticity into the simulation would require running the entire nested sampling procedure many times.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNGKey
num_live_points_per_sample (jaxns.internals.types.IntArray) – the number of live points for each sample
log_L_samples (jaxns.internals.types.FloatArray) – the log-L of samples
S (int) – The number of samples to produce
- Returns:
samples of log(Z)
- Return type:
jaxns.internals.types.FloatArray
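In the standard picture of nested sampling, each dead point shrinks the enclosed prior volume by a factor t_i ~ Beta(n_i, 1), where n_i is the number of live points, and Z ≈ sum_i L_i (X_{i-1} - X_i). The sketch below simulates S realisations of that shrinkage in log space. It is an illustration under these assumptions (hypothetical name, no contribution from remaining live points), not the jaxns routine:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def sample_log_evidence_sketch(key, num_live_points_per_sample, log_L_samples, S=100):
    """Return S simulated values of log(Z) for a fixed sequence of dead points."""
    N = log_L_samples.shape[0]
    u = jax.random.uniform(key, (S, N))
    # t_i = u ** (1 / n_i) ~ Beta(n_i, 1); work with log t_i.
    log_t = jnp.log(u) / num_live_points_per_sample
    log_X = jnp.cumsum(log_t, axis=1)  # log prior volume after each dead point
    log_X_prev = jnp.concatenate([jnp.zeros((S, 1)), log_X[:, :-1]], axis=1)
    # Volume removed at step i: X_{i-1} - X_i = X_{i-1} * (1 - t_i).
    log_dX = log_X_prev + jnp.log1p(-jnp.exp(log_t))
    return logsumexp(log_L_samples + log_dX, axis=1)


n = jnp.full((500,), 10)
log_Z = sample_log_evidence_sketch(jax.random.PRNGKey(0), n, jnp.zeros(500), S=50)
```

With a constant likelihood of 1, the removed volumes telescope to 1 - X_N, so each simulated log(Z) is essentially 0 once the volume has shrunk away.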
- bruteforce_posterior_samples(model, S=60)[source]
Compute the posterior by brute force over a regular grid.
- bruteforce_evidence(model, S=60)[source]
Compute the evidence by brute force over a regular grid.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – model
S (int) – resolution of grid
- Returns:
log(Z)
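Over the unit cube, the evidence is Z = integral of L(x(U)) dU over [0, 1]^D, so a regular grid of S points per axis gives a midpoint-rule estimate. The sketch below assumes the likelihood is already expressed as a function of U and uses a hypothetical toy model; it is not the jaxns implementation, and its cost grows as S**D:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def bruteforce_log_evidence_sketch(log_likelihood_of_U, D, S=60):
    """Midpoint-rule estimate of log(Z) over the unit cube [0, 1]^D."""
    u_1d = (jnp.arange(S) + 0.5) / S  # cell midpoints along one axis
    grids = jnp.meshgrid(*([u_1d] * D), indexing="ij")
    U = jnp.stack([g.ravel() for g in grids], axis=-1)  # shape (S**D, D)
    log_L = jax.vmap(log_likelihood_of_U)(U)
    # Each cell has volume S**-D, so Z is approximately sum(L) / S**D.
    return logsumexp(log_L) - D * jnp.log(S)


def log_L_of_U(U):
    # Hypothetical toy model: uniform prior on [-5, 5]^D, standard normal likelihood.
    x = -5.0 + 10.0 * U
    return jnp.sum(norm.logpdf(x))


log_Z_1d = bruteforce_log_evidence_sketch(log_L_of_U, D=1)
log_Z_2d = bruteforce_log_evidence_sketch(log_L_of_U, D=2)
```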
- save_pytree(pytree, save_file)[source]
Saves a pytree, such as nested sampler results, in a json file.
- Parameters:
pytree (NamedTuple) – pytree to save, e.g. a nested sampler result
save_file (str) – filename
- save_results(results, save_file)[source]
Saves results of nested sampler in a json file.
- Parameters:
results (NestedSamplerResults) – Nested sampler result
save_file (str) – filename
- load_pytree(save_file)[source]
Loads saved nested sampler results from a json file.
- Parameters:
save_file (str) – filename
- Returns:
NestedSamplerResults
- load_results(save_file)[source]
Loads saved nested sampler results from a json file.
- Parameters:
save_file (str) – filename
- Returns:
NestedSamplerResults
- Return type:
- class ShardedStaticNestedSampler[source]
Bases:
jaxns.nested_samplers.abc.AbstractNestedSampler
A static nested sampler that uses a fixed number of live points. It first uses a uniform sampler to generate the initial set of samples down to an efficiency threshold, then uses the provided sampler to generate the remaining samples until the termination condition is met.
- Parameters:
init_efficiency_threshold – the efficiency threshold to use for the initial uniform sampling. If 0 then turns it off.
sampler – the sampler to use after the initial uniform sampling.
num_live_points – the number of live points to use.
model – the model to use.
max_samples – the maximum number of samples to take.
devices – the devices to use, default is 1.
verbose – whether to log as we go.
- model: jaxns.framework.bases.BaseAbstractModel
- sampler: jaxns.samplers.abc.AbstractSampler
- class TerminationCondition[source]
Bases:
NamedTuple
Contains the termination conditions for the nested sampling run.
- Parameters:
ess – The effective sample size. The run terminates if the ESS (Kish’s estimate) is greater than this.
evidence_uncert – The uncertainty in the evidence. The run terminates if the uncertainty is less than this.
live_evidence_frac – Deprecated; use dlogZ.
dlogZ – Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Defaults to log(1 + 1e-2).
max_samples – Terminate if the number of samples exceeds this.
max_num_likelihood_evaluations – Terminate if the number of likelihood evaluations exceeds this.
log_L_contour – Terminate if this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour. Uncollected live points are not considered.
efficiency_threshold – Terminate if the efficiency (num_samples / num_likelihood_evaluations) is less than this, for the last shrinkage iteration.
rtol – Terminate when the relative difference 2*|log_L_max - log_L_min|/|log_L_max + log_L_min| < rtol.
atol – Terminate when the absolute difference |log_L_max - log_L_min| < atol.
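Three of the conditions above are explicit formulas, written out below as standalone checks. How jaxns estimates Z_remaining and tracks log_L_max/log_L_min is not described here, so the sketch takes them as inputs; the function names are hypothetical:

```python
import jax.numpy as jnp


def dlogZ_reached(log_Z_current, log_Z_remaining, dlogZ=jnp.log(1.0 + 1e-2)):
    # log(Z_current + Z_remaining) - log(Z_current) < dlogZ
    return jnp.logaddexp(log_Z_current, log_Z_remaining) - log_Z_current < dlogZ


def rtol_reached(log_L_max, log_L_min, rtol):
    # 2 * |log_L_max - log_L_min| / |log_L_max + log_L_min| < rtol
    return 2.0 * jnp.abs(log_L_max - log_L_min) / jnp.abs(log_L_max + log_L_min) < rtol


def atol_reached(log_L_max, log_L_min, atol):
    # |log_L_max - log_L_min| < atol
    return jnp.abs(log_L_max - log_L_min) < atol
```

With the default dlogZ, the run stops once the remaining evidence adds less than about 1% to the accumulated evidence.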
- class NestedSamplerResults[source]
Bases:
NamedTuple
Results of the nested sampling run.
- log_Z_mean: jaxns.internals.types.FloatArray
- log_Z_uncert: jaxns.internals.types.FloatArray
- ESS: jaxns.internals.types.FloatArray
- H_mean: jaxns.internals.types.FloatArray
- samples: jaxns.internals.types.XType
- parametrised_samples: jaxns.internals.types.XType
- U_samples: jaxns.internals.types.UType
- log_L_samples: jaxns.internals.types.FloatArray
- log_dp_mean: jaxns.internals.types.FloatArray
- log_X_mean: jaxns.internals.types.FloatArray
- log_posterior_density: jaxns.internals.types.FloatArray
- num_live_points_per_sample: jaxns.internals.types.IntArray
- num_likelihood_evaluations_per_sample: jaxns.internals.types.IntArray
- total_num_samples: jaxns.internals.types.IntArray
- total_phantom_samples: jaxns.internals.types.IntArray
- total_num_likelihood_evaluations: jaxns.internals.types.IntArray
- log_efficiency: jaxns.internals.types.FloatArray
- termination_reason: jaxns.internals.types.IntArray