from typing import NamedTuple, Optional, Union, Any, Callable, Tuple, Dict, List, TypeVar
import chex
import numpy as np
from jax import numpy as jnp
# Names exported by ``from <this module> import *``.
__all__ = [
    'EvidenceCalculation',
    'TerminationCondition',
    'StaticStandardNestedSamplerState',
    'NestedSamplerResults',
    'PRNGKey',
    'IntArray',
    'FloatArray',
    'BoolArray',
    'float_type',
    'int_type',
    'complex_type',
    'LikelihoodType',
    'UType',
    'XType',
    'LikelihoodInputType',
    'RandomVariableType',
    'MeasureType',
]
# JAX dtypes that Python's ``float``, ``int`` and ``complex`` resolve to.
float_type = jnp.result_type(float)
int_type = jnp.result_type(int)
complex_type = jnp.result_type(complex)

# Array aliases. They all name the same array type; the distinct names only
# document the intended element kind where they are used.
FloatArray = chex.Array
# NOTE(review): IntArray, BoolArray and PRNGKey are listed in ``__all__`` and
# IntArray/BoolArray are used in annotations further down, but no definition was
# visible in this file. They are defined here, mirroring FloatArray, so that the
# module imports without NameError -- confirm against the upstream definitions.
IntArray = chex.Array
BoolArray = chex.Array
PRNGKey = chex.PRNGKey

LikelihoodType = Callable[..., FloatArray]
RandomVariableType = TypeVar('RandomVariableType')
MeasureType = TypeVar('MeasureType')
# NOTE(review): exported in ``__all__`` but not defined in the visible file;
# assumed to be a variable-length tuple of arrays -- confirm.
LikelihoodInputType = Tuple[jnp.ndarray, ...]
UType = FloatArray  # Sample space type
XType = Dict[str, RandomVariableType]  # Prior variable type
class EvidenceCalculation(NamedTuple):
    """
    Contains a running estimate of evidence and related quantities.

    NOTE(review): judging by the field names, each entry is the log of a running
    mean (X, Z and dZ presumably being the enclosed prior volume, the evidence and
    the evidence increment) -- confirm against the code that updates this tuple.
    """
    log_X2_mean: FloatArray
    log_ZX_mean: FloatArray
    log_Z2_mean: FloatArray
    log_dZ_mean: FloatArray
    log_dZ2_mean: FloatArray
class TerminationCondition(NamedTuple):
    """
    Contains the termination conditions for the nested sampling run.

    Two conditions can be combined with ``&`` or ``|``; the result wraps both
    operands in a TerminationConditionConjunction or
    TerminationConditionDisjunction respectively.

    Args:
        ess: The effective sample size, if the ESS (Kish's estimate) is greater than this the run will terminate.
        evidence_uncert: The uncertainty in the evidence, if the uncertainty is less than this the run will terminate.
        live_evidence_frac: Deprecated, use dlogZ.
        dlogZ: Terminate if log(Z_current + Z_remaining) - log(Z_current) < dlogZ. Default log(1 + 1e-3)
        max_samples: Terminate if the number of samples exceeds this.
        max_num_likelihood_evaluations: Terminate if the number of likelihood evaluations exceeds this.
        log_L_contour: Terminate if this log(L) contour is reached. A contour is reached if any dead point
            has log(L) > log_L_contour. Uncollected live points are not considered.
        efficiency_threshold: Terminate if the efficiency (num_samples / num_likelihood_evaluations) is less than this,
            for the last shrinkage iteration.
    """
    ess: Optional[FloatArray] = jnp.asarray(jnp.inf, float_type)
    evidence_uncert: Optional[FloatArray] = jnp.asarray(0., float_type)
    live_evidence_frac: Optional[FloatArray] = None  # Deprecated, use dlogZ
    dlogZ: Optional[FloatArray] = jnp.asarray(np.log(1. + 1e-3), float_type)
    max_samples: Optional[IntArray] = jnp.asarray(jnp.iinfo(int_type).max, int_type)
    max_num_likelihood_evaluations: Optional[IntArray] = jnp.asarray(jnp.iinfo(int_type).max, int_type)
    log_L_contour: Optional[FloatArray] = jnp.asarray(jnp.inf, float_type)
    efficiency_threshold: Optional[FloatArray] = jnp.asarray(0., float_type)

    def __and__(self, other):
        """Return a conjunction holding ``self`` and ``other``, in that order."""
        return TerminationConditionConjunction(conds=[self, other])

    def __or__(self, other):
        """Return a disjunction holding ``self`` and ``other``, in that order."""
        return TerminationConditionDisjunction(conds=[self, other])
class TerminationConditionConjunction(NamedTuple):
    """
    A conjunction of termination conditions, e.g. term_cond1 & term_cond2

    ``conds`` holds the operands; each may itself be a plain condition, a
    conjunction or a disjunction.
    """
    conds: List[
        Union[
            'TerminationConditionDisjunction',
            'TerminationConditionConjunction',
            TerminationCondition,
        ]
    ]
class TerminationConditionDisjunction(NamedTuple):
    """
    A disjunction of termination conditions, e.g. term_cond1 | term_cond2

    ``conds`` holds the operands; each may itself be a plain condition, a
    conjunction or a disjunction.
    """
    conds: List[
        Union[
            'TerminationConditionDisjunction',
            TerminationConditionConjunction,
            TerminationCondition,
        ]
    ]
class NestedSamplerResults(NamedTuple):
    """
    Results of the nested sampling run.
    """
    log_Z_mean: FloatArray  # estimate of E[log(Z)]
    log_Z_uncert: FloatArray  # estimate of StdDev[log(Z)]
    ESS: FloatArray  # estimate of Kish's effective sample size
    H_mean: FloatArray  # estimate of E[int log(L) L dp/Z]
    samples: XType  # Dict of arrays with leading dimension num_samples
    parametrised_samples: XType  # Dict of arrays with leading dimension num_samples
    U_samples: UType  # U-space samples (an array, not a dict) with leading dimension num_samples
    log_L_samples: FloatArray  # log(L) of each sample
    # log(E[dZ]) of each sample, where dZ is how much it contributes to the total evidence.
    log_dp_mean: FloatArray
    log_X_mean: FloatArray  # log(E[U]) of each sample
    log_posterior_density: FloatArray  # log(P( theta | D )) log posteriori density
    num_live_points_per_sample: IntArray  # how many live points were taken for the samples.
    num_likelihood_evaluations_per_sample: IntArray  # how many likelihood evaluations were made per sample.
    total_num_samples: IntArray  # int, the total number of samples collected.
    total_phantom_samples: IntArray  # int, the total number of phantom samples collected.
    total_num_likelihood_evaluations: IntArray  # how many likelihood evaluations were made in total
    log_efficiency: FloatArray  # log(total_num_samples / total_num_likelihood_evaluations)
    termination_reason: IntArray  # this will be an int reflecting the reason for termination
class SignedLog(NamedTuple):
    """
    Represents a signed value in log-space

    The value is stored as two parts: ``log_abs_val`` and ``sign``.
    """
    log_abs_val: jnp.ndarray
    sign: Union[jnp.ndarray, Any]
class Sample(NamedTuple):
    """
    A sample (or batch of samples) in U-space together with its likelihood data.

    The leading ``...`` in the shape comments stands for optional batch dimensions.
    """
    U_sample: UType  # [..., D] sample in U-space
    log_L_constraint: FloatArray  # [...,] log(L) constraint
    log_L: FloatArray  # [...,] log(L) of sample
    num_likelihood_evaluations: IntArray  # [...,] number of likelihood evaluations
class StaticStandardSampleCollection(NamedTuple):
    """
    Per-sample arrays for a collection of N samples; every field has leading dimension N.
    """
    sender_node_idx: IntArray  # [N] with values in [0, N]
    log_L: MeasureType  # [N] log(L) of each sample
    U_samples: UType  # [N, D] samples in U-space
    num_likelihood_evaluations: IntArray  # [N] number of likelihood evaluations for each sample
    phantom: BoolArray  # [N] whether the sample is a phantom sample
class TerminationRegister(NamedTuple):
    """
    Bundle of run statistics: sample and likelihood-evaluation counts, two evidence
    calculations, a log(L) contour, an efficiency and a plateau flag.

    NOTE(review): presumably these are the quantities that termination conditions
    are checked against -- confirm against the code that builds and reads it.
    """
    num_samples_used: IntArray
    evidence_calc: EvidenceCalculation
    evidence_calc_with_remaining: EvidenceCalculation
    num_likelihood_evaluations: IntArray
    log_L_contour: FloatArray
    efficiency: FloatArray
    plateau: BoolArray
class StaticStandardNestedSamplerState(NamedTuple):
    """
    State of the static standard nested sampler: a sample collection plus two
    indices into it.
    """
    next_sample_idx: IntArray  # the next sample insert index <==> the number of samples
    sample_collection: StaticStandardSampleCollection
    front_idx: IntArray  # the index of the front of the live points within sample collection
def isinstance_namedtuple(obj) -> bool:
    """
    Check if object is a namedtuple.

    The check is structural: ``obj`` must be a ``tuple`` instance that also exposes
    the ``_asdict`` and ``_fields`` attributes.

    Args:
        obj: object

    Returns:
        bool
    """
    return (
        isinstance(obj, tuple)
        and hasattr(obj, '_asdict')
        and hasattr(obj, '_fields')
    )