uni_slice_sampler

jaxns.samplers.uni_slice_sampler

Module Contents

class UniDimSliceSampler(model, num_slices, num_phantom_save, midpoint_shrink, perfect, gradient_slice=False)

Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler

Unidimensional slice sampler: each update slices along a single direction (a line through the current point). Produces correlated (non-i.i.d.) samples.
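The idea of slicing along one line at a time can be illustrated with a minimal pure-Python sketch. This is not the jaxns implementation. It assumes a nested-sampling-style hard constraint log_L(x) > log_L_constraint on the unit hypercube. Each update draws a random direction and starts from the full extent of the cube along that line (an assumption standing in for the "maximal bounds"). It then samples uniformly on the line, shrinking the interval to each rejected point. The helpers random_direction, cube_bounds, slice_step and slice_chain are hypothetical names. Chaining num_slices updates gives a Markov chain, which is why successive samples are correlated.

```python
import math
import random
from typing import Callable, List, Tuple

Vector = List[float]


def random_direction(rng: random.Random, dim: int) -> Vector:
    """Draw a unit vector with uniformly random orientation (normalised Gaussian)."""
    v = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]


def cube_bounds(x: Vector, d: Vector) -> Tuple[float, float]:
    """Return (t_min, t_max) such that x + t * d stays inside the unit cube."""
    t_min, t_max = -math.inf, math.inf
    for xi, di in zip(x, d):
        if di > 0:
            t_min = max(t_min, -xi / di)
            t_max = min(t_max, (1.0 - xi) / di)
        elif di < 0:
            t_min = max(t_min, (1.0 - xi) / di)
            t_max = min(t_max, -xi / di)
    return t_min, t_max


def slice_step(
    rng: random.Random,
    x: Vector,
    log_L: Callable[[Vector], float],
    log_L_constraint: float,
) -> Vector:
    """One slice update along a random line; x must satisfy the constraint."""
    d = random_direction(rng, len(x))
    # Assumption: start from the full extent of the cube along the line (no step-out).
    left, right = cube_bounds(x, d)
    while True:
        t = rng.uniform(left, right)
        y = [xi + t * di for xi, di in zip(x, d)]
        if log_L(y) > log_L_constraint:
            return y
        # Rejected: shrink the interval to the rejected point, keeping t = 0 inside.
        if t < 0.0:
            left = t
        else:
            right = t


def slice_chain(
    rng: random.Random,
    x: Vector,
    log_L: Callable[[Vector], float],
    log_L_constraint: float,
    num_slices: int,
) -> Vector:
    """Apply num_slices slice updates; each output depends on the previous one."""
    for _ in range(num_slices):
        x = slice_step(rng, x, log_L, log_L_constraint)
    return x
```

Because the current point always satisfies the constraint and sits at t = 0, the shrinking interval eventually produces an accepted point.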

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from.

  • num_slices (int) – number of slice updates performed between accepted samples. Note: some other software specifies this in units of the prior dimension.

  • num_phantom_save (int) – number of phantom samples to save. Phantom samples are samples that meet the constraint but are not accepted. They can be used for several purposes, e.g. to estimate the evidence uncertainty.

  • midpoint_shrink (bool) – if true, contract the interval to its midpoint on rejection; otherwise, contract it to the rejection point. Midpoint contraction speeds up convergence but introduces minor autocorrelation.

  • perfect (bool) – if true, perform exponential shrinkage from the maximal bounds, so no step-out procedure is required. Otherwise, use a doubling procedure that finds the bracket by exponential expansion. Note: "perfect" is a misnomer, as sample quality also depends on the number of slices between accepted samples.

  • gradient_slice (bool) – if true, always slice along the gradient direction.
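When perfect is false, the bracket along the slice direction is found by doubling. The sketch below shows the textbook doubling procedure from Neal's "Slice Sampling" (Annals of Statistics, 2003), written against a hypothetical inside(t) predicate, where t is the offset along the slice direction and t = 0 is the current point. The name doubling_bracket and the defaults w and max_doublings are illustrative assumptions, not jaxns settings. jaxns's actual implementation may differ. Neal also pairs doubling with an extra acceptance check during shrinkage, which is omitted here.

```python
import random
from typing import Callable, Tuple


def doubling_bracket(
    inside: Callable[[float], bool],
    rng: random.Random,
    w: float = 0.1,
    max_doublings: int = 10,
) -> Tuple[float, float]:
    """Find a bracket [left, right] around t = 0 by repeated doubling.

    inside(t) reports whether the point at offset t along the slice direction
    satisfies the likelihood constraint. w and max_doublings are illustrative
    defaults (hypothetical, not jaxns parameters).
    """
    # Randomly position an initial interval of width w so that it contains t = 0.
    left = -w * rng.random()
    right = left + w
    # Double the interval on a randomly chosen side while either end is still
    # inside the slice and the doubling budget is not exhausted.
    for _ in range(max_doublings):
        if not (inside(left) or inside(right)):
            break
        width = right - left
        if rng.random() < 0.5:
            left -= width
        else:
            right += width
    return left, right
```

Because the width doubles at each step, a bracket comparable to the slice is reached in a logarithmic number of likelihood evaluations. Starting from the maximal bounds, as with perfect=True, skips this phase entirely.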

num_phantom()
Return type:

int

pre_process(state)
Parameters:

state (jaxns.internals.types.StaticStandardNestedSamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

post_process(sample_collection, sampler_state)
Parameters:
  • sample_collection (jaxns.internals.types.StaticStandardSampleCollection) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

jaxns.samplers.abc.SamplerState

get_seed_point(key, sampler_state, log_L_constraint)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

Return type:

jaxns.samplers.bases.SeedPoint
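In nested sampling, a new sample is typically grown from an existing point that already satisfies the current likelihood constraint. A minimal sketch of that selection follows. It assumes the state can be viewed as parallel lists of points and log-likelihoods, a hypothetical layout that is not jaxns's SamplerState. pick_seed_point is likewise a hypothetical name, and Python's random.Random stands in for a JAX PRNGKey.

```python
import random
from typing import List, Sequence, Tuple


def pick_seed_point(
    rng: random.Random,
    points: Sequence[List[float]],
    log_Ls: Sequence[float],
    log_L_constraint: float,
) -> Tuple[List[float], float]:
    """Choose uniformly among stored points whose log-likelihood exceeds the constraint.

    Hypothetical data layout: points[i] has log-likelihood log_Ls[i].
    """
    candidates = [i for i, log_L in enumerate(log_Ls) if log_L > log_L_constraint]
    if not candidates:
        raise ValueError("no stored point satisfies the constraint")
    i = rng.choice(candidates)
    return points[i], log_Ls[i]
```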

get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
Parameters:
  • key (jaxns.internals.types.PRNGKey) –

  • seed_point (jaxns.samplers.bases.SeedPoint) –

  • log_L_constraint (jaxns.internals.types.FloatArray) –

  • sampler_state (jaxns.samplers.abc.SamplerState) –

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
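The two Sample values in the return type can plausibly be read as the accepted sample and the saved phantom samples. The sketch below shows that bookkeeping under a stated assumption that this page does not confirm: the phantoms are taken to be the last num_phantom_save intermediate points of the num_slices-step chain. Each of those points meets the constraint but is not returned as the accepted sample. sample_from_seed is a hypothetical name, and step stands in for one slice update.

```python
from typing import Callable, List, Tuple, TypeVar

T = TypeVar("T")


def sample_from_seed(
    seed: T,
    step: Callable[[T], T],
    num_slices: int,
    num_phantom_save: int,
) -> Tuple[T, List[T]]:
    """Run num_slices updates from seed; return (accepted sample, phantom samples).

    Assumption: phantoms are the num_phantom_save intermediate points that
    immediately precede the accepted one in the chain.
    """
    if not 0 <= num_phantom_save < num_slices:
        raise ValueError("need 0 <= num_phantom_save < num_slices")
    chain: List[T] = []
    x = seed
    for _ in range(num_slices):
        x = step(x)
        chain.append(x)
    accepted = chain[-1]
    # Intermediate points just before the accepted one (empty if num_phantom_save == 0).
    phantoms = chain[len(chain) - 1 - num_phantom_save : len(chain) - 1]
    return accepted, phantoms
```

Under this reading, saving phantoms costs no extra likelihood evaluations, because they are points the chain visits anyway.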