samplers
jaxns.samplers
Package Contents
- class MultiEllipsoidalSampler
Bases:
jaxns.samplers.bases.BaseAbstractRejectionSampler[jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils.MultEllipsoidState]
Uses a multi-ellipsoidal decomposition of the live points to construct a bound around the regions to sample from.
It is inefficient for high-dimensional problems but can be very efficient for low-dimensional ones.
- model: jaxns.framework.bases.BaseAbstractModel
- property max_num_ellipsoids
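The bounding idea can be illustrated with a single ellipse in two dimensions. The sketch below is a minimal, hypothetical illustration in plain Python, not the jaxns implementation: it fits one ellipse to the live points from their mean and covariance, inflates it so that every live point lies inside, draws uniformly from it, and rejects proposals that leave the unit square or fail the likelihood constraint. The real sampler decomposes the live points into several ellipsoids (bounded by max_num_ellipsoids); the helper names and the `expand` factor are assumptions made for this example.

```python
import math
import random


def bounding_ellipse(points, expand=1.1):
    """Fit one ellipse (mean, covariance) that encloses every 2-D point.

    Hypothetical helper: a single ellipse only, no clustering into several ellipsoids.
    """
    n = len(points)
    mx = sum(p[0] for p in points) / n
    my = sum(p[1] for p in points) / n
    sxx = sum((p[0] - mx) ** 2 for p in points) / n
    syy = sum((p[1] - my) ** 2 for p in points) / n
    sxy = sum((p[0] - mx) * (p[1] - my) for p in points) / n
    det = sxx * syy - sxy ** 2

    def maha2(p):
        # Squared Mahalanobis distance under the sample covariance.
        dx, dy = p[0] - mx, p[1] - my
        return (syy * dx * dx - 2 * sxy * dx * dy + sxx * dy * dy) / det

    # Scale the covariance so the farthest live point sits strictly inside the boundary.
    scale = max(maha2(p) for p in points) * expand ** 2
    return (mx, my), (sxx * scale, sxy * scale, syy * scale)


def sample_in_ellipse(rng: random.Random, mean, cov):
    """Uniform draw from {x : (x - mean)^T cov^-1 (x - mean) <= 1}."""
    cxx, cxy, cyy = cov
    # Cholesky factor L of the 2x2 covariance (cov = L L^T).
    l11 = math.sqrt(cxx)
    l21 = cxy / l11
    l22 = math.sqrt(cyy - l21 ** 2)
    # Uniform point in the unit disk, mapped through L onto the ellipse.
    r = math.sqrt(rng.random())
    theta = 2.0 * math.pi * rng.random()
    z1, z2 = r * math.cos(theta), r * math.sin(theta)
    return (mean[0] + l11 * z1, mean[1] + l21 * z1 + l22 * z2)


def ellipsoidal_rejection_sample(rng, live_points, log_L, log_L_constraint, max_tries=100_000):
    """Propose from the bounding ellipse until a point beats the likelihood constraint."""
    mean, cov = bounding_ellipse(live_points)
    for num_proposals in range(1, max_tries + 1):
        x = sample_in_ellipse(rng, mean, cov)
        if not (0.0 <= x[0] <= 1.0 and 0.0 <= x[1] <= 1.0):
            continue  # outside the prior support (unit square): reject without evaluating log_L
        if log_L(x) > log_L_constraint:
            return x, num_proposals
    raise RuntimeError("no point satisfying the constraint was found")
```

Because the ellipse is fitted to the live points rather than to the true constrained region, it can clip parts of that region; the expansion factor is the usual guard against this, and it is also why a good decomposition matters more as the dimension grows.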
- class MultiDimSliceSampler
Bases:
jaxns.samplers.bases.BaseAbstractMarkovSampler[jaxns.nested_samplers.common.types.SampleCollection]
Multi-dimensional slice sampler with exponential shrinkage. Produces correlated (non-i.i.d.) samples.
Notes: Not very efficient.
- Parameters:
model – AbstractModel
num_slices – number of slices between acceptances, counted in absolute units (one slice per unit), unlike some other software, which counts in units of the prior dimension.
num_phantom_save – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for several purposes, e.g. to estimate the evidence uncertainty.
num_restrict_dims – size of subspace to slice along. Setting to 1 would be like UniDimSliceSampler, but far less efficient.
- model: jaxns.framework.bases.BaseAbstractModel
- get_seed_point(key, sampler_state, log_L_constraint)
- Parameters:
key (jaxns.internals.types.PRNGKey)
sampler_state (jaxns.nested_samplers.common.types.LivePointCollection)
log_L_constraint (jaxns.internals.types.FloatArray)
- Return type:
- get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
- Parameters:
key (jaxns.internals.types.PRNGKey)
seed_point (jaxns.samplers.bases.SeedPoint)
log_L_constraint (jaxns.internals.types.FloatArray)
sampler_state (jaxns.nested_samplers.common.types.SampleCollection)
- Return type:
Tuple[jaxns.nested_samplers.common.types.Sample, jaxns.nested_samplers.common.types.Sample]
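A single multi-dimensional slice step within a likelihood constraint can be sketched as follows. This is a minimal plain-Python illustration under stated assumptions, not the jaxns code: the prior is taken to be the unit hypercube, each slice starts from the maximal bracket along a random direction (so no step-out is needed), and the bracket is shrunk towards the seed at every rejected point. The `num_restrict_dims` argument mimics the documented parameter by drawing the direction inside a random subspace of that size; the function names are hypothetical.

```python
import math
import random


def slice_along_direction(rng, x0, direction, log_L, log_L_constraint):
    """One slice: a uniform point on x0 + t * direction that satisfies the constraint."""
    # Maximal bracket [t_min, t_max] that keeps the line inside the unit hypercube.
    t_min, t_max = -math.inf, math.inf
    for xi, di in zip(x0, direction):
        if di > 0:
            t_min = max(t_min, -xi / di)
            t_max = min(t_max, (1.0 - xi) / di)
        elif di < 0:
            t_min = max(t_min, (1.0 - xi) / di)
            t_max = min(t_max, -xi / di)
    num_evals = 0
    while True:
        t = rng.uniform(t_min, t_max)
        x = [xi + t * di for xi, di in zip(x0, direction)]
        num_evals += 1
        if log_L(x) > log_L_constraint:
            return x, num_evals
        # Shrink towards the seed (t = 0), which always satisfies the constraint.
        if t < 0:
            t_min = t
        else:
            t_max = t


def multi_dim_slice_sample(rng: random.Random, seed_point, log_L, log_L_constraint,
                           num_slices=5, num_restrict_dims=None):
    """Chain num_slices slices from the seed; returns the final point and evaluation count."""
    ndim = len(seed_point)
    k = ndim if num_restrict_dims is None else num_restrict_dims
    x = list(seed_point)
    total_evals = 0
    for _ in range(num_slices):
        # Random unit direction restricted to k randomly chosen coordinates.
        direction = [0.0] * ndim
        for i in rng.sample(range(ndim), k):
            direction[i] = rng.gauss(0.0, 1.0)
        norm = math.sqrt(sum(di * di for di in direction))
        direction = [di / norm for di in direction]
        x, n = slice_along_direction(rng, x, direction, log_L, log_L_constraint)
        total_evals += n
    return x, total_evals
```

Each slice starts from the previous accepted point, which is why the output is a correlated Markov chain rather than i.i.d. draws; increasing `num_slices` reduces that correlation at the cost of more likelihood evaluations.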
- class UniDimSliceSampler
Bases:
jaxns.samplers.bases.BaseAbstractMarkovSampler[jaxns.nested_samplers.common.types.SampleCollection]
Slice sampler for a single dimension.
- Parameters:
model – AbstractModel
num_slices – number of slices between acceptances. Note: some other software counts in units of the prior dimension.
midpoint_shrink – if true, contract to the midpoint of the interval on rejection; otherwise, contract to the rejection point. Midpoint contraction speeds up convergence but introduces minor auto-correlation.
num_phantom_save – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for several purposes, e.g. to estimate the evidence uncertainty.
perfect – if true, perform exponential shrinkage from the maximal bounds, which requires no step-out procedure. Otherwise, use a doubling procedure (finding the bracket exponentially). Note: "perfect" is a misnomer, as perfection also depends on the number of slices between acceptances.
gradient_slice – if true, always slice along the direction of increasing gradient.
adaptive_shrink – if true, shrink the interval to a random point within it rather than to its midpoint.
gradient_guided – if true, perform Householder reflections between proposals with 50% probability.
- model: jaxns.framework.bases.BaseAbstractModel
- get_seed_point(key, sampler_state, log_L_constraint)
- Parameters:
key (jaxns.internals.types.PRNGKey)
sampler_state (jaxns.nested_samplers.common.types.LivePointCollection)
log_L_constraint (jaxns.internals.types.FloatArray)
- Return type:
- get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
- Parameters:
key (jaxns.internals.types.PRNGKey)
seed_point (jaxns.samplers.bases.SeedPoint)
log_L_constraint (jaxns.internals.types.FloatArray)
sampler_state (jaxns.nested_samplers.common.types.SampleCollection)
- Return type:
Tuple[jaxns.nested_samplers.common.types.Sample, jaxns.nested_samplers.common.types.Sample]
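The single-dimension variant and the notion of phantom samples can be sketched together. This plain-Python example is a hypothetical illustration, not the jaxns implementation, and it assumes two things: the prior is the unit hypercube, and phantom samples are the intermediate points of the slice chain (each satisfies the constraint, but only the last one is accepted). Each slice picks a random coordinate axis and starts from the full [0, 1] bracket, which corresponds to the `perfect` mode described above, and contracts to the rejection point (the `midpoint_shrink=False` behaviour). The function names are made up for this sketch.

```python
import random


def uni_dim_slice(rng, x0, axis, log_L, log_L_constraint):
    """Slice along one coordinate axis, starting from the maximal [0, 1] bracket."""
    lo, hi = 0.0, 1.0
    num_evals = 0
    while True:
        v = rng.uniform(lo, hi)
        x = list(x0)
        x[axis] = v
        num_evals += 1
        if log_L(x) > log_L_constraint:
            return x, num_evals
        # Contract to the rejection point, keeping the seed coordinate inside the bracket.
        if v < x0[axis]:
            lo = v
        else:
            hi = v


def uni_dim_slice_sample(rng: random.Random, seed_point, log_L, log_L_constraint,
                         num_slices=4, num_phantom_save=0):
    """Run num_slices axis-aligned slices; keep up to num_phantom_save intermediate points."""
    x = list(seed_point)
    phantoms = []
    total_evals = 0
    for s in range(num_slices):
        axis = rng.randrange(len(x))
        x, n = uni_dim_slice(rng, x, axis, log_L, log_L_constraint)
        total_evals += n
        # Intermediate points satisfy the constraint but are not the accepted sample.
        if s < num_slices - 1 and len(phantoms) < num_phantom_save:
            phantoms.append(x)
    return x, phantoms, total_evals
```

Phantom points come at no extra likelihood cost, since they are by-products of the chain, but they are correlated with the accepted sample and with each other.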
- class UniformSampler
Bases:
jaxns.samplers.bases.BaseAbstractRejectionSampler[Tuple]
A sampler that produces uniform samples from the model within the log_L_constraint.
- model: jaxns.framework.bases.BaseAbstractModel
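Uniform sampling within a likelihood constraint amounts to rejection sampling from the prior. The sketch below is a hypothetical plain-Python illustration that assumes the prior is the unit hypercube; it is not the jaxns code.

```python
import random


def uniform_rejection_sample(rng: random.Random, ndim, log_L, log_L_constraint,
                             max_evals=1_000_000):
    """Draw from the unit hypercube until log_L(u) exceeds the constraint."""
    for num_evals in range(1, max_evals + 1):
        u = [rng.random() for _ in range(ndim)]
        if log_L(u) > log_L_constraint:
            return u, num_evals
    raise RuntimeError("efficiency too low: no point found within max_evals draws")
```

The acceptance probability of each draw equals the prior volume still enclosed by the constraint, so the expected cost grows roughly as 1/X as the run proceeds; this is why it is only practical early in a nested sampling run.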
- class ShardedStaticNestedSampler
Bases:
jaxns.nested_samplers.abc.AbstractNestedSampler
A static nested sampler that uses a fixed number of live points. It first uses a uniform sampler to generate samples until the sampling efficiency drops to an efficiency threshold, then uses the provided sampler to generate the remaining samples until the termination condition is met.
- Parameters:
init_efficiency_threshold – the efficiency threshold for the initial uniform sampling. If 0, the initial uniform sampling is turned off.
sampler – the sampler to use after the initial uniform sampling.
num_live_points – the number of live points to use.
model – the model to use.
max_samples – the maximum number of samples to take.
devices – the devices to use; the default is 1.
verbose – whether to log progress during the run.
- model: jaxns.framework.bases.BaseAbstractModel
- sampler: jaxns.samplers.abc.AbstractSampler
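The two-phase strategy (uniform sampling until the efficiency drops, then a different sampler) can be sketched as a toy static nested sampling loop. This is a hypothetical, single-device, plain-Python illustration in two dimensions, not the jaxns implementation. Assumptions: the prior is the unit square; the prior volume shrinks deterministically by a factor exp(-1/n) per iteration; the per-draw efficiency is 1 / (number of likelihood evaluations); the "provided sampler" is stood in for by a simple rejection sampler on the expanded bounding box of the live points; and termination uses the dlogZ-style test with the remaining evidence bounded by L_max * X.

```python
import math
import random


def logaddexp(a, b):
    """log(exp(a) + exp(b)) without overflow."""
    m = max(a, b)
    if m == -math.inf:
        return -math.inf
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def uniform_sampler(rng, live, log_L, log_L_min):
    """Rejection sampling from the whole unit square (the prior)."""
    num_evals = 0
    while True:
        num_evals += 1
        u = [rng.random(), rng.random()]
        if log_L(u) > log_L_min:
            return u, num_evals


def box_sampler(rng, live, log_L, log_L_min, expand=0.1):
    """Stand-in 'provided sampler': rejection from the expanded bounding box of the live points."""
    lo = [min(p[i] for p, _ in live) for i in range(2)]
    hi = [max(p[i] for p, _ in live) for i in range(2)]
    width = [h - l for l, h in zip(lo, hi)]
    lo = [max(0.0, l - expand * w) for l, w in zip(lo, width)]
    hi = [min(1.0, h + expand * w) for h, w in zip(hi, width)]
    num_evals = 0
    while True:
        num_evals += 1
        u = [rng.uniform(l, h) for l, h in zip(lo, hi)]
        if log_L(u) > log_L_min:
            return u, num_evals


def static_nested_sampling(rng: random.Random, log_L, num_live_points=100,
                           init_efficiency_threshold=0.1, sampler=box_sampler,
                           dlogZ=math.log(1 + 1e-2)):
    live = []
    for _ in range(num_live_points):
        u = [rng.random(), rng.random()]
        live.append((u, log_L(u)))
    log_Z = -math.inf
    log_X = 0.0  # log of the prior volume still enclosed by the live points
    use_uniform = init_efficiency_threshold > 0
    switch_iteration = None if use_uniform else 0
    iteration = 0
    while True:
        # Terminate once the remaining evidence (at most L_max * X) is negligible.
        log_L_max = max(l for _, l in live)
        if logaddexp(log_Z, log_L_max + log_X) - log_Z < dlogZ:
            break
        # Kill the worst live point and credit it with the prior-volume shell it leaves.
        i_worst = min(range(num_live_points), key=lambda i: live[i][1])
        log_L_min = live[i_worst][1]
        iteration += 1
        log_X_new = -iteration / num_live_points
        log_dX = log_X + math.log1p(-math.exp(log_X_new - log_X))
        log_Z = logaddexp(log_Z, log_L_min + log_dX)
        log_X = log_X_new
        # Replace it with a new point above the constraint, using the active sampler.
        draw = uniform_sampler if use_uniform else sampler
        u, num_evals = draw(rng, live, log_L, log_L_min)
        if use_uniform and 1.0 / num_evals < init_efficiency_threshold:
            use_uniform = False  # efficiency fell below the threshold: switch samplers
            switch_iteration = iteration
        live[i_worst] = (u, log_L(u))
    # Credit the surviving live points with equal shares of the remaining volume.
    for _, l in live:
        log_Z = logaddexp(log_Z, l + log_X - math.log(num_live_points))
    return log_Z, switch_iteration
```

For a narrow Gaussian likelihood with width 0.1 centred in the unit square, the evidence is close to 2π(0.1)², so the returned log_Z can be checked against log(2π · 0.01) up to the usual statistical scatter of about sqrt(H / num_live_points).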
- class TerminationCondition
Bases:
NamedTuple
Contains the termination conditions for the nested sampling run.
- Parameters:
ess – The effective sample size; the run terminates if the ESS (Kish’s estimate) exceeds this.
evidence_uncert – The uncertainty in the evidence; the run terminates if the uncertainty falls below this.
live_evidence_frac – Deprecated; use dlogZ instead.
dlogZ – Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Default: log(1 + 1e-2).
max_samples – Terminate if the number of samples exceeds this.
max_num_likelihood_evaluations – Terminate if the number of likelihood evaluations exceeds this.
log_L_contour – Terminate if this log(L) contour is reached. A contour is reached if any dead point has log(L) > log_L_contour. Uncollected live points are not considered.
efficiency_threshold – Terminate if the efficiency (num_samples / num_likelihood_evaluations) over the last shrinkage iteration is less than this.
rtol – Terminate when the relative difference 2*|log_L_max - log_L_min| / |log_L_max + log_L_min| < rtol.
atol – Terminate when the absolute difference |log_L_max - log_L_min| < atol.
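The dlogZ, rtol and atol criteria above are simple formulas and can be checked in isolation. The plain-Python sketch below is a hypothetical illustration, not the jaxns code. It assumes the caller already has an estimate of log(Z_remaining) (a common choice is log L_max of the live points plus the log of the remaining prior volume) and that log_L_max and log_L_min refer to the current live points; the function names are made up for this example.

```python
import math


def logaddexp(a, b):
    """log(exp(a) + exp(b)) computed stably."""
    m = max(a, b)
    if m == -math.inf:
        return -math.inf
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def dlogZ_reached(log_Z_current, log_Z_remaining, dlogZ=math.log(1 + 1e-2)):
    """True if log(Z_current + Z_remaining) - log(Z_current) < dlogZ."""
    return logaddexp(log_Z_current, log_Z_remaining) - log_Z_current < dlogZ


def log_L_plateau_reached(log_L_max, log_L_min, rtol=None, atol=None):
    """Absolute and relative spread tests on the log-likelihood (atol / rtol)."""
    diff = abs(log_L_max - log_L_min)
    if atol is not None and diff < atol:
        return True
    denom = abs(log_L_max + log_L_min)
    # The relative test is undefined when log_L_max + log_L_min == 0, so skip it then.
    if rtol is not None and denom > 0 and 2.0 * diff / denom < rtol:
        return True
    return False
```

With the default dlogZ, the run stops once the estimated remaining evidence is below about 1% of the evidence accumulated so far.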
- class NestedSamplerResults
Bases:
NamedTuple
Results of the nested sampling run.
- log_Z_mean: jaxns.internals.types.FloatArray
- log_Z_uncert: jaxns.internals.types.FloatArray
- ESS: jaxns.internals.types.FloatArray
- H_mean: jaxns.internals.types.FloatArray
- samples: jaxns.internals.types.XType
- parametrised_samples: jaxns.internals.types.XType
- U_samples: jaxns.internals.types.UType
- log_L_samples: jaxns.internals.types.FloatArray
- log_dp_mean: jaxns.internals.types.FloatArray
- log_X_mean: jaxns.internals.types.FloatArray
- log_posterior_density: jaxns.internals.types.FloatArray
- num_live_points_per_sample: jaxns.internals.types.IntArray
- num_likelihood_evaluations_per_sample: jaxns.internals.types.IntArray
- total_num_samples: jaxns.internals.types.IntArray
- total_phantom_samples: jaxns.internals.types.IntArray
- total_num_likelihood_evaluations: jaxns.internals.types.IntArray
- log_efficiency: jaxns.internals.types.FloatArray
- termination_reason: jaxns.internals.types.IntArray
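Several of these fields are related. The plain-Python sketch below shows, under an assumption, how a point estimate of log Z, normalised posterior weights, and Kish's effective sample size could be derived from per-sample arrays. The assumption (inferred from the field names, not stated on this page) is that log_dp_mean holds the log prior-mass increment of each sample, so that log_L_samples + log_dp_mean is the log of its unnormalised posterior weight. The reported log_Z_mean and ESS may be computed differently (for example, averaged over shrinkage statistics); the function names here are hypothetical.

```python
import math


def logsumexp(xs):
    """log(sum(exp(x) for x in xs)) computed stably."""
    m = max(xs)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(x - m) for x in xs))


def summarise_results(log_L_samples, log_dp_mean):
    """Point estimates of log Z, posterior weights and Kish's ESS from per-sample arrays."""
    log_w = [l + d for l, d in zip(log_L_samples, log_dp_mean)]
    log_Z = logsumexp(log_w)
    weights = [math.exp(lw - log_Z) for lw in log_w]  # normalised so they sum to 1
    ess = 1.0 / sum(w * w for w in weights)  # Kish: (sum w)^2 / sum w^2 with sum w = 1
    return log_Z, weights, ess


def posterior_mean(samples, weights):
    """Posterior-weighted mean of scalar samples."""
    return sum(w * x for w, x in zip(weights, samples))
```

Equal weights give an ESS equal to the number of samples, while a single dominant weight drives the ESS towards 1.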