from typing import NamedTuple, Tuple
from jax import random, numpy as jnp, lax, tree_map
from jaxns.internals.shrinkage_statistics import compute_evidence_stats
from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges
from jaxns.internals.types import IntArray, StaticStandardNestedSamplerState, UType, StaticStandardSampleCollection
from jaxns.internals.types import PRNGKey, FloatArray
from jaxns.internals.types import Sample, int_type
from jaxns.samplers.abc import SamplerState
from jaxns.samplers.bases import BaseAbstractRejectionSampler
from jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils import ellipsoid_clustering, MultEllipsoidState
from jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils import sample_multi_ellipsoid
# Names exported by ``from <this module> import *``.
__all__ = ['MultiEllipsoidalSampler']
class MultiEllipsoidalSampler(BaseAbstractRejectionSampler):
    """
    Uses a multi-ellipsoidal decomposition of the live points to create a bound around regions to sample from.

    Inefficient for high dimensional problems, but can be very efficient for low dimensional problems.
    """

    def __init__(self, depth: int, expansion_factor: float, *args, **kwargs):
        """
        Store the clustering depth and radius expansion factor, then initialise the base sampler.

        Args:
            depth: the sampler allows at most ``2 ** depth`` ellipsoids. Must be >= 0.
            expansion_factor: multiplier applied to the ellipsoid radii when drawing samples. Must be > 0.
            *args: forwarded unchanged to the base class constructor.
            **kwargs: forwarded unchanged to the base class constructor.

        Raises:
            ValueError: if ``depth`` is negative or ``expansion_factor`` is not positive.
        """
        # A negative depth would make `2 ** depth` fractional, which is not a valid ellipsoid count.
        if depth < 0:
            raise ValueError(f"depth must be >= 0, got {depth}.")
        # Non-positive scaling collapses/inverts the radii, so the rejection loop could never accept a point.
        if expansion_factor <= 0:
            raise ValueError(f"expansion_factor must be > 0, got {expansion_factor}.")
        self._depth = depth
        self._expansion_factor = expansion_factor
        super().__init__(*args, **kwargs)

    def num_phantom(self) -> int:
        """
        Number of phantom samples produced per call to `get_sample`.

        Returns:
            always 0; this sampler produces no phantom samples.
        """
        return 0

    def pre_process(self, state: StaticStandardNestedSamplerState) -> SamplerState:
        """
        Build the multi-ellipsoid sampler state from the current nested sampler state.

        Computes evidence statistics over the samples collected so far and clusters the points at
        `state.front_idx` into at most `max_num_ellipsoids` ellipsoids.

        Args:
            state: the current nested sampler state.

        Returns:
            the result of `ellipsoid_clustering`, used as this sampler's state.
        """
        # Only the second sub-key is consumed here; `state` itself is not modified.
        _, sampler_key = random.split(state.key)

        sample_tree = SampleTreeGraph(
            sender_node_idx=state.sample_collection.sender_node_idx,
            log_L=state.sample_collection.log_L
        )

        live_point_counts = count_crossed_edges(sample_tree=sample_tree, num_samples=state.next_sample_idx)
        log_L = sample_tree.log_L[live_point_counts.samples_indices]
        num_live_points = live_point_counts.num_live_points

        final_evidence_stats, _ = compute_evidence_stats(
            log_L=log_L,
            num_live_points=num_live_points,
            num_samples=state.next_sample_idx
        )

        # Points selected by the front indices are the ones that get clustered.
        points = state.sample_collection.U_samples[state.front_idx]

        # NOTE(review): `log_X_mean` is presumably the log of the mean enclosed prior volume -- confirm
        # against `compute_evidence_stats`.
        return ellipsoid_clustering(
            key=sampler_key,
            points=points,
            log_VS=final_evidence_stats.log_X_mean,
            max_num_ellipsoids=self.max_num_ellipsoids,
            method='em_gmm'
        )

    def post_process(self, sample_collection: StaticStandardSampleCollection,
                     sampler_state: SamplerState) -> SamplerState:
        """
        Return the sampler state unchanged.

        Args:
            sample_collection: unused.
            sampler_state: the current sampler state.

        Returns:
            `sampler_state`, unmodified.
        """
        return sampler_state

    @property
    def max_num_ellipsoids(self) -> int:
        """
        Maximum number of ellipsoids passed to the clustering, equal to ``2 ** depth``.
        """
        return 2 ** self._depth

    def get_sample(self, key: PRNGKey, log_L_constraint: FloatArray,
                   sampler_state: MultEllipsoidState) -> Tuple[Sample, Sample]:
        """
        Draw from the expanded multi-ellipsoid bound until a point exceeds the likelihood constraint.

        Args:
            key: PRNG key.
            log_L_constraint: draws are repeated while the drawn ``log_L <= log_L_constraint``.
            sampler_state: ellipsoid parameters (`params.mu`, `params.radii`, `params.rotation`).

        Returns:
            the accepted sample, and a phantom-sample structure whose leaves have a leading axis of size 0.
        """
        # Loop-invariant: scale the radii once rather than on every draw.
        expanded_radii = sampler_state.params.radii * self._expansion_factor

        def _sample_multi_ellipsoid(key: PRNGKey) -> UType:
            # Only the second return value (the drawn point) is used.
            _, U = sample_multi_ellipsoid(
                key=key,
                mu=sampler_state.params.mu,
                radii=expanded_radii,
                rotation=sampler_state.params.rotation,
                unit_cube_constraint=True
            )
            return U

        class CarryState(NamedTuple):
            key: PRNGKey
            U: FloatArray
            log_L: FloatArray
            num_likelihood_evals: IntArray

        def cond(carry: CarryState):
            # Keep drawing while the latest point does not strictly exceed the constraint.
            return carry.log_L <= log_L_constraint

        def body(carry: CarryState):
            key, sample_key = random.split(carry.key, 2)
            point_U = _sample_multi_ellipsoid(key=sample_key)
            log_L = self.model.forward(U=point_U)
            # `ones_like` keeps the counter's dtype fixed across loop iterations.
            num_likelihood_evals = carry.num_likelihood_evals + jnp.ones_like(carry.num_likelihood_evals)
            return CarryState(
                key=key,
                U=point_U,
                log_L=log_L,
                num_likelihood_evals=num_likelihood_evals
            )

        # The first draw happens outside the loop, so the evaluation counter starts at 1.
        key, sample_key = random.split(key, 2)
        point_U = _sample_multi_ellipsoid(key=sample_key)
        init_carry_state = CarryState(
            key=key,
            U=point_U,
            log_L=self.model.forward(U=point_U),
            num_likelihood_evals=jnp.asarray(1, int_type)
        )

        final_carry = lax.while_loop(
            cond_fun=cond,
            body_fun=body,
            init_val=init_carry_state
        )

        sample = Sample(
            U_sample=final_carry.U,
            log_L_constraint=log_L_constraint,
            log_L=final_carry.log_L,
            num_likelihood_evaluations=final_carry.num_likelihood_evals
        )

        # Same structure as `sample`, but each leaf gets a leading axis of length 0 (no phantom samples).
        phantom_samples = tree_map(lambda x: jnp.zeros((0,) + x.shape, x.dtype), sample)
        return sample, phantom_samples