import logging
from typing import Tuple, Optional, Union
import haiku as hk
import jax.nn
import tensorflow_probability.substrates.jax as tfp
from jax import numpy as jnp
from jaxns.framework.bases import BaseAbstractPrior, BaseAbstractDistribution
from jaxns.framework.distribution import Distribution
from jaxns.internals.types import FloatArray, IntArray, BoolArray, XType, UType, float_type
tfpd = tfp.distributions
__all__ = [
"Prior",
"InvalidPriorName"
]
logger = logging.getLogger('jaxns')
class InvalidPriorName(Exception):
    """
    Raised when a prior name is already taken by another prior in the same model.
    """

    def __init__(self, name: Optional[str] = None):
        # Embed the offending name in the message so the duplicate is easy to locate.
        super(InvalidPriorName, self).__init__(f'Prior name {name} already taken by another prior.')
class SingularPrior(BaseAbstractPrior):
    """
    A prior pinned to a single point.

    There is no inverse transformation back to base space, but ``log_prob``
    is still defined: it is evaluated at the singular value through the
    wrapped distribution.
    """

    def __init__(self, value: jnp.ndarray, dist: BaseAbstractDistribution, name: Optional[str] = None):
        super().__init__(name=name)
        self.value = value
        self.dist = dist

    def __repr__(self):
        return f"{self.value} -> {self.dist}"

    def _dtype(self):
        # Delegate dtype to the wrapped distribution.
        return self.dist.dtype

    def _base_shape(self) -> Tuple[int, ...]:
        # The value is fixed, so there are no base-space degrees of freedom.
        return (0,)

    def _shape(self) -> Tuple[int, ...]:
        # Output shape matches the wrapped distribution's shape.
        return self.dist.shape

    def _forward(self, U: UType) -> Union[FloatArray, IntArray, BoolArray]:
        # Every base sample maps to the same fixed value.
        return self.value

    def _inverse(self, X: XType) -> FloatArray:
        # Nothing to invert — emit an empty base-space array.
        return jnp.asarray([], float_type)

    def _log_prob(self, X: XType) -> FloatArray:
        # Log-probability of the singular value under the wrapped distribution.
        return self.dist.log_prob(X)
class Prior(BaseAbstractPrior):
    """
    Represents a generative prior.

    Wraps either a distribution (a raw tfp distribution or a jaxns
    ``BaseAbstractDistribution``) or a constant value. A constant-valued
    prior consumes no base-space variables and contributes zero to the
    log-probability.
    """

    def __init__(self, dist_or_value: Union[tfpd.Distribution, BaseAbstractDistribution, jnp.ndarray],
                 name: Optional[str] = None):
        """
        Args:
            dist_or_value: a tfp distribution, a jaxns distribution, or a constant value.
            name: optional name for the prior.
        """
        super(Prior, self).__init__(name=name)
        if isinstance(dist_or_value, tfpd.Distribution):
            # Adapt raw tfp distributions to the jaxns Distribution interface.
            self._type = 'dist'
            self._dist = Distribution(dist_or_value)
        elif isinstance(dist_or_value, BaseAbstractDistribution):
            self._type = 'dist'
            self._dist = dist_or_value
        else:
            # Anything else is treated as a fixed (point-mass) value.
            self._type = 'value'
            self._value = jnp.asarray(dist_or_value)
        self.name = name

    @property
    def dist(self) -> BaseAbstractDistribution:
        """The wrapped distribution. Raises ValueError if this is a value-type prior."""
        if self._type != 'dist':
            raise ValueError(f"Wrong type, got {self._type}")
        return self._dist

    @property
    def value(self) -> jnp.ndarray:
        """The fixed value. Raises ValueError if this is a distribution-type prior."""
        if self._type != 'value':
            raise ValueError(f"Wrong type, got {self._type}")
        return self._value

    def _base_shape(self) -> Tuple[int, ...]:
        if self._type == 'value':
            # A fixed value consumes no base-space variables.
            return (0,)
        elif self._type == 'dist':
            return self.dist.base_shape
        else:
            raise NotImplementedError()

    def _shape(self) -> Tuple[int, ...]:
        if self._type == 'value':
            return self.value.shape
        elif self._type == 'dist':
            return self.dist.shape
        else:
            raise NotImplementedError()

    def _dtype(self):
        if self._type == 'value':
            return self.value.dtype
        elif self._type == 'dist':
            return self.dist.dtype
        else:
            raise NotImplementedError()

    def _forward(self, U: UType) -> Union[FloatArray, IntArray, BoolArray]:
        if self._type == 'value':
            # Fixed priors ignore the base sample entirely.
            return self.value
        elif self._type == 'dist':
            return self.dist.forward(U)
        else:
            raise NotImplementedError()

    def _inverse(self, X: XType) -> FloatArray:
        if self._type == 'value':
            # Nothing to invert for a fixed value.
            return jnp.asarray([], float_type)
        elif self._type == 'dist':
            return self.dist.inverse(X)
        else:
            raise NotImplementedError()

    def _log_prob(self, X: XType) -> FloatArray:
        if self._type == 'value':
            # A point mass contributes nothing to the joint log-probability.
            return jnp.asarray(0., float_type)
        elif self._type == 'dist':
            return self.dist.log_prob(X=X)
        else:
            raise NotImplementedError()

    def parametrised(self, random_init: bool = False) -> SingularPrior:
        """
        Convert this prior into a non-Bayesian parameter, that takes a single value in the model, but still has an associated
        log_prob. The parameter is registered as a `hk.Parameter` with added `_param` name suffix.

        Args:
            random_init: whether to initialise the parameter randomly or at the median of the distribution.

        Returns:
            A singular prior.

        Raises:
            ValueError: if this is a value-type prior (there is no distribution to parametrise).
        """
        if self._type == 'value':
            raise ValueError("Cannot parametrise a prior without distribution.")
        return prior_to_parametrised_singular(self, random_init=random_init)
def prior_to_parametrised_singular(prior: Prior, random_init: bool = False) -> SingularPrior:
    """
    Convert a prior into a non-Bayesian parameter, that takes a single value in the model, but still has an associated
    log_prob. The parameter is registered as a `hk.Parameter` with added `_param` name suffix.
    To constrain the parameter we use a Normal parameter with centre on unit cube, and scale covering the whole cube,
    as the base representation. This base representation covers the whole real line and can be reliably used with SGD, etc.

    Args:
        prior: any prior
        random_init: whether to initialise the parameter randomly or at the median of the distribution.

    Returns:
        A parameter representing the prior.

    Raises:
        ValueError: if the prior has no name (the name is needed to register the hk.Parameter).
    """
    if prior.name is None:
        raise ValueError("Prior must have a name to be parametrised.")
    name = f"{prior.name}_param"
    # Initial value in the unconstrained space: either a random normal draw, or zero,
    # which sigmoid maps to 0.5 on the unit cube (i.e. the median of the distribution).
    if random_init:
        init_value = jax.random.normal(hk.next_rng_key(), shape=prior.base_shape, dtype=float_type)
    else:
        init_value = jnp.zeros(prior.base_shape, dtype=float_type)
    if init_value.size == 0:
        # A base shape of (0,) means the prior consumes no base variables; parametrising it is a no-op.
        logger.warning(f"Creating a zero-sized parameter for {prior.name}. Probably unintended.")
    # Register the unconstrained value as a trainable haiku parameter.
    norm_U_base_param = hk.get_parameter(
        name=name,
        shape=prior.base_shape,
        dtype=float_type,
        init=hk.initializers.Constant(init_value)
    )
    # transform [-inf, inf] -> [0,1]
    # Sigmoid is faster than ndtr to save FLOPs
    # U_base_param = ndtr(norm_U_base_param)
    U_base_param = jax.nn.sigmoid(norm_U_base_param)
    # Push the unit-cube point through the prior's forward transform to get the model-space value.
    param = prior.forward(U_base_param)
    return SingularPrior(value=param, dist=prior.dist, name=prior.name)