# Source code for jaxns.samplers.abc

from abc import ABC, abstractmethod
from typing import TypeVar, Tuple

from jaxns.internals.types import PRNGKey, FloatArray, Sample, StaticStandardSampleCollection, \
    StaticStandardNestedSamplerState

# Generic type of the sampler-specific data pytree: returned by `pre_process`,
# threaded through `post_process`, and consumed by `get_sample`.
SamplerState = TypeVar('SamplerState')
class AbstractSampler(ABC):
    """
    Abstract interface that every sampler must implement.

    Concrete subclasses provide `pre_process`, `post_process`, `get_sample` and `num_phantom`.
    """

    @abstractmethod
    def pre_process(self, state: StaticStandardNestedSamplerState) -> SamplerState:
        """
        Run this periodically on the state to produce a data pytree that can be used by the sampler, and
        updated quickly.

        Args:
            state: the current state of the sampler

        Returns:
            any valid pytree
        """
        ...

    @abstractmethod
    def post_process(self, sample_collection: StaticStandardSampleCollection,
                     sampler_state: SamplerState) -> SamplerState:
        """
        Post process the sampler state, after the sampler has been run. Should be quick.

        Args:
            sample_collection: a sample collection post sample step
            sampler_state: data pytree produced by the sampler

        Returns:
            the updated sampler state
        """
        ...

    @abstractmethod
    def get_sample(self, key: PRNGKey, log_L_constraint: FloatArray,
                   sampler_state: SamplerState) -> Tuple[Sample, Sample]:
        """
        Produce a single i.i.d. sample from the model within the log_L_constraint.

        Args:
            key: PRNGkey
            log_L_constraint: the constraint to sample within
            sampler_state: the data pytree needed and produced by the sampler

        Returns:
            an i.i.d. sample, and batched phantom samples
        """
        ...

    @abstractmethod
    def num_phantom(self) -> int:
        """
        The number of phantom samples produced by the sampler.

        Returns:
            number of phantom samples
        """
        ...