Source code for jaxns.samplers.multi_ellipsoidal_samplers

import dataclasses
import warnings
from typing import NamedTuple, Tuple, Any

import jax
from jax import random, numpy as jnp, lax

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.types import IntArray, UType
from jaxns.internals.types import PRNGKey, FloatArray
from jaxns.nested_samplers.common.types import Sample
from jaxns.samplers.abc import EphemeralState
from jaxns.samplers.bases import BaseAbstractRejectionSampler
from jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils import ellipsoid_clustering, MultEllipsoidState
from jaxns.samplers.multi_ellipsoid.multi_ellipsoid_utils import sample_multi_ellipsoid

# Names exported by a star-import of this module; only the sampler class is listed.
__all__ = [
    'MultiEllipsoidalSampler'
]


@dataclasses.dataclass(eq=False)
class MultiEllipsoidalSampler(BaseAbstractRejectionSampler[MultEllipsoidState]):
    """
    Rejection sampler that bounds the live points with a multi-ellipsoidal decomposition and
    proposes new points from inside that bound.

    Inefficient for high dimensional problems, but can be very efficient for low dimensional
    problems.
    """
    # Model whose ``forward`` is evaluated at every proposed point; the result is compared
    # against the likelihood constraint.
    model: BaseAbstractModel
    # Controls the ellipsoid budget: at most ``2 ** depth`` ellipsoids are requested.
    depth: int
    # Multiplier applied to the fitted ellipsoid radii before sampling from them.
    expansion_factor: float

    def __post_init__(self):
        """
        Emit the stability warning, then validate the configuration.

        Raises:
            ValueError: if ``depth`` is negative or ``expansion_factor`` is not strictly positive.
        """
        # The warning is issued first, so it is emitted even when validation fails below.
        warnings.warn(
            "MultiEllipsoidalSampler does not give consistent results/performance. Consider `UniDimSliceSampler`."
        )
        depth = self.depth
        if depth < 0:
            raise ValueError(f"depth {depth} must be >= 0")
        expansion_factor = self.expansion_factor
        if expansion_factor <= 0.:
            raise ValueError(f"expansion_factor {expansion_factor} must be > 0")

    def num_phantom(self) -> int:
        """
        Number of phantom samples produced per accepted sample; always zero for this sampler.
        """
        return 0

    @property
    def max_num_ellipsoids(self):
        """
        Upper limit on the number of ellipsoids passed to the clustering step, ``2 ** depth``.
        """
        return 2 ** self.depth

    def _pre_process(self, ephemeral_state: EphemeralState) -> Any:
        """
        Build the sampler state by clustering the current live points into ellipsoids.

        Args:
            ephemeral_state: supplies the PRNG key, the live points in U-space, and the
                ``log_X_mean`` value that is handed to the clustering routine as ``log_VS``.

        Returns:
            Whatever ``ellipsoid_clustering`` returns for the 'em_gmm' method.
        """
        clustering_key = ephemeral_state.key
        live_points_U = ephemeral_state.live_points_collection.U_sample
        log_VS = ephemeral_state.termination_register.evidence_calc_with_remaining.log_X_mean
        return ellipsoid_clustering(
            key=clustering_key,
            points=live_points_U,
            log_VS=log_VS,
            max_num_ellipsoids=self.max_num_ellipsoids,
            method='em_gmm'
        )

    def _post_process(self, ephemeral_state: EphemeralState, sampler_state: Any) -> Any:
        """
        No post-processing is needed: the sampler state is handed back unchanged.
        """
        return sampler_state

    def _get_sample(self, key: PRNGKey, log_L_constraint: FloatArray, sampler_state: MultEllipsoidState) -> Tuple[
        Sample, Sample]:
        """
        Draw points from the expanded ellipsoids until one exceeds the likelihood constraint.

        Args:
            key: PRNG key; it is split once per proposal.
            log_L_constraint: value that the accepted point's ``log_L`` must strictly exceed.
            sampler_state: ellipsoid parameters (``params.mu``, ``params.radii``,
                ``params.rotation``) to sample from.

        Returns:
            A pair ``(sample, phantom_samples)``. ``sample`` holds the accepted point, its
            ``log_L``, the constraint, and the number of model evaluations spent.
            ``phantom_samples`` has the same structure with a leading axis of length zero.
        """
        # The expanded radii do not change between proposals, so compute them once up front.
        expanded_radii = sampler_state.params.radii * jnp.asarray(
            self.expansion_factor, mp_policy.measure_dtype
        )

        def _propose(rng: PRNGKey) -> Tuple[PRNGKey, UType]:
            # Split off a fresh key for this draw and return the key to carry forward.
            # NOTE(review): ``unit_cube_constraint=True`` presumably keeps draws inside the
            # unit cube -- confirm against ``sample_multi_ellipsoid``.
            next_rng, draw_rng = random.split(rng, 2)
            _, proposal = sample_multi_ellipsoid(
                key=draw_rng,
                mu=sampler_state.params.mu,
                radii=expanded_radii,
                rotation=sampler_state.params.rotation,
                unit_cube_constraint=True
            )
            return next_rng, proposal

        class RejectionCarry(NamedTuple):
            rng: PRNGKey
            point: FloatArray
            log_L: FloatArray
            evals: IntArray

        def _still_rejected(state: RejectionCarry):
            # Keep looping while the latest proposal fails to exceed the constraint.
            return state.log_L <= log_L_constraint

        def _retry(state: RejectionCarry):
            next_rng, proposal = _propose(state.rng)
            proposal_log_L = self.model.forward(U=proposal)
            # ``ones_like`` keeps the counter's dtype fixed across loop iterations.
            evals = state.evals + jnp.ones_like(state.evals)
            return RejectionCarry(
                rng=next_rng,
                point=proposal,
                log_L=proposal_log_L,
                evals=evals
            )

        # The first proposal is made outside the loop, so at least one evaluation is counted.
        loop_rng, first_point = _propose(key)
        accepted = lax.while_loop(
            cond_fun=_still_rejected,
            body_fun=_retry,
            init_val=RejectionCarry(
                rng=loop_rng,
                point=first_point,
                log_L=self.model.forward(first_point),
                evals=jnp.asarray(1, mp_policy.count_dtype)
            )
        )

        sample = Sample(
            U_sample=accepted.point,
            log_L_constraint=log_L_constraint,
            log_L=accepted.log_L,
            num_likelihood_evaluations=accepted.evals
        )

        # TODO: rejected draws could be kept for ML applications; they cannot serve as phantom
        # samples because they do not satisfy the constraint.
        def _empty_stack(leaf):
            # Same trailing shape and dtype as the leaf, but with zero entries along axis 0.
            return jnp.zeros((0, *leaf.shape), leaf.dtype)

        phantom_samples = jax.tree.map(_empty_stack, sample)  # [k, D]
        return sample, phantom_samples