uniform_samplers

jaxns.samplers.uniform_samplers

Module Contents

class UniformSampler(model, max_likelihood_evals=100)

Bases: jaxns.samplers.bases.BaseAbstractRejectionSampler

A rejection sampler that produces uniform samples from the model, restricted to the region that satisfies log_L_constraint.

Initialises the sampler.

Parameters:
  • model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from

  • max_likelihood_evals (int) – the maximum number of likelihood evaluations to perform before stopping. This keeps the sampler from getting stuck on likelihood plateaus or in forbidden zones, where a point satisfying the constraint may be very hard or impossible to find.
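The role of max_likelihood_evals can be illustrated with a minimal, stdlib-only sketch of capped uniform rejection sampling. This is not the jaxns implementation. It assumes that the likelihood is a plain Python function of a point in the unit hypercube and that a point is accepted when its log-likelihood strictly exceeds the constraint. The function name capped_uniform_rejection_sample is hypothetical.

```python
import math
import random
from typing import Callable, List, Tuple


def capped_uniform_rejection_sample(
    rng: random.Random,
    log_likelihood: Callable[[List[float]], float],
    log_L_constraint: float,
    num_dims: int,
    max_likelihood_evals: int = 100,
) -> Tuple[List[float], float, int, bool]:
    """Hypothetical sketch: draw points uniformly from the unit hypercube until
    one satisfies log_L > log_L_constraint (assumed strict), or until the budget
    of max_likelihood_evals likelihood evaluations is used up.

    Returns (point, log_L, num_evals, accepted).
    """
    point: List[float] = []
    log_L = -math.inf
    for num_evals in range(1, max_likelihood_evals + 1):
        point = [rng.random() for _ in range(num_dims)]
        log_L = log_likelihood(point)
        if log_L > log_L_constraint:
            return point, log_L, num_evals, True
    # Budget exhausted: return the last draw and flag it as not accepted.
    return point, log_L, max_likelihood_evals, False
```

On a plateau, where every point has a log-likelihood equal to the constraint, a strict inequality is never satisfied. Without the cap the loop would never terminate. With it, the call returns after max_likelihood_evals draws with accepted set to False.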

num_phantom()

The number of phantom samples produced by the sampler.

Returns:

number of phantom samples

Return type:

int

pre_process(state)

Run this periodically on the state to produce a data pytree that the sampler can use and that can be updated quickly.

Parameters:

state (jaxns.internals.types.StaticStandardNestedSamplerState) – the current state of the sampler

Returns:

any valid pytree

Return type:

jaxns.samplers.abc.SamplerState

post_process(sample_collection, sampler_state)

Post-process the sampler state after the sampler has been run. This should be quick.

Parameters:
  • sample_collection (jaxns.internals.types.StaticStandardSampleCollection) – the sample collection after the sampling step

  • sampler_state (jaxns.samplers.abc.SamplerState) – data pytree produced by the sampler

Returns:

the updated sampler state

Return type:

jaxns.samplers.abc.SamplerState

get_sample(key, log_L_constraint, sampler_state)

Produce a single i.i.d. sample from the model within the log_L_constraint.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – the PRNG key

  • log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within

  • sampler_state (jaxns.samplers.abc.SamplerState) – the data pytree that the sampler needs and produces

Returns:

an i.i.d. sample and a batch of phantom samples

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
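The calling convention of get_sample can be sketched as a pure function: an explicit key goes in, and a tuple of one sample and a batch of phantom samples comes out. Everything here is hypothetical. Sample is a stand-in NamedTuple rather than jaxns.internals.types.Sample, and an integer seed passed to random.Random plays the role of the PRNGKey. The phantom batch is assumed to be empty for a uniform rejection sampler.

```python
import random
from typing import Callable, List, NamedTuple, Tuple


class Sample(NamedTuple):
    """Hypothetical stand-in for a sample: a unit-cube point and its log-likelihood."""
    U: Tuple[float, ...]
    log_L: float


def get_sample_sketch(
    key: int,
    log_L_constraint: float,
    log_likelihood: Callable[[Tuple[float, ...]], float],
    num_dims: int,
    max_likelihood_evals: int = 100,
) -> Tuple[Sample, List[Sample]]:
    # The key alone determines the draws, so the same key reproduces the same
    # sample and distinct keys give independent samples.
    rng = random.Random(key)
    sample = Sample(U=(), log_L=float("-inf"))
    for _ in range(max_likelihood_evals):
        U = tuple(rng.random() for _ in range(num_dims))
        sample = Sample(U=U, log_L=log_likelihood(U))
        if sample.log_L > log_L_constraint:
            break
    # Assumption: a uniform rejection sampler produces no phantom samples,
    # so the batch is empty.
    phantom_samples: List[Sample] = []
    return sample, phantom_samples
```

Calling the sketch twice with key=42 returns identical samples, while key=7 gives a different, independent draw. In JAX code the analogous practice is to split a PRNG key for each new draw rather than reuse it.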