shapes

jaxns.internals.shapes

Module Contents

broadcast_dtypes(*dtypes)

Returns the dtype with the highest precision.

Parameters:

*dtypes – one or more JAX dtypes, passed as separate positional arguments.

Returns: dtype
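A minimal sketch of the idea, assuming "highest precision" can be approximated by the largest item size in bytes. The name broadcast_dtypes_sketch is hypothetical and this is not necessarily how jaxns implements broadcast_dtypes.

```python
import jax.numpy as jnp
import numpy as np


def broadcast_dtypes_sketch(*dtypes):
    """Hypothetical stand-in: return the dtype with the largest item size."""
    if not dtypes:
        raise ValueError("need at least one dtype")
    # np.dtype normalises JAX scalar types (e.g. jnp.float32) to NumPy dtypes
    normalised = [np.dtype(d) for d in dtypes]
    # On ties (e.g. int32 vs float32) max() keeps the first dtype it saw
    return max(normalised, key=lambda d: d.itemsize)


print(broadcast_dtypes_sketch(jnp.float16, jnp.float32))  # float32
```

Comparing by item size alone cannot distinguish integer from floating-point dtypes of the same width; jnp.result_type applies JAX's full type-promotion rules if that distinction matters.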

convert_to_array(v)

Converts v to a jnp.ndarray if necessary. Prior instances are passed through unchanged.

Parameters:

v – array-like or scalar

Returns: jnp.ndarray
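The pass-through behaviour can be sketched as follows. The Prior class here is a hypothetical empty placeholder standing in for jaxns' own Prior, and convert_to_array_sketch is an illustrative name, not the library function.

```python
import jax.numpy as jnp


class Prior:
    """Hypothetical placeholder for jaxns' Prior class (illustration only)."""


def convert_to_array_sketch(v):
    # Priors are returned as-is so they can be resolved later
    if isinstance(v, Prior):
        return v
    # Scalars, lists and NumPy arrays are converted to a JAX array
    return jnp.asarray(v)


print(convert_to_array_sketch([1.0, 2.0]).shape)  # (2,)
```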

tuple_prod(t)

Returns the product of the entries of a shape tuple.

Parameters:

t – tuple of int

Returns: int
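A minimal sketch, assuming the usual convention that the empty shape () of a scalar has product 1. The name tuple_prod_sketch is hypothetical.

```python
import math


def tuple_prod_sketch(t):
    # Multiply all entries; math.prod returns 1 for an empty tuple
    return math.prod(t)


print(tuple_prod_sketch((2, 3, 4)))  # 24
```

For an array shape, this product equals the number of elements in the array.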

broadcast_shapes(shape1, shape2)

Broadcasts two shapes together.

Parameters:
  • shape1 – tuple of int

  • shape2 – tuple of int

Returns: tuple of int giving the broadcast shape.
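A sketch of two-shape broadcasting, assuming NumPy-style rules (right-align the shapes, then each dimension pair must be equal or contain a 1). broadcast_shapes_sketch is a hypothetical name; jnp.broadcast_shapes is JAX's own function and is used only as a cross-check.

```python
import jax.numpy as jnp


def broadcast_shapes_sketch(shape1, shape2):
    """Hypothetical re-implementation of NumPy-style broadcasting for two shapes."""
    ndim = max(len(shape1), len(shape2))
    # Left-pad the shorter shape with 1s so both have the same rank
    s1 = (1,) * (ndim - len(shape1)) + tuple(shape1)
    s2 = (1,) * (ndim - len(shape2)) + tuple(shape2)
    out = []
    for a, b in zip(s1, s2):
        if a == b or b == 1:
            out.append(a)
        elif a == 1:
            out.append(b)
        else:
            raise ValueError(f"incompatible shapes {shape1} and {shape2}")
    return tuple(out)


print(broadcast_shapes_sketch((3, 1), (4,)))  # (3, 4)
print(jnp.broadcast_shapes((3, 1), (4,)))  # (3, 4)
```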