from abc import ABC, abstractmethod
from typing import TypeVar, Tuple
from jaxns.internals.types import PRNGKey, FloatArray, Sample, StaticStandardSampleCollection, \
StaticStandardNestedSamplerState
# Type of the sampler-specific data pytree produced by ``AbstractSampler.pre_process`` and
# threaded through ``get_sample`` / ``post_process``. Each concrete sampler chooses its own type.
SamplerState = TypeVar('SamplerState')
class AbstractSampler(ABC):
    """
    Abstract interface for samplers that draw new samples subject to a log-likelihood constraint.

    Concrete samplers implement four hooks:

    - ``pre_process``: run periodically on the nested sampler state to build a sampler-specific
      data pytree (a ``SamplerState``).
    - ``get_sample``: draw one i.i.d. sample within a log-likelihood constraint, using that pytree.
    - ``post_process``: update the pytree after the sampler has been run.
    - ``num_phantom``: report how many phantom samples ``get_sample`` produces.
    """

    @abstractmethod
    def pre_process(self, state: StaticStandardNestedSamplerState) -> SamplerState:
        """
        Run this periodically on the state to produce a data pytree that can be used by the sampler,
        and updated quickly.

        Args:
            state: the current nested sampler state

        Returns:
            any valid pytree
        """
        ...

    @abstractmethod
    def post_process(self, sample_collection: StaticStandardSampleCollection,
                     sampler_state: SamplerState) -> SamplerState:
        """
        Post-process the sampler state after the sampler has been run. Should be quick.

        Args:
            sample_collection: the sample collection after the sample step
            sampler_state: data pytree produced by the sampler

        Returns:
            the updated sampler state
        """
        ...

    @abstractmethod
    def get_sample(self, key: PRNGKey, log_L_constraint: FloatArray,
                   sampler_state: SamplerState) -> Tuple[Sample, Sample]:
        """
        Produce a single i.i.d. sample from the model within the log_L_constraint.

        Args:
            key: PRNGKey
            log_L_constraint: the constraint to sample within
            sampler_state: the data pytree needed and produced by the sampler

        Returns:
            a tuple of (an i.i.d. sample, batched phantom samples); the number of phantom
            samples is given by ``num_phantom``
        """
        ...

    @abstractmethod
    def num_phantom(self) -> int:
        """
        The number of phantom samples produced by the sampler.

        Returns:
            number of phantom samples
        """
        ...