import logging
from typing import Optional, Tuple, Union
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from jax import tree_map, core
from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.types import PRNGKey, IntArray, StaticStandardNestedSamplerState, TerminationCondition, \
NestedSamplerResults
from jaxns.nested_sampler.bases import BaseAbstractNestedSampler
from jaxns.nested_sampler.standard_static import StandardStaticNestedSampler
from jaxns.plotting import plot_cornerplot, plot_diagnostics
from jaxns.samplers.uni_slice_sampler import UniDimSliceSampler
from jaxns.utils import summary, save_results, load_results
tfpd = tfp.distributions
logger = logging.getLogger('jaxns')
__all__ = [
'DefaultNestedSampler',
'ApproximateNestedSampler',
'ExactNestedSampler',
'TerminationCondition'
]
[docs]
class DefaultNestedSampler:
"""
A static nested sampler that uses 1-dimensional slice sampler for the sampling step.
Uses the phantom-powered algorithm. A robust default choice is provided for all parameters.
"""
def __init__(self, model: BaseAbstractModel,
max_samples: Union[int, float],
num_live_points: Optional[int] = None,
s: Optional[int] = None,
k: Optional[int] = None,
c: Optional[int] = None,
num_parallel_workers: int = 1,
difficult_model: bool = False,
parameter_estimation: bool = False,
init_efficiency_threshold: float = 0.1,
verbose: bool = False):
"""
Initialises the nested sampler.
s,k,c are defined in the paper: https://arxiv.org/abs/2312.11330
Args:
model: a model to perform nested sampling on
max_samples: maximum number of samples to take
num_live_points: approximate number of live points to use. Defaults is c * (k + 1).
s: number of slices to use per dimension. Defaults to 4.
k: number of phantom samples to use. Defaults to 0.
c: number of parallel Markov-chains to use. Defaults to 20 * D.
num_parallel_workers: number of parallel workers to use. Defaults to 1. Experimental feature.
difficult_model: if True, uses more robust default settings. Defaults to False.
parameter_estimation: if True, uses more robust default settings for parameter estimation. Defaults to False.
init_efficiency_threshold: if > 0 then use uniform sampling first down to this acceptance efficiency.
0 turns it off.
verbose: whether to use JAX printing
"""
if difficult_model:
self._s = 10 if s is None else int(s)
else:
self._s = 5 if s is None else int(s)
if self._s <= 0:
raise ValueError(f"Expected s > 0, got s={self._s}")
if parameter_estimation:
self._k = model.U_ndims if k is None else int(k)
else:
self._k = 0 if k is None else int(k)
if not (0 <= self._k < self._s * model.U_ndims):
raise ValueError(f"Expected 0 <= k < s * U_ndims, got k={self._k}, s={self._s}, U_ndims={model.U_ndims}")
if num_live_points is not None:
self._c = max(1, int(num_live_points / (self._k + 1)))
logger.info(f"Number of parallel Markov-chains set to: {self._c}")
else:
if difficult_model:
self._c = 50 * model.U_ndims if c is None else int(c)
else:
self._c = 30 * model.U_ndims if c is None else int(c)
if self._c <= 0:
raise ValueError(f"Expected c > 0, got c={self._c}")
# Sanity check for max_samples (should be able to at least do one shrinkage)
if max_samples < self._c * (self._k + 1):
logger.warning(f"max_samples={max_samples} is likely too small!")
self._nested_sampler = StandardStaticNestedSampler(
model=model,
num_live_points=self._c,
max_samples=max_samples,
sampler=UniDimSliceSampler(
model=model,
num_slices=model.U_ndims * self._s,
num_phantom_save=self._k,
midpoint_shrink=True,
perfect=True
),
init_efficiency_threshold=init_efficiency_threshold,
num_parallel_workers=num_parallel_workers,
verbose=verbose
)
# Post-analysis utilities
self.summary = summary
self.plot_cornerplot = plot_cornerplot
self.plot_diagnostics = plot_diagnostics
self.save_results = save_results
self.load_results = load_results
[docs]
def __repr__(self):
return f"DefaultNestedSampler(s={self._s}, c={self._c}, k={self._k})"
@property
[docs]
def num_live_points(self) -> int:
return self._nested_sampler.num_live_points
@property
[docs]
def nested_sampler(self) -> BaseAbstractNestedSampler:
return self._nested_sampler
[docs]
def __call__(self, key: PRNGKey, term_cond: Optional[TerminationCondition] = None) -> Tuple[
IntArray, StaticStandardNestedSamplerState]:
"""
Performs nested sampling with the given termination conditions.
Args:
key: PRNGKey
term_cond: termination conditions. If not given, see `TerminationCondition` for defaults.
Returns:
termination reason, state
"""
if term_cond is None:
term_cond = TerminationCondition()
term_cond = term_cond._replace(
max_samples=jnp.minimum(term_cond.max_samples, self._nested_sampler.max_samples)
)
return self._nested_sampler._run(
key=key,
term_cond=term_cond
)
[docs]
def to_results(self, termination_reason: IntArray, state: StaticStandardNestedSamplerState,
trim: bool = True) -> NestedSamplerResults:
"""
Convert the state to results.
Note: Requires static context.
Args:
termination_reason: termination reason
state: state to convert
trim: if True, trims the results to the number of samples taken, requires static context.
Returns:
results
"""
return self._nested_sampler._to_results(
termination_reason=termination_reason,
state=state,
trim=trim
)
@staticmethod
[docs]
def trim_results(results: NestedSamplerResults) -> NestedSamplerResults:
"""
Trims the results to the number of samples taken. Requires static context.
Args:
results: results to trim
Returns:
trimmed results
"""
if isinstance(results.total_num_samples, core.Tracer):
raise RuntimeError("Tracer detected, but expected imperative context.")
def trim(x):
if x.size > 1:
return x[:results.total_num_samples]
return x
results = tree_map(trim, results)
return results
[docs]
class ApproximateNestedSampler(DefaultNestedSampler):
def __init__(self, *args, **kwargs):
logger.warning(f"ApproximateNestedSampler is deprecated. Use DefaultNestedSampler instead.")
super().__init__(*args, **kwargs)
[docs]
class ExactNestedSampler(ApproximateNestedSampler):
def __init__(self, *args, **kwargs):
logger.warning(f"ExactNestedSampler is deprecated. Use DefaultNestedSampler instead.")
super().__init__(*args, **kwargs)