samplers
jaxns.samplers
Package Contents
- class MultiEllipsoidalSampler(depth, expansion_factor, *args, **kwargs)
Bases: jaxns.samplers.bases.BaseAbstractRejectionSampler
Uses a multi-ellipsoidal decomposition of the live points to create a bound around the regions to sample from.
Inefficient for high-dimensional problems, but can be very efficient for low-dimensional ones.
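As a rough illustration of how a bound can be built from the live points, the sketch below fits a single axis-aligned ellipsoid in plain Python. It is not the jaxns implementation, which decomposes the live points into several ellipsoids and runs in JAX. The helper names are hypothetical, and treating expansion_factor as a multiplier on the radii is an assumption about how that parameter is applied.

```python
import math
import random
from typing import List, Sequence, Tuple


def fit_axis_aligned_ellipsoid(
    points: Sequence[Sequence[float]], expansion_factor: float = 1.2
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: (center, radii) of an axis-aligned ellipsoid that
    encloses every point, enlarged by expansion_factor (assumed to scale radii)."""
    n, d = len(points), len(points[0])
    center = [sum(p[i] for p in points) / n for i in range(d)]
    # Per-axis standard deviation, floored to avoid dividing by zero.
    std = [
        max(math.sqrt(sum((p[i] - center[i]) ** 2 for p in points) / n), 1e-12)
        for i in range(d)
    ]
    # Largest normalised distance of any point from the center; scaling the
    # standard deviations by it puts every point on or inside the ellipsoid.
    scale = max(
        math.sqrt(sum(((p[i] - center[i]) / std[i]) ** 2 for i in range(d)))
        for p in points
    )
    scale = max(scale, 1e-12)
    radii = [expansion_factor * scale * s for s in std]
    return center, radii


def inside(x: Sequence[float], center: Sequence[float], radii: Sequence[float]) -> bool:
    """True if x lies in the axis-aligned ellipsoid (center, radii)."""
    return sum(((xi - ci) / ri) ** 2 for xi, ci, ri in zip(x, center, radii)) <= 1.0


# Usage: an elongated cloud of "live points" around the middle of the unit square.
rng = random.Random(0)
live_points = [[rng.gauss(0.5, 0.1), rng.gauss(0.5, 0.02)] for _ in range(100)]
center, radii = fit_axis_aligned_ellipsoid(live_points)
```

With an expansion factor above 1, every live point ends up strictly inside the bound, leaving some slack for the unseen part of the constrained region.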
- property max_num_ellipsoids
- pre_process(state)
- Parameters:
state (jaxns.internals.types.StaticStandardNestedSamplerState) –
- Return type:
jaxns.samplers.abc.SamplerState
- post_process(sample_collection, sampler_state)
- Parameters:
sample_collection (jaxns.internals.types.StaticStandardSampleCollection) –
sampler_state (jaxns.samplers.abc.SamplerState) –
- Return type:
jaxns.samplers.abc.SamplerState
- get_sample(key, log_L_constraint, sampler_state)
- Parameters:
key (jaxns.internals.types.PRNGKey) –
log_L_constraint (jaxns.internals.types.FloatArray) –
sampler_state (jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils.MultEllipsoidState) –
- Return type:
Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
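Given a set of ellipsoids, a standard way to draw uniformly from their union (used, for example, by MultiNest-style samplers) is: pick an ellipsoid with probability proportional to its volume, draw uniformly inside it, and keep the point with probability 1/k, where k is the number of ellipsoids containing it. The sketch below adds the unit-cube and likelihood checks that a get_sample step needs. It assumes axis-aligned ellipsoids in the unit cube and uses Python's random module instead of JAX PRNG keys; the function names are hypothetical, and this is not a transcription of jaxns's get_sample.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple

# Hypothetical representation: each ellipsoid is (center, radii), axis-aligned.
Ellipsoid = Tuple[List[float], List[float]]


def sample_unit_ball(d: int, rng: random.Random) -> List[float]:
    """Uniform draw from the d-dimensional unit ball."""
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(v * v for v in g))  # zero only with probability 0
    r = rng.random() ** (1.0 / d)  # radius density proportional to r**(d-1)
    return [r * v / norm for v in g]


def num_containing(x: Sequence[float], ellipsoids: Sequence[Ellipsoid]) -> int:
    """Number of ellipsoids that contain x."""
    return sum(
        sum(((xi - ci) / ri) ** 2 for xi, ci, ri in zip(x, c, r)) <= 1.0
        for c, r in ellipsoids
    )


def sample_from_union(
    ellipsoids: Sequence[Ellipsoid],
    log_likelihood: Callable[[List[float]], float],
    log_L_constraint: float,
    rng: random.Random,
    max_tries: int = 10_000,
) -> List[float]:
    d = len(ellipsoids[0][0])
    volumes = [math.prod(r) for _, r in ellipsoids]  # proportional to volume
    for _ in range(max_tries):
        # 1. Choose an ellipsoid with probability proportional to its volume.
        c, r = rng.choices(ellipsoids, weights=volumes, k=1)[0]
        # 2. Draw uniformly inside it.
        u = sample_unit_ball(d, rng)
        x = [ci + ri * ui for ci, ri, ui in zip(c, r, u)]
        # 3. Correct for overlaps: keep x with probability 1 / (#ellipsoids containing x).
        if rng.random() >= 1.0 / max(num_containing(x, ellipsoids), 1):
            continue
        # 4. Reject points outside the unit cube or not above the constraint.
        if all(0.0 <= xi <= 1.0 for xi in x) and log_likelihood(x) > log_L_constraint:
            return x
    raise RuntimeError("no point satisfied the constraint within max_tries")


# Usage: two overlapping ellipsoids around a circular constrained region.
rng = random.Random(1)
ellipsoids = [([0.4, 0.5], [0.2, 0.2]), ([0.6, 0.5], [0.2, 0.2])]
log_likelihood = lambda x: -((x[0] - 0.5) ** 2 + (x[1] - 0.5) ** 2)
sample = sample_from_union(ellipsoids, log_likelihood, -0.05, rng)
```

Without step 3, points in the overlap would be drawn twice as often as points covered by a single ellipsoid.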
- class MultiDimSliceSampler(model, num_slices, num_phantom_save, num_restrict_dims=None)
Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler
Multi-dimensional slice sampler with exponential shrinkage. Produces correlated (non-i.i.d.) samples.
Note: not very efficient.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from
num_slices (int) – number of slices between acceptances, counted in units of 1 rather than in units of the prior dimension as some other software does.
num_phantom_save (int) – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for several purposes, e.g. to estimate the evidence uncertainty.
num_restrict_dims (Optional[int]) – size of the subspace to slice along. Setting it to 1 behaves like UniDimSliceSampler, but far less efficiently.
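One way to realise slicing in a restricted subspace is Neal's hyperrectangle slice sampler, sketched below in plain Python. Only num_restrict_dims randomly chosen coordinates move, the box starts at the maximal bounds of the unit cube, and each rejection moves the bounds in to the rejected point, so the box volume shrinks by a random factor per rejection (exponentially in the number of rejections). The initial box, the shrinkage rule, and the helper name are assumptions; jaxns's own slicing scheme may differ.

```python
import random
from typing import Callable, List, Optional, Sequence


def multi_dim_slice_step(
    x0: Sequence[float],
    log_likelihood: Callable[[List[float]], float],
    log_L_constraint: float,
    rng: random.Random,
    num_restrict_dims: Optional[int] = None,
    max_shrinks: int = 100,
) -> List[float]:
    """One hyperrectangle slice step inside the unit cube (hypothetical helper).

    x0 must already satisfy the constraint. Only num_restrict_dims randomly
    chosen coordinates move; all of them move when it is None.
    """
    d = len(x0)
    dims = rng.sample(range(d), d if num_restrict_dims is None else num_restrict_dims)
    # Assumption: start from the maximal bounds, i.e. the unit cube.
    lower = {i: 0.0 for i in dims}
    upper = {i: 1.0 for i in dims}
    for _ in range(max_shrinks):
        x = list(x0)
        for i in dims:
            x[i] = rng.uniform(lower[i], upper[i])
        if log_likelihood(x) > log_L_constraint:
            return x
        # Rejected: move each bound on the rejected side in to the rejected
        # coordinate, so the box keeps containing x0 while it shrinks.
        for i in dims:
            if x[i] < x0[i]:
                lower[i] = x[i]
            else:
                upper[i] = x[i]
    return list(x0)  # fallback: no acceptance within the shrink budget


# Usage: a short chain in 3 dimensions, moving 2 coordinates per step.
rng = random.Random(2)
log_likelihood = lambda x: -sum((xi - 0.5) ** 2 for xi in x)
chain = [[0.5, 0.5, 0.5]]
for _ in range(50):
    chain.append(
        multi_dim_slice_step(chain[-1], log_likelihood, -0.1, rng, num_restrict_dims=2)
    )
```

Consecutive points share all but the moved coordinates, which is why such samplers produce correlated rather than i.i.d. samples.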
- pre_process(state)
- Parameters:
state (jaxns.internals.types.StaticStandardNestedSamplerState) –
- Return type:
jaxns.samplers.abc.SamplerState
- post_process(sample_collection, sampler_state)
- Parameters:
sample_collection (jaxns.internals.types.StaticStandardSampleCollection) –
sampler_state (jaxns.samplers.abc.SamplerState) –
- Return type:
jaxns.samplers.abc.SamplerState
- get_seed_point(key, sampler_state, log_L_constraint)
- Parameters:
key (jaxns.internals.types.PRNGKey) –
sampler_state (jaxns.samplers.abc.SamplerState) –
log_L_constraint (jaxns.internals.types.FloatArray) –
- Return type:
- get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
- Parameters:
key (jaxns.internals.types.PRNGKey) –
seed_point (jaxns.samplers.bases.SeedPoint) –
log_L_constraint (jaxns.internals.types.FloatArray) –
sampler_state (jaxns.samplers.abc.SamplerState) –
- Return type:
Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
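The num_slices and num_phantom_save parameters, together with get_seed_point and get_sample_from_seed, suggest a Markov step of the following shape, sketched here under stated assumptions. The seed is a live point that already satisfies the constraint, num_slices slice moves are chained (counted in units of 1, not multiplied by the dimension), the last point is the accepted sample, and up to num_phantom_save of the earlier points are kept as phantom samples. Each move is a simple slice along one random coordinate axis. All function names are hypothetical, and which points jaxns actually keeps as phantoms is an assumption.

```python
import random
from typing import Callable, List, Sequence, Tuple

LogLikelihood = Callable[[List[float]], float]


def axis_slice_step(
    x0: Sequence[float],
    log_likelihood: LogLikelihood,
    log_L_constraint: float,
    rng: random.Random,
    max_shrinks: int = 100,
) -> List[float]:
    """Hypothetical slice move along one random coordinate axis of the unit cube."""
    i = rng.randrange(len(x0))
    lo, hi = 0.0, 1.0
    for _ in range(max_shrinks):
        x = list(x0)
        x[i] = rng.uniform(lo, hi)
        if log_likelihood(x) > log_L_constraint:
            return x
        # Shrink the interval towards the seed, up to the rejected value.
        if x[i] < x0[i]:
            lo = x[i]
        else:
            hi = x[i]
    return list(x0)


def markov_sample(
    live_points: Sequence[Sequence[float]],
    log_L_live: Sequence[float],
    log_L_constraint: float,
    log_likelihood: LogLikelihood,
    num_slices: int,
    num_phantom_save: int,
    rng: random.Random,
) -> Tuple[List[float], List[List[float]]]:
    if num_slices < 1:
        raise ValueError("num_slices must be at least 1")
    # Seed point: a live point that already satisfies the constraint.
    candidates = [p for p, logl in zip(live_points, log_L_live) if logl > log_L_constraint]
    x = list(rng.choice(candidates))
    visited = []
    for _ in range(num_slices):
        x = axis_slice_step(x, log_likelihood, log_L_constraint, rng)
        visited.append(x)
    # The last point is the accepted sample; earlier points also satisfy the
    # constraint, and up to num_phantom_save of them are kept as phantoms.
    phantoms = visited[:-1][-num_phantom_save:] if num_phantom_save > 0 else []
    return visited[-1], phantoms


# Usage: replace the worst of 20 live points, as a nested sampler would.
rng = random.Random(3)
log_likelihood = lambda x: -sum((xi - 0.5) ** 2 for xi in x)
live_points = [[rng.random(), rng.random()] for _ in range(20)]
log_L_live = [log_likelihood(p) for p in live_points]
log_L_constraint = min(log_L_live)  # worst live point
accepted, phantoms = markov_sample(
    live_points, log_L_live, log_L_constraint, log_likelihood,
    num_slices=5, num_phantom_save=3, rng=rng,
)
```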
- class UniDimSliceSampler(model, num_slices, num_phantom_save, midpoint_shrink, perfect, gradient_slice=False)
Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler
Unidimensional slice sampler: each slice is taken along a single direction. Produces correlated (non-i.i.d.) samples.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from
num_slices (int) – number of slices between acceptances. Note: some other software uses units of the prior dimension.
num_phantom_save (int) – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for several purposes, e.g. to estimate the evidence uncertainty.
midpoint_shrink (bool) – if true, contract to the midpoint of the interval on rejection; otherwise, contract to the rejection point. Midpoint contraction speeds up convergence but introduces minor autocorrelation.
perfect (bool) – if true, perform exponential shrinkage from the maximal bounds, so no step-out procedure is needed; otherwise, use a doubling procedure that finds a bracket by exponential expansion. Note: "perfect" is a misnomer, since the quality of the samples also depends on the number of slices between acceptances.
gradient_slice (bool) – if true then always slice along gradient direction.
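A single unidimensional slice in "perfect" mode can be sketched as follows: choose a random direction, take the full segment of that line inside the unit cube as the initial interval (so no step-out is needed), and shrink on rejection either to the rejection point or, with midpoint_shrink, at least as far as the midpoint between the seed and the current bound. That reading of "midpoint of the interval" is an assumption, as are the unit-cube domain and the helper names; gradient_slice is not modelled.

```python
import math
import random
from typing import Callable, List, Sequence, Tuple


def random_unit_vector(d: int, rng: random.Random) -> List[float]:
    g = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(v * v for v in g))
    return [v / norm for v in g]


def maximal_bounds(x0: Sequence[float], n: Sequence[float]) -> Tuple[float, float]:
    """Range of t for which x0 + t * n stays inside the unit cube."""
    left, right = -math.inf, math.inf
    for xi, ni in zip(x0, n):
        if ni != 0.0:
            a, b = (0.0 - xi) / ni, (1.0 - xi) / ni
            left, right = max(left, min(a, b)), min(right, max(a, b))
    return left, right  # left <= 0 <= right because x0 is inside the cube


def uni_slice_step(
    x0: Sequence[float],
    log_likelihood: Callable[[List[float]], float],
    log_L_constraint: float,
    rng: random.Random,
    midpoint_shrink: bool = False,
    max_shrinks: int = 100,
) -> List[float]:
    n = random_unit_vector(len(x0), rng)
    left, right = maximal_bounds(x0, n)  # "perfect": start from maximal bounds
    for _ in range(max_shrinks):
        t = rng.uniform(left, right)
        # Clamp to guard against rounding just outside the cube.
        x = [min(max(xi + t * ni, 0.0), 1.0) for xi, ni in zip(x0, n)]
        if log_likelihood(x) > log_L_constraint:
            return x
        if t < 0.0:
            # Contract to the rejection point, or (assumed meaning of midpoint
            # shrink) at least as far as halfway between the seed and the bound.
            left = max(t, 0.5 * left) if midpoint_shrink else t
        else:
            right = min(t, 0.5 * right) if midpoint_shrink else t
    return list(x0)


# Usage: alternate the two shrinkage rules along a short chain.
rng = random.Random(4)
log_likelihood = lambda x: -sum((xi - 0.5) ** 2 for xi in x)
chain = [[0.5, 0.5]]
for k in range(40):
    chain.append(
        uni_slice_step(chain[-1], log_likelihood, -0.05, rng, midpoint_shrink=(k % 2 == 0))
    )
```

Contracting past the rejection point can cut away parts of the slice that were never tested, which is one way midpoint shrinkage trades a little extra correlation for fewer likelihood evaluations.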
- pre_process(state)
- Parameters:
state (jaxns.internals.types.StaticStandardNestedSamplerState) –
- Return type:
jaxns.samplers.abc.SamplerState
- post_process(sample_collection, sampler_state)
- Parameters:
sample_collection (jaxns.internals.types.StaticStandardSampleCollection) –
sampler_state (jaxns.samplers.abc.SamplerState) –
- Return type:
jaxns.samplers.abc.SamplerState
- get_seed_point(key, sampler_state, log_L_constraint)
- Parameters:
key (jaxns.internals.types.PRNGKey) –
sampler_state (jaxns.samplers.abc.SamplerState) –
log_L_constraint (jaxns.internals.types.FloatArray) –
- Return type:
- get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
- Parameters:
key (jaxns.internals.types.PRNGKey) –
seed_point (jaxns.samplers.bases.SeedPoint) –
log_L_constraint (jaxns.internals.types.FloatArray) –
sampler_state (jaxns.samplers.abc.SamplerState) –
- Return type:
Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
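When perfect is false, the bracket is found by a doubling procedure. The sketch below shows only the bracket-finding stage of the doubling procedure from Neal (2003), in plain Python, along a line parameterised by t with the seed at t = 0. A complete doubling slice sampler also needs shrinkage and Neal's extra acceptance test for points drawn from a doubled interval, both omitted here. The initial width w, the max_doublings budget, and the function name are hypothetical, and this is not necessarily how jaxns implements the step.

```python
import random
from typing import Callable, Tuple


def doubling_bracket(
    in_slice: Callable[[float], bool],
    w: float,
    max_doublings: int,
    rng: random.Random,
) -> Tuple[float, float]:
    """Bracket-finding stage of the doubling procedure around t = 0 (the seed).

    Returns (left, right) once both ends lie outside the slice or the
    doubling budget is spent.
    """
    left = -w * rng.random()  # randomly position an initial interval of width w
    right = left + w
    k = max_doublings
    while k > 0 and (in_slice(left) or in_slice(right)):
        # Double the interval by extending a randomly chosen side.
        if rng.random() < 0.5:
            left -= right - left
        else:
            right += right - left
        k -= 1
    return left, right


# Usage: a toy slice that is the open interval (-1, 2) along the line.
rng = random.Random(5)
in_slice = lambda t: -1.0 < t < 2.0
left, right = doubling_bracket(in_slice, w=0.1, max_doublings=60, rng=rng)
```

Because the width doubles at every step, the bracket reaches the slice boundaries in a number of steps logarithmic in the ratio of slice size to w.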
- class UniformSampler(model, max_likelihood_evals=100)
Bases: jaxns.samplers.bases.BaseAbstractRejectionSampler
A sampler that produces uniform samples from the model within the log_L_constraint.
The constructor takes the following parameters.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from
max_likelihood_evals (int) – the maximum number of likelihood evaluations to perform before stopping. This prevents the sampler from getting stuck on likelihood plateaus or in forbidden zones.
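The budgeted rejection loop described by max_likelihood_evals can be sketched as follows, assuming the prior is represented as the unit hypercube so that uniform prior draws are uniform draws in [0, 1]^dim. The function name and the returned (point, evaluations, satisfied) triple are hypothetical; what the real sampler returns when the budget runs out is not specified here.

```python
import random
from typing import Callable, List, Optional, Tuple


def uniform_rejection_sample(
    log_likelihood: Callable[[List[float]], float],
    log_L_constraint: float,
    dim: int,
    rng: random.Random,
    max_likelihood_evals: int = 100,
) -> Tuple[Optional[List[float]], int, bool]:
    """Draw from the unit cube until log L > log_L_constraint or the budget runs out.

    Returns (last point tried, evaluations used, constraint satisfied).
    """
    x: Optional[List[float]] = None
    for num_evals in range(1, max_likelihood_evals + 1):
        x = [rng.random() for _ in range(dim)]
        if log_likelihood(x) > log_L_constraint:
            return x, num_evals, True
    # Budget exhausted, e.g. on a plateau where nothing exceeds the constraint.
    return x, max_likelihood_evals, False


rng = random.Random(6)
log_likelihood = lambda x: -sum((xi - 0.5) ** 2 for xi in x)
x, evals, ok = uniform_rejection_sample(log_likelihood, -0.3, 2, rng)
# An unattainable constraint stops after exactly the budgeted evaluations.
x_stuck, evals_stuck, ok_stuck = uniform_rejection_sample(
    log_likelihood, float("inf"), 2, rng, max_likelihood_evals=7
)
```

The expected number of evaluations grows as the inverse of the constrained prior volume, which is why a cap is needed as the constraint tightens.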
- num_phantom()
The number of phantom samples produced by the sampler.
- Returns:
number of phantom samples
- Return type:
- pre_process(state)
Run this periodically on the state to produce a data pytree that the sampler uses and that can be updated quickly between runs.
- Parameters:
state (jaxns.internals.types.StaticStandardNestedSamplerState) – the current state of the sampler
- Returns:
any valid pytree
- Return type:
jaxns.samplers.abc.SamplerState
- post_process(sample_collection, sampler_state)
Post-process the sampler state after the sampler has been run. Should be quick.
- Parameters:
sample_collection (jaxns.internals.types.StaticStandardSampleCollection) – the sample collection after the sampling step
sampler_state (jaxns.samplers.abc.SamplerState) – data pytree produced by the sampler
- Returns:
the updated sampler state
- Return type:
jaxns.samplers.abc.SamplerState
- get_sample(key, log_L_constraint, sampler_state)
Produce a single i.i.d. sample from the model within the log_L_constraint.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within
sampler_state (jaxns.samplers.abc.SamplerState) – the data pytree needed and produced by the sampler
- Returns:
an i.i.d. sample, and batched phantom samples
- Return type:
Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
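The split between an expensive, periodic pre_process, a per-sample get_sample, and a cheap post_process can be illustrated with a toy bounding-box rejection sampler in plain Python. It mirrors the documented argument order (key, log_L_constraint, sampler_state) but is not the jaxns interface: jaxns passes JAX PRNG keys and typed nested-sampler states and drives its samplers itself. BoxState, ToyBoxSampler, margin, and the refit period of 20 are hypothetical. As with any bounded rejection sampler, the box is not guaranteed to cover the whole constrained region.

```python
import random
from typing import Callable, List, NamedTuple, Sequence, Tuple


class BoxState(NamedTuple):
    """Hypothetical sampler state: a bounding box plus running statistics."""
    lower: List[float]
    upper: List[float]
    num_samples: int = 0
    num_evals: int = 0


class ToyBoxSampler:
    """Hypothetical rejection sampler mirroring pre_process / get_sample / post_process."""

    def __init__(self, log_likelihood: Callable[[List[float]], float], margin: float = 0.1):
        self.log_likelihood = log_likelihood
        self.margin = margin

    def pre_process(self, live_points: Sequence[Sequence[float]]) -> BoxState:
        # Expensive step, run periodically: refit the bound to the live points.
        d = len(live_points[0])
        lower = [max(0.0, min(p[i] for p in live_points) - self.margin) for i in range(d)]
        upper = [min(1.0, max(p[i] for p in live_points) + self.margin) for i in range(d)]
        return BoxState(lower, upper)

    def get_sample(
        self, rng: random.Random, log_L_constraint: float, state: BoxState, max_tries: int = 1000
    ) -> Tuple[List[float], int]:
        # Rejection-sample inside the box; return the point and evaluations used.
        for num_evals in range(1, max_tries + 1):
            x = [rng.uniform(lo, hi) for lo, hi in zip(state.lower, state.upper)]
            if self.log_likelihood(x) > log_L_constraint:
                return x, num_evals
        raise RuntimeError("box no longer overlaps the constrained region")

    def post_process(
        self, samples: Sequence[Tuple[List[float], int]], state: BoxState
    ) -> BoxState:
        # Cheap step, run after each sampling step: update running statistics only.
        return state._replace(
            num_samples=state.num_samples + len(samples),
            num_evals=state.num_evals + sum(n for _, n in samples),
        )


# Toy nested-sampling loop: replace the worst live point each iteration.
rng = random.Random(7)
log_likelihood = lambda x: -sum((xi - 0.5) ** 2 for xi in x) / 0.02
sampler = ToyBoxSampler(log_likelihood)
live = [[rng.random(), rng.random()] for _ in range(50)]
initial_worst = min(log_likelihood(p) for p in live)
for it in range(200):
    if it % 20 == 0:
        state = sampler.pre_process(live)  # periodic refit
    worst = min(range(len(live)), key=lambda j: log_likelihood(live[j]))
    new_sample = sampler.get_sample(rng, log_likelihood(live[worst]), state)
    live[worst] = new_sample[0]
    state = sampler.post_process([new_sample], state)  # quick update
```

Keeping the refit out of the inner loop is the point of the split: the bound is rebuilt only every few iterations, while the per-sample update stays cheap.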