"""Type aliases and helpers defined in ``jaxns.internals.types``."""

from typing import NamedTuple, Union, Any, Callable, Tuple, Dict, TypeVar

import jax
import numpy as np

# Public names re-exported by ``from jaxns.internals.types import *``.
__all__ = [
    'PRNGKey',
    'IntArray',
    'FloatArray',
    'BoolArray',
    'LikelihoodType',
    'UType',
    'XType',
    'LikelihoodInputType',
    'RandomVariableType',
    'MeasureType',
]

# Alias used to annotate PRNG keys; it is simply ``jax.Array``.
PRNGKey = jax.Array

# Array-like unions. Each member sits on its own line so that the trailing
# comments cannot swallow the closing bracket.
Array = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
]
FloatArray = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
    float,  # valid scalars
]
IntArray = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
    int,  # valid scalars
]
BoolArray = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
    np.bool_,
    bool,  # valid scalars
]

# Attach human-readable descriptions to the alias objects themselves.
Array.__doc__ = "Type annotation for JAX array-like objects, with no scalar types."
FloatArray.__doc__ = "Type annotation for JAX array-like objects, with float scalar types."
IntArray.__doc__ = "Type annotation for JAX array-like objects, with int scalar types."
BoolArray.__doc__ = "Type annotation for JAX array-like objects, with bool scalar types."

# A likelihood is any callable (arbitrary arguments) returning a FloatArray.
LikelihoodType = Callable[..., FloatArray]

# Generic placeholders for random-variable and measure types.
RandomVariableType = TypeVar('RandomVariableType')
MeasureType = TypeVar('MeasureType')

LikelihoodInputType = Union[Tuple[Any, ...], Any]  # Likelihood conditional variables
UType = jax.Array  # Sample space type
WType = Tuple[jax.Array, ...]  # Variable-length tuple of JAX arrays (not exported in __all__)
XType = Dict[str, RandomVariableType]  # Prior variable type
class SignedLog(NamedTuple):
    """
    Represents a signed value in log-space.

    Presumably the represented value is ``sign * exp(log_abs_val)``
    (inferred from the field names; confirm against callers).

    Attributes:
        log_abs_val: log of the absolute value.
        sign: sign of the value; annotated loosely as ``jax.Array`` or ``Any``.
    """
    log_abs_val: jax.Array
    sign: Union[jax.Array, Any]


def isinstance_namedtuple(obj: Any) -> bool:
    """
    Check if object is a namedtuple.

    The check is duck-typed. ``obj`` must be a ``tuple`` instance that also
    has ``_asdict`` and ``_fields`` attributes, as instances of
    ``collections.namedtuple`` and ``typing.NamedTuple`` classes do. A
    namedtuple *class* is not a ``tuple`` instance, so it yields False.

    Args:
        obj: object to inspect.

    Returns:
        True if ``obj`` is a tuple exposing ``_asdict`` and ``_fields``,
        otherwise False.
    """
    return (
        isinstance(obj, tuple)
        and hasattr(obj, '_asdict')
        and hasattr(obj, '_fields')
    )