Source code for jaxns.internals.shapes

import numpy as np
from jax import numpy as jnp


def broadcast_dtypes(*dtypes):
    """
    Returns the dtype with highest precision.

    Precision is ranked by position in a fixed ladder:
    bool_ < int32 < int64 < float32 < float64 < complex64 < complex128.

    Args:
        *dtypes: list of JAX dtypes. Each must be one of the ladder entries.

    Returns:
        dtype: the highest-ranked ladder entry among the inputs.
        NOTE: when called with no arguments the running index stays at -1,
        so the last ladder entry (complex128) is returned.

    Raises:
        ValueError: if any given dtype is not one of the ladder entries.
    """
    levels = [jnp.bool_, jnp.int32, jnp.int64, jnp.float32, jnp.float64, jnp.complex64, jnp.complex128]
    output = -1
    for dtype in dtypes:
        if dtype not in levels:
            # f-string so the offending dtype and the accepted list are actually
            # interpolated into the message (the placeholders used to be literal).
            raise ValueError(f"dtype {dtype} not in list {levels}.")
        output = max(output, levels.index(dtype))
    return levels[output]
def convert_to_array(v):
    """
    Turn plain Python / NumPy values into a jnp.ndarray, leaving everything else alone.

    Lists, tuples, NumPy arrays and Python scalars (float, int, bool, complex)
    are converted with ``jnp.asarray``. Any other object (the original docs
    mention Prior objects) is returned unchanged.

    Args:
        v: array-like or scalar

    Returns:
        jnp.ndarray for convertible inputs, otherwise ``v`` itself.
    """
    convertible_types = (list, tuple, np.ndarray, float, int, bool, complex)
    if not isinstance(v, convertible_types):
        # Not something we know how to convert: pass it through untouched.
        return v
    return jnp.asarray(v)
def tuple_prod(t):
    """
    Multiply together all entries of a shape tuple.

    Args:
        t: tuple

    Returns:
        int: the product of the entries, or 1 for an empty tuple.
    """
    if len(t) == 0:
        # Empty shape (scalar): the empty product is 1.
        return 1
    product, *remaining = t
    for factor in remaining:
        product *= factor
    return product
def broadcast_shapes(shape1, shape2):
    """
    Broadcasts two shapes together.

    The shorter shape is left-padded with 1s, then each pair of axes is
    combined: two axes are compatible when they are equal or when one of
    them is 1, and the resulting axis is the one that is not 1. This also
    covers zero-length axes, e.g. (0,) with (1,) gives (0,).

    Args:
        shape1: tuple of int (a bare int is treated as a 1-tuple)
        shape2: tuple of int (a bare int is treated as a 1-tuple)

    Returns:
        tuple of int with resulting shape.

    Raises:
        ValueError: if any pair of axes is incompatible.
    """
    if isinstance(shape1, int):
        shape1 = (shape1,)
    if isinstance(shape2, int):
        shape2 = (shape2,)

    def left_pad_shape(shape, pad):
        # Prepend `pad` singleton axes so both shapes have the same rank.
        return tuple([1] * pad + list(shape))

    ndim = max(len(shape1), len(shape2))
    shape1 = left_pad_shape(shape1, ndim - len(shape1))
    shape2 = left_pad_shape(shape2, ndim - len(shape2))
    out_shape = []
    for s1, s2 in zip(shape1, shape2):
        if s1 == 1:
            # A singleton axis stretches to whatever the other axis is (including 0).
            out_shape.append(s2)
        elif s2 == 1 or s1 == s2:
            out_shape.append(s1)
        else:
            raise ValueError("Trying to broadcast {} with {}".format(shape1, shape2))
    return tuple(out_shape)