jaxns.samplers

Package Contents

class MultiEllipsoidalSampler(depth, expansion_factor, *args, **kwargs)

Bases: jaxns.samplers.bases.BaseAbstractRejectionSampler

Uses a multi-ellipsoidal decomposition of the live points to create a bound around regions to sample from.

Inefficient for high-dimensional problems, but can be very efficient for low-dimensional problems.
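To make the bound-then-reject idea concrete, here is a minimal JAX sketch for a single ellipsoid. It is not jaxns code: the function names (`sample_in_ellipsoid`, `rejection_sample`), the parameterisation (centre `mu`, semi-axis lengths `radii`, orthogonal `rotation`) and the `max_tries` cap are illustrative assumptions. A point is drawn uniformly inside the ellipsoid and kept only if it satisfies the likelihood constraint. A multi-ellipsoid scheme would additionally choose one ellipsoid per proposal (typically in proportion to its volume) and correct for overlaps, which is omitted here. Because any bounding volume tends to overestimate the constrained region more as dimension grows, the rejection rate of this kind of scheme usually rises with dimension.

```python
import jax
import jax.numpy as jnp


def sample_in_ellipsoid(key, mu, radii, rotation):
    """Draw one point uniformly from the ellipsoid
    {x : ||diag(1/radii) @ rotation.T @ (x - mu)|| <= 1} (hypothetical helper)."""
    dim = mu.shape[0]
    dir_key, rad_key = jax.random.split(key)
    # Uniform direction on the unit sphere.
    direction = jax.random.normal(dir_key, (dim,))
    direction = direction / jnp.linalg.norm(direction)
    # Radius ~ U^(1/d) gives a uniform point inside the unit ball.
    r = jax.random.uniform(rad_key) ** (1.0 / dim)
    u = r * direction
    # Map the unit ball onto the ellipsoid: scale the axes, rotate, shift.
    return mu + rotation @ (radii * u)


def rejection_sample(key, log_L, log_L_constraint, mu, radii, rotation, max_tries=1000):
    """Propose from the bounding ellipsoid until log_L(x) > log_L_constraint
    or max_tries proposals have been made. Returns (point, number of proposals)."""

    def cond(carry):
        _, x, n = carry
        return (log_L(x) <= log_L_constraint) & (n < max_tries)

    def body(carry):
        key, _, n = carry
        key, sub = jax.random.split(key)
        return key, sample_in_ellipsoid(sub, mu, radii, rotation), n + 1

    key, sub = jax.random.split(key)
    init = (key, sample_in_ellipsoid(sub, mu, radii, rotation), jnp.int32(1))
    _, x, n = jax.lax.while_loop(cond, body, init)
    return x, n
```

For example, with `mu = jnp.zeros(2)`, `radii = jnp.ones(2)` and `rotation = jnp.eye(2)` the bound is the unit disc, and `rejection_sample` returns a point inside the constrained region together with the number of proposals it took.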

Parameters:
  • depth (int) –

  • expansion_factor (float) –

property max_num_ellipsoids

num_phantom()
Return type:

int

pre_process(state)
Parameters:

state (jaxns.internals.types.StaticStandardNestedSamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

post_process(sample_collection, sampler_state)
Parameters:
  • sample_collection (jaxns.internals.types.StaticStandardSampleCollection) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

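get_sample(key, log_L_constraint, sampler_state)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type: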

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]

class MultiDimSliceSampler(model, num_slices, num_phantom_save, num_restrict_dims=None)

Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler

Multi-dimensional slice sampler, with exponential shrinkage. Produces correlated (non-i.i.d.) samples.

Notes: Not very efficient.
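The following is a minimal sketch of one multi-dimensional slice step with shrinkage, assuming the prior has been transformed to the unit cube. `slice_step`, `line_bounds_in_unit_cube` and `max_shrinks` are hypothetical names, and jaxns's actual implementation may differ. A random direction is drawn, the bracket starts as the full chord of the cube through the seed point, and each rejected proposal shrinks the bracket towards the seed, so the bracket typically contracts geometrically until a point satisfying the constraint is found.

```python
import jax
import jax.numpy as jnp


def line_bounds_in_unit_cube(x0, n):
    """Return (t_min, t_max) such that x0 + t * n stays inside [0, 1]^d."""
    # Avoid division by zero for direction components that are exactly 0.
    safe_n = jnp.where(n == 0.0, 1e-12, n)
    t_a = (0.0 - x0) / safe_n  # where each coordinate hits 0
    t_b = (1.0 - x0) / safe_n  # where each coordinate hits 1
    return jnp.max(jnp.minimum(t_a, t_b)), jnp.min(jnp.maximum(t_a, t_b))


def slice_step(key, x0, log_L, log_L_constraint, max_shrinks=100):
    """One slice step from a seed x0 that satisfies log_L(x0) > log_L_constraint."""
    dir_key, key = jax.random.split(key)
    n = jax.random.normal(dir_key, x0.shape)
    n = n / jnp.linalg.norm(n)  # random unit direction
    t_min, t_max = line_bounds_in_unit_cube(x0, n)  # maximal bracket

    def cond(carry):
        _, _, _, _, x, i = carry
        return (log_L(x) <= log_L_constraint) & (i < max_shrinks)

    def body(carry):
        key, t_min, t_max, t, _, i = carry
        # Shrink the bracket towards the seed (t = 0) using the rejected t.
        t_min = jnp.where(t < 0.0, t, t_min)
        t_max = jnp.where(t >= 0.0, t, t_max)
        key, sub = jax.random.split(key)
        t = jax.random.uniform(sub, minval=t_min, maxval=t_max)
        return key, t_min, t_max, t, x0 + t * n, i + 1

    key, sub = jax.random.split(key)
    t = jax.random.uniform(sub, minval=t_min, maxval=t_max)
    init = (key, t_min, t_max, t, x0 + t * n, jnp.int32(1))
    _, _, _, _, x, i = jax.lax.while_loop(cond, body, init)
    return x, i
```

Because the seed itself satisfies the constraint and the bracket always keeps it inside, repeated shrinkage eventually lands in the constrained region; chaining several such steps is what would make successive accepted points less correlated.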

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – AbstractModel

  • num_slices (int) – number of slices between acceptances, counted in absolute units (unlike some other software, which counts in multiples of the prior dimension).

  • num_phantom_save (int) – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for numerous things, e.g. to estimate the evidence uncertainty.

  • num_restrict_dims (Optional[int]) – size of the subspace to slice along. Setting it to 1 is similar to UniDimSliceSampler, but far less efficient.
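One plausible reading of `num_restrict_dims` is that each slice direction is confined to a randomly chosen subspace of that many coordinates. The sketch below builds such a direction; `restricted_direction` is a hypothetical helper, not a jaxns function, and the subspace-selection rule is an assumption. With `num_restrict_dims=1` it reduces to an axis-aligned direction, consistent with the comparison to UniDimSliceSampler in the parameter description.

```python
import jax
import jax.numpy as jnp


def restricted_direction(key, dim, num_restrict_dims):
    """Random unit vector whose non-zero entries lie in num_restrict_dims
    randomly chosen coordinates (hypothetical helper, not jaxns API)."""
    perm_key, dir_key = jax.random.split(key)
    # Choose which coordinates may move, without replacement.
    idx = jax.random.permutation(perm_key, dim)[:num_restrict_dims]
    mask = jnp.zeros(dim).at[idx].set(1.0)
    # Gaussian entries on the chosen coordinates give an isotropic direction
    # within that subspace once normalised.
    n = jax.random.normal(dir_key, (dim,)) * mask
    return n / jnp.linalg.norm(n)
```

Such a direction would simply replace the full-dimensional random direction in a slice step; everything else (bracketing and shrinkage) stays the same.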

num_phantom()
Return type:

int

pre_process(state)
Parameters:

state (jaxns.internals.types.StaticStandardNestedSamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

post_process(sample_collection, sampler_state)
Parameters:
  • sample_collection (jaxns.internals.types.StaticStandardSampleCollection) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

get_seed_point(key, sampler_state, log_L_constraint)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

Return type:

jaxns.samplers.bases.SeedPoint

get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • seed_point (jaxns.samplers.bases.SeedPoint) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]

class UniDimSliceSampler(model, num_slices, num_phantom_save, midpoint_shrink, perfect, gradient_slice=False)

Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler

Slice sampler for a single dimension. Produces correlated (non-i.i.d.) samples.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – AbstractModel

  • num_slices (int) – number of slices between acceptances. Note: some other software uses units of the prior dimension.

  • midpoint_shrink (bool) – if true, contract to the midpoint of the interval on rejection; otherwise, contract to the rejection point. Midpoint contraction speeds up convergence, but introduces minor autocorrelation.

  • num_phantom_save (int) – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for numerous things, e.g. to estimate the evidence uncertainty.

  • perfect (bool) – if true, perform exponential shrinkage from the maximal bounds, requiring no step-out procedure; otherwise, use a doubling procedure (finding a bracket by exponential expansion). Note: "perfect" is a misnomer, as perfection also depends on the number of slices between acceptances.

  • gradient_slice (bool) – if true, always slice along the gradient direction.
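As an illustration of the `perfect` and `midpoint_shrink` options, here is a minimal sketch of one axis-aligned slice step in the unit cube; it is not the jaxns implementation. With the `perfect`-style bracket, the interval starts at the full prior range `[0, 1]` along the chosen axis, so no step-out is needed (the doubling alternative is not shown). The midpoint variant here moves the violated end of the bracket to the midpoint between the rejected value and the current value, which is an assumed interpretation of "contract to the midpoint of the interval". `unidim_slice_step` and `max_shrinks` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def unidim_slice_step(key, x0, log_L, log_L_constraint, midpoint_shrink=False, max_shrinks=100):
    """One axis-aligned slice step in [0, 1]^d from a seed x0 that satisfies
    log_L(x0) > log_L_constraint (hypothetical helper, not jaxns API)."""
    axis_key, key = jax.random.split(key)
    d = jax.random.randint(axis_key, (), 0, x0.shape[0])  # random axis
    c = x0[d]  # current value along the chosen axis

    def propose(key, lo, hi):
        u = jax.random.uniform(key, minval=lo, maxval=hi)
        return u, x0.at[d].set(u)

    def cond(carry):
        _, _, _, _, x, i = carry
        return (log_L(x) <= log_L_constraint) & (i < max_shrinks)

    def body(carry):
        key, lo, hi, u, _, i = carry
        # On rejection, pull the violated end of the bracket towards c:
        # to the rejected value u, or (assumed midpoint rule) halfway to c.
        new_end = jnp.where(midpoint_shrink, 0.5 * (u + c), u)
        lo = jnp.where(u < c, new_end, lo)
        hi = jnp.where(u >= c, new_end, hi)
        key, sub = jax.random.split(key)
        u, x = propose(sub, lo, hi)
        return key, lo, hi, u, x, i + 1

    # "Perfect" bracket: start from the full prior range, no step-out.
    lo, hi = jnp.float32(0.0), jnp.float32(1.0)
    key, sub = jax.random.split(key)
    u, x = propose(sub, lo, hi)
    init = (key, lo, hi, u, x, jnp.int32(1))
    _, _, _, _, x, i = jax.lax.while_loop(cond, body, init)
    return x, i
```

Moving the bracket end halfway towards the current point shrinks it faster than moving it only to the rejected value, which matches the stated trade-off: quicker convergence at the cost of keeping the new point closer to the old one.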

num_phantom()
Return type:

int

pre_process(state)
Parameters:

state (jaxns.internals.types.StaticStandardNestedSamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

post_process(sample_collection, sampler_state)
Parameters:
  • sample_collection (jaxns.internals.types.StaticStandardSampleCollection) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

get_seed_point(key, sampler_state, log_L_constraint)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

Return type:

jaxns.samplers.bases.SeedPoint

get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • seed_point (jaxns.samplers.bases.SeedPoint) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]

class UniformSampler(model, max_likelihood_evals=100)

Bases: jaxns.samplers.bases.BaseAbstractRejectionSampler

A sampler that produces uniform samples from the model within the log_L_constraint.

Initialises the sampler.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from

  • max_likelihood_evals (int) – the maximum number of likelihood evaluations to perform before stopping. This is important for not getting stuck on plateaus or in forbidden zones.
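A minimal sketch of uniform rejection sampling with an evaluation budget, assuming the model's prior has already been mapped to the unit hypercube. `uniform_rejection_sample` is a hypothetical name, not part of jaxns. Points are drawn from Uniform([0, 1]^dim) until one satisfies the constraint or `max_likelihood_evals` evaluations have been spent; in the latter case the returned point may violate the constraint, so the caller should check the returned log-likelihood.

```python
import jax
import jax.numpy as jnp


def uniform_rejection_sample(key, log_L, log_L_constraint, dim, max_likelihood_evals=100):
    """Draw U ~ Uniform([0, 1]^dim) until log_L(U) > log_L_constraint or the
    evaluation budget is exhausted. Returns (point, log_L(point), evaluations)."""

    def cond(carry):
        _, _, log_l, n = carry
        return (log_l <= log_L_constraint) & (n < max_likelihood_evals)

    def body(carry):
        key, _, _, n = carry
        key, sub = jax.random.split(key)
        u = jax.random.uniform(sub, (dim,))
        return key, u, log_L(u), n + 1

    # Start with no evaluations and a log-likelihood that never passes.
    init = (key, jnp.zeros(dim), jnp.float32(-jnp.inf), jnp.int32(0))
    _, u, log_l, n = jax.lax.while_loop(cond, body, init)
    return u, log_l, n
```

For instance, with an unsatisfiable constraint and `max_likelihood_evals=7`, the loop stops after exactly seven evaluations instead of running forever.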

num_phantom()

The number of phantom samples produced by the sampler.

Returns:

number of phantom samples

Return type:

int

pre_process(state)

Run this periodically on the state to produce a data pytree that the sampler can use and update quickly.

Parameters:

state (jaxns.internals.types.StaticStandardNestedSamplerState) – the current state of the sampler

Returns:

any valid pytree

Return type:

jaxns.samplers.abc.SamplerState
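The split between an occasional pre-processing step, a fast sampling step and a quick post-processing update can be pictured with a toy state pytree. In this hypothetical sketch, none of `BoxState`, `toy_pre_process`, `toy_propose` or `toy_post_process` is jaxns code. The state is a `NamedTuple`, which JAX treats as a pytree, holding an axis-aligned box around the live points: it is computed from all live points, consumed inside a jitted proposal, and updated cheaply after each step. A practical bound would normally be enlarged; the bare box only shows the data flow.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BoxState(NamedTuple):
    """Hypothetical sampler state: an axis-aligned box around the live points."""
    lower: jnp.ndarray
    upper: jnp.ndarray


def toy_pre_process(live_points):
    # Summarise all live points; run only occasionally.
    return BoxState(lower=jnp.min(live_points, axis=0), upper=jnp.max(live_points, axis=0))


@jax.jit
def toy_propose(key, state):
    # Cheap use of the state inside a jitted sampling step.
    return jax.random.uniform(key, state.lower.shape, minval=state.lower, maxval=state.upper)


def toy_post_process(new_point, state):
    # Quick incremental update after one sampling step.
    return BoxState(
        lower=jnp.minimum(state.lower, new_point),
        upper=jnp.maximum(state.upper, new_point),
    )
```

Keeping the state as a small pytree of arrays means it can be passed through `jax.jit`-compiled functions unchanged, while the expensive summary is recomputed only when needed.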

post_process(sample_collection, sampler_state)

Post process the sampler state, after the sampler has been run. Should be quick.

Parameters:
  • sample_collection (jaxns.internals.types.StaticStandardSampleCollection) – the sample collection after the sampling step

  • sampler_state (jaxns.samplers.abc.SamplerState) – data pytree produced by the sampler

Returns:

the updated sampler state

Return type:

jaxns.samplers.abc.SamplerState

get_sample(key, log_L_constraint, sampler_state)

Produce a single i.i.d. sample from the model within the log_L_constraint.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within

  • sampler_state (jaxns.samplers.abc.SamplerState) – the data pytree needed and produced by the sampler

Returns:

an i.i.d. sample, and batched phantom samples

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]