"""Termination-condition evaluation for nested samplers (module ``jaxns.nested_samplers.common.termination``)."""

import warnings
from typing import Union, Tuple

from jax import numpy as jnp

from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.stats import linear_to_log_stats, effective_sample_size_kish
from jaxns.internals.types import BoolArray, IntArray
from jaxns.nested_samplers.common.types import TerminationConditionDisjunction, TerminationConditionConjunction, \
    TerminationRegister, TerminationCondition


def _set_done_bit(bit_done, bit_reason: int, done: BoolArray,
                  termination_reason: IntArray) -> Tuple[BoolArray, IntArray]:
    """
    Fold one scalar termination check into the running ``(done, termination_reason)`` pair.

    Args:
        bit_done: scalar boolean (array-like), True if this particular check fired.
        bit_reason: bit index of this check in the termination-reason bitmask.
        done: running done flag.
        termination_reason: running termination-reason bitmask.

    Returns:
        updated ``(done, termination_reason)``; bit ``bit_reason`` is set in the reason only
        when ``bit_done`` is True.

    Raises:
        RuntimeError: if ``bit_done`` is not a scalar.
    """
    # Coerce so Python bools / numpy scalars also have `.size`.
    bit_done = jnp.asarray(bit_done, jnp.bool_)
    if bit_done.size != 1:
        raise RuntimeError("bit_done must be a scalar.")
    # Drop any singleton dimensions so the running flags stay 0-d.
    bit_done = jnp.reshape(bit_done, ())
    done = jnp.logical_or(done, bit_done)
    # bitwise_or (instead of +=) is idempotent: setting the same bit twice cannot spill
    # into neighbouring bits.
    termination_reason = jnp.bitwise_or(
        termination_reason,
        jnp.where(bit_done,
                  jnp.asarray(2 ** bit_reason, mp_policy.count_dtype),
                  jnp.asarray(0, mp_policy.count_dtype))
    )
    return done, termination_reason


def _determine_leaf_termination(term_cond: TerminationCondition,
                                termination_register: TerminationRegister) -> Tuple[BoolArray, IntArray]:
    """
    Evaluate a single (non-composite) termination condition.

    Every check whose threshold is not None is evaluated and OR-ed together. The plateau
    (bit 7) and no-seed-points (bit 10) checks are always evaluated.

    Args:
        term_cond: leaf termination condition.
        termination_register: register of termination variables to check against.

    Returns:
        boolean done signal, and termination-reason bitmask.
    """
    termination_reason = jnp.asarray(0, mp_policy.count_dtype)
    done = jnp.asarray(False, jnp.bool_)

    if term_cond.live_evidence_frac is not None:
        warnings.warn("live_evidence_frac is deprecated, use dlogZ instead.")

    if term_cond.max_samples is not None:
        # used all points
        reached_max_samples = termination_register.num_samples_used >= term_cond.max_samples
        done, termination_reason = _set_done_bit(reached_max_samples, 0, done, termination_reason)

    if term_cond.evidence_uncert is not None:
        _, log_Z_var = linear_to_log_stats(
            log_f_mean=termination_register.evidence_calc_with_remaining.log_Z_mean,
            log_f2_mean=termination_register.evidence_calc_with_remaining.log_Z2_mean)
        evidence_uncert_low_enough = log_Z_var <= jnp.square(term_cond.evidence_uncert)
        done, termination_reason = _set_done_bit(evidence_uncert_low_enough, 1, done, termination_reason)

    if term_cond.dlogZ is not None:
        # log((Z_current + Z_remaining) / Z_current) < dlogZ, comparing the evidence estimate
        # that includes the remaining contribution against the current evidence alone.
        log_Z_mean_1, _ = linear_to_log_stats(
            log_f_mean=termination_register.evidence_calc_with_remaining.log_Z_mean,
            log_f2_mean=termination_register.evidence_calc_with_remaining.log_Z2_mean)
        log_Z_mean_0, _ = linear_to_log_stats(
            log_f_mean=termination_register.evidence_calc.log_Z_mean,
            log_f2_mean=termination_register.evidence_calc.log_Z2_mean)
        small_remaining_evidence = jnp.less(log_Z_mean_1 - log_Z_mean_0, term_cond.dlogZ)
        done, termination_reason = _set_done_bit(small_remaining_evidence, 2, done, termination_reason)

    if term_cond.ess is not None:
        # Kish's ESS = [sum weights]^2 / [sum weights^2]
        ess = effective_sample_size_kish(termination_register.evidence_calc_with_remaining.log_Z_mean,
                                         termination_register.evidence_calc_with_remaining.log_dZ2_mean)
        ess_reached = ess >= term_cond.ess
        done, termination_reason = _set_done_bit(ess_reached, 3, done, termination_reason)

    if term_cond.max_num_likelihood_evaluations is not None:
        num_likelihood_evaluations = jnp.sum(termination_register.num_likelihood_evaluations)
        too_max_likelihood_evaluations = num_likelihood_evaluations >= term_cond.max_num_likelihood_evaluations
        done, termination_reason = _set_done_bit(too_max_likelihood_evaluations, 4, done, termination_reason)

    if term_cond.log_L_contour is not None:
        likelihood_contour_reached = termination_register.log_L_contour >= term_cond.log_L_contour
        done, termination_reason = _set_done_bit(likelihood_contour_reached, 5, done, termination_reason)

    if term_cond.efficiency_threshold is not None:
        efficiency_too_low = termination_register.efficiency < term_cond.efficiency_threshold
        done, termination_reason = _set_done_bit(efficiency_too_low, 6, done, termination_reason)

    # Always checked, independent of term_cond.
    done, termination_reason = _set_done_bit(termination_register.plateau, 7, done, termination_reason)

    if term_cond.rtol is not None:
        relative_spread_low = termination_register.relative_spread < term_cond.rtol
        done, termination_reason = _set_done_bit(relative_spread_low, 8, done, termination_reason)

    if term_cond.atol is not None:
        absolute_spread_low = termination_register.absolute_spread < term_cond.atol
        done, termination_reason = _set_done_bit(absolute_spread_low, 9, done, termination_reason)

    # Always checked, independent of term_cond.
    done, termination_reason = _set_done_bit(termination_register.no_seed_points, 10, done, termination_reason)

    if term_cond.peak_XL_frac is not None:
        log_XL = termination_register.evidence_calc.log_X_mean + termination_register.evidence_calc.log_L
        peak_log_XL = termination_register.peak_log_XL
        XL_reduction_reached = log_XL < peak_log_XL + jnp.log(term_cond.peak_XL_frac)
        done, termination_reason = _set_done_bit(XL_reduction_reached, 11, done, termination_reason)

    return done, termination_reason


def determine_termination(
        term_cond: Union[TerminationConditionDisjunction, TerminationConditionConjunction, TerminationCondition],
        termination_register: TerminationRegister) -> Tuple[BoolArray, IntArray]:
    """
    Determine if termination should happen.

    Termination Flags are bits:
        0-bit -> 1: used maximum allowed number of samples
        1-bit -> 2: evidence uncert below threshold
        2-bit -> 4: live points evidence below threshold
        3-bit -> 8: effective sample size big enough
        4-bit -> 16: used maximum allowed number of likelihood evaluations
        5-bit -> 32: maximum log-likelihood contour reached
        6-bit -> 64: sampler efficiency too low
        7-bit -> 128: entire live-points set is a single plateau
        8-bit -> 256: relative spread of live points < rtol
        9-bit -> 512: absolute spread of live points < atol
        10-bit -> 1024: no seed points left
        11-bit -> 2048: XL < max(XL) * peak_XL_frac

    Multiple flags are combined (each bit is set at most once, so this equals their sum).

    Composite conditions:
        - Disjunction: done if ANY sub-condition is done; reason is the OR of sub-reasons.
        - Conjunction: done if ALL sub-conditions are done; reason is the OR of sub-reasons
          when done, else 0. An empty conjunction never terminates.

    Args:
        term_cond: termination condition
        termination_register: register of termination variables to check against termination condition

    Returns:
        boolean done signal, and termination reason
    """
    zero_reason = jnp.asarray(0, mp_policy.count_dtype)

    if isinstance(term_cond, TerminationConditionConjunction):
        conds = list(term_cond.conds)
        if not conds:
            # Avoid vacuous truth: nothing to satisfy means "do not stop".
            return jnp.asarray(False, jnp.bool_), zero_reason
        done = jnp.asarray(True, jnp.bool_)
        reason = zero_reason
        for c in conds:
            _done, _reason = determine_termination(term_cond=c, termination_register=termination_register)
            done = jnp.logical_and(done, _done)
            reason = jnp.bitwise_or(reason, _reason)
        # Report reasons only if the whole conjunction fired (keeps reason != 0 iff done).
        termination_reason = jnp.where(done, reason, zero_reason)
        return done, termination_reason

    if isinstance(term_cond, TerminationConditionDisjunction):
        done = jnp.asarray(False, jnp.bool_)
        termination_reason = zero_reason
        for c in term_cond.conds:
            _done, _reason = determine_termination(term_cond=c, termination_register=termination_register)
            done = jnp.logical_or(done, _done)
            termination_reason = jnp.bitwise_or(termination_reason, _reason)
        return done, termination_reason

    return _determine_leaf_termination(term_cond, termination_register)