import logging
from typing import TypeVar, NamedTuple, Tuple

import jax
from jax import numpy as jnp, random, lax, tree_map

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.cumulative_ops import cumulative_op_static
from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, float_type, int_type, \
    StaticStandardNestedSamplerState, \
    IntArray, UType, StaticStandardSampleCollection
from import SamplerState
from jaxns.samplers.bases import SeedPoint, BaseAbstractMarkovSampler

# Module-level logger, obtained under the name 'jaxns'.
logger = logging.getLogger('jaxns')

# Generic type variable; not referenced by any definition visible in this part of the file.
T = TypeVar('T')

def _sample_direction(n_key: PRNGKey, ndim: int) -> FloatArray:
    """
    Choose a direction randomly from S^(D-1).

    Args:
        n_key: PRNG key
        ndim: int, number of dimensions

    Returns:
        direction: [D] direction from S^(D-1). For ndim == 1 a scalar one (shape ()) is returned instead,
            and the PRNG key is not used.
    """
    if ndim == 1:
        # Only one axis to move along, so the direction is fixed; the scalar broadcasts against a length-1 point.
        return jnp.ones(())
    # An isotropic Gaussian draw scaled to unit norm gives a direction on the unit sphere.
    direction = random.normal(n_key, shape=(ndim,), dtype=float_type)
    direction /= jnp.linalg.norm(direction)
    return direction

def _slice_bounds(point_U0: FloatArray, direction: FloatArray) -> Tuple[FloatArray, FloatArray]:
    """
    Compute the slice bounds, t, where point_U0 + direction * t intersects the unit cube boundary.

    Args:
        point_U0: [D]
        direction: [D]

    Returns:
        left_bound: left most point (<= 0).
        right_bound: right most point (>= 0).
    """
    # Per-dimension t at which the line hits the upper face (U == 1).
    # A zero direction component yields +/-inf or nan here; the masks below replace anything on the wrong side
    # (and nan, whose comparisons are False) by +/-inf so it never wins the min / max.
    t1 = (1. - point_U0) / direction
    t1_right = jnp.min(jnp.where(t1 >= 0., t1, jnp.inf))
    t1_left = jnp.max(jnp.where(t1 <= 0., t1, -jnp.inf))
    # Per-dimension t at which the line hits the lower face (U == 0).
    t0 = -point_U0 / direction
    t0_right = jnp.min(jnp.where(t0 >= 0., t0, jnp.inf))
    t0_left = jnp.max(jnp.where(t0 <= 0., t0, -jnp.inf))
    # The first face reached in each direction bounds the slice.
    right_bound = jnp.minimum(t0_right, t1_right)
    left_bound = jnp.maximum(t0_left, t1_left)
    return left_bound, right_bound

def _pick_point_in_interval(key: PRNGKey, point_U0: FloatArray, direction: FloatArray, left: FloatArray,
                            right: FloatArray) -> Tuple[FloatArray, FloatArray]:
    """
    Select a point along slice in [point_U0 + direction * left, point_U0 + direction * right]

    Args:
        key: PRNG key
        point_U0: [D]
        direction: [D]
        left: left most point (<= 0).
        right: right most point (>= 0).

    Returns:
        point_U: [D]
        t: selection point between [left, right]
    """
    # Uniform draw of the scalar step t, then move that far from point_U0 along the direction.
    t = random.uniform(key, minval=left, maxval=right, dtype=float_type)
    point_U = point_U0 + t * direction
    # NOTE: a guard that snapped back to point_U0 (with t = 0) when [left, right] had collapsed to within
    # ~10 machine epsilons of zero existed here as commented-out code; it is intentionally not active.
    return point_U, t

def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray, log_L_proposal: FloatArray,
                     log_L_constraint: FloatArray, log_L0: FloatArray,
                     midpoint_shrink: bool) -> Tuple[FloatArray, FloatArray]:
    """
    Not successful proposal, so shrink, optionally apply exponential shrinkage.

    Args:
        key: PRNG key, only consumed when midpoint_shrink is true
        t: position of the rejected proposal along the slice
        left: current left bound of the slice (<= 0)
        right: current right bound of the slice (>= 0)
        log_L_proposal: log-likelihood at the rejected proposal
        log_L_constraint: the log-likelihood constraint
        log_L0: log-likelihood at the seed point (t = 0)
        midpoint_shrink: static flag, if true then additionally apply the randomised extra shrink

    Returns:
        left: the shrunk left bound
        right: the shrunk right bound
    """
    # Without exponential shrinkage, we shrink to failed proposal point, which is 100% correct.
    left = jnp.where(t < 0., t, left)
    right = jnp.where(t > 0., t, right)

    if midpoint_shrink:
        # Only the third sub-key is consumed. The three-way split is kept so the random stream is unchanged.
        _, _, midpoint_key = random.split(key, 3)
        # we take two points along lines from origin:
        #  - alpha: one from a satisfying point (t=0) to non-satisfying proposal,
        #  - beta: and, one from non-satisfying proposal to the constraint.
        # We shrink to alpha point if beta point is above alpha point.
        # Intuitively, the tangent from constraint is more accurate than proposal and should be a supremum of heights
        # in reasonable cases.
        # An extension is to make alpha shrink to constraint line, which would shrink very fast, but introduce
        # auto-correlation which must be later refined away.
        # Line logic:
        # logL(t) = m * t + b
        # logL(0) = b
        # (logL(t_R) - logL(0))/t_R
        # logL(t_R*alpha) = (logL(t_R) - logL(0))*alpha + logL(0)
        alpha_key, beta_key = random.split(midpoint_key, 2)
        alpha = random.uniform(alpha_key)
        beta = random.uniform(beta_key)
        logL_alpha = log_L0 + alpha * (log_L_proposal - log_L0)
        logL_beta = log_L_proposal + beta * (log_L_constraint - log_L_proposal)
        do_mid_point_shrink = logL_alpha < logL_beta
        # Only the side on which the rejected proposal fell is contracted further (by the factor alpha).
        left = jnp.where((t < 0.) & do_mid_point_shrink, alpha * left, left)
        right = jnp.where((t > 0.) & do_mid_point_shrink, alpha * right, right)
    return left, right

def _new_proposal(key: PRNGKey, seed_point: SeedPoint, midpoint_shrink: bool, perfect: bool,
                  gradient_slice: bool,
                  log_L_constraint: FloatArray,
                  model: BaseAbstractModel) -> Tuple[FloatArray, FloatArray, IntArray]:
    """
    Sample from a slice about a seed point.

    NOTE(review): in this copy of the file every multi-line call in this function is cut off after its opening
    parenthesis, and the `else:` lines of the `if gradient_slice:` / `if perfect:` branches appear to be missing.
    The code lines are left exactly as found; restore the missing pieces from the upstream source before use.

    Args:
        key: PRNG key
        seed_point: the seed point to sample from
        midpoint_shrink: if true then contract to the midpoint of interval on rejection. Otherwise, normal contract
        perfect: if true then perform exponential shrinkage from maximal bounds, requiring no step-out procedure.
        gradient_slice: if true then slice along gradient direction
        log_L_constraint: the constraint to sample within
        model: the model to sample from

    Returns:
        point_U: the new sample
        log_L: the log-likelihood of the new sample
        num_likelihood_evaluations: the number of likelihood evaluations performed
    """

    class Carry(NamedTuple):
        # State carried between iterations of the shrink loop below.
        key: PRNGKey
        direction: FloatArray
        left: FloatArray
        right: FloatArray
        t: FloatArray
        point_U: UType
        log_L: FloatArray
        num_likelihood_evaluations: IntArray

    def cond(carry: Carry) -> BoolArray:
        # Keep looping until one of the three stopping conditions below holds.
        # 1) The interval has collapsed to at most two machine epsilons wide.
        close_to_zero_interval = (carry.right - carry.left) <= 2 * jnp.finfo(carry.right.dtype).eps
        # 2) The current proposal strictly exceeds the constraint.
        satisfaction = carry.log_L > log_L_constraint
        # Allow if on plateau to fly around the plateau for a while
        # 3) Both the seed point and the proposal sit exactly on the constraint value.
        lesser_satisfaction = jnp.bitwise_and(seed_point.log_L0 == log_L_constraint, carry.log_L == log_L_constraint)
        done = jnp.bitwise_or(jnp.bitwise_or(close_to_zero_interval, satisfaction), lesser_satisfaction)
        return jnp.bitwise_not(done)

    def body(carry: Carry) -> Carry:
        # One rejection step: shrink the interval, draw a new point in it, and evaluate the model there.
        key, t_key, shrink_key = random.split(carry.key, 3)
        # NOTE(review): argument list missing in this copy.
        left, right = _shrink_interval(
        # NOTE(review): argument list missing in this copy.
        point_U, t = _pick_point_in_interval(
        log_L = model.forward(point_U)
        # One more `model.forward` call has been made.
        num_likelihood_evaluations = carry.num_likelihood_evaluations + jnp.ones_like(carry.num_likelihood_evaluations)
        # NOTE(review): field values missing in this copy.
        return Carry(

    key, n_key, t_key = random.split(key, 3)
    if gradient_slice:
        # Slice along the normalised gradient of the model's forward function at the seed point.
        direction = jax.grad(model.forward)(seed_point.U0)
        norm = jnp.linalg.norm(direction)
        direction /= norm
        # A zero or non-finite gradient norm selects a randomly sampled direction instead.
        # NOTE(review): the remaining argument(s) and closing parenthesis of this `jnp.where` are missing.
        direction = jnp.where(
            jnp.bitwise_or(jnp.equal(norm, jnp.zeros_like(norm)), ~jnp.isfinite(norm)),
            _sample_direction(n_key, seed_point.U0.size),
        # Count starts at one in this branch; presumably this accounts for the gradient evaluation (confirm).
        num_likelihood_evaluations = jnp.full((), 1, int_type)
        # NOTE(review): argument list missing in this copy.
        (left, right) = _slice_bounds(
        # Left bound is reset to zero, so only non-negative steps along the direction remain.
        left = jnp.zeros_like(left)
        # NOTE(review): an `else:` appears to be missing before the next line in this copy.
        direction = _sample_direction(n_key, seed_point.U0.size)
        num_likelihood_evaluations = jnp.full((), 0, int_type)
        if perfect:
            # NOTE(review): argument list missing in this copy; an `else:` appears to be missing after it.
            (left, right) = _slice_bounds(
            # TODO: implement doubling step out
            raise NotImplementedError("TODO: implement doubling step out")
    # First proposal on the initial interval, evaluated before entering the loop.
    # NOTE(review): argument list missing in this copy.
    point_U, t = _pick_point_in_interval(
    log_L = model.forward(point_U)
    # NOTE(review): field values missing in this copy.
    init_carry = Carry(

    # NOTE(review): argument list missing in this copy; `cond` and `body` above are presumably the loop functions.
    carry = lax.while_loop(
    return carry.point_U, carry.log_L, carry.num_likelihood_evaluations

class UniDimSliceSampler(BaseAbstractMarkovSampler):
    """
    Slice sampler for a single dimension. Produces correlated (non-i.i.d.) samples.
    """

    def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: int, midpoint_shrink: bool,
                 perfect: bool, gradient_slice: bool = False):
        """
        Unidimensional slice sampler.

        Args:
            model: AbstractModel
            num_slices: number of slices between acceptance. Note: some other software use units of prior dimension.
            midpoint_shrink: if true then contract to the midpoint of interval on rejection. Otherwise, contract to
                rejection point. Speeds up convergence, but introduces minor auto-correlation.
            num_phantom_save: number of phantom samples to save. Phantom samples are samples that meet the
                constraint but are not accepted. They can be used for numerous things, e.g. to estimate the
                evidence uncertainty.
            perfect: if true then perform exponential shrinkage from maximal bounds, requiring no step-out
                procedure. Otherwise, uses a doubling procedure (exponentially finding bracket).
                Note: Perfect is a misnomer, as perfection also depends on the number of slices between acceptance.
            gradient_slice: if true then always slice along gradient direction.

        Raises:
            ValueError: if num_slices < 1, if num_phantom_save < 0, if num_phantom_save >= num_slices,
                or if perfect is false (only the perfect variant is implemented).
        """
        super().__init__(model=model)
        if num_slices < 1:
            raise ValueError(f"num_slices should be >= 1, got {num_slices}.")
        if num_phantom_save < 0:
            raise ValueError(f"num_phantom_save should be >= 0, got {num_phantom_save}.")
        if num_phantom_save >= num_slices:
            raise ValueError(f"num_phantom_save should be < num_slices, got {num_phantom_save} >= {num_slices}.")
        self.num_slices = int(num_slices)
        self.num_phantom_save = int(num_phantom_save)
        self.midpoint_shrink = bool(midpoint_shrink)
        self.perfect = bool(perfect)
        self.gradient_slice = bool(gradient_slice)
        if not self.perfect:
            raise ValueError("Only perfect slice sampler is implemented.")

    def num_phantom(self) -> int:
        """
        Number of phantom samples retained per accepted sample.

        Returns:
            the configured num_phantom_save
        """
        return self.num_phantom_save

    def pre_process(self, state: StaticStandardNestedSamplerState) -> SamplerState:
        """
        Build the sampler state from the nested sampler state.

        Args:
            state: the nested sampler state

        Returns:
            a one-element tuple holding the sample collection indexed by state.front_idx
        """
        sample_collection = tree_map(lambda x: x[state.front_idx], state.sample_collection)
        if self.perfect:
            # nothing needed
            return (sample_collection,)
        else:
            # TODO: step out with doubling, using ellipsoidal clustering
            return (sample_collection,)  # multi_ellipsoidal_params()

    def post_process(self, sample_collection: StaticStandardSampleCollection,
                     sampler_state: SamplerState) -> SamplerState:
        """
        Build a new sampler state from a sample collection; the previous sampler state is not used.

        Args:
            sample_collection: the sample collection to wrap
            sampler_state: the previous sampler state (ignored)

        Returns:
            a one-element tuple holding sample_collection
        """
        if self.perfect:
            # nothing needed
            return (sample_collection,)
        else:
            # TODO: step out with doubling, using ellipsoidal clustering, could shrink ellipsoids
            return (sample_collection,)

    def get_seed_point(self, key: PRNGKey, sampler_state: SamplerState,
                       log_L_constraint: FloatArray) -> SeedPoint:
        """
        Randomly pick a seed point from the samples held in the sampler state.

        Args:
            key: PRNG key
            sampler_state: one-element tuple holding the sample collection
            log_L_constraint: the constraint; samples strictly above it are preferred

        Returns:
            the seed point (U0 and log_L0) of the selected sample
        """
        sample_collection: StaticStandardSampleCollection
        (sample_collection,) = sampler_state

        select_mask = sample_collection.log_L > log_L_constraint
        # If no sample satisfies the constraint, then choose uniformly among all of them instead.
        any_satisfied = jnp.any(select_mask)
        yes_ = jnp.asarray(0., float_type)
        no_ = jnp.asarray(-jnp.inf, float_type)
        unnorm_select_log_prob = jnp.where(
            any_satisfied,
            jnp.where(select_mask, yes_, no_),
            yes_
        )
        # Choose randomly where mask is True: argmax of Gumbel noise plus log-weights (Gumbel-max trick).
        g = random.gumbel(key, shape=unnorm_select_log_prob.shape)
        sample_idx = jnp.argmax(g + unnorm_select_log_prob)

        return SeedPoint(
            U0=sample_collection.U_samples[sample_idx],
            log_L0=sample_collection.log_L[sample_idx]
        )

    def get_sample_from_seed(self, key: PRNGKey, seed_point: SeedPoint, log_L_constraint: FloatArray,
                             sampler_state: SamplerState) -> Tuple[Sample, Sample]:
        """
        Run num_slices chained slice proposals starting from the seed point.

        Args:
            key: PRNG key
            seed_point: the point to start the chain from
            log_L_constraint: the constraint to sample within
            sampler_state: the sampler state (not used here)

        Returns:
            final_sample: the last sample of the chain
            phantom_samples: the num_phantom_save samples immediately preceding the last one
        """

        def propose_op(sample: Sample, key: PRNGKey) -> Sample:
            # One slice step, seeded from the previous sample of the chain.
            U_sample, log_L, num_likelihood_evaluations = _new_proposal(
                key=key,
                seed_point=SeedPoint(
                    U0=sample.U_sample,
                    log_L0=sample.log_L
                ),
                midpoint_shrink=self.midpoint_shrink,
                perfect=self.perfect,
                gradient_slice=self.gradient_slice,
                log_L_constraint=log_L_constraint,
                model=self.model
            )
            # The evaluation count accumulates along the chain.
            return Sample(
                U_sample=U_sample,
                log_L_constraint=log_L_constraint,
                log_L=log_L,
                num_likelihood_evaluations=num_likelihood_evaluations + sample.num_likelihood_evaluations
            )

        init_sample = Sample(
            U_sample=seed_point.U0,
            log_L_constraint=log_L_constraint,
            log_L=seed_point.log_L0,
            num_likelihood_evaluations=jnp.asarray(0, int_type)
        )
        # One PRNG key per slice step.
        final_sample, cumulative_samples = cumulative_op_static(
            op=propose_op,
            init=init_sample,
            xs=random.split(key, self.num_slices),
            unroll=2
        )

        # Last sample is the final sample, the rest are potential phantom samples
        # Take only the last num_phantom_save phantom samples
        phantom_samples: Sample = tree_map(lambda x: x[-(self.num_phantom_save + 1):-1], cumulative_samples)

        # Due to the cumulative nature of the sampler, the final number of likelihood evaluations should be divided
        # equally among the accepted sample and retained phantom samples.
        # Integer floor division keeps the split exact; true division would round-trip through floating point
        # and can lose precision for large counts.
        num_likelihood_evaluations_per_phantom_sample = (
                final_sample.num_likelihood_evaluations // (self.num_phantom_save + 1)
        ).astype(int_type)
        # The accepted sample takes the remainder so that the total is conserved.
        num_likelihood_evaluations_per_accepted_sample = (
                final_sample.num_likelihood_evaluations
                - num_likelihood_evaluations_per_phantom_sample * self.num_phantom_save
        )
        final_sample = final_sample._replace(
            num_likelihood_evaluations=num_likelihood_evaluations_per_accepted_sample
        )
        phantom_samples = phantom_samples._replace(
            num_likelihood_evaluations=jnp.full(
                phantom_samples.num_likelihood_evaluations.shape,
                num_likelihood_evaluations_per_phantom_sample,
                phantom_samples.num_likelihood_evaluations.dtype
            )
        )
        return final_sample, phantom_samples