from typing import Tuple
import jax
from jax import numpy as jnp
from jaxns.framework.bases import PriorModelType, BaseAbstractPrior
from jaxns.framework.prior import InvalidPriorName, SingularPrior
from jaxns.internals.types import UType, XType, float_type, LikelihoodInputType, FloatArray, LikelihoodType, PRNGKey, \
isinstance_namedtuple
__all__ = [
'simulate_prior_model'
]
def compute_U_ndims(prior_model: PriorModelType) -> int:
"""
Computes placeholders of model.
Args:
prior_model: a callable that produces a prior model generator
Returns:
number of U dims
"""
U_ndims = 0
gen = prior_model()
prior_response = None
names = set()
while True:
try:
prior: BaseAbstractPrior = gen.send(prior_response)
d = prior.base_ndims
U_ndims += d
u = jnp.zeros(prior.base_shape, float_type)
prior_response = prior.forward(u)
if prior.name is not None:
if prior.name in names:
raise InvalidPriorName(name=prior.name)
names.add(prior.name)
except StopIteration:
break
return U_ndims
[docs]
def simulate_prior_model(key: PRNGKey, prior_model: PriorModelType) -> Tuple[LikelihoodInputType, XType]:
"""
Simulate a prior model.
Args:
key: PRNGKey
prior_model: A prior model
Returns:
a tuple of the likelihood input variables, and dict of non-hidden (named) prior variables.
"""
U_ndims = compute_U_ndims(prior_model=prior_model)
U = jax.random.uniform(key, shape=(U_ndims,), dtype=float_type)
return prepare_input(U=U, prior_model=prior_model), transform(U=U, prior_model=prior_model)
def parse_prior(prior_model: PriorModelType) -> Tuple[UType, XType]:
"""
Computes placeholders of model.
Args:
prior_model: a callable that produces a prior model generator
Returns:
U placeholder, X placeholder
"""
U_ndims = 0
gen = prior_model()
prior_response = None
names = set()
X_placeholder: XType = dict()
while True:
try:
prior: BaseAbstractPrior = gen.send(prior_response)
d = prior.base_ndims
U_ndims += d
u = jnp.zeros(prior.base_shape, float_type)
prior_response = prior.forward(u)
if prior.name is not None:
if prior.name in names:
raise InvalidPriorName(name=prior.name)
names.add(prior.name)
if not isinstance(prior, SingularPrior):
X_placeholder[prior.name] = prior_response
except StopIteration:
break
U_placeholder = jnp.zeros((U_ndims,), float_type)
return U_placeholder, X_placeholder
def parse_joint(prior_model: PriorModelType, log_likelihood: LikelihoodType) -> Tuple[
UType, XType, LikelihoodInputType, FloatArray]:
"""
Computes placeholders of model.
Args:
prior_model: a callable that produces a prior model generator
Returns:
U placeholder, X placeholder
"""
U_ndims = 0
gen = prior_model()
prior_response = None
names = set()
X_placeholder: XType = dict()
while True:
try:
prior: BaseAbstractPrior = gen.send(prior_response)
d = prior.base_ndims
U_ndims += d
u = jnp.zeros(prior.base_shape, float_type)
prior_response = prior.forward(u)
if prior.name is not None:
if prior.name in names:
raise InvalidPriorName(name=prior.name)
names.add(prior.name)
if not isinstance(prior, SingularPrior):
X_placeholder[prior.name] = prior_response
except StopIteration as e:
output = e.value
if (not isinstance(output, tuple)) or isinstance_namedtuple(output):
output = (output,)
break
likelihood_input_placeholder = output
log_L_placeholder = jnp.asarray(log_likelihood(*output), float_type)
U_placeholder = jnp.zeros((U_ndims,), float_type)
return U_placeholder, X_placeholder, likelihood_input_placeholder, log_L_placeholder
def transform(U: UType, prior_model: PriorModelType) -> XType:
"""
Transforms a flat array of `U_ndims` i.i.d. samples of U[0,1] into the target prior.
Args:
U: [U_ndims] a flat array of i.i.d. samples of U[0,1]
prior_model: a callable that produces a prior model generator
Returns:
the prior variables
"""
gen = prior_model()
prior_response = None
names = set()
X_collection = dict()
idx = 0
while True:
try:
prior: BaseAbstractPrior = gen.send(prior_response)
d = prior.base_ndims
u = jnp.reshape(U[idx:idx + d], prior.base_shape)
idx += d
prior_response = prior.forward(u)
if prior.name is not None:
if prior.name in names:
raise InvalidPriorName(name=prior.name)
names.add(prior.name)
if not isinstance(prior, SingularPrior):
X_collection[prior.name] = prior_response
except StopIteration:
break
return X_collection
def transform_parametrised(U: UType, prior_model: PriorModelType) -> XType:
"""
Transforms a flat array of `U_ndims` i.i.d. samples of U[0,1] into the the parametrised prior variables.
Args:
U: [U_ndims] a flat array of i.i.d. samples of U[0,1]
prior_model: a callable that produces a prior model generator
Returns:
the parametrised prior variables
"""
gen = prior_model()
prior_response = None
names = set()
Y_collection = dict()
idx = 0
while True:
try:
prior: BaseAbstractPrior = gen.send(prior_response)
d = prior.base_ndims
u = jnp.reshape(U[idx:idx + d], prior.base_shape)
idx += d
prior_response = prior.forward(u)
if prior.name is not None:
if prior.name in names:
raise InvalidPriorName(name=prior.name)
names.add(prior.name)
if isinstance(prior, SingularPrior):
Y_collection[prior.name] = prior_response
except StopIteration:
break
return Y_collection
def prepare_input(U: UType, prior_model: PriorModelType) -> LikelihoodInputType:
"""
Transforms a flat array of `U_ndims` i.i.d. samples of U[0,1] into the likelihood conditional variables.
Args:
U: [U_ndims] a flat array of i.i.d. samples of U[0,1]
prior_model: a callable that produces a prior model generator
Returns:
the conditional variables of likelihood model
"""
gen = prior_model()
prior_response = None
idx = 0
while True:
try:
prior: BaseAbstractPrior = gen.send(prior_response)
d = prior.base_ndims
u = jnp.reshape(U[idx:idx + d], prior.base_shape)
idx += d
prior_response = prior.forward(u)
except StopIteration as e:
output = e.value
if (not isinstance(output, tuple)) or isinstance_namedtuple(output):
output = (output,)
break
return output
def compute_log_prob_prior(U: UType, prior_model: PriorModelType) -> FloatArray:
"""
Computes the prior log-density from a U-space sample.
Args:
U: [U_ndims] a flat array of i.i.d. samples of U[0,1]
prior_model: a callable that produces a prior model generator
Returns:
prior log-density
"""
gen = prior_model()
prior_response = None
log_prob = []
idx = 0
while True:
try:
prior: BaseAbstractPrior = gen.send(prior_response)
d = prior.base_ndims
u = jnp.reshape(U[idx:idx + d], prior.base_shape)
idx += d
prior_response = prior.forward(u)
log_prob.append(prior.log_prob(prior_response))
except StopIteration:
break
return sum(log_prob, jnp.asarray(0., float_type))
def compute_log_likelihood(U: UType, prior_model: PriorModelType, log_likelihood: LikelihoodType,
allow_nan: bool = False) -> FloatArray:
"""
Computes the log likelihood from U-space sample.
Args:
U: [U_ndims] a flat array of i.i.d. samples of U[0,1]
prior_model: a callable that produces a prior model generator
log_likelihood: callable that takes arrays returned by the prior model and returns a scalar float
allow_nan: whether to allow nans in likelihood
Returns:
log-likelihood
"""
V = prepare_input(U=U, prior_model=prior_model)
log_L = jnp.asarray(log_likelihood(*V), float_type)
if not allow_nan:
log_L = jnp.where(jnp.isnan(log_L), -jnp.inf, log_L)
if log_L.size != 1:
raise ValueError(f"Log likelihood should be scalar, but got {log_L.shape}.")
if log_L.shape != ():
log_L = jnp.reshape(log_L, ())
return log_L