from typing import Tuple, Optional, NamedTuple
import jax.numpy as jnp
from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
from jaxns.internals.log_semiring import LogSpace
from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges
from jaxns.internals.types import MeasureType, EvidenceCalculation, float_type, IntArray, FloatArray
def compute_enclosed_prior_volume(sample_tree: SampleTreeGraph) -> MeasureType:
    """
    Compute the log enclosed prior volume at each sample's likelihood constraint.

    The enclosed volume starts at X = 1 (the whole prior) and shrinks per
    sample by the expected factor E[t] = n / (n + 1), where n is the number
    of live points at that contour.

    Args:
        sample_tree: The sample tree graph.

    Returns:
        The log enclosed prior volume per sample.
    """
    live_point_counts = count_crossed_edges(sample_tree=sample_tree)

    def op(log_X, num_live_points):
        X_mean = LogSpace(log_X)
        # E[t] = n / (n + 1), evaluated stably in log space:
        # log(n/(n+1)) = -log(1 + 1/n) = -logaddexp(0, -log(n)).
        T_mean = LogSpace(- jnp.logaddexp(0., -jnp.log(num_live_points)))
        next_X_mean = X_mean * T_mean
        return next_X_mean.log_abs_val

    # Start from log(X) = 0, i.e. X = 1 (the full prior volume), matching
    # init_evidence_calc's log_X_mean=0. Starting from -inf would be a bug:
    # op only multiplies the carry down, so every output would be -inf.
    _, log_X = cumulative_op_static(op=op, init=jnp.asarray(0., float_type),
                                    xs=live_point_counts.num_live_points)
    return log_X
class EvidenceUpdateVariables(NamedTuple):
    """
    Per-sample inputs required to push one new sample through the running
    evidence statistics.

    Attributes:
        num_live_points: number of live points at the sample's likelihood contour.
        log_L_next: log likelihood of the new sample.
    """
    num_live_points: FloatArray
    # Restored field: the update op reads y.log_L_next and
    # compute_evidence_stats constructs this tuple with log_L_next=log_L.
    log_L_next: MeasureType
def _update_evidence_calc_op(carry: EvidenceCalculation, y: EvidenceUpdateVariables) -> EvidenceCalculation:
    """
    Push a single new sample through the running evidence statistics.

    All quantities are tracked in log space. With n = y.num_live_points, the
    expected shrinkage moments used below (presumably those of the usual
    nested-sampling shrinkage distribution — the closed forms are noted inline)
    update the first and second moments of the enclosed volume X, the
    evidence Z, and their cross term ZX.

    Args:
        carry: the current evidence statistics.
        y: the new sample's number of live points and log likelihood.

    Returns:
        The updated evidence statistics.
    """
    next_L = LogSpace(y.log_L_next)
    L_contour = LogSpace(carry.log_L)
    # Trapezoidal rule: midpoint of the previous and new likelihood contours.
    midL = LogSpace(jnp.log(0.5)) * (next_L + L_contour)
    X_mean = LogSpace(carry.log_X_mean)
    X2_mean = LogSpace(carry.log_X2_mean)
    Z_mean = LogSpace(carry.log_Z_mean)
    ZX_mean = LogSpace(carry.log_ZX_mean)
    Z2_mean = LogSpace(carry.log_Z2_mean)
    dZ2_mean = LogSpace(carry.log_dZ2_mean)
    # E[t] = n/(n+1), stably: -log(1 + 1/n) = -logaddexp(0, -log(n)).
    T_mean = LogSpace(- jnp.logaddexp(0., -jnp.log(y.num_live_points)))
    # E[1-t] = 1/(n+1).
    t_mean = LogSpace(- jnp.log(y.num_live_points + 1.))
    # E[t^2] = n/(n+2), stably: -log(1 + 2/n) = -logaddexp(0, log(2) - log(n)).
    T2_mean = LogSpace(- jnp.logaddexp((0.), jnp.log(2.) - jnp.log(y.num_live_points)))
    # E[(1-t)^2] = 2/((n+1)(n+2)).
    t2_mean = LogSpace(jnp.log(2.) - jnp.log(y.num_live_points + 1.) - jnp.log(y.num_live_points + 2.))
    # E[t(1-t)] = n/((n+1)(n+2)) = E[t] / (n+2).
    tT_mean = LogSpace(- jnp.logaddexp(0., -jnp.log(y.num_live_points)) - jnp.log(y.num_live_points + 2.))
    # Evidence increment from this shell: dZ = E[1-t] * X * midL.
    dZ_mean = X_mean * t_mean * midL
    next_X_mean = X_mean * T_mean
    next_X2_mean = X2_mean * T2_mean
    next_Z_mean = Z_mean + dZ_mean
    next_ZX_mean = ZX_mean * T_mean + X2_mean * tT_mean * midL
    # Second moment picks up a cross term 2*E[Z*dZ] plus E[dZ^2].
    next_Z2_mean = Z2_mean + LogSpace(jnp.log(2.)) * ZX_mean * t_mean * midL + (X2_mean * t2_mean * midL ** 2)
    next_dZ2_mean = dZ2_mean + (X2_mean * t2_mean * midL ** 2)
    next_evidence_calculation = EvidenceCalculation(
        log_L=y.log_L_next.astype(float_type),
        log_X_mean=next_X_mean.log_abs_val.astype(float_type),
        log_X2_mean=next_X2_mean.log_abs_val.astype(float_type),
        log_Z_mean=next_Z_mean.log_abs_val.astype(float_type),
        log_Z2_mean=next_Z2_mean.log_abs_val.astype(float_type),
        log_ZX_mean=next_ZX_mean.log_abs_val.astype(float_type),
        log_dZ_mean=dZ_mean.log_abs_val.astype(float_type),
        log_dZ2_mean=next_dZ2_mean.log_abs_val.astype(float_type)
    )
    return next_evidence_calculation
def update_evicence_calculation(evidence_calculation: EvidenceCalculation,
                                update: EvidenceUpdateVariables) -> EvidenceCalculation:
    """
    Fold one new sample into the running evidence statistics.

    Thin public wrapper around the update op.

    Args:
        evidence_calculation: the current evidence statistics.
        update: the new sample's update variables.

    Returns:
        The updated evidence statistics.
    """
    # NOTE(review): "evicence" typo is the established public name — kept.
    return _update_evidence_calc_op(evidence_calculation, update)
def init_evidence_calc() -> EvidenceCalculation:
    """
    Build the initial evidence statistics.

    In log space: X moments start at 1 (log 0.), while L, Z and all
    increment moments start at 0 (log -inf).

    Returns:
        The initial evidence statistics.
    """
    log_zero = jnp.asarray(-jnp.inf, float_type)
    log_one = jnp.asarray(0., float_type)
    return EvidenceCalculation(
        log_L=log_zero,
        log_X_mean=log_one,
        log_X2_mean=log_one,
        log_Z_mean=log_zero,
        log_ZX_mean=log_zero,
        log_Z2_mean=log_zero,
        log_dZ_mean=log_zero,
        log_dZ2_mean=log_zero
    )
def compute_evidence_stats(log_L: MeasureType, num_live_points: FloatArray, num_samples: Optional[IntArray] = None) -> \
        Tuple[EvidenceCalculation, EvidenceCalculation]:
    """
    Compute the evidence statistics along the shrinkage process.

    Args:
        log_L: The log likelihoods of the samples.
        num_live_points: The number of live points at each sample.
        num_samples: The number of samples to use. If None, all samples are used.

    Returns:
        The final evidence statistics, and the evidence statistics for each sample.
    """
    update_vars = EvidenceUpdateVariables(
        num_live_points=num_live_points.astype(float_type),
        log_L_next=log_L
    )
    initial_calc = init_evidence_calc()
    if num_samples is None:
        # Static scan over all samples.
        final_calc, per_sample_calc = cumulative_op_static(
            op=_update_evidence_calc_op, init=initial_calc, xs=update_vars)
    else:
        # Dynamic scan that stops after num_samples entries.
        final_calc, per_sample_calc = cumulative_op_dynamic(
            op=_update_evidence_calc_op, init=initial_calc, xs=update_vars, stop_idx=num_samples)
    return final_calc, per_sample_calc