from abc import abstractmethod
from typing import Tuple, Optional, Generator, Callable
import jax.numpy as jnp
from jaxns.framework.abc import AbstractModel, AbstractPrior, AbstractDistribution
from jaxns.internals.shapes import tuple_prod
from jaxns.internals.types import LikelihoodInputType
from jaxns.internals.types import LikelihoodType, UType, XType, RandomVariableType, MeasureType
# Public API of this module.
__all__ = [
    "BaseAbstractModel",
    "BaseAbstractPrior",
    "BaseAbstractDistribution",
    "PriorModelGen",
    "PriorModelType",
]
class BaseAbstractPrior(AbstractPrior):
    """
    The base prior class with public methods.

    Each public member delegates to the corresponding private method (``_dtype``, ``_base_shape``,
    ``_shape``, ``_forward``, ``_inverse``, ``_log_prob``), which concrete priors provide.
    """

    def __init__(self, name: Optional[str] = None):
        """
        Args:
            name: optional name of the prior; when ``None`` the prior is shown as ``*`` in its repr.
        """
        self.name = name

    def __repr__(self):
        return f"{self.name if self.name is not None else '*'}\t{self.base_shape} -> {self.shape} {self.dtype}"

    @property
    def dtype(self):
        """
        The dtype of the prior random variable in X-space.
        """
        return self._dtype()

    @property
    def base_shape(self) -> Tuple[int, ...]:
        """
        The base shape of the prior random variable in U-space.
        """
        return self._base_shape()

    @property
    def base_ndims(self) -> int:
        """
        The number of dimensions of the prior random variable in U-space, computed as
        ``tuple_prod(self.base_shape)`` (i.e. the flattened size of the base shape).
        """
        return tuple_prod(self.base_shape)

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        The shape of the prior random variable in X-space.
        """
        return self._shape()

    def forward(self, U: UType) -> RandomVariableType:
        """
        The forward transformation from U-space to X-space.

        Args:
            U: U-space representation

        Returns:
            X-space representation
        """
        return self._forward(U)

    def inverse(self, X: RandomVariableType) -> UType:
        """
        The inverse transformation from X-space to U-space.

        Args:
            X: X-space representation

        Returns:
            U-space representation
        """
        return self._inverse(X)

    def log_prob(self, X: RandomVariableType) -> MeasureType:
        """
        The log probability of the prior.

        Whatever ``_log_prob`` returns is reduced to a scalar (shape ``()``) by summation, so a
        prior over many elements yields the joint log probability. An empty (size-0) result
        sums to ``0``.

        Args:
            X: X-space representation

        Returns:
            scalar log probability of the prior
        """
        # asarray so that a plain Python scalar returned by _log_prob also has .shape.
        log_prob = jnp.asarray(self._log_prob(X))
        if log_prob.shape != ():
            # A single sum covers every non-scalar case: size > 1 is reduced, size 1 collapses to a
            # scalar, and size 0 gives 0 (reshaping an empty array to () would raise instead).
            log_prob = jnp.sum(log_prob)
        return log_prob
# Generator protocol followed by a prior model (typing.Generator[YieldType, SendType, ReturnType]).
PriorModelGen = Generator[
    BaseAbstractPrior,  # yield type: each prior the model declares
    RandomVariableType,  # send type: the value sent back into the generator for that prior
    LikelihoodInputType,  # return type: the final value returned by the generator
]

# A prior model is a zero-argument callable that produces such a generator.
PriorModelType = Callable[[], PriorModelGen]
class BaseAbstractModel(AbstractModel):
    """
    Base model class providing the public model interface.

    A model couples a prior model with a log-likelihood function. Subclasses must implement
    ``_U_placeholder`` and ``_X_placeholder``, which back the public placeholder properties.
    """

    def __init__(self, prior_model: PriorModelType, log_likelihood: LikelihoodType):
        """
        Args:
            prior_model: the prior model callable
            log_likelihood: the log likelihood function
        """
        self._log_likelihood = log_likelihood
        self._prior_model = prior_model

    @abstractmethod
    def _U_placeholder(self) -> UType:
        """Return a placeholder U-space sample (implemented by subclasses)."""
        ...

    @abstractmethod
    def _X_placeholder(self) -> XType:
        """Return a placeholder X-space sample (implemented by subclasses)."""
        ...

    @property
    def prior_model(self) -> PriorModelType:
        """
        The prior model, exactly as given to the constructor.
        """
        model = self._prior_model
        return model

    @property
    def log_likelihood(self) -> LikelihoodType:
        """
        The log likelihood function.

        Returns:
            log likelihood function, exactly as given to the constructor
        """
        likelihood = self._log_likelihood
        return likelihood

    @property
    def U_placeholder(self) -> UType:
        """
        A placeholder for a U-space sample, obtained from ``_U_placeholder``.
        """
        placeholder = self._U_placeholder()
        return placeholder

    @property
    def X_placeholder(self) -> XType:
        """
        A placeholder for an X-space sample, obtained from ``_X_placeholder``.
        """
        placeholder = self._X_placeholder()
        return placeholder

    @property
    def U_ndims(self) -> int:
        """
        The prior dimensionality: the ``.size`` of the U-space placeholder.
        """
        u_sample = self.U_placeholder
        return u_sample.size
# TODO(Joshuaalbert): Distribution is too similar to prior; we only need to be able to extract the log_prob and
# potential transformations from the underlying object. Perhaps distributions should just be defined as ops used to
# create priors. Try making priors use distribution functionality directly. Special priors will need special treatment.
class BaseAbstractDistribution(AbstractDistribution):
    """
    Base distribution class providing the public distribution interface.

    Every public member forwards to the matching private method (``_dtype``, ``_base_shape``,
    ``_shape``, ``_forward``, ``_inverse``, ``_log_prob``).
    """

    @property
    def dtype(self):
        """
        The dtype of the distribution, in X-space.
        """
        x_dtype = self._dtype()
        return x_dtype

    @property
    def base_shape(self) -> Tuple[int, ...]:
        """
        The base shape of the distribution, in U-space.
        """
        u_shape = self._base_shape()
        return u_shape

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        The shape of the distribution, in X-space.
        """
        x_shape = self._shape()
        return x_shape

    def forward(self, U: UType) -> RandomVariableType:
        """
        Map a U-space value to X-space.

        Args:
            U: U-space representation

        Returns:
            X-space representation
        """
        x_value = self._forward(U)
        return x_value

    def inverse(self, X: RandomVariableType) -> UType:
        """
        Map an X-space value back to U-space.

        Args:
            X: X-space representation

        Returns:
            U-space representation
        """
        u_value = self._inverse(X)
        return u_value

    def log_prob(self, X: RandomVariableType) -> MeasureType:
        """
        The log probability of the distribution.

        The value produced by ``_log_prob`` is returned unchanged; no reduction or reshaping is
        applied here.

        Args:
            X: X-space representation

        Returns:
            log probability of the distribution
        """
        logp = self._log_prob(X)
        return logp