abc

jaxns.samplers.abc

Module Contents

SamplerState
class EphemeralState

Bases: NamedTuple

key: jaxns.internals.types.PRNGKey
live_points_collection: jaxns.nested_samplers.common.types.LivePointCollection
termination_register: jaxns.nested_samplers.common.types.TerminationRegister
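Because EphemeralState is a NamedTuple, it is immutable, and JAX treats NamedTuples as pytree nodes. The minimal sketch below uses a hypothetical stand-in class, ToyEphemeralState, with plain Python field types in place of the JAXNS types so that it runs without JAXNS. It shows that an "update" produces a new tuple and leaves the original unchanged.

```python
from typing import List, NamedTuple, Tuple


class ToyEphemeralState(NamedTuple):
    """Hypothetical stand-in for EphemeralState (field types simplified)."""
    key: int                                            # stand-in for a PRNGKey
    live_points_collection: List[Tuple[float, float]]   # stand-in: (u, log_L) pairs
    termination_register: dict                          # stand-in for TerminationRegister


state = ToyEphemeralState(
    key=0,
    live_points_collection=[(0.1, -2.0), (0.7, -0.5)],
    termination_register={"num_likelihood_evaluations": 2},
)

# NamedTuples are immutable: _replace returns a new tuple and leaves `state` intact.
new_state = state._replace(key=1)
print(state.key, new_state.key)  # 0 1
```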
class AbstractSampler

Bases: abc.ABC, Generic[SamplerState]

Abstract base class for samplers. SamplerState is the type of the data pytree that a concrete sampler builds in pre_process, uses in get_sample, and updates in post_process.
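The following sketch mirrors the shape of this interface with abc.ABC and Generic. ToyAbstractSampler, Incomplete and Complete are hypothetical names. The listing on this page marks only num_phantom as abstract, so the sketch does the same. It assumes the other methods raise NotImplementedError when they are not overridden. Python refuses to instantiate a subclass that leaves an abstract method unimplemented.

```python
import abc
from typing import Any, Generic, Tuple, TypeVar

SamplerState = TypeVar("SamplerState")


class ToyAbstractSampler(abc.ABC, Generic[SamplerState]):
    """Hypothetical mirror of the documented AbstractSampler signatures."""

    def pre_process(self, ephemeral_state: Any) -> SamplerState:
        raise NotImplementedError

    def post_process(self, ephemeral_state: Any, sampler_state: SamplerState) -> SamplerState:
        raise NotImplementedError

    def get_sample(self, key: Any, log_L_constraint: float,
                   sampler_state: SamplerState) -> Tuple[Any, Any]:
        raise NotImplementedError

    @abc.abstractmethod
    def num_phantom(self) -> int:
        """Number of phantom samples produced by the sampler."""


class Incomplete(ToyAbstractSampler[dict]):
    pass  # does not implement num_phantom, so it stays abstract


class Complete(ToyAbstractSampler[dict]):
    def num_phantom(self) -> int:
        return 0


try:
    Incomplete()
except TypeError as err:
    print("cannot instantiate:", err)

print(Complete().num_phantom())  # 0
```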

pre_process(ephemeral_state)

Run this periodically on the ephemeral state to build a data pytree that the sampler can use and that can be updated cheaply between runs.

Parameters:

ephemeral_state (EphemeralState) – the current state of the sampler

Returns:

any valid pytree

Return type:

SamplerState
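As an illustration of the "build periodically" step, the sketch below computes a hypothetical sampler state: an axis-aligned bounding box around the live points in unit-cube coordinates. BoxState is an invented name. For simplicity, this pre_process takes the live points directly. The documented method receives an EphemeralState and would read them from live_points_collection. Scanning every live point is the comparatively expensive part, which is why it is done only periodically.

```python
from typing import List, NamedTuple, Tuple


class BoxState(NamedTuple):
    """Hypothetical sampler state: a box enclosing the live points."""
    lower: Tuple[float, ...]
    upper: Tuple[float, ...]


def pre_process(live_points_U: List[Tuple[float, ...]]) -> BoxState:
    # Per-dimension min and max over all live points (O(N * dims) work).
    dims = list(zip(*live_points_U))
    return BoxState(lower=tuple(min(d) for d in dims),
                    upper=tuple(max(d) for d in dims))


box = pre_process([(0.2, 0.5), (0.4, 0.1), (0.3, 0.9)])
print(box)  # BoxState(lower=(0.2, 0.1), upper=(0.4, 0.9))
```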

post_process(ephemeral_state, sampler_state)

Post-process the sampler state after the sampler has been run. This step should be cheap.

Parameters:
  • ephemeral_state (EphemeralState) – the ephemeral state after the sampling step

  • sampler_state (SamplerState) – data pytree produced by the sampler

Returns:

the updated sampler state

Return type:

SamplerState
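Continuing the hypothetical bounding-box idea, a cheap post_process step can update the state from the newly accepted point alone instead of rescanning all live points. The box then remains a valid (if loose) bound on the live points. The names are illustrative, and this post_process takes the new point directly rather than an EphemeralState.

```python
from typing import NamedTuple, Tuple


class BoxState(NamedTuple):
    """Hypothetical sampler state: a box enclosing the live points."""
    lower: Tuple[float, ...]
    upper: Tuple[float, ...]


def post_process(new_point_U: Tuple[float, ...], sampler_state: BoxState) -> BoxState:
    # O(dims) update: grow the box just enough to contain the new point.
    lower = tuple(min(lo, x) for lo, x in zip(sampler_state.lower, new_point_U))
    upper = tuple(max(hi, x) for hi, x in zip(sampler_state.upper, new_point_U))
    return BoxState(lower, upper)


state = BoxState(lower=(0.2, 0.1), upper=(0.4, 0.9))
print(post_process((0.5, 0.3), state))  # BoxState(lower=(0.2, 0.1), upper=(0.5, 0.9))
```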

get_sample(key, log_L_constraint, sampler_state)

Produce a single i.i.d. sample from the model that satisfies the likelihood constraint log_L_constraint.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNG key

  • log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within

  • sampler_state (SamplerState) – the data pytree needed and produced by the sampler

Returns:

a tuple of an i.i.d. sample and the phantom samples

Return type:

Tuple[jaxns.nested_samplers.common.types.Sample, jaxns.nested_samplers.common.types.Sample]
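A simple way to get an exact i.i.d. draw from the prior restricted to the constraint is rejection sampling. The sketch below draws uniformly from a 2-D unit-cube prior until the log-likelihood exceeds log_L_constraint. It uses a strict inequality, which is an assumption of this sketch. ToySample, log_likelihood and the integer seed standing in for a PRNG key are all hypothetical, and sampler_state is omitted. Independent draws need no phantoms, so the sketch returns an empty list for them.

```python
import random
from typing import List, NamedTuple, Tuple


class ToySample(NamedTuple):
    """Hypothetical stand-in for a Sample: unit-cube point and its log-likelihood."""
    U: Tuple[float, float]
    log_L: float


def log_likelihood(U: Tuple[float, float]) -> float:
    # Toy Gaussian log-likelihood centred at (0.5, 0.5) with width 0.1.
    return -0.5 * sum(((u - 0.5) / 0.1) ** 2 for u in U)


def get_sample(key: int, log_L_constraint: float) -> Tuple[ToySample, List[ToySample]]:
    rng = random.Random(key)  # integer seed stands in for a JAX PRNG key
    while True:
        U = (rng.random(), rng.random())  # uniform draw from the unit-cube prior
        log_L = log_likelihood(U)
        if log_L > log_L_constraint:
            # Accepted draws are independent, so no phantom samples are produced.
            return ToySample(U, log_L), []


sample, phantoms = get_sample(key=42, log_L_constraint=-2.0)
print(sample.log_L > -2.0, len(phantoms))  # True 0
```

Under jax.jit, a Python while loop like this one would have to become jax.lax.while_loop, and rejection sampling becomes inefficient as the constrained region shrinks. Both are reasons a real sampler may use a different technique.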

abstract num_phantom()

The number of phantom samples produced by the sampler.

Returns:

number of phantom samples

Return type:

int
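The sketch below is a hypothetical sampler whose phantom count is fixed by its configuration. It runs a fixed-length 1-D random walk that only accepts moves staying inside the likelihood constraint. It returns the final point as the sample and every earlier point of the walk as a phantom. Two things here are assumptions, not confirmed by this page: treating intermediate constrained points as phantoms, and requiring that get_sample return exactly num_phantom() phantoms. A count known ahead of time would let the phantom output have a static shape, which JAX generally needs. The class name and the extra start and log_L arguments are also hypothetical.

```python
import random
from typing import Callable, List, Tuple


class ToyWalkSampler:
    """Hypothetical sampler: fixed-length random walk inside the constraint."""

    def __init__(self, num_steps: int, step_size: float = 0.05):
        self.num_steps = num_steps
        self.step_size = step_size

    def num_phantom(self) -> int:
        # Every point of the walk except the last is returned as a phantom.
        return self.num_steps - 1

    def get_sample(self, key: int, start: float, log_L: Callable[[float], float],
                   log_L_constraint: float) -> Tuple[float, List[float]]:
        rng = random.Random(key)  # integer seed stands in for a JAX PRNG key
        x = start  # assumed to already satisfy the constraint
        path = []
        for _ in range(self.num_steps):
            proposal = x + rng.uniform(-self.step_size, self.step_size)
            # Accept only moves that stay in [0, 1] and satisfy the constraint.
            if 0.0 <= proposal <= 1.0 and log_L(proposal) > log_L_constraint:
                x = proposal
            path.append(x)
        return path[-1], path[:-1]


def log_L(x: float) -> float:
    return -0.5 * ((x - 0.5) / 0.1) ** 2


sampler = ToyWalkSampler(num_steps=5)
sample, phantoms = sampler.get_sample(key=0, start=0.5, log_L=log_L, log_L_constraint=-2.0)
print(len(phantoms) == sampler.num_phantom())  # True
```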