# Source code for jaxns.framework.bases

from abc import abstractmethod
from typing import Tuple, Optional, Generator, Callable

import jax.numpy as jnp
import numpy as np
from jax import lax

from jaxns.framework.abc import AbstractModel, AbstractPrior, AbstractDistribution
from jaxns.internals.shapes import tuple_prod
from jaxns.internals.types import LikelihoodInputType
from jaxns.internals.types import LikelihoodType, UType, XType, RandomVariableType, MeasureType

__all__ = [
    "PriorModelGen",
    "PriorModelType"
]


class BaseAbstractPrior(AbstractPrior):
    """
    Public interface of a prior.

    Every public member here is a thin wrapper that delegates to an
    underscore-prefixed hook (``_dtype``, ``_base_shape``, ``_shape``,
    ``_forward``, ``_inverse``, ``_log_prob``). Those hooks are not defined in
    this class; presumably they are declared on ``AbstractPrior`` and
    implemented by concrete priors.
    """

    def __init__(self, name: Optional[str] = None):
        # Optional label; ``None`` marks an unnamed prior (rendered as '*' in repr).
        self.name = name

    def __repr__(self):
        label = '*' if self.name is None else self.name
        return f"{label}\t{self.base_shape} -> {self.shape} {self.dtype}"

    @property
    def dtype(self):
        """
        Dtype of the random variable in X-space, as reported by ``_dtype``.
        """
        return self._dtype()

    @property
    def base_shape(self) -> Tuple[int, ...]:
        """
        Shape of the random variable in U-space, as reported by ``_base_shape``.
        """
        return self._base_shape()

    @property
    def base_ndims(self):
        """
        Number of U-space dimensions, computed as ``tuple_prod(base_shape)``
        (presumably the product of the shape entries).
        """
        return tuple_prod(self.base_shape)

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        Shape of the random variable in X-space, as reported by ``_shape``.
        """
        return self._shape()

    def forward(self, U: UType) -> RandomVariableType:
        """
        Map a U-space value to X-space by delegating to ``_forward``.

        Args:
            U: U-space representation

        Returns:
            X-space representation
        """
        return self._forward(U)

    def inverse(self, X: RandomVariableType) -> UType:
        """
        Map an X-space value back to U-space by delegating to ``_inverse``.

        Args:
            X: X-space representation

        Returns:
            U-space representation
        """
        return self._inverse(X)

    def log_prob(self, X: RandomVariableType) -> MeasureType:
        """
        Evaluate the prior log-probability at ``X`` as a scalar.

        The raw value from ``_log_prob`` is summed when it holds more than one
        element, then reshaped to ``()`` if it is not already scalar-shaped.

        Args:
            X: X-space representation

        Returns:
            log probability of the prior, with shape ``()``
        """
        total = self._log_prob(X)
        # Collapse a multi-element result into a single summed value.
        if np.size(total) > 1:
            total = jnp.sum(total)
        # Already scalar-shaped: nothing left to do.
        if total.shape == ():
            return total
        # Otherwise (e.g. a one-element array) drop the remaining axes.
        return lax.reshape(total, ())


# A prior-model generator: yields priors (BaseAbstractPrior), is sent a
# RandomVariableType after each yield, and finally returns the
# LikelihoodInputType (typing.Generator[YieldType, SendType, ReturnType]).
PriorModelGen = Generator[BaseAbstractPrior, RandomVariableType, LikelihoodInputType]
# A zero-argument callable that creates a fresh PriorModelGen.
PriorModelType = Callable[[], PriorModelGen]
class BaseAbstractModel(AbstractModel):
    """
    The base model class with public methods.

    Stores the prior model and the log-likelihood callable, and exposes the
    U-space / X-space placeholders through properties that delegate to the
    abstract hooks ``_U_placeholder`` and ``_X_placeholder``.
    """

    def __init__(self, prior_model: PriorModelType, log_likelihood: LikelihoodType):
        """
        Store the two components of the model.

        Args:
            prior_model: the prior model
            log_likelihood: the log likelihood function
        """
        self._prior_model = prior_model
        self._log_likelihood = log_likelihood

    @property
    def prior_model(self) -> PriorModelType:
        """
        The prior model passed at construction.
        """
        return self._prior_model

    @property
    def log_likelihood(self) -> LikelihoodType:
        """
        The log likelihood function.

        Returns:
            log likelihood function passed at construction
        """
        return self._log_likelihood

    @abstractmethod
    def _U_placeholder(self) -> UType:
        """
        Hook that subclasses implement to produce the U-space placeholder.
        """
        ...

    @abstractmethod
    def _X_placeholder(self) -> XType:
        """
        Hook that subclasses implement to produce the X-space placeholder.
        """
        ...

    @property
    def U_placeholder(self) -> UType:
        """
        A placeholder for U-space sample, obtained from ``_U_placeholder``.
        """
        return self._U_placeholder()

    @property
    def X_placeholder(self) -> XType:
        """
        A placeholder for X-space sample, obtained from ``_X_placeholder``.
        """
        return self._X_placeholder()

    @property
    def U_ndims(self) -> int:
        """
        The prior dimensionality, read from the ``size`` attribute of the
        U-space placeholder.
        """
        return self.U_placeholder.size


# TODO(Joshuaalbert): distribution is too similar to prior, where we only need to be able to extract the log_prob and
# potential transformations from the underlying. Perhaps we should just define as ops to create priors. Try making
# priors just use distribution functionality. Special priors will need treatment.

class BaseAbstractDistribution(AbstractDistribution):
    """
    The base distribution class with public methods.

    Each public member delegates to the matching underscore-prefixed hook
    (``_dtype``, ``_base_shape``, ``_shape``, ``_forward``, ``_inverse``,
    ``_log_prob``), which is not defined in this class.
    """

    @property
    def dtype(self):
        """
        The dtype of the distribution, in X-space.
        """
        return self._dtype()

    @property
    def base_shape(self) -> Tuple[int, ...]:
        """
        The base shape of the distribution, in U-space.
        """
        return self._base_shape()

    @property
    def shape(self) -> Tuple[int, ...]:
        """
        The shape of the distribution, in X-space.
        """
        return self._shape()

    def forward(self, U: UType) -> RandomVariableType:
        """
        The forward transformation from U-space to X-space.

        Args:
            U: U-space representation

        Returns:
            X-space representation
        """
        return self._forward(U)

    def inverse(self, X: RandomVariableType) -> UType:
        """
        The inverse transformation from X-space to U-space.

        Args:
            X: X-space representation

        Returns:
            U-space representation
        """
        return self._inverse(X)

    def log_prob(self, X: RandomVariableType) -> MeasureType:
        """
        The log probability of the distribution.

        The value from ``_log_prob`` is returned as-is; no summation or
        reshaping is applied here.

        Args:
            X: X-space representation

        Returns:
            log probability of the distribution
        """
        return self._log_prob(X)