import dataclasses
import warnings
from functools import partial
from typing import List, Optional, Tuple, NamedTuple, Any
import jax
import jax.numpy as jnp
import numpy as np
from jax import core
from jax._src.mesh import Mesh
from jax._src.partition_spec import PartitionSpec
from jax.experimental.shard_map import shard_map
from jaxlib import xla_client
from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.cumulative_ops import cumulative_op_static
from jaxns.internals.log_semiring import LogSpace, normalise_log_space
from jaxns.internals.maps import create_mesh, tree_device_put, replace_index
from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.random import sample_uniformly_masked, resample_indicies
from jaxns.internals.shrinkage_statistics import EvidenceUpdateVariables, _update_evidence_calc_op, \
compute_evidence_stats
from jaxns.internals.stats import linear_to_log_stats, effective_sample_size_kish
from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges
from jaxns.internals.types import PRNGKey, IntArray, BoolArray
from jaxns.nested_samplers.abc import AbstractNestedSampler
from jaxns.nested_samplers.common.initialisation import create_init_state, create_init_termination_register
from jaxns.nested_samplers.common.termination import determine_termination
from jaxns.nested_samplers.common.types import TerminationCondition, NestedSamplerState, TerminationRegister, \
SampleCollection, LivePointCollection, NestedSamplerResults
from jaxns.samplers.abc import AbstractSampler
from jaxns.samplers.abc import EphemeralState
from jaxns.samplers.uniform_samplers import UniformSampler
__all__ = ['ShardedStaticNestedSampler']  # public API of this module
def _add_samples_to_state(sample_collection: LivePointCollection,
                          state: NestedSamplerState,
                          is_phantom: bool) -> NestedSamplerState:
    """
    Write a batch of samples into the sample buffer held by the nested sampler state.

    Each field of the batch is placed into the corresponding buffer field with `replace_index` at
    `state.next_sample_idx`. The write pointer then advances by the batch size, wrapping modulo the buffer length,
    and the running sample count grows by the batch size.

    Args:
        sample_collection: batched [N] samples
        state: state
        is_phantom: value stored in the `phantom` flag of every sample of the batch

    Returns:
        updated state
    """
    start = state.next_sample_idx
    buffer = state.sample_collection
    batch_size = np.shape(sample_collection.log_L)[0]
    buffer_size = np.shape(buffer.log_L)[0]

    def _write(dst, src):
        return replace_index(dst, src, start)

    phantom_flags = jnp.full((batch_size,), is_phantom, jnp.bool_)
    updated_buffer = SampleCollection(
        sender_node_idx=_write(buffer.sender_node_idx, sample_collection.sender_node_idx),
        log_L=_write(buffer.log_L, sample_collection.log_L),
        U_samples=_write(buffer.U_samples, sample_collection.U_sample),
        num_likelihood_evaluations=_write(buffer.num_likelihood_evaluations,
                                          sample_collection.num_likelihood_evaluations),
        phantom=_write(buffer.phantom, phantom_flags),
    )

    increment = jnp.asarray(batch_size, mp_policy.index_dtype)
    # The write pointer wraps at the buffer length (a trick for global optimisation, which does not need the
    # entire history of samples).
    wrapped_idx = (start + increment) % jnp.asarray(buffer_size, mp_policy.index_dtype)

    return NestedSamplerState(
        key=state.key,
        next_sample_idx=wrapped_idx,
        sample_collection=updated_buffer,
        num_samples=state.num_samples + increment,
    )
def get_samples(key, mesh: Mesh, sampler: AbstractSampler, sampler_state: Any, log_L_contour, num_samples: int):
    """
    Draw `num_samples` samples above a likelihood contour, sharded over the device mesh.

    One PRNG key is split off per sample. The keys are placed along the 'shard' mesh axis, and each shard calls
    `sampler.get_sample` sequentially (via `jax.lax.scan`) over its share of keys.

    Args:
        key: the PRNG key
        mesh: the mesh to use
        sampler: the sampler to use
        sampler_state: the sampler state (replicated to every shard)
        log_L_contour: the log likelihood contour (replicated to every shard)
        num_samples: the number of samples to get

    Returns:
        samples: the batched samples, one per key.
        phantom_samples: the phantom samples, with the per-key and per-phantom axes flattened into one leading axis.
    """

    def _merge_leading_axes(x):
        # Stacked by scan, each phantom leaf is [num_local_keys, k, ...]; flatten to [num_local_keys * k, ...].
        shape = np.shape(x)
        return jax.lax.reshape(x, (int(np.prod(shape[:2])),) + shape[2:])

    @partial(
        shard_map,
        mesh=mesh,
        in_specs=(PartitionSpec('shard'), PartitionSpec(), PartitionSpec()),
        out_specs=(PartitionSpec('shard'), PartitionSpec('shard')),
        check_rep=False
    )
    def _sample_on_shard(local_keys, contour, local_sampler_state: Any):
        def _step(carry, one_key):
            drawn = sampler.get_sample(
                key=one_key,
                log_L_constraint=contour,
                sampler_state=local_sampler_state
            )
            return carry, drawn

        _, (drawn_samples, drawn_phantoms) = jax.lax.scan(_step, None, local_keys)
        return drawn_samples, jax.tree.map(_merge_leading_axes, drawn_phantoms)

    per_sample_keys = jax.random.split(key, num_samples)
    sharded_keys = tree_device_put(per_sample_keys, mesh, ('shard',))
    return _sample_on_shard(sharded_keys, log_L_contour, sampler_state)
def dynamic_new_live_point_collection(
        key,
        mesh: Mesh,
        state: NestedSamplerState,
        sender_sample_idx: IntArray,
        sampler: AbstractSampler,
        sampler_state: Any,
        num_live_points: int
):
    """
    Draw a fresh set of live points above the likelihood of a chosen existing sample.

    Args:
        key: the PRNG key used to draw the new samples
        mesh: the mesh to use
        state: the state
        sender_sample_idx: index into the sample buffer of the sample whose log_L is used as the contour
        sampler: the sampler
        sampler_state: the sampler state
        num_live_points: the number of live points to get

    Returns:
        live_point_collection: the new live points, sorted ascending by log_L
        state: the state with the phantom samples of the draw added
    """
    contour = state.sample_collection.log_L[sender_sample_idx]
    fresh, phantoms = get_samples(
        key=key,
        mesh=mesh,
        sampler=sampler,
        sampler_state=sampler_state,
        log_L_contour=contour,
        num_samples=num_live_points
    )

    # Node is sample_idx - 1
    # NOTE(review): `_collect_shell` uses the buffer index of its contour sample directly as the sender node index;
    #  confirm which convention `SampleTreeGraph` expects.
    sender_node_idx = sender_sample_idx - jnp.asarray(1, mp_policy.index_dtype)
    unsorted_points = LivePointCollection(
        sender_node_idx=jnp.full((num_live_points,), sender_node_idx, mp_policy.index_dtype),
        U_sample=fresh.U_sample,
        log_L=fresh.log_L,
        log_L_constraint=fresh.log_L_constraint,
        num_likelihood_evaluations=fresh.num_likelihood_evaluations
    )
    # Live points are kept sorted ascending by log_L.
    order = jnp.argsort(unsorted_points.log_L)
    sorted_points = jax.tree.map(lambda leaf: leaf[order], unsorted_points)

    return sorted_points, add_phantom_samples_to_state(phantoms, sender_node_idx, state)
def add_phantom_samples_to_state(phantom_samples, sender_node_idx, state):
    """
    Append phantom samples to the state's sample buffer.

    Phantom samples are the extra samples the sampler returns alongside each real sample; how many there are is
    a user option of the sampler (`k`). Every phantom sample is given the same sender node and flagged
    `phantom=True` in the buffer.

    Args:
        phantom_samples: the batched phantom samples; the leading axis of `log_L` is the number of phantom samples
        sender_node_idx: the sender node index assigned to every phantom sample
        state: the state

    Returns:
        the updated state (only the state is returned; no live point collection)
    """
    num_phantom = np.shape(phantom_samples.log_L)[0]
    phantom_collection = LivePointCollection(
        sender_node_idx=jnp.full((num_phantom,), sender_node_idx, mp_policy.index_dtype),
        U_sample=phantom_samples.U_sample,
        log_L=phantom_samples.log_L,
        log_L_constraint=phantom_samples.log_L_constraint,
        num_likelihood_evaluations=phantom_samples.num_likelihood_evaluations
    )
    return _add_samples_to_state(
        sample_collection=phantom_collection,
        state=state,
        is_phantom=True
    )
def _collect_shell(
        mesh: Mesh,
        live_point_collection: LivePointCollection,
        state: NestedSamplerState,
        termination_register: TerminationRegister,
        sampler: AbstractSampler,
        sampler_state: Any,
        shell_size: int
) -> Tuple[LivePointCollection, NestedSamplerState, TerminationRegister]:
    """
    Discard and replace a single shell of the lowest-likelihood live points.

    The `shell_size` lowest live points are appended to the sample collection. They are then replaced by new
    samples drawn above the largest discarded log_L, the live points are re-sorted, and the termination register
    is updated.

    Args:
        mesh: the device mesh to use
        live_point_collection: the live point collection, sorted ascending by log_L
        state: the state of the nested sampler at the start
        termination_register: the termination register at the start
        sampler: sampler to use
        sampler_state: the sampler state to use for drawing the replacement samples
        shell_size: the number of live points to discard and replace

    Returns:
        live_point_collection: the updated live point collection, sorted ascending by log_L
        state: the state of the nested sampler at the end
        termination_register: the termination register at the end
    """
    front_size = np.shape(live_point_collection.log_L)[0]

    # Find and discard the shell. Live points are always left sorted, so the shell is the first `shell_size`
    # entries and no sort is needed here.
    discarded_sample_collection: LivePointCollection = jax.tree.map(lambda x: x[:shell_size], live_point_collection)
    state = _add_samples_to_state(
        sample_collection=discarded_sample_collection,
        state=state,
        is_phantom=False
    )

    # Replace the discarded samples with samples drawn above the biggest discarded log_L.
    key, sample_key = jax.random.split(state.key, 2)
    state = state._replace(key=key)
    supremum_index = shell_size - 1  # Biggest of discarded
    log_L_contour = live_point_collection.log_L[supremum_index]
    new_samples, phantom_samples = get_samples(
        key=sample_key,
        mesh=mesh,
        sampler=sampler,
        sampler_state=sampler_state,
        log_L_contour=log_L_contour,
        num_samples=shell_size
    )

    # Sender is the maximum log_L sample from discarded, i.e. that last added discarded
    sender_node_idx = state.next_sample_idx - jnp.asarray(1, mp_policy.index_dtype)
    new_sample_collection = LivePointCollection(
        sender_node_idx=jnp.full((shell_size,), sender_node_idx, mp_policy.index_dtype),
        U_sample=new_samples.U_sample,
        log_L=new_samples.log_L,
        log_L_constraint=new_samples.log_L_constraint,
        num_likelihood_evaluations=new_samples.num_likelihood_evaluations
    )
    live_point_collection: LivePointCollection = jax.tree.map(
        lambda x, update: x.at[:shell_size, ...].set(update),
        live_point_collection,
        new_sample_collection
    )
    # Restore the sorted-ascending invariant.
    sort_indices = jnp.argsort(live_point_collection.log_L)
    live_point_collection = jax.tree.map(lambda x: x[sort_indices], live_point_collection)
    # TODO: compute insert index KS-statistic. The sorted positions of the new samples are those where
    #  `sort_indices < shell_size`.

    state = add_phantom_samples_to_state(phantom_samples, sender_node_idx, state)

    # Update termination register.
    # While the shell is discarded the replacements lie above its supremum, so the i-th discarded point (0-based) is
    # removed with front_size - i live points. This matches the decreasing counts used for the remaining live points
    # below. Plateaus (tied log_L) are only handled exactly in `_to_results`.
    evidence_calc, _ = cumulative_op_static(
        op=_update_evidence_calc_op,
        init=termination_register.evidence_calc,
        xs=EvidenceUpdateVariables(
            num_live_points=jnp.arange(front_size, front_size - shell_size, -1., mp_policy.measure_dtype),
            log_L_next=discarded_sample_collection.log_L
        ),
    )
    # Evidence estimate if all current live points were consumed now.
    evidence_calc_with_remaining, _ = cumulative_op_static(
        op=_update_evidence_calc_op,
        init=evidence_calc,
        xs=EvidenceUpdateVariables(
            num_live_points=jnp.arange(front_size, 0., -1., mp_policy.measure_dtype),
            log_L_next=live_point_collection.log_L
        ),
    )
    # Note we consider phantom samples requiring 0 num_likelihood_evaluations
    num_likelihood_evaluations = termination_register.num_likelihood_evaluations + jnp.sum(
        new_samples.num_likelihood_evaluations)
    # We determine efficiency
    efficiency = jnp.asarray(front_size / jnp.sum(live_point_collection.num_likelihood_evaluations),
                             mp_policy.measure_dtype)
    plateau = jnp.all(jnp.equal(live_point_collection.log_L, live_point_collection.log_L[0]))
    absolute_spread = jnp.abs(live_point_collection.log_L[-1] - live_point_collection.log_L[0])
    relative_spread = 2. * absolute_spread / jnp.abs(live_point_collection.log_L[0] + live_point_collection.log_L[-1])
    # True when the next shell's supremum already equals the largest live log_L.
    no_seed_points = live_point_collection.log_L[supremum_index] >= live_point_collection.log_L[-1]
    peak_log_XL = jnp.maximum(termination_register.peak_log_XL, evidence_calc.log_X_mean + evidence_calc.log_L)
    termination_register = TerminationRegister(
        num_samples_used=state.num_samples,
        evidence_calc=evidence_calc,
        evidence_calc_with_remaining=evidence_calc_with_remaining,
        num_likelihood_evaluations=num_likelihood_evaluations,
        log_L_contour=log_L_contour,
        efficiency=efficiency,
        plateau=plateau,
        no_seed_points=no_seed_points,
        relative_spread=relative_spread,
        absolute_spread=absolute_spread,
        peak_log_XL=peak_log_XL
    )
    return live_point_collection, state, termination_register
def _dynamic_posterior_refinement_iteration(
        mesh: Mesh,
        state: NestedSamplerState,
        sampler: AbstractSampler,
        termination_register: TerminationRegister,
        refine_threshold: float,
        num_live_points: int
) -> NestedSamplerState:
    """
    Perform a dynamic posterior refinement iteration.

    A non-phantom root sample is chosen at random among the samples whose log(X) + log(L) exceeds
    max(log(X) + log(L)) + log(refine_threshold). A new set of `num_live_points` live points is then drawn above
    the root's likelihood and added to the state, together with the phantom samples of that draw.

    Args:
        mesh: the mesh to use
        state: the state of the nested sampler
        sampler: the sampler to use
        termination_register: the termination register
        refine_threshold: the threshold for refinement, will attach new points from somewhere within a masked region,
            where XL > max(XL) * refine_threshold
        num_live_points: the number of live points to get

    Returns:
        the updated state
    """
    key, sample_key, ephemeral_key, reseed_key, select_attach_key = jax.random.split(state.key, 5)
    state = state._replace(key=key)
    sample_collection = state.sample_collection

    # Get the part of prior region that contains most of the posterior mass. This is where XL is largest.
    sample_tree = SampleTreeGraph(
        sender_node_idx=sample_collection.sender_node_idx,
        log_L=sample_collection.log_L
    )
    num_samples = jnp.minimum(
        state.num_samples,
        jnp.asarray(state.sample_collection.log_L.size, mp_policy.count_dtype)
    )
    live_point_counts = count_crossed_edges(sample_tree=sample_tree, num_samples=num_samples)
    # Per-sample statistics are computed in the order given by `samples_indices`, not in buffer order.
    samples_indices = live_point_counts.samples_indices
    _, per_sample_evidence_stats = compute_evidence_stats(
        log_L=sample_tree.log_L[samples_indices],
        num_live_points=live_point_counts.num_live_points,
        num_samples=num_samples
    )
    log_XL = per_sample_evidence_stats.log_X_mean + per_sample_evidence_stats.log_L
    peak_log_XL = jnp.max(log_XL)

    # Randomly choose attach point
    # TODO: ensure attachment certain to be within sample set.
    attach_mask = jnp.logical_and(
        log_XL > peak_log_XL + np.log(refine_threshold),
        # Gather the phantom flags into the same (sorted) order as `log_XL` before combining.
        jnp.logical_not(sample_collection.phantom[samples_indices])
    )
    log_weights = jnp.where(attach_mask, 0., -jnp.inf)
    chosen_position = resample_indicies(select_attach_key, log_weights, 1)[0]
    # Map the chosen sorted position back into an index of the sample buffer.
    sample_root_idx = samples_indices[chosen_position]

    # Create a fake live point collection to create sampler state. We don't care about the samples being i.i.d.
    log_L_constraint = sample_collection.log_L[sample_root_idx]
    above_root_mask = sample_collection.log_L > log_L_constraint
    ephemeral_sample_collection: SampleCollection = sample_uniformly_masked(
        key=reseed_key,
        v=sample_collection,
        select_mask=above_root_mask,
        num_samples=num_live_points
    )
    ephemeral_live_point_collection = LivePointCollection(
        sender_node_idx=ephemeral_sample_collection.sender_node_idx,
        U_sample=ephemeral_sample_collection.U_samples,
        log_L=ephemeral_sample_collection.log_L,
        # NOTE(review): indexes log_L by sender_node_idx + 1, whereas `_collect_shell` uses the buffer index of the
        #  sender sample directly as its node index; confirm the node/sample index convention.
        log_L_constraint=sample_collection.log_L[ephemeral_sample_collection.sender_node_idx + 1],
        num_likelihood_evaluations=ephemeral_sample_collection.num_likelihood_evaluations
    )
    ephemeral_state = EphemeralState(
        key=ephemeral_key,
        live_points_collection=ephemeral_live_point_collection,
        termination_register=termination_register
    )
    sampler_state = sampler.pre_process(ephemeral_state)

    # Get a proper i.i.d. set of live points
    live_point_collection, state = dynamic_new_live_point_collection(
        key=sample_key,
        mesh=mesh,
        state=state,
        sender_sample_idx=sample_root_idx,
        sampler=sampler,
        sampler_state=sampler_state,
        num_live_points=num_live_points
    )

    # Add all live points to the state
    state = _add_samples_to_state(
        sample_collection=live_point_collection,
        state=state,
        is_phantom=False
    )
    return state
def _main_ns_thread(
        mesh: Mesh,
        live_point_collection: LivePointCollection,
        state: NestedSamplerState,
        termination_register: TerminationRegister,
        termination_cond: TerminationCondition,
        sampler: AbstractSampler,
        num_discards_per_iteration: int,
        shell_fraction: float,
        verbose: bool
) -> Tuple[LivePointCollection, NestedSamplerState, TerminationRegister, IntArray]:
    """
    Runs a single thread of static nested sampling until a stopping condition is reached.

    Each shell discards `int(num_live_points * shell_fraction)` of the lowest live points at once, replacing them
    from the supremum contour of the discarded shell and building up a sample tree. The sampler state is
    pre-processed at the start of every iteration and post-processed after `num_discards_per_iteration` shells.

    Args:
        mesh: the device mesh to use
        live_point_collection: the initial live points, sorted ascending by log_L
        state: the state of the nested sampler at the start
        termination_register: the termination register at the start
        termination_cond: the termination condition
        sampler: the sampler to use
        num_discards_per_iteration: number of discard shells per iteration, between processing sampler state.
        shell_fraction: fraction of the live points discarded per shell
        verbose: whether to log debug messages.

    Returns:
        live_point_collection: the final set of live points
        state: the final state
        termination_register: the termination register
        termination_condition: the reason for termination

    Raises:
        ValueError: if num_discards_per_iteration <= 0, or if the shell would be empty.
    """
    if num_discards_per_iteration <= 0:
        raise ValueError(f"num_discards_per_iteration must be > 0 got {num_discards_per_iteration}.")
    front_size = np.shape(live_point_collection.log_L)[0]
    shell_size = int(front_size * shell_fraction)
    if shell_size <= 0:
        # An empty shell makes no progress, so the while loop below would never terminate.
        raise ValueError(
            f"shell_fraction={shell_fraction} gives an empty shell for {front_size} live points; "
            f"each shell must discard at least one point."
        )

    # Update the termination condition to stop before going over the maximum number of samples.
    space_needed_per_iteration = num_discards_per_iteration * shell_size * (1 + sampler.num_phantom())
    if termination_cond.max_samples is not None:
        termination_cond = termination_cond._replace(
            max_samples=jnp.minimum(
                termination_cond.max_samples,
                np.shape(state.sample_collection.log_L)[0] - space_needed_per_iteration
            )
        )

    # Catch case of no seed points left
    no_seed_points = live_point_collection.log_L[shell_size - 1] >= live_point_collection.log_L[-1]
    termination_register = termination_register._replace(no_seed_points=no_seed_points)

    class CarryType(NamedTuple):
        live_point_collection: LivePointCollection
        state: NestedSamplerState
        termination_register: TerminationRegister

    def cond(carry: CarryType) -> BoolArray:
        done, _ = determine_termination(
            term_cond=termination_cond,
            termination_register=carry.termination_register
        )
        return jnp.bitwise_not(done)

    def body(carry: CarryType) -> CarryType:
        # Discard shells of the live points and replace them with new samples
        live_point_collection, state, termination_register = carry

        key, ephemeral_key = jax.random.split(state.key, 2)
        state = state._replace(key=key)
        ephemeral_state = EphemeralState(
            key=ephemeral_key,
            live_points_collection=live_point_collection,
            termination_register=termination_register
        )
        sampler_state = sampler.pre_process(ephemeral_state)

        for _ in range(num_discards_per_iteration):
            live_point_collection, state, termination_register = _collect_shell(
                mesh=mesh,
                live_point_collection=live_point_collection,
                state=state,
                sampler=sampler,
                termination_register=termination_register,
                sampler_state=sampler_state,
                shell_size=shell_size
            )
            key, ephemeral_key = jax.random.split(state.key, 2)
            state = state._replace(key=key)
            ephemeral_state = EphemeralState(
                key=ephemeral_key,
                live_points_collection=live_point_collection,
                termination_register=termination_register
            )
            sampler_state = sampler.post_process(ephemeral_state=ephemeral_state, sampler_state=sampler_state)

        if verbose:
            log_Z_mean, log_Z_var = linear_to_log_stats(
                log_f_mean=termination_register.evidence_calc_with_remaining.log_Z_mean,
                log_f2_mean=termination_register.evidence_calc_with_remaining.log_Z2_mean)
            log_Z_uncert = jnp.sqrt(log_Z_var)
            log_Z_mean0, log_Z_var0 = linear_to_log_stats(
                log_f_mean=termination_register.evidence_calc.log_Z_mean,
                log_f2_mean=termination_register.evidence_calc.log_Z2_mean)
            log_Z_remaining = log_Z_mean - log_Z_mean0
            log_Z_remaining_error = jnp.sqrt(log_Z_var + log_Z_var0)
            ess = effective_sample_size_kish(termination_register.evidence_calc_with_remaining.log_Z_mean,
                                             termination_register.evidence_calc_with_remaining.log_dZ2_mean)
            jax.debug.print(
                "-------\n"
                "Num samples: {num_samples}\n"
                "Num likelihood evals: {num_likelihood_evals}\n"
                "Efficiency: {efficiency}\n"
                "log(L) contour: {log_L_contour}\n"
                "log(Z) est.: {log_Z_mean} +- {log_Z_uncert}\n"
                "log(Z | remaining) est.: {log_Z_remaining} +- {log_Z_remaining_error}\n"
                "ESS: {ess}\n",
                num_samples=termination_register.num_samples_used,
                num_likelihood_evals=termination_register.num_likelihood_evaluations,
                efficiency=termination_register.efficiency,
                log_L_contour=termination_register.log_L_contour,
                log_Z_mean=log_Z_mean,
                log_Z_uncert=log_Z_uncert,
                log_Z_remaining=log_Z_remaining,
                log_Z_remaining_error=log_Z_remaining_error,
                ess=ess
            )

        return CarryType(
            state=state,
            termination_register=termination_register,
            live_point_collection=live_point_collection
        )

    carry = CarryType(
        state=state,
        termination_register=termination_register,
        live_point_collection=live_point_collection
    )

    carry = jax.lax.while_loop(
        cond_fun=cond,
        body_fun=body,
        init_val=carry
    )

    _, termination_reason = determine_termination(
        term_cond=termination_cond,
        termination_register=carry.termination_register
    )
    return carry.live_point_collection, carry.state, carry.termination_register, termination_reason
def round_up_num_live_points(init_num_live_points, shell_frac, num_devices):
    """
    Round the number of live points up so the shell size is a positive multiple of the number of devices.

    The shell size is ``int(num_live_points * shell_frac)``. It must be divisible by `num_devices` so each shell can
    be sharded evenly. It must also be at least 1, because an empty shell makes no progress and gives a zero block
    size in `round_up_max_samples`.

    Args:
        init_num_live_points: the requested number of live points (converted with `int`).
        shell_frac: fraction of live points discarded per shell, expected in (0, 1].
        num_devices: number of devices the shell is sharded over.

    Returns:
        the smallest integer >= int(init_num_live_points) satisfying the above.

    Raises:
        ValueError: if shell_frac <= 0 (no shell is ever non-empty) or num_devices < 1.
    """
    if shell_frac <= 0:
        raise ValueError(f"Expected shell_frac > 0, got {shell_frac}.")
    if num_devices < 1:
        raise ValueError(f"Expected num_devices >= 1, got {num_devices}.")
    num_live_points = int(init_num_live_points)
    while True:
        shell_size = int(num_live_points * shell_frac)
        if shell_size > 0 and shell_size % num_devices == 0:
            return num_live_points
        num_live_points += 1
def round_up_max_samples(init_max_samples, num_discard, num_phantom_points):
    """
    Round the sample budget up to a multiple of the space one shell consumes.

    Each shell adds `num_discard` discarded samples plus `num_phantom_points` phantom samples per discarded
    sample, i.e. ``num_discard * (1 + num_phantom_points)`` entries.

    Args:
        init_max_samples: the requested maximum number of samples (converted with `int`).
        num_discard: number of samples discarded per shell.
        num_phantom_points: number of phantom samples produced per sample.

    Returns:
        the smallest multiple of the block size that is >= int(init_max_samples).

    Raises:
        ValueError: if the block size is < 1 (previously this surfaced as a ZeroDivisionError).
    """
    max_samples = int(init_max_samples)
    block_size = num_discard * (1 + num_phantom_points)
    if block_size < 1:
        raise ValueError(
            f"Expected num_discard * (1 + num_phantom_points) >= 1, got {block_size} "
            f"(num_discard={num_discard}, num_phantom_points={num_phantom_points})."
        )
    # Ceiling division, then scale back up to a whole number of blocks.
    return int(-(-max_samples // block_size) * block_size)
@dataclasses.dataclass(eq=False)
class ShardedStaticNestedSampler(AbstractNestedSampler):
    """
    A static nested sampler that uses a fixed number of live points. This uses a uniform sampler to generate the
    initial set of samples down to an efficiency threshold, then uses a provided sampler to generate the rest of the
    samples until the termination condition is met. Drawing the replacement samples of each shell is sharded over
    `devices`.

    Args:
        init_efficiency_threshold: the efficiency threshold to use for the initial uniform sampling. If 0 then
            turns it off.
        sampler: the sampler to use after the initial uniform sampling.
        num_live_points: the number of live points to use; rounded up so the shell size is a positive multiple of
            the number of devices.
        model: the model to use.
        max_samples: the maximum number of samples to take; rounded up to a multiple of
            shell_size * (1 + num_phantom).
        shell_fraction: fraction of live points discarded per shell, must be in (0, 1]; defaults to 0.5 and is
            raised to at least 1 / num_live_points.
        num_dynamic_refinement_iterations: number of (experimental) dynamic refinement iterations run after the
            main sampling; 0 disables refinement.
        refine_threshold: refinement attaches new live points where XL > max(XL) * refine_threshold.
        devices: the devices to use; defaults to all devices returned by `jax.devices()`.
        verbose: whether to log as we go.
    """
    model: BaseAbstractModel
    init_efficiency_threshold: float
    sampler: AbstractSampler
    shell_fraction: Optional[float] = None
    num_dynamic_refinement_iterations: int = 0
    refine_threshold: float = 0.01
    devices: Optional[List[xla_client.Device]] = None

    def __post_init__(self):
        if self.shell_fraction is None:
            self.shell_fraction = 0.5
        # Validate the user's value before clamping; otherwise a non-positive value would be silently replaced by
        # 1 / num_live_points instead of being rejected.
        if (self.shell_fraction <= 0.) or (self.shell_fraction > 1.):
            raise ValueError(
                f"Expected 0 < shell_fraction <= 1, got {self.shell_fraction}. Best to keep it around 0.5.")
        # Discard at least one live point per shell.
        self.shell_fraction = max(self.shell_fraction, 1. / self.num_live_points)
        if self.devices is None:
            self.devices = jax.devices()
        if len(self.devices) > 1:
            print(f"Running over {len(self.devices)} devices.")
        # Make sure the shell size is a multiple of the number of devices
        self.num_live_points = round_up_num_live_points(
            init_num_live_points=self.num_live_points,
            shell_frac=self.shell_fraction,
            num_devices=len(self.devices)
        )
        self.max_samples = round_up_max_samples(
            init_max_samples=self.max_samples,
            # TODO: if we do more than 1 discard per iteration need to update here too.
            num_discard=int(self.shell_fraction * self.num_live_points),
            num_phantom_points=self.sampler.num_phantom()
        )
        if self.num_dynamic_refinement_iterations > 0:
            warnings.warn("Dynamic refinement is experimental and may not work as expected.")

    def _to_results(self, termination_reason: IntArray, state: NestedSamplerState,
                    trim: bool) -> NestedSamplerResults:
        # Number of valid samples in the buffer (the buffer wraps, so this is capped at its length).
        num_samples = jnp.minimum(
            state.num_samples,
            jnp.asarray(state.sample_collection.log_L.size, mp_policy.count_dtype)
        )

        sample_collection = state.sample_collection
        if trim:
            if isinstance(num_samples, core.Tracer):
                raise RuntimeError("Tracer detected, but expected imperative context.")
            sample_collection = jax.tree.map(lambda x: x[:num_samples], sample_collection)
            # After trimming every entry is a valid sample, so no explicit sample count is passed on.
            count_kwargs = {}
        else:
            count_kwargs = dict(num_samples=num_samples)

        sample_tree = SampleTreeGraph(
            sender_node_idx=sample_collection.sender_node_idx,
            log_L=sample_collection.log_L
        )
        live_point_counts = count_crossed_edges(sample_tree=sample_tree, **count_kwargs)
        # All per-sample arrays below follow the order given by `samples_indices`.
        samples_indices = live_point_counts.samples_indices
        num_live_points = live_point_counts.num_live_points
        log_L = sample_tree.log_L[samples_indices]
        U_samples = sample_collection.U_samples[samples_indices]
        num_likelihood_evaluations = sample_collection.num_likelihood_evaluations[samples_indices]
        final_evidence_stats, per_sample_evidence_stats = compute_evidence_stats(
            log_L=log_L,
            num_live_points=num_live_points,
            **count_kwargs
        )

        log_Z_mean, log_Z_var = linear_to_log_stats(
            log_f_mean=final_evidence_stats.log_Z_mean,
            log_f2_mean=final_evidence_stats.log_Z2_mean
        )
        log_Z_uncert = jnp.sqrt(log_Z_var)

        # Correction by sqrt(k+1)
        total_phantom_samples = jnp.sum(mp_policy.cast_to_count(sample_collection.phantom, quiet=True))
        phantom_fraction = total_phantom_samples / num_samples  # k / (k+1)
        k = phantom_fraction / (1. - phantom_fraction)
        log_Z_uncert = log_Z_uncert * jnp.sqrt(1. + k)

        # Kish's ESS = [sum dZ]^2 / [sum dZ^2]
        ESS = effective_sample_size_kish(final_evidence_stats.log_Z_mean, final_evidence_stats.log_dZ2_mean)
        ESS = ESS / (1. + k)

        samples = jax.vmap(self.model.transform)(U_samples)
        parametrised_samples = jax.vmap(self.model.transform_parametrised)(U_samples)

        log_L_samples = log_L
        dp_mean = LogSpace(per_sample_evidence_stats.log_dZ_mean)
        dp_mean = normalise_log_space(dp_mean)
        H_mean_instable = -(
            (
                dp_mean * LogSpace.from_signed_value(
                    jnp.where(jnp.isneginf(dp_mean.log_abs_val), 0., log_L_samples)
                )
            ).sum().value - log_Z_mean
        )
        # H \approx E[-log(compression)] = E[-log(X)] (More stable than E[log(L) - log(Z)] but biased)
        H_mean_stable = -((dp_mean * LogSpace(jnp.log(-per_sample_evidence_stats.log_X_mean))).sum().value)
        H_mean = jnp.where(jnp.isfinite(H_mean_instable), H_mean_instable, H_mean_stable)
        X_mean = LogSpace(per_sample_evidence_stats.log_X_mean)
        num_likelihood_evaluations_per_sample = num_likelihood_evaluations
        total_num_likelihood_evaluations = jnp.sum(num_likelihood_evaluations_per_sample)
        num_live_points_per_sample = num_live_points
        efficiency = LogSpace(jnp.log(num_samples) - jnp.log(total_num_likelihood_evaluations))

        log_posterior_density = log_L + jax.vmap(self.model.log_prob_prior)(
            U_samples)

        return NestedSamplerResults(
            log_Z_mean=log_Z_mean,  # estimate of log(E[Z])
            log_Z_uncert=log_Z_uncert,  # estimate of log(StdDev[Z])
            ESS=ESS,  # estimate of Kish's effective sample size
            H_mean=H_mean,  # estimate of E[int log(L) L dp/Z]
            total_num_samples=num_samples,  # int, the total number of samples collected.
            total_phantom_samples=total_phantom_samples,  # int, the total number of phantom samples collected.
            log_L_samples=log_L_samples,  # log(L) of each sample
            log_dp_mean=dp_mean.log_abs_val,
            log_posterior_density=log_posterior_density,
            # log(E[dZ]) of each sample, where dZ is how much it contributes to the total evidence.
            # log(StdDev[dZ]) of each sample, where dZ is how much it contributes to the total evidence.
            log_X_mean=X_mean.log_abs_val,  # log(E[U]) of each sample
            num_likelihood_evaluations_per_sample=num_likelihood_evaluations_per_sample,
            # how many likelihood evaluations were made per sample.
            num_live_points_per_sample=num_live_points_per_sample,
            # how many live points were taken for the samples.
            total_num_likelihood_evaluations=total_num_likelihood_evaluations,
            # how many likelihood evaluations were made in total,
            # sum of num_likelihood_evaluations_per_sample.
            log_efficiency=efficiency.log_abs_val,
            # total_num_samples / total_num_likelihood_evaluations
            termination_reason=termination_reason,  # termination condition as bit mask
            samples=samples,
            parametrised_samples=parametrised_samples,
            U_samples=U_samples
        )

    def _run(self, key: PRNGKey, term_cond: TerminationCondition) -> Tuple[
        IntArray, TerminationRegister, NestedSamplerState]:
        # Create sampler threads.
        mesh = create_mesh((len(self.devices),), ('shard',), devices=self.devices)
        if self.verbose:
            jax.debug.print(f"Creating initial state with {self.num_live_points} live points.")

        live_point_collection, state = create_init_state(
            key=key,
            num_live_points=self.num_live_points,
            max_samples=self.max_samples,
            model=self.model,
            mesh=mesh
        )
        termination_register = create_init_termination_register()

        if self.init_efficiency_threshold > 0.:
            if self.verbose:
                jax.debug.print(
                    f"Running uniform sampling down to efficiency threshold of {self.init_efficiency_threshold}."
                )
            # Uniform sampling down to a given mean efficiency
            uniform_sampler = UniformSampler(model=self.model)
            uniform_term_cond = TerminationCondition(
                efficiency_threshold=jnp.asarray(self.init_efficiency_threshold, mp_policy.measure_dtype),
                dlogZ=jnp.asarray(0., mp_policy.measure_dtype),
                max_samples=jnp.asarray(self.max_samples, mp_policy.count_dtype)
            )
            live_point_collection, state, termination_register, termination_reason = _main_ns_thread(
                mesh=mesh,
                live_point_collection=live_point_collection,
                state=state,
                termination_register=termination_register,
                termination_cond=uniform_term_cond,
                sampler=uniform_sampler,
                num_discards_per_iteration=1,
                shell_fraction=self.shell_fraction,
                verbose=self.verbose
            )

        if self.verbose:
            jax.debug.print("Running until termination condition: {term_cond}",
                            term_cond=term_cond)
        # Continue sampling with provided sampler until user-defined termination condition is met.
        live_point_collection, state, termination_register, termination_reason = _main_ns_thread(
            mesh=mesh,
            live_point_collection=live_point_collection,
            state=state,
            termination_register=termination_register,
            termination_cond=term_cond,
            sampler=self.sampler,
            num_discards_per_iteration=1,
            shell_fraction=self.shell_fraction,
            verbose=self.verbose
        )

        # Consume the remaining live points by adding them to the sample collection.
        state = _add_samples_to_state(
            sample_collection=live_point_collection,
            state=state,
            is_phantom=False
        )

        def body(i, state: NestedSamplerState) -> NestedSamplerState:
            state = _dynamic_posterior_refinement_iteration(
                mesh=mesh, state=state, sampler=self.sampler, termination_register=termination_register,
                refine_threshold=self.refine_threshold,
                num_live_points=self.num_live_points
            )
            return state

        if self.num_dynamic_refinement_iterations > 0:
            state = jax.lax.fori_loop(0, self.num_dynamic_refinement_iterations, body, state)

        return termination_reason, termination_register, state