Source code for jaxns.internals.pytree_utils
import jax
import jax.numpy as jnp
[docs]
def tree_dot(x, y):
    """Return the inner product of two pytrees with matching structure.

    Each pair of corresponding leaves is combined with ``jnp.vdot`` (which
    flattens its inputs and conjugates the first argument), and the per-leaf
    results are added together.

    Args:
        x: first pytree; its leaves are conjugated by ``jnp.vdot``.
        y: second pytree with the same tree structure as ``x``.

    Returns:
        A scalar array holding the sum of the leaf-wise ``vdot`` results.
        If the pytrees contain no leaves, a scalar zero is returned.
    """
    dots = jax.tree.leaves(jax.tree.map(jnp.vdot, x, y))
    if not dots:
        # No leaves means no components: the empty sum is zero. Without this
        # guard, indexing dots[0] below would raise IndexError.
        return jnp.zeros(())
    # Seed the sum with the first term so the result keeps the leaves' dtype
    # instead of starting from the Python int 0.
    return sum(dots[1:], start=dots[0])
[docs]
def tree_norm(x):
    """Return the square root of ``tree_dot(x, x)`` as a real-valued scalar.

    Args:
        x: pytree whose leaves are accepted by ``tree_dot``.

    Returns:
        ``sqrt(tree_dot(x, x))``. When the squared norm has a complex dtype,
        only its real part is passed to the square root.
    """
    squared = tree_dot(x, x)
    is_complex = jnp.issubdtype(squared.dtype, jnp.complexfloating)
    # For complex inputs the self inner product is carried in a complex
    # dtype; keep just the real component before taking the root.
    magnitude_sq = squared.real if is_complex else squared
    return jnp.sqrt(magnitude_sq)