uni_slice_sampler
jaxns.samplers.uni_slice_sampler
Module Contents
- class UniDimSliceSampler(model, num_slices, num_phantom_save, midpoint_shrink, perfect, gradient_slice=False)
Bases: jaxns.samplers.bases.BaseAbstractMarkovSampler
Slice sampler for a single dimension. Produces correlated (non-i.i.d.) samples.
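Slicing "for a single dimension" means that each update moves the current point along one line through the sample space. The sketch below assumes, without confirmation from this page, that the line runs through the unit hypercube [0, 1]^d in a random direction (the gradient_slice option would replace that random direction with the gradient direction). The helpers random_direction and maximal_bracket are hypothetical names; the bracket they compute is the kind of "maximal bounds" that the perfect option shrinks from.

```python
import math
import random
from typing import List, Tuple


def random_direction(d: int, rng: random.Random) -> List[float]:
    """Draw a direction uniformly from the unit sphere in d dimensions."""
    v = [rng.gauss(0.0, 1.0) for _ in range(d)]
    norm = math.sqrt(sum(c * c for c in v))
    return [c / norm for c in v]


def maximal_bracket(x: List[float], n: List[float]) -> Tuple[float, float]:
    """Return (t_min, t_max) so that x + t * n stays in [0, 1]^d for t_min <= t <= t_max.

    Assumes x lies inside the unit cube and n is not the zero vector.
    """
    t_min, t_max = -math.inf, math.inf
    for xi, ni in zip(x, n):
        if ni > 0:
            # Moving forward reaches the face xi = 1; moving backward reaches xi = 0.
            t_min = max(t_min, -xi / ni)
            t_max = min(t_max, (1.0 - xi) / ni)
        elif ni < 0:
            # Signs flip: forward reaches xi = 0, backward reaches xi = 1.
            t_min = max(t_min, (1.0 - xi) / ni)
            t_max = min(t_max, -xi / ni)
        # ni == 0: this coordinate is constant along the line, so it sets no limit.
    return t_min, t_max


# Usage: bracket a random line through a point of the 3-D unit cube.
rng = random.Random(0)
x = [0.2, 0.7, 0.5]
n = random_direction(len(x), rng)
t_min, t_max = maximal_bracket(x, n)
print(t_min, t_max)
```

The current point always sits at t = 0, so t_min <= 0 <= t_max, and any point drawn from that interval stays inside the cube.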
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from.
num_slices (int) – number of slices performed between accepted samples. Note: some other software expresses this in units of the prior dimension.
midpoint_shrink (bool) – if True, contract the interval to its midpoint on rejection; otherwise, contract it to the rejection point. Midpoint contraction speeds up convergence but introduces minor auto-correlation.
num_phantom_save (int) – number of phantom samples to save. Phantom samples are samples that meet the likelihood constraint but are not accepted. They have several uses, e.g. estimating the evidence uncertainty.
perfect (bool) – if True, perform exponential shrinkage from the maximal bounds, so no step-out procedure is needed. Otherwise, use a doubling procedure that finds the bracket exponentially. Note: "perfect" is a misnomer, as sample quality also depends on the number of slices between accepted samples.
gradient_slice (bool) – if True, always slice along the gradient direction.
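The parameters above can be illustrated with a one-dimensional sketch of the shrinkage loop on the unit interval. It assumes perfect=True behaviour (shrink from the maximal bounds [0, 1] with no step-out; the doubling path is not shown), repeats num_slices slices, and treats the intermediate points of the chain as phantom samples, which is an assumption about how phantoms arise. The midpoint rule used here (move the rejected-side bound to the interval midpoint when that midpoint lies between the current point and the rejection point) is one hypothetical reading of midpoint_shrink; the library's exact rule may differ. shrink_slice and slice_chain are illustrative names, not jaxns functions.

```python
import random
from typing import Callable, List, Tuple


def shrink_slice(
    x: float,
    log_L: Callable[[float], float],
    log_L_constraint: float,
    rng: random.Random,
    midpoint_shrink: bool = False,
) -> float:
    """Draw y in [0, 1] with log_L(y) > log_L_constraint, shrinking from the maximal bracket."""
    assert log_L(x) > log_L_constraint, "the current point must satisfy the constraint"
    # 'perfect'-style start: the maximal bounds of the unit interval, no step-out.
    left, right = 0.0, 1.0
    while True:
        y = rng.uniform(left, right)
        if log_L(y) > log_L_constraint:
            return y
        mid = 0.5 * (left + right)
        if y > x:
            # Standard rule: move the right bound to the rejection point.
            # Hypothetical midpoint rule: jump to the midpoint if it lies between x and y.
            right = mid if (midpoint_shrink and x < mid < y) else y
        else:
            left = mid if (midpoint_shrink and y < mid < x) else y


def slice_chain(
    x0: float,
    log_L: Callable[[float], float],
    log_L_constraint: float,
    num_slices: int,
    num_phantom_save: int,
    rng: random.Random,
    midpoint_shrink: bool = False,
) -> Tuple[float, List[float]]:
    """Run num_slices slices from x0; return the accepted point and the saved intermediate points."""
    assert num_slices >= 1
    points: List[float] = []
    x = x0
    for _ in range(num_slices):
        x = shrink_slice(x, log_L, log_L_constraint, rng, midpoint_shrink)
        points.append(x)
    # Hypothetical phantom samples: intermediate points that satisfy the
    # constraint but are not the accepted (final) point.
    intermediate = points[:-1]
    keep = min(num_phantom_save, len(intermediate))
    return points[-1], intermediate[len(intermediate) - keep:]


def log_L(u: float) -> float:
    """Toy log-likelihood peaked at 0.5; log_L(u) > -1 means |u - 0.5| < 0.1."""
    return -100.0 * (u - 0.5) ** 2


accepted, phantoms = slice_chain(0.5, log_L, -1.0, num_slices=5, num_phantom_save=2,
                                 rng=random.Random(0), midpoint_shrink=True)
print(accepted, phantoms)
```

Both contraction rules keep the current point strictly inside the interval, so the loop terminates; midpoint contraction discards more of the interval per rejection, which is why it can need fewer likelihood evaluations.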
- pre_process(state)
- Parameters:
state (jaxns.internals.types.StaticStandardNestedSamplerState)
- Return type:
jaxns.samplers.abc.SamplerState
- post_process(sample_collection, sampler_state)
- Parameters:
sample_collection (jaxns.internals.types.StaticStandardSampleCollection)
sampler_state (jaxns.samplers.abc.SamplerState)
- Return type:
jaxns.samplers.abc.SamplerState
- get_seed_point(key, sampler_state, log_L_constraint)
- Parameters:
key (jaxns.internals.types.PRNGKey)
sampler_state (jaxns.samplers.abc.SamplerState)
log_L_constraint (jaxns.internals.types.FloatArray)
- Return type:
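get_seed_point is listed without a description. A common pattern in constrained Markov samplers, assumed here and not confirmed by this page, is to start the chain from an existing sample that already satisfies the likelihood constraint, since the slice shrinkage needs a valid starting point. In this sketch, Python's random.Random stands in for the JAX PRNGKey, and SeedPointSketch and choose_seed_point are hypothetical names that only loosely mirror what a seed point needs (a location and its log-likelihood).

```python
import random
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class SeedPointSketch:
    """Hypothetical stand-in for a seed point: a location and its log-likelihood."""
    U0: List[float]
    log_L0: float


def choose_seed_point(
    rng: random.Random,
    points: Sequence[Sequence[float]],
    log_Ls: Sequence[float],
    log_L_constraint: float,
) -> SeedPointSketch:
    """Pick uniformly among stored points whose log-likelihood exceeds the constraint."""
    candidates = [i for i, log_L in enumerate(log_Ls) if log_L > log_L_constraint]
    if not candidates:
        raise ValueError("no stored point satisfies the likelihood constraint")
    i = rng.choice(candidates)
    return SeedPointSketch(U0=list(points[i]), log_L0=log_Ls[i])


points = [[0.1, 0.2], [0.5, 0.5], [0.9, 0.4]]
log_Ls = [-3.0, -0.5, -1.2]
# Only the second point has log_L > -1.0, so it is the only possible seed.
seed = choose_seed_point(random.Random(0), points, log_Ls, log_L_constraint=-1.0)
print(seed)
```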
- get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
- Parameters:
key (jaxns.internals.types.PRNGKey)
seed_point (jaxns.samplers.bases.SeedPoint)
log_L_constraint (jaxns.internals.types.FloatArray)
sampler_state (jaxns.samplers.abc.SamplerState)
- Return type:
Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]