# Source code for jaxns.framework.special_priors

from functools import partial
from typing import Tuple, Union, Optional, Literal

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp, vmap, lax
from jax.scipy.special import gammaln

from jaxns.framework.bases import BaseAbstractPrior
from jaxns.internals.log_semiring import cumulative_logsumexp
from jaxns.internals.types import FloatArray, IntArray, BoolArray, float_type, int_type

tfpd = tfp.distributions

__all__ = [
    "Bernoulli",
    "Beta",
    "Categorical",
    "ForcedIdentifiability",
    "Poisson",
    "UnnormalisedDirichlet"
]


class Bernoulli(BaseAbstractPrior):
    def __init__(self, *, logits=None, probs=None, name: Optional[str] = None):
        super(Bernoulli, self).__init__(name=name)
        self.dist = tfpd.Bernoulli(logits=logits, probs=probs)

    def _dtype(self):
        return self.dist.dtype

    def _base_shape(self) -> Tuple[int, ...]:
        return self._shape()

    def _shape(self) -> Tuple[int, ...]:
        return tuple(self.dist.batch_shape_tensor()) + tuple(self.dist.event_shape_tensor())

    def _forward(self, U) -> Union[FloatArray, IntArray, BoolArray]:
        return self._quantile(U)

    def _inverse(self, X) -> FloatArray:
        return self.dist.cdf(X)

    def _log_prob(self, X) -> FloatArray:
        return self.dist.log_prob(X)

    def _quantile(self, U):
        probs = self.dist._probs_parameter_no_checks()
        sample = jnp.less(U, probs)
        return sample.astype(self.dtype)
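
# Illustrative sketch (hypothetical, not part of jaxns): the Bernoulli quantile
# above is just a threshold test, so a uniform draw below `probs` maps to 1.
# Pushing many uniforms through it recovers the success probability.
def _demo_bernoulli_quantile(p=0.3):
    u = jax.random.uniform(jax.random.PRNGKey(0), (100_000,))
    samples = jnp.less(u, p).astype(int_type)
    return samples.mean()  # approximately p
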
class Beta(BaseAbstractPrior):
    def __init__(self, *, concentration0=None, concentration1=None, name: Optional[str] = None):
        super(Beta, self).__init__(name=name)

        # Special case: when either concentration is 1, the Beta distribution
        # coincides with the Kumaraswamy distribution, which is faster to sample.
        def _is_one(c) -> bool:
            if isinstance(c, (float, int)):
                return c == 1
            if isinstance(c, np.ndarray):
                return bool(np.all(c == 1))
            return False

        if _is_one(concentration0) or _is_one(concentration1):
            self.dist = tfpd.Kumaraswamy(concentration0=concentration0, concentration1=concentration1)
        else:
            self.dist = tfpd.Beta(concentration0=concentration0, concentration1=concentration1)

    def _dtype(self):
        return self.dist.dtype

    def _base_shape(self) -> Tuple[int, ...]:
        return self._shape()

    def _shape(self) -> Tuple[int, ...]:
        return tuple(self.dist.batch_shape_tensor()) + tuple(self.dist.event_shape_tensor())

    def _forward(self, U) -> Union[FloatArray, IntArray, BoolArray]:
        return self._quantile(U)

    def _inverse(self, X) -> FloatArray:
        return self.dist.cdf(X)

    def _log_prob(self, X) -> FloatArray:
        return self.dist.log_prob(X)

    def _quantile(self, U):
        X = self.dist.quantile(U)
        return X.astype(self.dtype)
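
# Illustrative sketch (hypothetical, not part of jaxns): the Kumaraswamy special
# case above works because Kumaraswamy(a=concentration1, b=concentration0) has
# the closed-form quantile (1 - (1 - u)**(1/b))**(1/a), which agrees with the
# Beta quantile whenever either concentration equals 1; e.g. Beta(alpha, 1) has
# quantile u**(1/alpha).
def _demo_beta_kumaraswamy_quantile(u=0.7, alpha=2.5):
    kumaraswamy = (1. - (1. - u) ** (1. / 1.)) ** (1. / alpha)  # b = concentration0 = 1
    beta = u ** (1. / alpha)  # analytic Beta(alpha, 1) quantile
    return kumaraswamy, beta  # identical
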
class Categorical(BaseAbstractPrior):
    def __init__(self, parametrisation: Literal['gumbel_max', 'cdf'], *, logits=None, probs=None,
                 name: Optional[str] = None):
        """
        Initialise a Categorical special prior.

        Args:
            parametrisation: 'cdf' suits discrete parameters whose neighbouring categories are
                correlated; otherwise 'gumbel_max' is the better choice.
            logits: log-probability of each category
            probs: probability of each category
            name: optional name
        """
        super(Categorical, self).__init__(name=name)
        self.dist = tfpd.Categorical(logits=logits, probs=probs)
        self._parametrisation = parametrisation

    def _dtype(self):
        return self.dist.dtype

    def _base_shape(self) -> Tuple[int, ...]:
        if self._parametrisation == 'gumbel_max':
            return self._shape() + (self.dist._num_categories(),)
        elif self._parametrisation == 'cdf':
            return self._shape()

    def _shape(self) -> Tuple[int, ...]:
        return tuple(self.dist.batch_shape_tensor()) + tuple(self.dist.event_shape_tensor())

    def _forward(self, U) -> Union[FloatArray, IntArray, BoolArray]:
        if self._parametrisation == 'gumbel_max':
            return self._quantile_gumbelmax(U)
        elif self._parametrisation == 'cdf':
            return self._quantile_cdf(U)

    def _inverse(self, X) -> FloatArray:
        if self._parametrisation == 'cdf':
            return self.dist.cdf(X)
        raise NotImplementedError()

    def _log_prob(self, X) -> FloatArray:
        return self.dist.log_prob(X)

    def _quantile_gumbelmax(self, U):
        logits = self.dist._logits_parameter_no_checks()
        sample_dtype = self.dtype
        z = -jnp.log(-jnp.log(U))  # standard Gumbel noise
        draws = jnp.argmax(logits + z, axis=-1).astype(sample_dtype)
        return draws

    def _cdf_gumbelmax(self, X):
        logits = self.dist._logits_parameter_no_checks()
        z = jnp.max(logits, axis=-1, keepdims=True)
        logits -= z
        return jnp.exp(logits) / jnp.sum(jnp.exp(logits), axis=-1, keepdims=True)

    def _quantile_cdf(self, U):
        """
        The quantile for the CDF parametrisation.

        Args:
            U: [...]

        Returns:
            [...]
        """
        logits = self.dist._logits_parameter_no_checks()  # [..., N]
        cum_logits = cumulative_logsumexp(logits, axis=-1)  # [..., N]
        cum_logits -= cum_logits[..., -1:]
        N = cum_logits.shape[-1]
        cum_logits_flat = jnp.reshape(cum_logits, (-1, N))
        log_U_flat = jnp.reshape(jnp.log(U), (-1,))
        category_flat = vmap(lambda a, v: jnp.searchsorted(a, v, side='left'))(cum_logits_flat, log_U_flat)
        category = jnp.reshape(category_flat, U.shape)
        return category
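
# Illustrative sketch (hypothetical, not part of jaxns): the 'gumbel_max'
# parametrisation relies on the Gumbel-max trick: argmax(logits + g) with
# i.i.d. standard Gumbel noise g = -log(-log(U)) is Categorical(logits)
# distributed.
def _demo_gumbel_max_trick():
    logits = jnp.log(jnp.asarray([0.2, 0.3, 0.5]))
    u = jax.random.uniform(jax.random.PRNGKey(0), (50_000, 3))
    draws = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=-1)
    freqs = jnp.stack([jnp.mean(draws == k) for k in range(3)])
    return freqs  # approximately [0.2, 0.3, 0.5]
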
class ForcedIdentifiability(BaseAbstractPrior):
    """
    Prior for a sequence of `n` random variables uniformly distributed on U[low, high] such that
    U[i, ...] <= U[i + 1, ...]. Under broadcasting, the result is sorted elementwise along the
    first dimension.

    Args:
        n: number of samples within [low, high]
        low: minimum of distribution
        high: maximum of distribution
        fix_left: if True, the leftmost value is fixed to `low`
        fix_right: if True, the rightmost value is fixed to `high`
    """

    def __init__(self, *, n: int, low=None, high=None, fix_left: bool = False, fix_right: bool = False,
                 name: Optional[str] = None):
        super(ForcedIdentifiability, self).__init__(name=name)
        n_min = (1 if fix_left else 0) + (1 if fix_right else 0)
        if n < n_min:
            raise ValueError(f'`n` too small for fix_left={fix_left} and fix_right={fix_right}')
        self.n = n
        low, high = jnp.broadcast_arrays(low, high)
        self.low = low
        self.high = high
        self.fix_left = fix_left
        self.fix_right = fix_right

    def _dtype(self):
        return float_type

    def _base_shape(self) -> Tuple[int, ...]:
        num_base = self.n
        if self.fix_left:
            num_base -= 1
        if self.fix_right:
            num_base -= 1
        return (num_base,) + np.shape(self.low)

    def _shape(self) -> Tuple[int, ...]:
        return (self.n,) + np.shape(self.low)

    def _forward(self, U) -> Union[FloatArray, IntArray, BoolArray]:
        return self._quantile(U)

    def _inverse(self, X) -> FloatArray:
        return self._cdf(X)

    def _log_prob(self, X) -> FloatArray:
        n = self.n
        if self.fix_left:
            n -= 1
        if self.fix_right:
            n -= 1
        log_n_fac = gammaln(n + 1)
        diff = self.high - self.low
        log_prob = -log_n_fac - n * jnp.log(diff)  # no check that X lies within [low, high]
        return log_prob

    def _cdf(self, X):
        if self.fix_left:
            X = X[1:]
        if self.fix_right:
            X = X[:-1]
        n = self.n
        if self.fix_left:
            n -= 1
        if self.fix_right:
            n -= 1
        # Step 1: map X back to [0, 1] and take logs
        log_output = jnp.log((X - self.low) / (self.high - self.low))
        # Step 2: undo the cumulative sum operation
        inner = jnp.diff(log_output[::-1], prepend=0, axis=0)[::-1]
        # Step 3: undo the reshaping and division
        k = jnp.arange(n) + 1
        log_x = inner * jnp.reshape(k, (n,) + (1,) * (len(self.shape) - 1))
        # Step 4: recover U by exponentiating log_x
        U = jnp.exp(log_x)
        return U.astype(self.dtype)

    def _quantile(self, U):
        n = self.n
        if self.fix_left:
            n -= 1
        if self.fix_right:
            n -= 1
        log_x = jnp.log(U)  # [n, ...]
        k = jnp.arange(n) + 1
        inner = log_x / lax.reshape(k, (n,) + (1,) * (len(self.shape) - 1))
        log_output = jnp.cumsum(inner[::-1], axis=0)[::-1]
        output = self.low + (self.high - self.low) * jnp.exp(log_output)
        if self.fix_left:
            output = jnp.concatenate([self.low[None], output], axis=0)
        if self.fix_right:
            output = jnp.concatenate([output, self.high[None]], axis=0)
        return output.astype(self.dtype)
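
# Illustrative sketch (hypothetical, not part of jaxns): the quantile above is
# the classic order-statistics construction x_n = U_n**(1/n),
# x_k = x_{k+1} * U_k**(1/k), which maps n i.i.d. uniforms to n sorted
# uniforms without an explicit sort.
def _demo_forced_identifiability_transform(n=5):
    u = jax.random.uniform(jax.random.PRNGKey(0), (n,))
    k = jnp.arange(n) + 1
    inner = jnp.log(u) / k
    x = jnp.exp(jnp.cumsum(inner[::-1])[::-1])
    return x  # non-decreasing values in [0, 1]
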
def _poisson_quantile_bisection(U, rate, max_iter=15, unroll: bool = True):
    # max_iter is set so that error < 1 up to rate of 1e4
    rate = jnp.maximum(jnp.asarray(rate), 1e-5)
    if np.size(rate) > 1:
        raise ValueError("Rate must be a scalar")
    if np.size(U) > 1:
        U_flat = U.ravel()
        x_final, x_results = vmap(lambda u: _poisson_quantile_bisection(u, rate, max_iter, unroll))(U_flat)
        return x_final.reshape(U.shape), x_results.reshape(U.shape + (max_iter,))

    def smooth_cdf(x, rate):
        return lax.igammac(x + 1., rate)

    def fixed_point_update(x, args):
        (a, b, f_a, f_b) = x
        c = 0.5 * (a + b)
        f_c = smooth_cdf(c, rate)
        left = f_c > U
        a1 = jnp.where(left, a, c)
        f_a1 = jnp.where(left, f_a, f_c)
        b1 = jnp.where(left, c, b)
        f_b1 = jnp.where(left, f_c, f_b)
        a2 = a
        f_a2 = f_a
        b2 = b * 2.
        f_b2 = smooth_cdf(b2, rate)
        bounded = f_b >= U  # current [a, b] already brackets the root
        a = jnp.where(bounded, a1, a2)
        b = jnp.where(bounded, b1, b2)
        f_a = jnp.where(bounded, f_a1, f_a2)
        f_b = jnp.where(bounded, f_b1, f_b2)
        new_x = (a, b, f_a, f_b)
        return new_x, 0.5 * (a + b)

    a = jnp.asarray(0.)
    b = jnp.asarray(rate)
    f_a = jnp.asarray(0.)
    f_b = smooth_cdf(b, rate)
    init = (a, b, f_a, f_b)
    # Dummy array to facilitate using scan for a fixed number of iterations
    (a, b, f_a, f_b), x_results = lax.scan(
        fixed_point_update,
        init,
        jnp.arange(max_iter),
        unroll=max_iter if unroll else 1
    )
    c = 0.5 * (a + b)
    return c, x_results


@partial(jax.jit, static_argnames=("unroll",))
def _poisson_quantile(U, rate, unroll: bool = False):
    x, _ = _poisson_quantile_bisection(U, rate, unroll=unroll)
    return x.astype(int_type)
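
# Illustrative sketch (hypothetical, not part of jaxns): the bisection above
# inverts the smooth CDF u = igammac(x + 1, rate), doubling `b` until the root
# is bracketed. Evaluating the CDF at the returned root should recover u up to
# the bisection tolerance.
def _demo_poisson_bisection_check(u=0.9, rate=100.):
    x, _ = _poisson_quantile_bisection(u, rate)
    return x, lax.igammac(x + 1., rate)  # second value is close to u
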
class Poisson(BaseAbstractPrior):
    def __init__(self, *, rate=None, log_rate=None, name: Optional[str] = None):
        super(Poisson, self).__init__(name=name)
        self.dist = tfpd.Poisson(rate=rate, log_rate=log_rate)

    def _dtype(self):
        return int_type

    def _base_shape(self) -> Tuple[int, ...]:
        return self._shape()

    def _shape(self) -> Tuple[int, ...]:
        return tuple(self.dist.batch_shape_tensor()) + tuple(self.dist.event_shape_tensor())

    def _forward(self, U) -> Union[FloatArray, IntArray, BoolArray]:
        return self._quantile(U)

    def _inverse(self, X) -> FloatArray:
        return self.dist.cdf(X)

    def _log_prob(self, X) -> FloatArray:
        return self.dist.log_prob(X)

    def _quantile(self, U):
        """
        Poisson quantile, computed by bisecting the smooth CDF igammac(x + 1, rate)
        (see `_poisson_quantile_bisection`). The classical inversion-by-sequential-search
        generator it replaces is:

        .. code-block::

            init:
                Let log_x ← -inf, log_p ← −λ, log_s ← log_p.
                Generate uniform random number u in [0, 1].
            while log_u > log_s do:
                log_x ← logaddexp(log_x, 0).
                log_p ← log_p + log_λ - log_x.
                log_s ← logaddexp(log_s, log_p).
            return exp(log_x).
        """
        rate = self.dist.rate_parameter()
        if np.size(rate) > 1:
            return jax.vmap(lambda u, r: _poisson_quantile(u, r))(U.ravel(), rate.ravel()).reshape(self.shape)
        return _poisson_quantile(U, rate)
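
# Illustrative sketch (hypothetical, not part of jaxns): pushing i.i.d.
# uniforms through the Poisson quantile yields Poisson counts, so the sample
# mean approaches the rate (up to the method's < 1 quantile error).
def _demo_poisson_sample_mean(rate=3.0):
    u = jax.random.uniform(jax.random.PRNGKey(0), (10_000,))
    x = vmap(lambda ui: _poisson_quantile(ui, rate))(u)
    return x.mean()  # approximately rate
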
class UnnormalisedDirichlet(BaseAbstractPrior):
    """
    Represents an unnormalised Dirichlet distribution over K classes. That is, the output is
    related to the K-simplex via normalisation:

        X ~ UnnormalisedDirichlet(alpha)
        Y = X / sum(X) ==> Y ~ Dirichlet(alpha)
    """

    def __init__(self, *, concentration, name: Optional[str] = None):
        super(UnnormalisedDirichlet, self).__init__(name=name)
        self._dirichlet_dist = tfpd.Dirichlet(concentration=concentration)
        self._gamma_dist = tfpd.Gamma(concentration=concentration, log_rate=0.)

    def _dtype(self):
        return self._dirichlet_dist.dtype

    def _base_shape(self) -> Tuple[int, ...]:
        return self._shape()

    def _shape(self) -> Tuple[int, ...]:
        return tuple(self._dirichlet_dist.batch_shape_tensor()) + tuple(self._dirichlet_dist.event_shape_tensor())

    def _forward(self, U) -> Union[FloatArray, IntArray, BoolArray]:
        return self._quantile(U)

    def _inverse(self, X) -> FloatArray:
        return self._gamma_dist.cdf(X)

    def _log_prob(self, X) -> FloatArray:
        return jnp.sum(self._gamma_dist.log_prob(X), axis=-1)

    def _quantile(self, U):
        gamma = self._gamma_dist.quantile(U).astype(self.dtype)
        return gamma
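
# Illustrative sketch (hypothetical, not part of jaxns): independent
# Gamma(alpha_k, 1) draws normalised by their sum are Dirichlet(alpha)
# distributed, which is why this prior can sample unnormalised Gamma
# coordinates and defer normalisation to the user.
def _demo_unnormalised_dirichlet():
    alpha = jnp.asarray([1., 2., 3.])
    x = jax.random.gamma(jax.random.PRNGKey(0), alpha, (10_000, 3))
    y = x / jnp.sum(x, axis=-1, keepdims=True)
    return jnp.mean(y, axis=0)  # approximately alpha / alpha.sum() = [1/6, 1/3, 1/2]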