"""Source code for jaxns.public."""

import dataclasses
from typing import Optional, Tuple, Union, List

import jax
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import core
from jaxlib import xla_client

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.logging import logger
from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.types import PRNGKey, IntArray
from jaxns.nested_samplers.abc import AbstractNestedSampler
from jaxns.nested_samplers.common.types import TerminationCondition, NestedSamplerResults, \
    NestedSamplerState
from jaxns.nested_samplers.sharded import ShardedStaticNestedSampler
from jaxns.plotting import plot_cornerplot, plot_diagnostics
from jaxns.samplers.uni_slice_sampler import UniDimSliceSampler
from jaxns.utils import summary, save_results, load_results

# Shorthand alias for the TensorFlow Probability (JAX substrate) distributions namespace.
tfpd = tfp.distributions

# Names exported by a star-import of this module.
__all__ = [
    'NestedSampler'
]


@dataclasses.dataclass(eq=False)
class NestedSampler:
    """
    A static nested sampler that uses 1-dimensional slice sampler for the sampling step.
    Uses the phantom-powered algorithm. A robust default choice is provided for all parameters.

    s,k,c are defined in the paper: https://arxiv.org/abs/2312.11330

    Below, D denotes `model.U_ndims`.

    Args:
        model: a model to perform nested sampling on
        max_samples: maximum number of samples to take. Defaults to 100 * c * (k + 1).
        num_live_points: approximate number of live points to use. If given, `c` is set to
            max(1, int(num_live_points / (k + 1))), overriding any supplied `c`.
        num_slices: total number of slices per acceptance. If given it is used directly and `s` is
            not consulted. Defaults to int(D * s).
        s: number of slices to use per dimension. Defaults to 5, or 10 if `difficult_model`.
        k: number of phantom samples to use. Defaults to 0, or min(D, num_slices - 1) if
            `parameter_estimation`. Must satisfy 0 <= k < num_slices.
        c: number of parallel Markov-chains to use. Defaults to 30 * D, or 100 * D if `difficult_model`.
        devices: devices to use. Passed through to the underlying sharded sampler.
        difficult_model: if True, uses more robust default settings. Defaults to False.
        parameter_estimation: if True, uses more robust default settings for parameter estimation.
            Defaults to False.
        shell_fraction: fraction of the shell to use for the slice sampler. Defaults to 0.5.
        gradient_guided: if True, uses gradient guided sampling. Defaults to False.
        init_efficiency_threshold: if > 0 then use uniform sampling first down to this acceptance
            efficiency. 0 turns it off. Defaults to 0.1.
        verbose: whether to log progress.
    """
    model: BaseAbstractModel
    max_samples: Optional[Union[int, float]] = None
    num_live_points: Optional[int] = None
    num_slices: Optional[int] = None
    s: Optional[Union[int, float]] = None
    k: Optional[int] = None
    c: Optional[int] = None
    devices: Optional[List[xla_client.Device]] = None
    difficult_model: bool = False
    parameter_estimation: bool = False
    shell_fraction: float = 0.5
    gradient_guided: bool = False
    init_efficiency_threshold: float = 0.1
    verbose: bool = False

    def __post_init__(self):
        """
        Resolve defaults for `num_slices`, `k`, `c` and `max_samples`, then build the underlying sampler.

        Raises:
            ValueError: if `s <= 0`, if `k` is outside `[0, num_slices)`, or if `c <= 0`.
        """
        # Determine number of slices per acceptance
        if self.num_slices is None:
            if self.difficult_model:
                self.s = 10 if self.s is None else float(self.s)
            else:
                self.s = 5 if self.s is None else float(self.s)
            if self.s <= 0:
                raise ValueError(f"Expected s > 0, got s={self.s}")
            self.num_slices = self.model.U_ndims * self.s
        self.num_slices = int(self.num_slices)

        # Determine number of phantom samples
        if self.parameter_estimation:
            # Cap against the resolved integer `num_slices` rather than `s`: `s` is still None when
            # the caller supplies `num_slices` directly, and a fractional `s` would produce a float k.
            max_k = self.num_slices - 1
            self.k = min(self.model.U_ndims, max_k) if self.k is None else int(self.k)
        else:
            self.k = 0 if self.k is None else int(self.k)
        if not (0 <= self.k < self.num_slices):
            raise ValueError(
                f"Expected 0 <= k < num_slices, got k={self.k}, num_slices={self.num_slices}, U_ndims={self.model.U_ndims}")

        # Determine number of parallel Markov-chains
        if self.num_live_points is not None:
            # A requested live-point count takes precedence over any supplied `c`.
            self.c = max(1, int(self.num_live_points / (self.k + 1)))
            logger.info(f"Number of Markov-chains set to: {self.c}")
        else:
            if self.difficult_model:
                self.c = 100 * self.model.U_ndims if self.c is None else int(self.c)
            else:
                self.c = 30 * self.model.U_ndims if self.c is None else int(self.c)
        if self.c <= 0:
            raise ValueError(f"Expected c > 0, got c={self.c}")

        # Resolve max_samples
        if self.max_samples is None:
            # Default to 100 shrinkages.
            self.max_samples = self.c * (self.k + 1) * 100
        self.max_samples = int(self.max_samples)

        self._nested_sampler = ShardedStaticNestedSampler(
            model=self.model,
            num_live_points=self.c,
            max_samples=self.max_samples,
            sampler=UniDimSliceSampler(
                model=self.model,
                num_slices=self.num_slices,
                num_phantom_save=self.k,
                midpoint_shrink=not self.difficult_model,
                gradient_guided=self.gradient_guided,
                perfect=True
            ),
            init_efficiency_threshold=self.init_efficiency_threshold,
            shell_fraction=self.shell_fraction,
            devices=self.devices,
            verbose=self.verbose,
        )

        # Back propagate any updates here
        self.num_live_points = self._nested_sampler.num_live_points

        # Post-analysis utilities (plain module functions stored as instance attributes)
        self.summary = summary
        self.plot_cornerplot = plot_cornerplot
        self.plot_diagnostics = plot_diagnostics
        self.save_results = save_results
        self.load_results = load_results

    @property
    def nested_sampler(self) -> AbstractNestedSampler:
        """The underlying nested sampler constructed in `__post_init__`."""
        return self._nested_sampler

    def __call__(self, key: PRNGKey, term_cond: Optional[TerminationCondition] = None) -> Tuple[
        IntArray, NestedSamplerState]:
        """
        Performs nested sampling with the given termination conditions.

        Args:
            key: PRNGKey
            term_cond: termination conditions. If not given, a default is built: `peak_XL_frac=0.1` when
                `parameter_estimation` is set, otherwise `dlogZ=log(1 + 1e-3)`. In every case
                `max_samples` is capped at the underlying sampler's `max_samples`.

        Returns:
            termination reason, state
        """
        if term_cond is None:
            # Start from the largest representable count; the sampler's own cap is applied below.
            unbounded = jnp.asarray(jnp.iinfo(mp_policy.count_dtype).max, mp_policy.count_dtype)
            if self.parameter_estimation:
                term_cond = TerminationCondition(
                    peak_XL_frac=jnp.asarray(0.1, mp_policy.measure_dtype),
                    max_samples=unbounded
                )
            else:
                term_cond = TerminationCondition(
                    dlogZ=jnp.asarray(np.log(1. + 1e-3), mp_policy.measure_dtype),
                    max_samples=unbounded
                )

        # Never exceed the sampler's max_samples, whatever the caller asked for.
        sampler_cap = jnp.asarray(self._nested_sampler.max_samples, mp_policy.count_dtype)
        if term_cond.max_samples is not None:
            capped_max_samples = jnp.minimum(term_cond.max_samples, sampler_cap)
        else:
            capped_max_samples = sampler_cap
        term_cond = term_cond._replace(max_samples=capped_max_samples)

        # The middle element of the returned triple is not exposed by this wrapper.
        termination_reason, _, state = self._nested_sampler._run(
            key=key,
            term_cond=term_cond
        )
        return termination_reason, state

    def to_results(self, termination_reason: IntArray, state: NestedSamplerState,
                   trim: bool = True) -> NestedSamplerResults:
        """
        Convert the state to results.

        Note: Requires static context.

        Args:
            termination_reason: termination reason
            state: state to convert
            trim: if True, trims the results to the number of samples taken, requires static context.

        Returns:
            results
        """
        return self._nested_sampler._to_results(
            termination_reason=termination_reason,
            state=state,
            trim=trim
        )

    @staticmethod
    def trim_results(results: NestedSamplerResults) -> NestedSamplerResults:
        """
        Trims the results to the number of samples taken. Requires static context.

        Args:
            results: results to trim

        Returns:
            trimmed results

        Raises:
            RuntimeError: if `results.total_num_samples` is a JAX tracer.
        """
        if isinstance(results.total_num_samples, core.Tracer):
            raise RuntimeError("Tracer detected, but expected imperative context.")

        def trim(x):
            # Leaves with at most one element are returned unchanged.
            if x.size > 1:
                return x[:results.total_num_samples]
            return x

        results = jax.tree.map(trim, results)
        return results
# Alias: `DefaultNestedSampler` is bound to the same class object as `NestedSampler`.
DefaultNestedSampler = NestedSampler