multi_slice_sampler

jaxns.samplers.multi_slice_sampler

Module Contents

class MultiDimSliceSampler(model, num_slices, num_phantom_save, num_restrict_dims=None)

Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler

Multi-dimensional slice sampler with exponential shrinkage. Produces correlated (non-i.i.d.) samples.

Notes: Not very efficient.
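The following is a minimal, hypothetical sketch of one multi-dimensional slice move in the unit hypercube. It is written in plain Python rather than JAX for readability. It draws a random direction, clips the slice bracket to the cube, and on each rejection shrinks the bracket to the rejected proposal (Neal's standard shrinkage rule). The bracket therefore contracts by a random factor until a point satisfying the likelihood constraint is found. The helper name slice_step and the max_shrinks cap are assumptions for illustration, and the exact exponential-shrinkage rule used by MultiDimSliceSampler may differ.

```python
import math
import random
from typing import Callable, List


def slice_step(
    u0: List[float],
    log_L: Callable[[List[float]], float],
    log_L_constraint: float,
    rng: random.Random,
    max_shrinks: int = 100,
) -> List[float]:
    """Hypothetical helper: one slice move from u0 within
    {u in [0, 1]^d : log_L(u) > log_L_constraint}.

    Not jaxns code: uses Neal-style shrinkage to the rejected proposal.
    """
    d = len(u0)
    # Random direction, uniform on the sphere (normalised Gaussian vector).
    n = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(x * x for x in n))
    n = [x / norm for x in n]

    # Bracket [left, right] of t such that u0 + t * n stays inside [0, 1]^d.
    left, right = -math.inf, math.inf
    for ui, ni in zip(u0, n):
        if ni > 0:
            left = max(left, -ui / ni)
            right = min(right, (1.0 - ui) / ni)
        elif ni < 0:
            left = max(left, (1.0 - ui) / ni)
            right = min(right, -ui / ni)

    for _ in range(max_shrinks):
        t = rng.uniform(left, right)
        u = [ui + t * ni for ui, ni in zip(u0, n)]
        if log_L(u) > log_L_constraint:
            return u
        # Rejected: shrink the bracket to the proposal, keeping t = 0 (the seed) inside.
        if t < 0:
            left = t
        else:
            right = t
    # Give up and stay at the seed, which satisfies the constraint by assumption.
    return list(u0)
```

Calling slice_step repeatedly, each time starting from the previous output, yields a chain of points whose successive members are correlated, which is why such a sampler does not produce i.i.d. samples.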

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from.

  • num_slices (int) – number of slice steps taken between accepted samples. This is an absolute count, unlike other software, which counts in multiples of the prior dimension.

  • num_phantom_save (int) – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for several purposes, e.g. estimating the evidence uncertainty.

  • num_restrict_dims (Optional[int]) – size of the subspace to slice along. Setting it to 1 would be similar to UniDimSliceSampler, but far less efficient.
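As a worked example of the num_slices convention: a sampler configured elsewhere with 5 slices per dimension on a 10-dimensional prior takes 5 * 10 = 50 slice steps per sample, so the comparable setting here is num_slices=50. The sketch below encodes that conversion, together with one plausible reading of num_restrict_dims, namely a random direction confined to a randomly chosen subset of coordinate axes. Both helper names are hypothetical, and the coordinate-subset interpretation is an assumption rather than a description of jaxns internals.

```python
import math
import random
from typing import List


def equivalent_num_slices(slices_per_dim: int, ndim: int) -> int:
    """Hypothetical helper: convert a per-dimension slice count (the convention
    some other samplers use) into the absolute count num_slices expects."""
    return slices_per_dim * ndim


def restricted_direction(ndim: int, num_restrict_dims: int, rng: random.Random) -> List[float]:
    """Hypothetical helper: a random unit vector whose non-zero entries lie in
    a randomly chosen subset of num_restrict_dims coordinates (assumed reading)."""
    # rng.sample raises ValueError if num_restrict_dims > ndim.
    active = rng.sample(range(ndim), num_restrict_dims)
    direction = [0.0] * ndim
    for i in active:
        direction[i] = rng.gauss(0.0, 1.0)
    norm = math.sqrt(sum(x * x for x in direction))
    return [x / norm for x in direction]
```

With num_restrict_dims equal to the prior dimension, restricted_direction reduces to an ordinary random direction on the sphere.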

num_phantom()
Return type:

int

pre_process(state)
Parameters:

state (jaxns.internals.types.StaticStandardNestedSamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

post_process(sample_collection, sampler_state)
Parameters:
  • sample_collection (jaxns.internals.types.StaticStandardSampleCollection) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

get_seed_point(key, sampler_state, log_L_constraint)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

Return type:

jaxns.samplers.bases.SeedPoint
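get_seed_point is documented here only by its signature. A common convention in nested sampling is to start the Markov chain from a stored point that already satisfies the current likelihood constraint, chosen uniformly at random. The plain-Python sketch below illustrates that convention. choose_seed is a hypothetical helper operating on lists rather than jaxns pytrees, and it is an assumption, not a statement of how MultiDimSliceSampler builds its SeedPoint.

```python
import random
from typing import List, Tuple


def choose_seed(
    points: List[List[float]],
    log_Ls: List[float],
    log_L_constraint: float,
    rng: random.Random,
) -> Tuple[List[float], float]:
    """Hypothetical helper: pick a seed uniformly among stored points whose
    log-likelihood exceeds the constraint."""
    candidates = [i for i, log_L in enumerate(log_Ls) if log_L > log_L_constraint]
    if not candidates:
        raise ValueError("no stored point satisfies the constraint")
    i = rng.choice(candidates)
    return points[i], log_Ls[i]
```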

get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • seed_point (jaxns.samplers.bases.SeedPoint) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
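get_sample_from_seed returns a pair of Sample objects. One plausible reading, consistent with the num_slices and num_phantom_save parameters of the class, is an accepted sample together with a batch of phantom samples drawn from the same chain. The hypothetical sketch below runs num_slices constraint-preserving moves from a seed, returns the final point as the accepted sample, and keeps the last num_phantom_save intermediate points as phantoms. The generic step callable, and the choice of which intermediate points to keep, are assumptions for illustration.

```python
import random
from typing import Callable, List, Tuple

Point = List[float]


def sample_from_seed(
    seed: Point,
    step: Callable[[Point, random.Random], Point],
    num_slices: int,
    num_phantom_save: int,
    rng: random.Random,
) -> Tuple[Point, List[Point]]:
    """Hypothetical helper: run num_slices constraint-preserving moves from seed.

    The last point is the accepted sample; up to num_phantom_save of the
    intermediate points (which also satisfy the constraint) are kept as phantoms.
    """
    if num_slices < 1:
        raise ValueError("num_slices must be at least 1")
    chain: List[Point] = []
    point = seed
    for _ in range(num_slices):
        point = step(point, rng)
        chain.append(point)
    accepted = chain[-1]
    intermediates = chain[:-1]
    # Keep the most recent intermediate points as phantoms (a choice of this sketch).
    phantoms = intermediates[-num_phantom_save:] if num_phantom_save > 0 else []
    return accepted, phantoms
```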