jaxns.samplers.bases
Module Contents
- class BaseAbstractSampler(model)
Bases:
jaxns.samplers.abc.AbstractSampler
Helper class that provides a standard way to create an ABC using inheritance.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from
- class BaseAbstractRejectionSampler(model)
Bases:
BaseAbstractSampler
Samplers based on rejection sampling. They are usually the first line of attack, and are stopped once their sampling efficiency becomes too low.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from
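To illustrate the idea (not the jaxns implementation), here is a minimal rejection sampler under stated assumptions: a uniform prior on the unit cube, a plain-Python log-likelihood, and Python's `random.Random` standing in for a JAX PRNG key. The function `rejection_sample` and its `max_trials` budget are hypothetical; the budget plays the role of the "efficiency too low" stopping rule, since giving up after `max_trials` draws means the efficiency `1 / trials` has dropped below `1 / max_trials`.

```python
import random
from typing import Callable, List, Optional, Tuple

Point = List[float]


# Hypothetical helper for illustration; not part of jaxns.
def rejection_sample(
    log_likelihood: Callable[[Point], float],
    log_L_constraint: float,
    ndim: int,
    rng: random.Random,
    max_trials: int = 1000,
) -> Tuple[Optional[Point], int]:
    """Draw from a uniform prior on the unit cube until log_L exceeds the constraint.

    Returns (sample, trials). The sample is None when max_trials is exhausted,
    i.e. the efficiency 1 / trials has dropped below 1 / max_trials.
    """
    for trial in range(1, max_trials + 1):
        u = [rng.random() for _ in range(ndim)]  # one prior draw
        if log_likelihood(u) > log_L_constraint:
            return u, trial
    # Too inefficient: a caller would now hand over to another sampler.
    return None, max_trials


# Example: a narrow Gaussian-shaped likelihood centred in the unit square.
def log_L(u: Point) -> float:
    return -sum((ui - 0.5) ** 2 for ui in u) / (2 * 0.1 ** 2)


rng = random.Random(0)
sample, trials = rejection_sample(log_L, log_L_constraint=-1.0, ndim=2, rng=rng)
efficiency = 1.0 / trials

# An unreachable constraint (log_L never exceeds 0) exhausts the budget.
missing, used = rejection_sample(log_L, 1.0, 2, random.Random(1), max_trials=50)
```

As the constraint tightens during a run, the acceptance region shrinks and the trial count grows, which is why rejection samplers are typically swapped out for Markov samplers later on.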
- class BaseAbstractMarkovSampler(model)
Bases:
BaseAbstractSampler
A sampler that conditions on a known satisfying point, e.g. a seed point.
- Parameters:
model (jaxns.framework.bases.BaseAbstractModel) – the model to sample from
- abstract get_sample_from_seed(key, seed_point, log_L_constraint, sampler_state)
Produce a single i.i.d. sample from the model within the log_L_constraint.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
seed_point (SeedPoint) – the seed point from which to start sampling
log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within
sampler_state (jaxns.samplers.abc.SamplerState) – the data pytree needed by, and produced by, the sampler
- Returns:
an i.i.d. sample, and batched phantom samples
- Return type:
Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
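One way a concrete subclass could move away from a seed point while staying inside the constraint is sketched below. This is an assumption-laden illustration, not jaxns's own move: it uses random-walk Metropolis (a swapped-in technique) on a uniform prior over the unit cube, with Python's `random.Random` in place of a JAX key. With a uniform prior and a symmetric proposal, the Metropolis rule reduces to accepting a proposal exactly when it stays in the cube and satisfies `log_L > log_L_constraint`. The name `sample_from_seed` is hypothetical. The returned point is only approximately independent of the seed after enough steps, and treating the intermediate chain states as analogous to the documented phantom samples is an assumption of this sketch.

```python
import random
from typing import Callable, List, Tuple

Point = List[float]


# Hypothetical helper for illustration; not part of jaxns.
def sample_from_seed(
    log_likelihood: Callable[[Point], float],
    seed_point: Point,
    log_L_constraint: float,
    rng: random.Random,
    num_steps: int = 20,
    step_size: float = 0.1,
) -> Tuple[Point, List[Point]]:
    """Random-walk Metropolis on a uniform prior restricted to log_L > constraint."""
    if log_likelihood(seed_point) <= log_L_constraint:
        raise ValueError("seed point must satisfy the constraint")
    x = list(seed_point)
    chain: List[Point] = []
    for _ in range(num_steps):
        proposal = [xi + step_size * rng.gauss(0.0, 1.0) for xi in x]
        in_prior = all(0.0 <= p <= 1.0 for p in proposal)
        if in_prior and log_likelihood(proposal) > log_L_constraint:
            x = proposal  # accept; otherwise stay at the current point
        chain.append(list(x))
    # Final state is the returned sample; earlier states are correlated extras.
    return x, chain[:-1]


def log_L(u: Point) -> float:
    return -sum((ui - 0.5) ** 2 for ui in u)


sample, extras = sample_from_seed(log_L, [0.5, 0.5], -0.05, random.Random(0))
```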
- abstract get_seed_point(key, sampler_state, log_L_constraint)
Samples a seed point from the live points.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
sampler_state (jaxns.samplers.abc.SamplerState) – the current sampler state
log_L_constraint (jaxns.internals.types.FloatArray) – a log-L constraint to sample within. There must always be at least one live point with log-likelihood above this constraint; otherwise the search can loop forever.
- Returns:
a seed point
- Return type:
SeedPoint
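The seed-selection step, and why the "at least one point above the constraint" requirement matters, can be sketched as follows. This is a hypothetical stand-in (`get_seed_point` here is a plain function over Python lists, not the jaxns method, and `random.Random` replaces a JAX key): it picks uniformly among live points whose log-likelihood is strictly above the constraint, and raises an error instead of looping when no such point exists.

```python
import random
from typing import List, Tuple

Point = List[float]


# Hypothetical helper for illustration; not the jaxns method.
def get_seed_point(
    live_points: List[Point],
    live_log_L: List[float],
    log_L_constraint: float,
    rng: random.Random,
) -> Tuple[Point, float]:
    """Pick a live point uniformly among those strictly above the constraint."""
    candidates = [i for i, ll in enumerate(live_log_L) if ll > log_L_constraint]
    if not candidates:
        # Guard against the infinite loop the docstring warns about.
        raise ValueError("no live point lies above log_L_constraint")
    i = rng.choice(candidates)
    return live_points[i], live_log_L[i]


live_points = [[0.1, 0.2], [0.4, 0.5], [0.7, 0.9]]
live_log_L = [-3.0, -1.0, -2.0]
point, point_log_L = get_seed_point(live_points, live_log_L, -1.5, random.Random(0))
```

Here only the second live point satisfies `-1.0 > -1.5`, so it is always chosen.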
- get_sample(key, log_L_constraint, sampler_state)
Produce a single i.i.d. sample from the model within the log_L_constraint.
- Parameters:
key (jaxns.internals.types.PRNGKey) – PRNG key
log_L_constraint (jaxns.internals.types.FloatArray) – the constraint to sample within
sampler_state (jaxns.samplers.abc.SamplerState) – the data pytree needed by, and produced by, the sampler
- Returns:
an i.i.d. sample, and batched phantom samples
- Return type:
Tuple[jaxns.internals.types.Sample, jaxns.internals.types.Sample]
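The relationship between the three methods above can be pictured as an abstract base class whose concrete `get_sample` first picks a seed and then walks from it. The sketch below is hypothetical and simplified: `MarkovSamplerSketch` and `UniformRandomWalk` are illustrative names, a list of live points stands in for `sampler_state`, phantom samples are dropped, `random.Random` replaces JAX keys, and the random-walk move is a swapped-in technique rather than what jaxns uses.

```python
import random
from abc import ABC, abstractmethod
from typing import Callable, List

Point = List[float]


class MarkovSamplerSketch(ABC):
    """Hypothetical stand-in: get_sample = pick a seed, then walk from it."""

    def __init__(self, log_likelihood: Callable[[Point], float]):
        self.log_likelihood = log_likelihood

    @abstractmethod
    def get_seed_point(self, rng: random.Random, live_points: List[Point],
                       log_L_constraint: float) -> Point: ...

    @abstractmethod
    def get_sample_from_seed(self, rng: random.Random, seed_point: Point,
                             log_L_constraint: float) -> Point: ...

    def get_sample(self, rng: random.Random, live_points: List[Point],
                   log_L_constraint: float) -> Point:
        # Derive two independent streams, analogous to splitting a JAX PRNG key.
        seed_rng = random.Random(rng.getrandbits(64))
        walk_rng = random.Random(rng.getrandbits(64))
        seed = self.get_seed_point(seed_rng, live_points, log_L_constraint)
        return self.get_sample_from_seed(walk_rng, seed, log_L_constraint)


class UniformRandomWalk(MarkovSamplerSketch):
    """Hypothetical concrete subclass: uniform prior on the unit cube."""

    def get_seed_point(self, rng, live_points, log_L_constraint):
        above = [p for p in live_points if self.log_likelihood(p) > log_L_constraint]
        if not above:
            raise ValueError("no live point lies above log_L_constraint")
        return rng.choice(above)

    def get_sample_from_seed(self, rng, seed_point, log_L_constraint,
                             num_steps=20, step_size=0.1):
        x = list(seed_point)
        for _ in range(num_steps):
            proposal = [xi + step_size * rng.gauss(0.0, 1.0) for xi in x]
            # Uniform prior + symmetric proposal: accept iff still feasible.
            if (all(0.0 <= p <= 1.0 for p in proposal)
                    and self.log_likelihood(proposal) > log_L_constraint):
                x = proposal
        return x


def log_L(u: Point) -> float:
    return -sum((ui - 0.5) ** 2 for ui in u)


rng = random.Random(0)
live = [[rng.random(), rng.random()] for _ in range(50)]
constraint = min(log_L(p) for p in live)  # replace the worst live point
sampler = UniformRandomWalk(log_L)
new_point = sampler.get_sample(rng, live, constraint)
```

Keeping `get_sample` concrete in the base class means subclasses only decide how to choose a seed and how to move from it, which mirrors how the documented class marks only `get_seed_point` and `get_sample_from_seed` as abstract.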