from typing import NamedTuple, Union, Any, Callable, Tuple, Dict, TypeVar
import jax
import numpy as np
__all__ = [
    'PRNGKey',
    'IntArray',
    'FloatArray',
    'BoolArray',
    'LikelihoodType',
    'UType',
    'XType',
    'LikelihoodInputType',
    'RandomVariableType',
    'MeasureType',
]

# NOTE(review): 'PRNGKey' and 'LikelihoodInputType' are exported above, but no other
# visible part of this module defines them, so `from <this module> import *` raised
# AttributeError. The definitions below are assumptions to be confirmed against callers:
# JAX PRNG keys are `jax.Array` values, and likelihood inputs are assumed to be a
# variable-length tuple of JAX arrays.
PRNGKey = jax.Array
LikelihoodInputType = Tuple[jax.Array, ...]
def _attach_doc(alias: Any, doc: str) -> None:
    """
    Set ``alias.__doc__`` to ``doc`` when the alias object allows it.

    Typing aliases such as ``Union[...]`` are not ordinary classes. On newer Python
    versions their ``__doc__`` may be read-only. The docstring is purely descriptive,
    so a failed assignment is skipped rather than allowed to break module import.

    Args:
        alias: the typing alias object to annotate
        doc: the docstring text to attach
    """
    try:
        alias.__doc__ = doc
    except (AttributeError, TypeError):
        # Read-only __doc__ on this alias type; the docstring is cosmetic, so ignore.
        pass


# Array-like type aliases. Every alias accepts JAX and NumPy arrays; the
# scalar-accepting variants additionally admit the matching scalar type(s).
Array = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
]
FloatArray = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
    float,  # valid scalars
]
IntArray = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
    int,  # valid scalars
]
BoolArray = Union[
    jax.Array,  # JAX array type
    np.ndarray,  # NumPy array type
    np.bool_, bool,  # valid scalars
]

_attach_doc(Array, "Type annotation for JAX array-like objects, with no scalar types.")
_attach_doc(FloatArray, "Type annotation for JAX array-like objects, with float scalar types.")
_attach_doc(IntArray, "Type annotation for JAX array-like objects, with int scalar types.")
_attach_doc(BoolArray, "Type annotation for JAX array-like objects, with bool scalar types.")
# Type of a likelihood function: it accepts arbitrary arguments (`...`) and returns a
# FloatArray (a JAX/NumPy array or a float scalar).
LikelihoodType = Callable[..., FloatArray]
# Generic type variables used as placeholders for random-variable and measure types.
RandomVariableType = TypeVar('RandomVariableType')
MeasureType = TypeVar('MeasureType')

UType = jax.Array  # Sample space type
WType = Tuple[jax.Array, ...]  # Variable-length tuple of JAX arrays
XType = Dict[str, RandomVariableType]  # Prior variable type: str-keyed mapping to random variables
class SignedLog(NamedTuple):
    """
    A signed quantity whose magnitude is stored in log-space.

    The magnitude and the sign are kept as separate fields. Presumably the represented
    value is ``sign * exp(log_abs_val)`` — confirm against callers.
    """

    # Logarithm of the magnitude (absolute value).
    log_abs_val: jax.Array
    # Sign of the value; annotated loosely so non-array sign values are also accepted.
    sign: Union[jax.Array, Any]
def isinstance_namedtuple(obj) -> bool:
    """
    Report whether ``obj`` looks like a namedtuple instance.

    The check is structural. ``obj`` must be a ``tuple`` (or a subclass of it) that also
    exposes both the ``_asdict`` and ``_fields`` attributes.

    Args:
        obj: object to inspect

    Returns:
        True if all three conditions hold, otherwise False.
    """
    if not isinstance(obj, tuple):
        return False
    # Checked in order, stopping at the first missing attribute.
    return all(hasattr(obj, attr_name) for attr_name in ('_asdict', '_fields'))