jaxns.internals.mixed_precision
Module Contents
- class Policy
Encapsulates casting for inputs, outputs and parameters.
- measure_dtype: jax.numpy.dtype
- index_dtype: jax.numpy.dtype
- count_dtype: jax.numpy.dtype
- cast_to_index(x, quiet=False)
Converts index values to the index dtype.
- Parameters:
x (X)
quiet (bool)
- Return type:
X
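The following is a minimal sketch of a policy object that keeps all dtype choices in one place, so index arrays are always cast the same way. It is not jaxns's implementation. Several details are assumptions made to keep the example self-contained:

- the class name `SketchPolicy` is hypothetical;
- the default dtypes are chosen for illustration only;
- NumPy stands in for `jax.numpy`;
- the `quiet` flag is left out, because this page does not document what it does.

```python
from dataclasses import dataclass
from typing import Any

import numpy as np


@dataclass(frozen=True)
class SketchPolicy:
    """Hypothetical stand-in for Policy that only stores the three dtypes.

    The defaults below are illustrative assumptions, not jaxns's defaults.
    """

    measure_dtype: np.dtype = np.dtype(np.float64)
    index_dtype: np.dtype = np.dtype(np.int32)
    count_dtype: np.dtype = np.dtype(np.int32)

    def cast_to_index(self, x: Any) -> np.ndarray:
        # Convert array-like index values to the policy's index dtype.
        # The real method's `quiet` flag is intentionally not modeled here.
        return np.asarray(x).astype(self.index_dtype)


# Usage: float-valued indices are cast to the policy's integer index dtype.
policy = SketchPolicy()
idx = policy.cast_to_index([0.0, 2.0, 5.0])
print(idx.dtype)  # int32
```

One plausible reason for this design is that every cast goes through a single object. Changing the index width, for example from 32-bit to 64-bit, then means editing one field rather than searching the codebase for scattered `astype` calls.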