mixed_precision

jaxns.internals.mixed_precision

Module Contents

T
X
class Policy

Encapsulates casting for inputs, outputs and parameters.

measure_dtype: jax.numpy.dtype
index_dtype: jax.numpy.dtype
count_dtype: jax.numpy.dtype
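A minimal sketch of the general shape of such a policy, assuming a frozen dataclass that holds the three dtype fields listed above. The class name `SketchPolicy` and its default dtypes are hypothetical; this page does not state which defaults jaxns uses.

```python
import dataclasses

import jax.numpy as jnp


# Hypothetical stand-in for Policy: only the three documented dtype fields.
# The defaults below are illustrative assumptions, not the jaxns defaults.
@dataclasses.dataclass(frozen=True)
class SketchPolicy:
    measure_dtype: jnp.dtype = jnp.dtype("float32")  # dtype for measure values
    index_dtype: jnp.dtype = jnp.dtype("int32")      # dtype for index values
    count_dtype: jnp.dtype = jnp.dtype("int32")      # dtype for count values


default_policy = SketchPolicy()
# Override a single field, e.g. to trade precision for memory.
half_policy = SketchPolicy(measure_dtype=jnp.dtype("float16"))
```

Freezing the dataclass keeps a policy hashable and immutable, which suits a single module-wide configuration object.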
cast_to_index(x, quiet=False)

Converts index values to the index dtype.

Parameters:
  • x (X)

  • quiet (bool)

Return type:

X
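A minimal sketch of what `cast_to_index` might do, assuming `x` may be any pytree of array-like leaves and that the cast is a plain `astype` to the index dtype. `cast_to_index_sketch` is a hypothetical name, and `quiet` is left out because its effect is not documented on this page.

```python
import jax
import jax.numpy as jnp


def cast_to_index_sketch(x, index_dtype="int32"):
    """Hypothetical index cast: convert every pytree leaf to `index_dtype`.

    A float-to-int astype drops the fractional part, so leaves are assumed
    to already hold integral values.
    """
    return jax.tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf).astype(index_dtype), x
    )


idx = cast_to_index_sketch({"start": jnp.array([0.0, 3.0]), "end": 7})
```

Mapping over the pytree lets one call handle a single array, a Python scalar, or a nested structure of them.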

cast_to_measure(x, quiet=False)

Converts measure values to the measure dtype.

Parameters:
  • x (X)

  • quiet (bool)

Return type:

X
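This page does not document what `quiet` controls. A minimal sketch, under the assumption that it suppresses a warning when the cast narrows floating-point precision (for example float32 to float16); `cast_to_measure_sketch` and the warning text are hypothetical.

```python
import warnings

import jax
import jax.numpy as jnp


def cast_to_measure_sketch(x, measure_dtype="float32", quiet=False):
    """Hypothetical measure cast; assumes `quiet` silences a precision-loss warning."""
    target = jnp.dtype(measure_dtype)

    def cast_leaf(leaf):
        leaf = jnp.asarray(leaf)
        # Narrowing: a floating leaf with more bits than the target float dtype.
        narrowing = (
            jnp.issubdtype(leaf.dtype, jnp.floating)
            and jnp.finfo(leaf.dtype).bits > jnp.finfo(target).bits
        )
        if narrowing and not quiet:
            warnings.warn(f"Casting {leaf.dtype} to {target} loses precision.")
        return leaf.astype(target)

    return jax.tree_util.tree_map(cast_leaf, x)
```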

cast_to_count(x, quiet=False)

Converts count values to the count dtype.

Parameters:
  • x (X)

  • quiet (bool)

Return type:

X
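One practical reason to fix a single count dtype: the carry of `jax.lax.fori_loop` must keep the same dtype on every iteration, so the initial value and each increment should pass through the same cast. A minimal sketch with a hypothetical `cast_to_count_sketch` and an assumed count dtype of int32:

```python
import jax
import jax.numpy as jnp

COUNT_DTYPE = "int32"  # hypothetical stand-in for Policy.count_dtype


def cast_to_count_sketch(x, count_dtype=COUNT_DTYPE):
    """Hypothetical count cast: convert every pytree leaf to `count_dtype`."""
    return jax.tree_util.tree_map(
        lambda leaf: jnp.asarray(leaf).astype(count_dtype), x
    )


def count_accepted(masks):
    """Total the True entries of a 2-D boolean array, one row per loop step."""

    def body(i, total):
        # jnp.sum of a bool row yields the default int dtype; cast it so the
        # carry dtype stays COUNT_DTYPE regardless of the x64 setting.
        return total + cast_to_count_sketch(jnp.sum(masks[i]))

    init = cast_to_count_sketch(0)
    return jax.lax.fori_loop(0, masks.shape[0], body, init)


total = count_accepted(jnp.array([[True, False, True], [False, False, True]]))
```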

mp_policy
float_type
int_type
complex_type
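This page does not say how `float_type`, `int_type` and `complex_type` are defined. One plausible approach, assumed here and not confirmed by the source, is to read off the dtypes JAX assigns to Python scalars, which follow the `jax_enable_x64` flag. The `sketch_*` names are hypothetical.

```python
import jax.numpy as jnp

# Assumption: default types taken from the dtypes JAX gives Python scalars.
# With jax_enable_x64 off (JAX's default) these are float32 / int32 / complex64;
# with it on, they are float64 / int64 / complex128.
sketch_float_type = jnp.asarray(1.0).dtype
sketch_int_type = jnp.asarray(1).dtype
sketch_complex_type = jnp.asarray(1j).dtype
```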