import inspect
import warnings
from typing import Tuple, Callable, Generator
import jax
import numpy as np
from jax import numpy as jnp, lax
from jaxns.framework.bases import PriorModelType, BaseAbstractPrior, PriorModelGen
from jaxns.framework.prior import InvalidPriorName, SingularPrior, Prior
from jaxns.internals.maps import pytree_unravel
from jaxns.internals.mixed_precision import mp_policy
from jaxns.internals.types import UType, XType, LikelihoodInputType, FloatArray, LikelihoodType, PRNGKey, \
isinstance_namedtuple, WType, RandomVariableType
# Names exported by `from <this module> import *`.
__all__ = [
    'simulate_prior_model'
]
def _get_prior_model_gen(prior_model: PriorModelType) -> PriorModelGen:
    """
    Calls the prior model and makes sure the result is a generator.

    If calling ``prior_model`` does not produce a generator, a warning is issued and the
    returned value is wrapped in a stand-in generator that yields a single ``Prior(0.)``
    and then returns that value.

    Args:
        prior_model: a callable that produces a prior model generator

    Returns:
        a generator that can be driven with ``send``
    """
    produced = prior_model()
    if inspect.isgenerator(produced):
        return produced

    # Not a generator: most likely the model body contains no `yield`.
    warnings.warn(
        "The provided prior_model is not a generator, this may mean you forget `yield` statements. "
        "This means there are no Bayesian variables."
    )

    def _passthrough_model(result):
        # One placeholder prior is yielded, then the original return value is handed back.
        _ = yield Prior(0.)
        return result

    return _passthrough_model(produced)
def compute_U_ndims(prior_model: PriorModelType) -> int:
    """
    Counts the total number of base (U-space) dimensions of a prior model.

    The model generator is driven to completion. Each yielded prior receives a base sample
    filled with 0.5, and the result of its forward transform is sent back into the generator.

    Args:
        prior_model: a callable that produces a prior model generator

    Returns:
        number of U dims

    Raises:
        InvalidPriorName: if two priors share the same non-None name
    """
    generator = _get_prior_model_gen(prior_model=prior_model)
    seen_names = set()
    total_ndims = 0
    response = None
    try:
        while True:
            prior: BaseAbstractPrior = generator.send(response)
            total_ndims += prior.base_ndims
            base_sample = jnp.full(prior.base_shape, 0.5, mp_policy.measure_dtype)
            response = prior.forward(base_sample)
            if prior.name is None:
                continue
            if prior.name in seen_names:
                raise InvalidPriorName(name=prior.name)
            seen_names.add(prior.name)
    except StopIteration:
        # The generator is exhausted: every prior has been visited.
        pass
    return total_ndims
def simulate_prior_model(key: PRNGKey, prior_model: PriorModelType) -> Tuple[LikelihoodInputType, XType]:
    """
    Simulate a prior model.

    A uniform sample with the shape of the flat U placeholder is drawn, unravelled into the
    W structure of the model, and then pushed through the prior model twice: once to get
    the likelihood inputs and once to get the named prior variables.

    Args:
        key: PRNGKey
        prior_model: A prior model

    Returns:
        a tuple of the likelihood input variables, and dict of non-hidden (named) prior variables.
    """
    # Parse the prior once to learn the flat U shape and the structure of W.
    U_placeholder, _, W_placeholder = parse_prior(prior_model=prior_model)
    _, unravel_fn = pytree_unravel(W_placeholder)
    U = jax.random.uniform(key, shape=U_placeholder.shape, dtype=mp_policy.measure_dtype)
    W = unravel_fn(U)
    return prepare_input(W=W, prior_model=prior_model), transform(W=W, prior_model=prior_model)
def parse_prior(prior_model: PriorModelType) -> Tuple[UType, XType, WType]:
    """
    Builds placeholders for the U, X and W representations of a prior model.

    The model generator is driven to completion with base samples filled with 0.5.

    Args:
        prior_model: a callable that produces a prior model generator

    Returns:
        U placeholder, X placeholder, W placeholder

    Raises:
        InvalidPriorName: if two priors share the same non-None name
    """
    generator = _get_prior_model_gen(prior_model=prior_model)
    seen_names = set()
    named_values: XType = dict()
    base_samples = []
    total_ndims = 0
    response = None
    try:
        while True:
            prior: BaseAbstractPrior = generator.send(response)
            total_ndims += prior.base_ndims
            base_sample = jnp.full(prior.base_shape, 0.5, mp_policy.measure_dtype)
            base_samples.append(base_sample)
            response = prior.forward(base_sample)
            if prior.name is None:
                continue
            if prior.name in seen_names:
                raise InvalidPriorName(name=prior.name)
            seen_names.add(prior.name)
            # Only named priors that are not SingularPrior are collected into X.
            if not isinstance(prior, SingularPrior):
                named_values[prior.name] = response
    except StopIteration:
        pass
    W_placeholder: WType = tuple(base_samples)
    U_placeholder = jnp.full((total_ndims,), 0.5, mp_policy.measure_dtype)
    return U_placeholder, named_values, W_placeholder
def parse_joint(prior_model: PriorModelType, log_likelihood: LikelihoodType) -> Tuple[
    UType, XType, WType, LikelihoodInputType, FloatArray
]:
    """
    Builds placeholders for a prior model together with its log-likelihood.

    The model generator is driven to completion with base samples filled with 0.5. Its
    return value is wrapped into a tuple if needed and passed to ``log_likelihood``.

    Args:
        prior_model: a callable that produces a prior model generator
        log_likelihood: callable applied to the unpacked values returned by the prior model

    Returns:
        U placeholder, X placeholder, W placeholder, likelihood input placeholder, log likelihood placeholder

    Raises:
        InvalidPriorName: if two priors share the same non-None name
    """
    generator = _get_prior_model_gen(prior_model=prior_model)
    seen_names = set()
    named_values: XType = dict()
    base_samples = []
    total_ndims = 0
    response = None
    try:
        while True:
            prior: BaseAbstractPrior = generator.send(response)
            total_ndims += prior.base_ndims
            base_sample = jnp.full(prior.base_shape, 0.5, mp_policy.measure_dtype)
            base_samples.append(base_sample)
            response = prior.forward(base_sample)
            if prior.name is None:
                continue
            if prior.name in seen_names:
                raise InvalidPriorName(name=prior.name)
            seen_names.add(prior.name)
            # Only named priors that are not SingularPrior are collected into X.
            if not isinstance(prior, SingularPrior):
                named_values[prior.name] = response
    except StopIteration as finished:
        # The generator's return value carries the likelihood inputs.
        output = finished.value
    is_plain_tuple = isinstance(output, tuple) and not isinstance_namedtuple(output)
    if not is_plain_tuple:
        # A single value (or a namedtuple) is treated as one positional argument.
        output = (output,)
    W_placeholder: WType = tuple(base_samples)
    likelihood_input_placeholder = output
    log_L_placeholder = mp_policy.cast_to_measure(log_likelihood(*output))
    U_placeholder = jnp.full((total_ndims,), 0.5, mp_policy.measure_dtype)
    return U_placeholder, named_values, W_placeholder, likelihood_input_placeholder, log_L_placeholder
def transform(W: WType, prior_model: PriorModelType) -> XType:
    """
    Maps a W-space sample to the named prior variables.

    Args:
        W: tuple of W-space samples
        prior_model: a callable that produces a prior model generator

    Returns:
        dict of the named prior variables, excluding those from SingularPrior instances

    Raises:
        InvalidPriorName: if two priors share the same non-None name
    """
    generator = _get_prior_model_gen(prior_model=prior_model)
    seen_names = set()
    named_values = dict()
    position = 0
    response = None
    try:
        while True:
            prior: BaseAbstractPrior = generator.send(response)
            # The i-th yielded prior consumes the i-th entry of W.
            response = prior.forward(W[position])
            position += 1
            if prior.name is None:
                continue
            if prior.name in seen_names:
                raise InvalidPriorName(name=prior.name)
            seen_names.add(prior.name)
            if not isinstance(prior, SingularPrior):
                named_values[prior.name] = response
    except StopIteration:
        pass
    return named_values
def transform_parametrised(W: WType, prior_model: PriorModelType) -> XType:
    """
    Maps a W-space sample to the named parametrised prior variables.

    Args:
        W: tuple of W-space samples
        prior_model: a callable that produces a prior model generator

    Returns:
        dict of the named prior variables that come from SingularPrior instances

    Raises:
        InvalidPriorName: if two priors share the same non-None name
    """
    generator = _get_prior_model_gen(prior_model=prior_model)
    seen_names = set()
    named_values = dict()
    position = 0
    response = None
    try:
        while True:
            prior: BaseAbstractPrior = generator.send(response)
            # The i-th yielded prior consumes the i-th entry of W.
            response = prior.forward(W[position])
            position += 1
            if prior.name is None:
                continue
            if prior.name in seen_names:
                raise InvalidPriorName(name=prior.name)
            seen_names.add(prior.name)
            if isinstance(prior, SingularPrior):
                named_values[prior.name] = response
    except StopIteration:
        pass
    return named_values
def prepare_input(W: WType, prior_model: PriorModelType) -> LikelihoodInputType:
    """
    Maps a W-space sample to the inputs of the likelihood.

    The prior model generator is driven to completion and its return value is used as the
    likelihood input. A return value that is not a plain tuple (including a namedtuple) is
    wrapped into a one-element tuple.

    Args:
        W: tuple of W-space samples
        prior_model: a callable that produces a prior model generator

    Returns:
        the conditional variables of likelihood model
    """
    generator = _get_prior_model_gen(prior_model=prior_model)
    position = 0
    response = None
    try:
        while True:
            prior: BaseAbstractPrior = generator.send(response)
            # The i-th yielded prior consumes the i-th entry of W.
            response = prior.forward(W[position])
            position += 1
    except StopIteration as finished:
        output = finished.value
    is_plain_tuple = isinstance(output, tuple) and not isinstance_namedtuple(output)
    if not is_plain_tuple:
        output = (output,)
    return output
def compute_log_prob_prior(W: WType, prior_model: PriorModelType) -> FloatArray:
    """
    Evaluates the prior log-density at a W-space sample.

    Each yielded prior contributes ``prior.log_prob`` of its forward-transformed value; the
    contributions are added together in the order the priors are yielded.

    Args:
        W: tuple of W-space samples
        prior_model: a callable that produces a prior model generator

    Returns:
        prior log-density
    """
    generator = _get_prior_model_gen(prior_model=prior_model)
    terms = []
    position = 0
    response = None
    try:
        while True:
            prior: BaseAbstractPrior = generator.send(response)
            response = prior.forward(W[position])
            position += 1
            terms.append(prior.log_prob(response))
    except StopIteration:
        pass
    if not terms:
        # No priors were yielded: return a zero log-density.
        return jnp.asarray(0., mp_policy.measure_dtype)
    total = terms[0]
    for term in terms[1:]:
        total = total + term
    return total
def compute_log_likelihood(W: WType, prior_model: PriorModelType, log_likelihood: LikelihoodType,
                           allow_nan: bool = False) -> FloatArray:
    """
    Evaluates the log-likelihood at a W-space sample.

    Args:
        W: tuple of W-space samples
        prior_model: a callable that produces a prior model generator
        log_likelihood: callable that takes arrays returned by the prior model and returns a scalar float
        allow_nan: whether to allow nans in likelihood; when False, NaN is replaced by -inf

    Returns:
        log-likelihood

    Raises:
        ValueError: if the log-likelihood result does not contain exactly one element
    """
    likelihood_inputs = prepare_input(W=W, prior_model=prior_model)
    log_L = mp_policy.cast_to_measure(log_likelihood(*likelihood_inputs))
    if np.size(log_L) != 1:
        raise ValueError(f"Log likelihood should be scalar, but got {log_L.shape}.")
    if log_L.shape != ():
        # Single-element arrays such as shape (1,) are collapsed to a true scalar.
        log_L = lax.reshape(log_L, ())
    if allow_nan:
        return log_L
    # NaN is the only value that compares unequal to itself.
    nan_mask = lax.ne(log_L, log_L)
    negative_infinity = jnp.asarray(-jnp.inf, mp_policy.measure_dtype)
    return lax.select(nan_mask, negative_infinity, log_L)
def memoize_prior_model(prior_model: PriorModelType, *args, **kwargs) -> Generator[
    BaseAbstractPrior, RandomVariableType, Callable
]:
    """
    Memoize the prior model into a pure function. This can be used, e.g. to compute jacobians, or gradients inside a
    prior model.

    The prior model is run once as a sub-generator: every prior it yields is re-yielded to the
    caller, and the response sent back for each prior is recorded. The returned function
    replays those recorded responses, in order, into a fresh run of ``prior_model``.

    Args:
        prior_model: a prior model optionally with *args or **kwargs.
        *args: inputs to prior model
        **kwargs: inputs to prior model

    Returns:
        a generator that eventually returns a pure function of prior_model inputs.

    Raises:
        RuntimeError: raised by the returned pure function (not by this generator) when the
            replayed prior model yields more priors than were recorded during memoization.
    """
    gen = prior_model(*args, **kwargs)
    prior_response = None
    stack = []
    while True:
        try:
            prior: BaseAbstractPrior = gen.send(prior_response)
        except StopIteration:
            break
        prior_response = yield prior
        stack.append(prior_response)

    def _pure_fn(*args, **kwargs):
        gen = prior_model(*args, **kwargs)
        stack_iter = iter(stack)
        prior_response = None
        while True:
            # Only the generator's own completion may end the replay.
            try:
                gen.send(prior_response)
            except StopIteration as e:
                return e.value
            # Pass the response that was given at the time of memoization. An exhausted
            # record must not be confused with the model finishing, which would silently
            # produce a None output.
            try:
                prior_response = next(stack_iter)
            except StopIteration:
                raise RuntimeError(
                    "The prior model yielded more priors than were recorded during memoization."
                ) from None

    return _pure_fn