multi_slice_sampler

jaxns.samplers.multi_slice_sampler

Module Contents

class MultiDimSliceSampler

Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler[jaxns.nested_samplers.common.types.SampleCollection]

Multi-dimensional slice sampler, with exponential shrinkage. Produces correlated (non-i.i.d.) samples.

Notes: Not very efficient.
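To illustrate the kind of step described above, here is a minimal NumPy sketch of a textbook hyperrectangle slice step with shrinkage, run in the unit cube under a likelihood constraint. It is not the jaxns implementation: the function name slice_step, the toy log_L, the unit-cube bracket, the axis-aligned choice of subspace, and treating num_restrict_dims=None as "all dimensions" are all assumptions made for the example.

```python
import numpy as np


def slice_step(rng, seed, log_L, log_L_constraint, num_restrict_dims=None):
    """One constrained slice step in the unit cube, starting from `seed`.

    Hypothetical helper for illustration only; not part of jaxns.
    Returns the accepted point and the number of likelihood evaluations.
    """
    if not log_L(seed) > log_L_constraint:
        raise ValueError("seed must satisfy the likelihood constraint")
    dim = seed.size
    # Assumption: None means "slice along every dimension".
    k = dim if num_restrict_dims is None else num_restrict_dims
    # Coordinates that are allowed to move (an axis-aligned subspace).
    active = rng.choice(dim, size=k, replace=False)
    # Initial bracket: the full unit interval along each active coordinate.
    left = np.zeros(k)
    right = np.ones(k)
    num_evals = 0
    while True:
        proposal = seed.copy()
        proposal[active] = rng.uniform(left, right)
        num_evals += 1
        if log_L(proposal) > log_L_constraint:
            return proposal, num_evals
        # Rejected: pull the bracket in toward the seed along each coordinate,
        # so the bracket always still contains the seed.
        below = proposal[active] < seed[active]
        left = np.where(below, proposal[active], left)
        right = np.where(below, right, proposal[active])


# Toy problem: narrow Gaussian centred in a 3-dimensional unit cube.
log_L = lambda u: -0.5 * np.sum(((u - 0.5) / 0.1) ** 2)
rng = np.random.default_rng(0)
point = np.full(3, 0.5)
for _ in range(5):  # plays the role of num_slices = 5
    point, _ = slice_step(rng, point, log_L, log_L_constraint=-2.0)
```

Because each step starts from the previous point, successive outputs are correlated, which matches the non-i.i.d. remark above. Every rejected proposal in the loop still costs a likelihood evaluation.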

Parameters:
  • model – AbstractModel

  • num_slices – number of slice steps taken between accepted samples. This is an absolute count (units of 1), whereas some other software specifies it in units of the prior dimension.

  • num_phantom_save – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for several purposes, e.g. to estimate the evidence uncertainty.

  • num_restrict_dims – size of the subspace to slice along. Setting this to 1 behaves like UniDimSliceSampler, but is far less efficient.
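Since num_slices is an absolute count, a setting carried over from a tool that counts slices per prior dimension needs to be multiplied by that dimension. A small sketch of the conversion follows; the helper name absolute_num_slices and its arguments are hypothetical and not part of jaxns.

```python
def absolute_num_slices(slices_per_dim: int, prior_dim: int) -> int:
    """Convert a per-dimension slice count into an absolute slice count.

    Hypothetical helper for illustration only; not part of jaxns.
    """
    if slices_per_dim < 1 or prior_dim < 1:
        raise ValueError("both arguments must be positive integers")
    return slices_per_dim * prior_dim


# Example: 5 slices per dimension on a 3-dimensional prior.
num_slices = absolute_num_slices(slices_per_dim=5, prior_dim=3)
```

The resulting integer is what would be passed as num_slices under this reading of the parameter description.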

model: jaxns.framework.bases.BaseAbstractModel
num_slices: int
num_phantom_save: int
num_restrict_dims: int | None = None
__post_init__()
num_phantom()
Return type:

int

get_seed_point(key, sampler_state, log_L_constraint)
Parameters:
  • key (jaxns.internals.types.PRNGKey)

  • sampler_state (jaxns.nested_samplers.common.types.LivePointCollection)

  • log_L_constraint (jaxns.internals.types.FloatArray)

Return type:

jaxns.samplers.bases.SeedPoint

get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
Parameters:
  • key (jaxns.internals.types.PRNGKey)

  • seed_point (jaxns.samplers.bases.SeedPoint)

  • log_L_constraint (jaxns.internals.types.FloatArray)

  • sampler_state (jaxns.nested_samplers.common.types.SampleCollection)

Return type:

Tuple[jaxns.nested_samplers.common.types.Sample, jaxns.nested_samplers.common.types.Sample]