# Source code for jaxns.nested_sampler.standard_static

import logging
from typing import Tuple, NamedTuple, Any, Union

import jax
from jax import random, pmap, tree_map, numpy as jnp, lax, core, vmap
from jax._src.lax import parallel

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.cumulative_ops import cumulative_op_static
from jaxns.internals.log_semiring import LogSpace, normalise_log_space
from jaxns.internals.shrinkage_statistics import compute_evidence_stats, init_evidence_calc, \
    update_evicence_calculation, EvidenceUpdateVariables, _update_evidence_calc_op
from jaxns.internals.stats import linear_to_log_stats, effective_sample_size
from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges, unbatch_state
from jaxns.internals.types import TerminationCondition, IntArray, PRNGKey, BoolArray, int_type, UType, MeasureType, \
    float_type, \
    TerminationConditionDisjunction, \
    TerminationConditionConjunction, Sample, StaticStandardSampleCollection, \
    StaticStandardNestedSamplerState, NestedSamplerResults, EvidenceCalculation, TerminationRegister
from jaxns.nested_sampler.bases import BaseAbstractNestedSampler
from jaxns.samplers.abc import SamplerState
from jaxns.samplers.bases import BaseAbstractSampler
from jaxns.samplers.uniform_samplers import UniformSampler

# Public API of this module. `TerminationCondition` is re-exported here from `jaxns.internals.types`.
__all__ = [
    'TerminationCondition',
    'StandardStaticNestedSampler'
]

# Package-wide logger (named 'jaxns', not derived from __name__).
logger = logging.getLogger('jaxns')


def _inter_sync_shrinkage_process(
        init_state: StaticStandardNestedSamplerState,
        init_termination_register: TerminationRegister,
        sampler: BaseAbstractSampler,
        num_samples: int) -> Tuple[StaticStandardNestedSamplerState, TerminationRegister]:
    """
    Run nested sampling until `num_samples` samples are collected.

    Each iteration takes the lowest-likelihood point on the front (the "dead" point), draws a new sample above
    that likelihood contour, puts the new sample on the front in place of the dead point, and records the new
    sample plus any phantom samples returned by the sampler into the sample collection.

    Args:
        init_state: the state of the nested sampler at the start
        init_termination_register: the termination register at the start
        sampler: sampler to use
        num_samples: number of samples to take, i.e. work to do, must be >= front_size

    Returns:
        the nested sampler state with samples added, and the updated termination register.

    Raises:
        RuntimeError: if `num_samples` is smaller than the front size.
    """

    front_size = init_state.front_idx.size
    if num_samples < front_size:
        raise RuntimeError(f"num_samples ({num_samples}) must be >= front_size ({front_size})")

    max_num_samples = init_state.sample_collection.log_L.size

    class CarryType(NamedTuple):
        front_sample_collection: StaticStandardSampleCollection
        sampler_state: SamplerState
        key: PRNGKey
        front_idx: IntArray
        next_sample_idx: IntArray
        evidence_calc: EvidenceCalculation

    class ResultType(NamedTuple):
        replace_idx: IntArray
        sample_collection: StaticStandardSampleCollection

    def body(carry: CarryType, unused_X: IntArray) -> Tuple[CarryType, ResultType]:
        front_loc = jnp.argmin(carry.front_sample_collection.log_L)
        dead_idx = carry.front_idx[front_loc]

        # Node index is based on root of 0, so sample-nodes are 1-indexed
        dead_node_idx = dead_idx + jnp.asarray(1, int_type)

        log_L_contour = carry.front_sample_collection.log_L[front_loc]

        # Update evidence calculation
        next_evidence_calculation = update_evicence_calculation(
            evidence_calculation=carry.evidence_calc,
            update=EvidenceUpdateVariables(
                num_live_points=jnp.asarray(carry.front_idx.size, float_type),
                log_L_next=log_L_contour
            )
        )

        key, sample_key = random.split(carry.key, 2)

        sample, phantom_samples = sampler.get_sample(
            key=sample_key,
            log_L_constraint=log_L_contour,
            sampler_state=carry.sampler_state
        )
        # Replace sample in front_sample_collection
        front_sample_collection = carry.front_sample_collection._replace(
            sender_node_idx=carry.front_sample_collection.sender_node_idx.at[front_loc].set(dead_node_idx),
            log_L=carry.front_sample_collection.log_L.at[front_loc].set(sample.log_L),
            U_samples=carry.front_sample_collection.U_samples.at[front_loc].set(sample.U_sample),
            num_likelihood_evaluations=carry.front_sample_collection.num_likelihood_evaluations.at[front_loc].set(
                sample.num_likelihood_evaluations),
            # Phantom samples are not on the front, so don't need to be updated from default of False
        )
        front_idx = carry.front_idx.at[front_loc].set(carry.next_sample_idx)

        # Set (non-phantom) sample as the next sample

        new_replace_idx = [
            carry.next_sample_idx[None]
        ]
        new_sender_node_idx = [
            dead_node_idx[None]
        ]
        new_log_L = [
            sample.log_L[None]
        ]
        new_U_samples = [
            sample.U_sample[None]
        ]
        new_num_likelihood_evaluations = [
            sample.num_likelihood_evaluations[None]
        ]
        new_phantom = [
            jnp.zeros((1,), jnp.bool_)
        ]

        # Clamped to the buffer size; writes at index `max_num_samples` fall outside the collection.
        next_sample_idx = jnp.minimum(carry.next_sample_idx + 1, max_num_samples)

        # Set phantom samples, whose sender nodes are all the dead point. These do not get set on the front.

        num_phantom = phantom_samples.log_L.size
        new_replace_idx.append(
            (next_sample_idx + jnp.arange(num_phantom)).astype(next_sample_idx.dtype)
        )
        new_sender_node_idx.append(
            jnp.full((num_phantom,), dead_node_idx)
        )
        new_log_L.append(
            phantom_samples.log_L
        )
        new_U_samples.append(
            phantom_samples.U_sample
        )
        new_num_likelihood_evaluations.append(
            phantom_samples.num_likelihood_evaluations
        )
        new_phantom.append(
            jnp.ones((num_phantom,), dtype=jnp.bool_)
        )

        next_sample_idx = jnp.minimum(next_sample_idx + num_phantom, max_num_samples)

        new_sample_collection = StaticStandardSampleCollection(
            sender_node_idx=jnp.concatenate(new_sender_node_idx, axis=0),
            log_L=jnp.concatenate(new_log_L, axis=0),
            U_samples=jnp.concatenate(new_U_samples, axis=0),
            num_likelihood_evaluations=jnp.concatenate(new_num_likelihood_evaluations, axis=0),
            phantom=jnp.concatenate(new_phantom, axis=0)
        )
        new_replace_idx = jnp.concatenate(new_replace_idx, axis=0)

        # Fast update of sampler state given a new sample collection that satisfies the front
        sampler_state = sampler.post_process(sample_collection=front_sample_collection,
                                             sampler_state=carry.sampler_state)

        new_carry = CarryType(
            front_sample_collection=front_sample_collection,
            sampler_state=sampler_state,
            key=key,
            front_idx=front_idx,
            next_sample_idx=next_sample_idx,
            evidence_calc=next_evidence_calculation
        )

        new_return = ResultType(
            replace_idx=new_replace_idx,
            sample_collection=new_sample_collection
        )

        return new_carry, new_return

    # Sampler state is created before all this work. Quickly updated during shrinkage.
    init_sampler_state = sampler.pre_process(state=init_state)
    init_front_sample_collection = tree_map(lambda x: x[init_state.front_idx], init_state.sample_collection)
    key, carry_key = random.split(init_state.key)
    init_carry = CarryType(
        sampler_state=init_sampler_state,
        key=carry_key,
        front_idx=init_state.front_idx,
        front_sample_collection=init_front_sample_collection,
        next_sample_idx=init_state.next_sample_idx,
        evidence_calc=init_termination_register.evidence_calc
    )
    out_carry, out_return = lax.scan(body, init_carry, jnp.arange(num_samples), unroll=1)

    # Replace the samples in the sample collection with out_return counterparts.
    sample_collection = tree_map(
        lambda x, y: x.at[out_return.replace_idx].set(y),
        init_state.sample_collection,
        out_return.sample_collection
    )
    # Note, discard front_sample_collection since it's already in out_return, and we've replaced the whole front.

    # Take front_idx and next_sample_idx from carry, which have been kept up-to-date at every iteration.
    state = StaticStandardNestedSamplerState(
        key=key,
        next_sample_idx=out_carry.next_sample_idx,
        sample_collection=sample_collection,
        front_idx=out_carry.front_idx
    )
    # Update termination register.
    # The "with remaining" evidence consumes the current front in increasing log(L) order, with the number of
    # live points decreasing from the front size down to 1.
    _n = init_state.front_idx.size
    evidence_calc_with_remaining, _ = cumulative_op_static(
        op=_update_evidence_calc_op,
        init=out_carry.evidence_calc,
        xs=EvidenceUpdateVariables(
            num_live_points=jnp.arange(_n, 0., -1., float_type),
            log_L_next=jnp.sort(out_carry.front_sample_collection.log_L)
        ),
    )
    # Likelihood evaluations spent during this sync only.
    num_likelihood_evaluations_this_sync = jnp.sum(out_return.sample_collection.num_likelihood_evaluations)
    num_likelihood_evaluations = init_termination_register.num_likelihood_evaluations + \
                                 num_likelihood_evaluations_this_sync
    # Per-sync efficiency: samples produced this sync over evaluations spent this sync. (Dividing by the
    # cumulative count would make efficiency decay with run length regardless of the sampler's actual efficiency.)
    efficiency = out_return.sample_collection.log_L.size / num_likelihood_evaluations_this_sync
    plateau = jnp.all(jnp.equal(out_carry.front_sample_collection.log_L, out_carry.front_sample_collection.log_L[0]))
    termination_register = TerminationRegister(
        num_samples_used=out_carry.next_sample_idx,
        evidence_calc=out_carry.evidence_calc,
        evidence_calc_with_remaining=evidence_calc_with_remaining,
        num_likelihood_evaluations=num_likelihood_evaluations,
        log_L_contour=out_carry.evidence_calc.log_L,
        efficiency=efficiency,
        plateau=plateau
    )
    return state, termination_register


def _single_thread_ns(init_state: StaticStandardNestedSamplerState,
                      init_termination_register: TerminationRegister,
                      termination_cond: TerminationCondition,
                      sampler: BaseAbstractSampler,
                      num_samples_per_sync: int,
                      verbose: bool = False) -> Tuple[
    StaticStandardNestedSamplerState, TerminationRegister, IntArray]:
    """
    Runs a single thread of static nested sampling until a stopping condition is reached. Runs `num_samples_per_sync`
    between updating samples to limit memory ops.

    A maximum-samples limit is always enforced so that one more sync never writes past the end of the sample
    collection. This holds even when `termination_cond.max_samples` is None, or when `termination_cond` is a
    conjunction/disjunction of conditions.

    Args:
        init_state: the state of the nested sampler at the start
        init_termination_register: the termination register at the start
        termination_cond: the termination condition (plain, conjunction or disjunction)
        sampler: the sampler to use
        num_samples_per_sync: number of samples to take per all-gather
        verbose: whether to log debug messages.

    Returns:
        final sampler state, final termination register, and the termination reason bit mask
    """

    # Update the termination condition to stop before going over the maximum number of samples. Each sync can write
    # `space_needed_per_sync` entries (each sample plus its phantom samples).
    space_needed_per_sync = num_samples_per_sync * (sampler.num_phantom() + 1)
    max_samples_capacity = init_state.sample_collection.log_L.size - space_needed_per_sync
    if isinstance(termination_cond, (TerminationConditionDisjunction, TerminationConditionConjunction)):
        # Compound conditions have no `max_samples` field of their own, so OR in a capacity condition.
        termination_cond = TerminationConditionDisjunction(
            conds=[termination_cond, TerminationCondition(max_samples=jnp.asarray(max_samples_capacity))]
        )
    elif termination_cond.max_samples is None:
        # No user limit: the capacity of the sample collection is the limit.
        termination_cond = termination_cond._replace(
            max_samples=jnp.asarray(max_samples_capacity)
        )
    else:
        termination_cond = termination_cond._replace(
            max_samples=jnp.minimum(
                termination_cond.max_samples,
                max_samples_capacity
            )
        )

    class CarryType(NamedTuple):
        state: StaticStandardNestedSamplerState
        termination_register: TerminationRegister

    def cond(carry: CarryType) -> BoolArray:
        done, _ = determine_termination(
            term_cond=termination_cond,
            termination_register=carry.termination_register
        )
        return jnp.bitwise_not(done)

    def body(carry: CarryType) -> CarryType:
        # Devices are independent, i.e. expect no communication between them in sampler.
        state, termination_register = _inter_sync_shrinkage_process(
            init_state=carry.state,
            sampler=sampler,
            num_samples=num_samples_per_sync,
            init_termination_register=carry.termination_register
        )
        if verbose:
            log_Z_mean, log_Z_var = linear_to_log_stats(
                log_f_mean=termination_register.evidence_calc_with_remaining.log_Z_mean,
                log_f2_mean=termination_register.evidence_calc_with_remaining.log_Z2_mean)
            log_Z_uncert = jnp.sqrt(log_Z_var)
            jax.debug.print(
                "-------\n"
                "Num samples: {num_samples}\n"
                "Num likelihood evals: {num_likelihood_evals}\n"
                "Efficiency: {efficiency}\n"
                "log(L) contour: {log_L_contour}\n"
                "log(Z) est.: {log_Z_mean} +- {log_Z_uncert}",
                num_samples=termination_register.num_samples_used,
                num_likelihood_evals=termination_register.num_likelihood_evaluations,
                efficiency=termination_register.efficiency,
                log_L_contour=termination_register.log_L_contour,
                log_Z_mean=log_Z_mean,
                log_Z_uncert=log_Z_uncert
            )

        return CarryType(state=state, termination_register=termination_register)

    init_carry_state = CarryType(
        state=init_state,
        termination_register=init_termination_register
    )

    carry_state: CarryType = lax.while_loop(
        cond_fun=cond,
        body_fun=body,
        init_val=init_carry_state
    )

    # Recompute the reason for the final register (the while-loop only exposes the done flag).
    _, termination_reason = determine_termination(
        term_cond=termination_cond,
        termination_register=carry_state.termination_register
    )

    return carry_state.state, carry_state.termination_register, termination_reason


def create_init_termination_register() -> TerminationRegister:
    """
    Build the termination register for a run in which no shrinkage has happened yet.

    Both counters start at zero, and both evidence calculations are freshly initialised via `init_evidence_calc`.
    The log-likelihood contour starts at -inf, efficiency at zero, and the plateau flag unset.

    Returns:
        The initial termination register.
    """
    zero_count = jnp.asarray(0, int_type)
    return TerminationRegister(
        evidence_calc=init_evidence_calc(),
        evidence_calc_with_remaining=init_evidence_calc(),
        num_samples_used=zero_count,
        num_likelihood_evaluations=zero_count,
        log_L_contour=jnp.asarray(-jnp.inf, float_type),
        efficiency=jnp.asarray(0., float_type),
        plateau=jnp.asarray(False, bool),
    )


def determine_termination(
        term_cond: Union[TerminationConditionDisjunction, TerminationConditionConjunction, TerminationCondition],
        termination_register: TerminationRegister) -> Tuple[BoolArray, IntArray]:
    """
    Determine if termination should happen. Termination Flags are bits:
        0-bit -> 1: used maximum allowed number of samples
        1-bit -> 2: evidence uncert below threshold
        2-bit -> 4: live points evidence below threshold
        3-bit -> 8: effective sample size big enough
        4-bit -> 16: used maximum allowed number of likelihood evaluations
        5-bit -> 32: maximum log-likelihood contour reached
        6-bit -> 64: sampler efficiency too low
        7-bit -> 128: entire live-points set is a single plateau

    Multiple flags are summed together.

    For a conjunction, termination happens only when every sub-condition is met. For a disjunction, it happens when
    any sub-condition is met. In both cases the reason is the bitwise-OR of the sub-conditions' reasons.

    Args:
        term_cond: termination condition
        termination_register: register of termination variables to check against termination condition

    Returns:
        boolean done signal, and termination reason
    """

    termination_reason = jnp.asarray(0, int_type)
    done = jnp.asarray(False, jnp.bool_)

    def _set_done_bit(bit_done, bit_reason, done, termination_reason):
        if bit_done.size > 1:
            raise RuntimeError("bit_done must be a scalar.")
        done = jnp.bitwise_or(bit_done, done)
        # Each bit is set at most once per plain condition, so adding is equivalent to OR-ing.
        termination_reason += jnp.where(bit_done,
                                        jnp.asarray(2 ** bit_reason, int_type),
                                        jnp.asarray(0, int_type))
        return done, termination_reason

    if isinstance(term_cond, TerminationConditionConjunction):
        # AND must start from True (its identity); starting from False would mean a conjunction never terminates.
        done = jnp.asarray(True, jnp.bool_)
        for c in term_cond.conds:
            _done, _reason = determine_termination(term_cond=c, termination_register=termination_register)
            done = jnp.bitwise_and(_done, done)
            # Report every criterion that was satisfied (AND-ing reasons into 0 would always yield 0).
            termination_reason = jnp.bitwise_or(_reason, termination_reason)
        return done, termination_reason

    if isinstance(term_cond, TerminationConditionDisjunction):
        for c in term_cond.conds:
            _done, _reason = determine_termination(term_cond=c, termination_register=termination_register)
            done = jnp.bitwise_or(_done, done)
            termination_reason = jnp.bitwise_or(_reason, termination_reason)
        return done, termination_reason

    if term_cond.live_evidence_frac is not None:
        logger.warning("live_evidence_frac is deprecated, use dlogZ instead.")

    if term_cond.max_samples is not None:
        # used all points
        reached_max_samples = termination_register.num_samples_used >= term_cond.max_samples
        done, termination_reason = _set_done_bit(reached_max_samples, 0,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.evidence_uncert is not None:
        _, log_Z_var = linear_to_log_stats(
            log_f_mean=termination_register.evidence_calc_with_remaining.log_Z_mean,
            log_f2_mean=termination_register.evidence_calc_with_remaining.log_Z2_mean)
        evidence_uncert_low_enough = log_Z_var <= jnp.square(term_cond.evidence_uncert)
        done, termination_reason = _set_done_bit(evidence_uncert_low_enough, 1,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.dlogZ is not None:
        # (Z_remaining + Z_current) / Z_remaining < exp(dlogZ)
        log_Z_mean_1, _ = linear_to_log_stats(
            log_f_mean=termination_register.evidence_calc_with_remaining.log_Z_mean,
            log_f2_mean=termination_register.evidence_calc_with_remaining.log_Z2_mean)

        log_Z_mean_0, _ = linear_to_log_stats(
            log_f_mean=termination_register.evidence_calc.log_Z_mean,
            log_f2_mean=termination_register.evidence_calc.log_Z2_mean)

        small_remaining_evidence = jnp.less(
            log_Z_mean_1 - log_Z_mean_0, term_cond.dlogZ
        )
        done, termination_reason = _set_done_bit(small_remaining_evidence, 2,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.ess is not None:
        # Kish's ESS = [sum weights]^2 / [sum weights^2]
        ess = effective_sample_size(termination_register.evidence_calc_with_remaining.log_Z_mean,
                                    termination_register.evidence_calc_with_remaining.log_dZ2_mean)
        ess_reached = ess >= term_cond.ess
        done, termination_reason = _set_done_bit(ess_reached, 3,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.max_num_likelihood_evaluations is not None:
        num_likelihood_evaluations = jnp.sum(termination_register.num_likelihood_evaluations)
        too_max_likelihood_evaluations = num_likelihood_evaluations >= term_cond.max_num_likelihood_evaluations
        done, termination_reason = _set_done_bit(too_max_likelihood_evaluations, 4,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.log_L_contour is not None:
        likelihood_contour_reached = termination_register.log_L_contour >= term_cond.log_L_contour
        done, termination_reason = _set_done_bit(likelihood_contour_reached, 5,
                                                 done=done, termination_reason=termination_reason)

    if term_cond.efficiency_threshold is not None:
        efficiency_too_low = termination_register.efficiency < term_cond.efficiency_threshold
        done, termination_reason = _set_done_bit(efficiency_too_low, 6,
                                                 done=done, termination_reason=termination_reason)

    # A plateau always terminates, regardless of the user's condition.
    done, termination_reason = _set_done_bit(termination_register.plateau, 7,
                                             done=done, termination_reason=termination_reason)

    return done, termination_reason


def _single_uniform_sample(key: PRNGKey, model: BaseAbstractModel) -> Sample:
    """
    Draw one point uniformly from the prior, redrawing until its log-likelihood is strictly above -inf.

    Draws whose log-likelihood equals -inf (forbidden regions) are rejected, and another point is drawn.

    Args:
        key: PRNGKey
        model: the model to use.

    Returns:
        a sample whose `num_likelihood_evaluations` counts every draw, including rejected ones.
    """

    log_L_constraint = jnp.asarray(-jnp.inf, float_type)

    class _DrawState(NamedTuple):
        key: PRNGKey
        U: UType
        log_L: MeasureType
        num_likelihood_evals: IntArray

    def _is_rejected(draw: _DrawState):
        # Keep looping while the current point is not strictly above the constraint.
        return draw.log_L <= log_L_constraint

    def _redraw(draw: _DrawState) -> _DrawState:
        next_key, draw_key = random.split(draw.key, 2)
        U = model.sample_U(key=draw_key)
        return _DrawState(
            key=next_key,
            U=U,
            log_L=model.forward(U=U),
            num_likelihood_evals=draw.num_likelihood_evals + jnp.ones_like(draw.num_likelihood_evals),
        )

    loop_key, first_key = random.split(key, 2)
    first_U = model.sample_U(key=first_key)
    first_draw = _DrawState(
        key=loop_key,
        U=first_U,
        log_L=model.forward(first_U),
        num_likelihood_evals=jnp.asarray(1, int_type),
    )

    final_draw = lax.while_loop(cond_fun=_is_rejected, body_fun=_redraw, init_val=first_draw)

    return Sample(
        U_sample=final_draw.U,
        log_L_constraint=log_L_constraint,
        log_L=final_draw.log_L,
        num_likelihood_evaluations=final_draw.num_likelihood_evals,
    )


def draw_uniform_samples(key: PRNGKey, num_live_points: int, model: BaseAbstractModel, method: str = 'vmap') -> Sample:
    """
    Get initial live points by uniformly sampling the entire prior.

    Args:
        key: PRNGKey
        num_live_points: the number of live points
        model: the model
        method: how to draw the points. 'vmap' is vectorised, which is more performant but uses more memory;
            'scan' draws them one after another.

    Returns:
        uniformly drawn samples within -inf bound, stacked along the leading axis

    Raises:
        ValueError: if `method` is neither 'vmap' nor 'scan'.
    """

    keys = random.split(key, num_live_points)

    def _draw_one(_key: PRNGKey) -> Sample:
        return _single_uniform_sample(key=_key, model=model)

    if method == 'scan':
        def _scan_step(carry: Any, _key: PRNGKey) -> Tuple[Any, Sample]:
            return carry, _draw_one(_key)

        _, samples = lax.scan(_scan_step, (), keys)
        return samples

    if method == 'vmap':
        return jax.vmap(_draw_one)(keys)

    raise ValueError(f'Invalid method {method}')


def create_init_state(key: PRNGKey, num_live_points: int, max_samples: int,
                      model: BaseAbstractModel) -> StaticStandardNestedSamplerState:
    """
    Return the initial nested sampler state, which will be incremented by the sampler.

    A sample collection of static size `max_samples` is allocated. Its first `num_live_points` entries are filled
    with uniform draws from the prior, and the front is set to those entries. All remaining entries keep
    placeholder values: log(L)=+inf, zero evaluations, not phantom, sender node 0 (the root).

    Args:
        key: PRNGKey
        num_live_points: the number of live points
        max_samples: the maximum number of samples
        model: the model to use.

    Returns:
        the initial nested sampler state

    Raises:
        ValueError: if `num_live_points` exceeds `max_samples`.
    """
    if num_live_points > max_samples:
        # Without this check, the `.at[:num_live_points].set(...)` below fails with an opaque shape error.
        raise ValueError(f"num_live_points ({num_live_points}) must be <= max_samples ({max_samples}).")

    def _repeat(a):
        return jnp.repeat(a[None], repeats=max_samples, axis=0, total_repeat_length=max_samples)

    sample_collection = StaticStandardSampleCollection(
        sender_node_idx=jnp.zeros(max_samples, dtype=int_type),
        log_L=jnp.full((max_samples,), jnp.inf, dtype=float_type),
        U_samples=_repeat(model.U_placeholder),
        num_likelihood_evaluations=jnp.full((max_samples,), 0, dtype=int_type),
        phantom=jnp.full((max_samples,), False, dtype=jnp.bool_)
    )

    key, sample_key = random.split(key, 2)
    init_samples = draw_uniform_samples(key=sample_key, num_live_points=num_live_points, model=model)
    # Merge the initial samples into the sample collection
    sample_collection = sample_collection._replace(
        log_L=sample_collection.log_L.at[:num_live_points].set(init_samples.log_L),
        U_samples=sample_collection.U_samples.at[:num_live_points].set(init_samples.U_sample),
        num_likelihood_evaluations=sample_collection.num_likelihood_evaluations.at[:num_live_points].set(
            init_samples.num_likelihood_evaluations)
    )

    return StaticStandardNestedSamplerState(
        key=key,
        next_sample_idx=jnp.asarray(num_live_points, int_type),
        sample_collection=sample_collection,
        front_idx=jnp.arange(num_live_points, dtype=int_type)
    )


class StandardStaticNestedSampler(BaseAbstractNestedSampler):
    """
    A static nested sampler that uses a fixed number of live points. This uses a uniform sampler to generate the
    initial set of samples down to an efficiency threshold, then uses a provided sampler to generate the rest of the
    samples until the termination condition is met.
    """

    def __init__(self, init_efficiency_threshold: float, sampler: BaseAbstractSampler, num_live_points: int,
                 model: BaseAbstractModel, max_samples: int, num_parallel_workers: int = 1, verbose: bool = False):
        """
        Initialise the static nested sampler.

        Args:
            init_efficiency_threshold: the efficiency threshold to use for the initial uniform sampling. If 0 then
                turns it off.
            sampler: the sampler to use after the initial uniform sampling.
            num_live_points: the number of live points to use.
            model: the model to use.
            max_samples: the maximum number of samples to take. Rounded up to the nearest multiple of
                `num_live_points`.
            num_parallel_workers: number of parallel workers to use. Defaults to 1. Experimental feature.
            verbose: whether to log as we go.

        Raises:
            ValueError: if `num_live_points` is less than 1.
        """
        self.init_efficiency_threshold = init_efficiency_threshold
        self.sampler = sampler
        self.num_live_points = int(num_live_points)
        if self.num_live_points < 1:
            raise ValueError(f"num_live_points must be >= 1, got {self.num_live_points}.")
        self.num_parallel_workers = int(num_parallel_workers)
        self.verbose = bool(verbose)
        # Amount needed to reach the next multiple of num_live_points (0 if already a multiple).
        remainder = max_samples % self.num_live_points
        extra = (self.num_live_points - remainder) % self.num_live_points
        if extra > 0:
            logger.warning(
                f"Increasing max_samples ({max_samples}) by {extra} to closest multiple of "
                f"num_live_points {self.num_live_points}."
            )
        max_samples = int(max_samples + extra)
        if self.num_parallel_workers > 1:
            logger.info(f"Using {self.num_parallel_workers} parallel workers, each running identical samplers.")
        super().__init__(model=model, max_samples=max_samples)

    def __repr__(self):
        return f"StandardStaticNestedSampler(init_efficiency_threshold={self.init_efficiency_threshold}, " \
               f"sampler={self.sampler}, num_live_points={self.num_live_points}, model={self.model}, " \
               f"max_samples={self.max_samples}, num_parallel_workers={self.num_parallel_workers})"

    def _to_results(self, termination_reason: IntArray, state: StaticStandardNestedSamplerState,
                    trim: bool) -> NestedSamplerResults:
        """
        Convert a final nested sampler state into `NestedSamplerResults`.

        Args:
            termination_reason: termination bit mask, as produced by `determine_termination`.
            state: the final nested sampler state.
            trim: if True, slice the sample collection down to the used samples. This needs concrete values, so it
                must not run under a JAX trace. If False, keep the full static-size collection and pass
                `num_samples` to the downstream statistics instead.

        Returns:
            the nested sampler results.

        Raises:
            RuntimeError: if `trim` is True and the number of samples is a tracer.
        """
        num_samples = jnp.minimum(state.next_sample_idx, state.sample_collection.log_L.size)

        sample_collection = state.sample_collection

        if trim:
            if isinstance(num_samples, core.Tracer):
                raise RuntimeError("Tracer detected, but expected imperative context.")
            sample_collection = tree_map(lambda x: x[:num_samples], sample_collection)
            # Collection already holds only the used samples.
            num_samples_kwargs = {}
        else:
            # Static-size collection: tell downstream ops how many leading entries are valid.
            num_samples_kwargs = {'num_samples': num_samples}

        sample_tree = SampleTreeGraph(
            sender_node_idx=sample_collection.sender_node_idx,
            log_L=sample_collection.log_L
        )

        live_point_counts = count_crossed_edges(sample_tree=sample_tree, **num_samples_kwargs)
        num_live_points = live_point_counts.num_live_points
        log_L = sample_tree.log_L[live_point_counts.samples_indices]
        U_samples = sample_collection.U_samples[live_point_counts.samples_indices]
        num_likelihood_evaluations = sample_collection.num_likelihood_evaluations[live_point_counts.samples_indices]

        final_evidence_stats, per_sample_evidence_stats = compute_evidence_stats(
            log_L=log_L,
            num_live_points=num_live_points,
            **num_samples_kwargs
        )

        log_Z_mean, log_Z_var = linear_to_log_stats(
            log_f_mean=final_evidence_stats.log_Z_mean,
            log_f2_mean=final_evidence_stats.log_Z2_mean
        )
        log_Z_uncert = jnp.sqrt(log_Z_var)

        # Correction by sqrt(k+1)
        total_phantom_samples = jnp.sum(sample_collection.phantom.astype(int_type))
        phantom_fraction = total_phantom_samples / num_samples  # k / (k+1)
        k = phantom_fraction / (1. - phantom_fraction)
        log_Z_uncert = log_Z_uncert * jnp.sqrt(1. + k)

        # Kish's ESS = [sum dZ]^2 / [sum dZ^2]
        ESS = effective_sample_size(final_evidence_stats.log_Z_mean, final_evidence_stats.log_dZ2_mean)
        ESS = ESS / (1. + k)

        samples = vmap(self.model.transform)(U_samples)
        parametrised_samples = vmap(self.model.transform_parametrised)(U_samples)

        log_L_samples = log_L
        dp_mean = LogSpace(per_sample_evidence_stats.log_dZ_mean)
        dp_mean = normalise_log_space(dp_mean)
        H_mean_instable = -((dp_mean * LogSpace(jnp.log(jnp.abs(log_L_samples)),
                                                jnp.sign(log_L_samples))).sum().value - log_Z_mean)
        # H \approx E[-log(compression)] = E[-log(X)] (More stable than E[log(L) - log(Z)]
        H_mean_stable = -((dp_mean * LogSpace(jnp.log(-per_sample_evidence_stats.log_X_mean))).sum().value)
        H_mean = jnp.where(jnp.isfinite(H_mean_instable), H_mean_instable, H_mean_stable)
        X_mean = LogSpace(per_sample_evidence_stats.log_X_mean)
        num_likelihood_evaluations_per_sample = num_likelihood_evaluations
        total_num_likelihood_evaluations = jnp.sum(num_likelihood_evaluations_per_sample)
        num_live_points_per_sample = num_live_points
        efficiency = LogSpace(jnp.log(num_samples) - jnp.log(total_num_likelihood_evaluations))

        log_posterior_density = log_L + vmap(self.model.log_prob_prior)(U_samples)

        return NestedSamplerResults(
            log_Z_mean=log_Z_mean,  # estimate of log(E[Z])
            log_Z_uncert=log_Z_uncert,  # estimate of log(StdDev[Z])
            ESS=ESS,  # estimate of Kish's effective sample size
            H_mean=H_mean,  # estimate of E[int log(L) L dp/Z]
            total_num_samples=num_samples,  # int, the total number of samples collected.
            total_phantom_samples=total_phantom_samples,  # int, the total number of phantom samples collected.
            log_L_samples=log_L_samples,  # log(L) of each sample
            # log of the normalised E[dZ] of each sample, where dZ is how much it contributes to the total evidence.
            log_dp_mean=dp_mean.log_abs_val,
            log_posterior_density=log_posterior_density,  # log(L) + log(prior density) of each sample
            log_X_mean=X_mean.log_abs_val,  # log(E[X]) of each sample
            # how many likelihood evaluations were made per sample.
            num_likelihood_evaluations_per_sample=num_likelihood_evaluations_per_sample,
            # how many live points were taken for the samples.
            num_live_points_per_sample=num_live_points_per_sample,
            # how many likelihood evaluations were made in total, sum of num_likelihood_evaluations_per_sample.
            total_num_likelihood_evaluations=total_num_likelihood_evaluations,
            log_efficiency=efficiency.log_abs_val,  # total_num_samples / total_num_likelihood_evaluations
            termination_reason=termination_reason,  # termination condition as bit mask
            samples=samples,
            parametrised_samples=parametrised_samples,
            U_samples=U_samples
        )

    def _run(self, key: PRNGKey, term_cond: TerminationCondition) -> Tuple[IntArray, StaticStandardNestedSamplerState]:
        """
        Run the nested sampler, optionally across several parallel workers.

        Args:
            key: PRNGKey
            term_cond: the user-defined termination condition.

        Returns:
            the termination reason bit mask, and the final state (unbatched if using parallel workers).
        """

        # Create sampler threads.
        def replica(key: PRNGKey) -> Tuple[StaticStandardNestedSamplerState, IntArray]:
            state = create_init_state(
                key=key,
                num_live_points=self.num_live_points,
                max_samples=self.max_samples,
                model=self.model
            )
            termination_register = create_init_termination_register()

            if self.init_efficiency_threshold > 0.:
                # Uniform sampling down to a given mean efficiency
                uniform_sampler = UniformSampler(model=self.model)
                termination_cond = TerminationCondition(
                    efficiency_threshold=jnp.asarray(self.init_efficiency_threshold),
                    dlogZ=jnp.asarray(0., float_type),
                    max_samples=jnp.asarray(self.max_samples)
                )
                state, termination_register, termination_reason = _single_thread_ns(
                    init_state=state,
                    init_termination_register=termination_register,
                    termination_cond=termination_cond,
                    sampler=uniform_sampler,
                    num_samples_per_sync=self.num_live_points,
                    verbose=self.verbose
                )

            # Continue sampling with provided sampler until user-defined termination condition is met.
            state, termination_register, termination_reason = _single_thread_ns(
                init_state=state,
                init_termination_register=termination_register,
                termination_cond=term_cond,
                sampler=self.sampler,
                num_samples_per_sync=self.num_live_points,
                verbose=self.verbose
            )
            if self.num_parallel_workers > 1:
                # We need to do a final sampling run to make all the chains consistent,
                # to a likelihood contour (i.e. standardise on L(X)). Would mean that some workers are idle.
                target_log_L_contour = jnp.max(
                    parallel.all_gather(termination_register.log_L_contour, 'i')
                )
                termination_cond = TerminationCondition(
                    dlogZ=jnp.asarray(0., float_type),
                    log_L_contour=target_log_L_contour,
                    max_samples=jnp.asarray(self.max_samples)
                )
                state, termination_register, termination_reason = _single_thread_ns(
                    init_state=state,
                    init_termination_register=termination_register,
                    termination_cond=termination_cond,
                    sampler=self.sampler,
                    num_samples_per_sync=self.num_live_points,
                    verbose=self.verbose
                )

            return state, termination_reason

        if self.num_parallel_workers > 1:
            parallel_ns = pmap(replica, axis_name='i')

            keys = random.split(key, self.num_parallel_workers)
            batched_state, termination_reason = parallel_ns(keys)
            state = unbatch_state(batched_state=batched_state)
        else:
            state, termination_reason = replica(key)

        return termination_reason, state