abc

jaxns.samplers.abc

Module Contents

SamplerState

class AbstractSampler

Bases: abc.ABC

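Abstract base class for samplers. All four methods below are abstract, so a concrete sampler must subclass AbstractSampler and implement each of them before it can be instantiated.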
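The sketch below shows what this ABC contract means in practice. It does not import jaxns. Instead it defines a hypothetical stand-in, SamplerInterface, that mirrors the four abstract method names listed on this page, and shows that a subclass which overrides only some of them cannot be instantiated.

```python
from abc import ABC, abstractmethod
from typing import Any, Tuple


class SamplerInterface(ABC):
    """Hypothetical stand-in mirroring the abstract methods of AbstractSampler."""

    @abstractmethod
    def pre_process(self, state: Any) -> Any: ...

    @abstractmethod
    def post_process(self, sample_collection: Any, sampler_state: Any) -> Any: ...

    @abstractmethod
    def get_sample(self, key: Any, log_L_constraint: float,
                   sampler_state: Any) -> Tuple[Any, Any]: ...

    @abstractmethod
    def num_phantom(self) -> int: ...


class Incomplete(SamplerInterface):
    # Only one abstract method is overridden, so the class stays abstract.
    def num_phantom(self) -> int:
        return 0


class Complete(Incomplete):
    # Overrides the remaining three methods, so it can be instantiated.
    def pre_process(self, state):
        return ()

    def post_process(self, sample_collection, sampler_state):
        return sampler_state

    def get_sample(self, key, log_L_constraint, sampler_state):
        return None, []  # placeholder: no real sampling in this toy class


try:
    Incomplete()
    incomplete_ok = True
except TypeError:
    incomplete_ok = False

print(incomplete_ok, isinstance(Complete(), SamplerInterface))  # False True
```

Python enforces this when the object is constructed, not when the class is defined. A missing method therefore fails at the point where the sampler is built.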

abstract pre_process(state)

Run this periodically on the nested sampler state to produce a data pytree that the sampler can use and that can be updated quickly.

Parameters:

state (jaxns.internals.types.StaticStandardNestedSamplerState) – the current state of the sampler

Returns:

any valid pytree

Return type:

SamplerState
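As an illustration, here is a minimal pre_process sketch that summarises the live points as an axis-aligned bounding box. Several things are assumptions and do not come from JAXNS: ToyNestedState is a hypothetical stand-in for StaticStandardNestedSamplerState, BoxState is a hypothetical sampler state, and live points are taken to be tuples in the unit hypercube. Plain Python tuples keep the sketch runnable without JAX. A real sampler would more likely use jax.numpy arrays. A NamedTuple is a valid JAX pytree either way.

```python
from typing import List, NamedTuple, Tuple


class ToyNestedState(NamedTuple):
    # Hypothetical stand-in for StaticStandardNestedSamplerState:
    # only the live points in the unit hypercube are kept.
    live_U: List[Tuple[float, ...]]


class BoxState(NamedTuple):
    # Hypothetical sampler state: an axis-aligned bounding box.
    lower: Tuple[float, ...]
    upper: Tuple[float, ...]


def pre_process(state: ToyNestedState) -> BoxState:
    """Summarise the live points as a bounding box (hypothetical example)."""
    dims = list(zip(*state.live_U))  # transpose: one tuple per dimension
    lower = tuple(min(d) for d in dims)
    upper = tuple(max(d) for d in dims)
    return BoxState(lower=lower, upper=upper)


state = ToyNestedState(live_U=[(0.2, 0.5), (0.4, 0.1), (0.3, 0.9)])
box = pre_process(state)
print(box)  # BoxState(lower=(0.2, 0.1), upper=(0.4, 0.9))
```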

abstract post_process(sample_collection, sampler_state)

Post-process the sampler state after the sampler has run. This step should be fast.

Parameters:
  • sample_collection (jaxns.internals.types.StaticStandardSampleCollection) – the sample collection after the sampling step

  • sampler_state (SamplerState) – data pytree produced by the sampler

Returns:

the updated sampler state

Return type:

SamplerState
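The sketch below shows one hypothetical way a cheap post_process could work: it adapts a proposal step size using the cost of the last batch of samples. ToySampleCollection, StepState, the num_likelihood_evaluations field and the halve-or-double rule are all illustrative assumptions. They are not taken from JAXNS.

```python
from typing import List, NamedTuple


class ToySampleCollection(NamedTuple):
    # Hypothetical stand-in for StaticStandardSampleCollection:
    # likelihood evaluations spent on each sample.
    num_likelihood_evaluations: List[int]


class StepState(NamedTuple):
    # Hypothetical sampler state: a proposal step size.
    step_size: float


def post_process(sample_collection: ToySampleCollection,
                 sampler_state: StepState,
                 target_evals: float = 10.0) -> StepState:
    """Cheaply adapt the step size from the last batch (hypothetical rule)."""
    evals = sample_collection.num_likelihood_evaluations
    mean_evals = sum(evals) / len(evals)
    # Many evaluations per sample suggests proposals are too large,
    # so shrink the step; few suggests they are too small, so grow it.
    factor = 0.5 if mean_evals > target_evals else 2.0
    return sampler_state._replace(step_size=sampler_state.step_size * factor)


collection = ToySampleCollection(num_likelihood_evaluations=[12, 20, 16])
new_state = post_process(collection, StepState(step_size=0.4))
print(new_state.step_size)  # 0.2
```

The update touches only a summary statistic and returns a new NamedTuple. This matches the requirement above that post-processing should be quick.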

abstract get_sample(key, log_L_constraint, sampler_state)

Produce a single i.i.d. sample from the model that satisfies the likelihood constraint log_L_constraint.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – a JAX PRNG key

  • log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within

  • sampler_state (SamplerState) – the data pytree that the sampler produces and consumes

Returns:

an i.i.d. sample and a batch of phantom samples

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
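To make the get_sample contract concrete, the sketch below uses plain rejection sampling from a uniform prior on the unit hypercube. This is a deliberately simple stand-in technique and not necessarily what any JAXNS sampler does. ToySample, log_likelihood and the Gaussian model are hypothetical. An integer seed with Python's random module replaces a JAX PRNG key so that the sketch runs without JAX. The phantom batch is empty, which corresponds to a sampler whose num_phantom() is 0.

```python
import random
from typing import List, NamedTuple, Tuple


class ToySample(NamedTuple):
    # Hypothetical stand-in for jaxns.internals.types.Sample.
    U: Tuple[float, ...]   # point in the unit hypercube
    log_L: float           # log-likelihood at U


def log_likelihood(U: Tuple[float, ...]) -> float:
    # Hypothetical model: isotropic Gaussian centred at 0.5 in each dimension.
    return -0.5 * sum(((u - 0.5) / 0.1) ** 2 for u in U)


def get_sample(key: int, log_L_constraint: float, ndims: int = 2,
               max_tries: int = 100_000) -> Tuple[ToySample, List[ToySample]]:
    """Rejection sampling from the uniform prior (a simple stand-in technique).

    Uses an integer seed in place of a JAX PRNG key and returns an empty
    phantom batch.
    """
    rng = random.Random(key)
    for _ in range(max_tries):
        U = tuple(rng.random() for _ in range(ndims))
        log_L = log_likelihood(U)
        if log_L > log_L_constraint:  # accept only points inside the constraint
            return ToySample(U=U, log_L=log_L), []
    raise RuntimeError("no sample satisfied the constraint")


sample, phantoms = get_sample(key=0, log_L_constraint=-2.0)
print(sample.log_L > -2.0, len(phantoms))  # True 0
```

Each accepted point is an independent draw from the prior restricted to the constrained region, which is why it qualifies as i.i.d. Rejection sampling becomes very inefficient as the constraint tightens, and that cost motivates more elaborate samplers.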

abstract num_phantom()

The number of phantom samples produced by the sampler.

Returns:

number of phantom samples

Return type:

int
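This page does not define phantom samples. One plausible reading, assumed here and not confirmed by the source, is that they are the extra, correlated points a sampler visits while producing its returned sample. In the hypothetical sketch below, a short constrained random walk (Metropolis-style moves under a uniform prior) returns its final state as the sample and the earlier states as phantoms. num_phantom() then reports the size of that batch. ToyChainSampler, ToySample and log_likelihood are illustrative names, and sampler_state is taken to be a starting point that already satisfies the constraint.

```python
import random
from typing import List, NamedTuple, Tuple


class ToySample(NamedTuple):
    # Hypothetical stand-in for jaxns.internals.types.Sample.
    U: float      # 1-D point in the unit interval
    log_L: float  # log-likelihood at U


def log_likelihood(u: float) -> float:
    # Hypothetical 1-D model peaked at u = 0.5.
    return -abs(u - 0.5)


class ToyChainSampler:
    """Hypothetical sampler: the last state of a short constrained random walk
    is the returned sample; the earlier states are the phantom samples."""

    def __init__(self, num_steps: int, step: float = 0.05):
        self.num_steps = num_steps
        self.step = step

    def num_phantom(self) -> int:
        # Every step except the last yields one phantom sample.
        return self.num_steps - 1

    def get_sample(self, key: int, log_L_constraint: float,
                   sampler_state: float) -> Tuple[ToySample, List[ToySample]]:
        # sampler_state is assumed to be a start point inside the constraint.
        rng = random.Random(key)
        u = sampler_state
        chain = []
        for _ in range(self.num_steps):
            proposal = u + rng.uniform(-self.step, self.step)
            # Accept only proposals inside the prior support and the constraint.
            if 0.0 <= proposal <= 1.0 and log_likelihood(proposal) > log_L_constraint:
                u = proposal
            chain.append(ToySample(U=u, log_L=log_likelihood(u)))
        return chain[-1], chain[:-1]


sampler = ToyChainSampler(num_steps=5)
sample, phantoms = sampler.get_sample(key=1, log_L_constraint=-0.2, sampler_state=0.5)
print(len(phantoms) == sampler.num_phantom())  # True
```

Keeping num_phantom() equal to the phantom batch size lets calling code allocate fixed-size buffers ahead of time. Under JAX, static shapes matter for this reason.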