uniform_samplers
jaxns.samplers.uniform_samplers
Module Contents
- class UniformSampler(model, max_likelihood_evals=100)
Bases: jaxns.samplers.bases.BaseAbstractRejectionSampler
A sampler that produces uniform samples from the model within the log_L_constraint.
Initialises the sampler.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from
max_likelihood_evals (int) – the maximum number of likelihood evaluations to perform before stopping. This cap keeps the sampler from getting stuck on likelihood plateaus or in forbidden zones.
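The role of max_likelihood_evals can be illustrated with a minimal, standard-library sketch of capped rejection sampling. This is not the jaxns implementation: the helper rejection_sample, the toy likelihoods, and the strict inequality used for acceptance are all assumptions made for illustration.

```python
import math
import random


def rejection_sample(log_likelihood, log_L_constraint, max_likelihood_evals, rng, ndim=2):
    """Draw uniformly from the unit cube until log L exceeds the constraint.

    Hypothetical helper, not part of jaxns. Returns
    (point, log_L, num_evals, accepted) and stops after
    max_likelihood_evals evaluations even if nothing was accepted.
    """
    point, log_L = None, -math.inf
    for num_evals in range(1, max_likelihood_evals + 1):
        point = [rng.random() for _ in range(ndim)]
        log_L = log_likelihood(point)
        # Assumption: acceptance requires a strict improvement.
        if log_L > log_L_constraint:
            return point, log_L, num_evals, True
    return point, log_L, max_likelihood_evals, False


def gaussian_log_L(x):
    # Toy likelihood peaked at the centre of the unit cube.
    return -0.5 * sum((xi - 0.5) ** 2 for xi in x) / 0.1 ** 2


def plateau_log_L(x):
    # Constant likelihood: no point is strictly better than another.
    return 0.0


rng = random.Random(0)

# Loose constraint: a draw is accepted after a handful of evaluations.
_, log_L, n_ok, ok = rejection_sample(gaussian_log_L, -5.0, 100, rng)

# Plateau at the constraint value: the strict inequality never holds,
# so only the cap terminates the loop.
_, _, n_plateau, ok_plateau = rejection_sample(plateau_log_L, 0.0, 100, rng)
```

In the plateau case the loop would never terminate without the cap, which is the situation the parameter description warns about.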
- num_phantom()
The number of phantom samples produced by the sampler.
- Returns:
number of phantom samples
- Return type:
int
- pre_process(state)
Run this periodically on the state to produce a data pytree that the sampler can use and update quickly.
- Parameters:
state (jaxns.internals.types.StaticStandardNestedSamplerState) – the current state of the sampler
- Returns:
any valid pytree
- Return type:
jaxns.samplers.abc.SamplerState
- post_process(sample_collection, sampler_state)
Post-process the sampler state after the sampler has run. This should be quick.
- Parameters:
sample_collection (jaxns.internals.types.StaticStandardSampleCollection) – the sample collection after a sampling step
sampler_state (jaxns.samplers.abc.SamplerState) – data pytree produced by the sampler
- Returns:
the updated sampler state
- Return type:
jaxns.samplers.abc.SamplerState
- get_sample(key, log_L_constraint, sampler_state)
Produce a single i.i.d. sample from the model within the log_L_constraint.
- Parameters:
key (jaxns.internals.types.PRNGKey) – the PRNG key
log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within
sampler_state (jaxns.samplers.abc.SamplerState) – the data pytree needed and produced by the sampler
- Returns:
an i.i.d. sample, and batched phantom samples
- Return type:
Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
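The shape of the get_sample contract (a key in, a pair of sample and phantom samples out) can be sketched in plain Python. Everything here is a hypothetical stand-in: ToySample is not jaxns.internals.types.Sample, an integer seed replaces the PRNG key, and the empty phantom list reflects the assumption that plain rejection sampling produces no phantom samples.

```python
import random
from typing import List, NamedTuple, Tuple


class ToySample(NamedTuple):
    # Hypothetical stand-in for a sample record; field names are illustrative.
    point: List[float]
    log_L: float
    num_likelihood_evaluations: int


def toy_get_sample(key: int, log_L_constraint: float,
                   sampler_state=()) -> Tuple[ToySample, List[ToySample]]:
    """Return one accepted sample plus a (here empty) batch of phantom samples."""
    rng = random.Random(key)  # an integer seed stands in for a PRNG key
    max_likelihood_evals = 100
    for num_evals in range(1, max_likelihood_evals + 1):
        point = [rng.random(), rng.random()]
        # Toy log-likelihood peaked at the centre of the unit square.
        log_L = -sum((p - 0.5) ** 2 for p in point)
        if log_L > log_L_constraint:
            break
    sample = ToySample(point, log_L, num_evals)
    phantom_samples: List[ToySample] = []  # assumption: none for this sampler
    return sample, phantom_samples


first, phantoms = toy_get_sample(key=42, log_L_constraint=-0.1)
second, _ = toy_get_sample(key=42, log_L_constraint=-0.1)
```

Because all randomness is derived from the key, calling the function twice with the same key yields the same sample, which mirrors how explicit PRNG keys make sampling reproducible.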