bases

jaxns.samplers.bases

Module Contents

class BaseAbstractSampler(model)

Bases: jaxns.samplers.abc.AbstractSampler

Helper class that provides a standard way to create an ABC using inheritance.

Parameters:

model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from

__repr__()

Return repr(self).

class BaseAbstractRejectionSampler(model)

Bases: BaseAbstractSampler

Samplers that are based on rejection sampling. They are usually the first line of attack, and are stopped once their efficiency gets too low.
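
The idea of stopping a rejection sampler once its efficiency drops can be sketched in plain Python. This is a minimal illustration and not the jaxns implementation: `log_likelihood`, `rejection_sample`, and the `min_efficiency` parameter are hypothetical names, and the prior is assumed to be uniform on the unit square.

```python
import math
import random


def log_likelihood(u):
    # Hypothetical toy log-likelihood on the unit square, peaked at (0.5, 0.5).
    return -sum((x - 0.5) ** 2 for x in u) / (2 * 0.1 ** 2)


def rejection_sample(rng, log_L_constraint, min_efficiency=0.05):
    """Draw uniformly from the unit square until a point beats the constraint.

    Returns (point, log_L, num_trials), or None once the best-case acceptance
    rate (one success in num_trials draws) falls below min_efficiency.
    """
    num_trials = 0
    while True:
        u = [rng.random(), rng.random()]
        num_trials += 1
        log_L = log_likelihood(u)
        if log_L > log_L_constraint:
            return u, log_L, num_trials
        # Even if the next draw succeeded, efficiency would be about
        # 1 / num_trials, so give up once that is too low.
        if 1.0 / num_trials < min_efficiency:
            return None
```

With a loose constraint such as `-math.inf` the first draw is accepted; as the constraint tightens, more draws are wasted and the function eventually returns `None`, which is the point at which a caller would switch to a different kind of sampler.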

Parameters:

model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from

class SeedPoint

Bases: NamedTuple

U0: jaxns.internals.types.FloatArray

log_L0: jaxns.internals.types.FloatArray

class BaseAbstractMarkovSampler(model)

Bases: BaseAbstractSampler

A sampler that conditions on a known point satisfying the constraint, e.g. a seed point.

Parameters:

model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from

abstract get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)

Produce a single i.i.d. sample from the model within the log_L_constraint.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • seed_point (SeedPoint) – the seed point to start sampling from

  • log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within

  • sampler_state (jaxns.samplers.abc.SamplerState) – the data pytree needed and produced by the sampler

Returns:

an i.i.d. sample, and batched phantom samples

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
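
One way a subclass might realise this method is a constrained random walk started at the seed. The sketch below uses plain Python stand-ins rather than jaxns types: `SeedPoint` here is a local look-alike, `log_likelihood`, `sample_from_seed`, `num_steps`, and `step` are hypothetical, and treating the intermediate chain states as the "phantom samples" is an assumption about what that term covers.

```python
import random
from typing import List, NamedTuple


class SeedPoint(NamedTuple):
    # Stand-in for the SeedPoint documented above, with plain Python types.
    U0: List[float]
    log_L0: float


def log_likelihood(u):
    # Hypothetical toy log-likelihood on the unit cube, peaked at the centre.
    return -sum((x - 0.5) ** 2 for x in u) / (2 * 0.1 ** 2)


def sample_from_seed(rng, seed_point, log_L_constraint, num_steps=10, step=0.1):
    """Random-walk from the seed, accepting only moves above the constraint."""
    u, log_L = list(seed_point.U0), seed_point.log_L0
    chain = []
    for _ in range(num_steps):
        # Symmetric proposal; reject anything outside the unit cube or
        # at or below the likelihood constraint.
        proposal = [x + rng.uniform(-step, step) for x in u]
        if all(0.0 <= x <= 1.0 for x in proposal):
            proposal_log_L = log_likelihood(proposal)
            if proposal_log_L > log_L_constraint:
                u, log_L = proposal, proposal_log_L
        chain.append((list(u), log_L))
    # The final state is returned as the sample; the earlier (correlated)
    # states play the role of phantom samples in this sketch.
    return chain[-1], chain[:-1]
```

Because the walk starts at a point that already satisfies the constraint and never accepts a move that violates it, every state in the chain stays inside the constrained region.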

abstract get_seed_point(key, sampler_state, log_L_constraint)

Samples a seed point from the live points.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • sampler_state (jaxns.samplers.abc.SamplerState) – the current sampler state

  • log_L_constraint (jaxns.internals.types.FloatArray) – a log-L constraint to sample within. There must always be at least one sample above this constraint, to avoid an infinite loop.

Returns:

a seed point

Return type:

SeedPoint
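
A seed-point selection of this kind might look like the following sketch, which picks uniformly among the live points that lie above the constraint. It is an illustration under assumptions, not the jaxns code: the sampler state is reduced to a plain list of points, `SeedPoint` is a local stand-in, and `pick_seed_point` is a hypothetical name.

```python
import random
from typing import List, NamedTuple


class SeedPoint(NamedTuple):
    # Stand-in for the SeedPoint documented above, with plain Python types.
    U0: List[float]
    log_L0: float


def pick_seed_point(rng, live_points, log_L_constraint):
    """Choose uniformly among live points strictly above the constraint."""
    candidates = [p for p in live_points if p.log_L0 > log_L_constraint]
    if not candidates:
        # Mirrors the documented requirement: with no point above the
        # constraint, a retry loop would never terminate.
        raise ValueError("no live point lies above log_L_constraint")
    return rng.choice(candidates)
```

Raising on an empty candidate list makes the documented precondition explicit; a loop that kept redrawing until it found a valid point would otherwise spin forever.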

get_sample(key, log_L_constraint, sampler_state)

Produce a single i.i.d. sample from the model within the log_L_constraint.

Parameters:
  • key (jaxns.internals.types.PRNGKey) – PRNGKey

  • log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within

  • sampler_state (jaxns.samplers.abc.SamplerState) – the data pytree needed and produced by the sampler

Returns:

an i.i.d. sample, and batched phantom samples

Return type:

Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
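
Since get_sample is concrete while get_seed_point and get_sample_from_seed are abstract, a plausible reading is that get_sample chains the two. The sketch below shows that template-method pattern in plain Python. It is an assumption about the structure, not the jaxns source: `MarkovSamplerSketch` and `ToySampler` are hypothetical, keys are plain integers, and the sampler state is a list of `(u, log_L)` pairs.

```python
import random
from abc import ABC, abstractmethod


class MarkovSamplerSketch(ABC):
    @abstractmethod
    def get_seed_point(self, key, sampler_state, log_L_constraint): ...

    @abstractmethod
    def get_sample_from_seed(self, key, seed_point, log_L_constraint, sampler_state): ...

    def get_sample(self, key, log_L_constraint, sampler_state):
        # Derive two sub-seeds, loosely mimicking a PRNG key split.
        rng = random.Random(key)
        seed_key, sample_key = rng.getrandbits(32), rng.getrandbits(32)
        seed_point = self.get_seed_point(seed_key, sampler_state, log_L_constraint)
        return self.get_sample_from_seed(
            sample_key, seed_point, log_L_constraint, sampler_state
        )


class ToySampler(MarkovSamplerSketch):
    """1D toy with log_L(u) = -(u - 0.5)**2 on the unit interval."""

    def get_seed_point(self, key, sampler_state, log_L_constraint):
        # Pick uniformly among live points above the constraint.
        candidates = [p for p in sampler_state if p[1] > log_L_constraint]
        return random.Random(key).choice(candidates)

    def get_sample_from_seed(self, key, seed_point, log_L_constraint, sampler_state):
        rng = random.Random(key)
        u, log_L = seed_point
        phantoms = []
        for _ in range(5):
            proposal = u + rng.uniform(-0.1, 0.1)
            proposal_log_L = -(proposal - 0.5) ** 2
            if 0.0 <= proposal <= 1.0 and proposal_log_L > log_L_constraint:
                # Keep the state being left behind as a phantom sample.
                phantoms.append((u, log_L))
                u, log_L = proposal, proposal_log_L
        return (u, log_L), phantoms
```

A subclass then only has to supply the two abstract methods; for example, `ToySampler().get_sample(0, -0.1, [(0.5, 0.0), (0.2, -0.09), (0.9, -0.16)])` returns a sample above the constraint together with the states visited along the way.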