pytree_utils

jaxns.internals.pytree_utils

Module Contents

tree_dot(x, y)
tree_norm(x)
tree_mul(x, y)
tree_sub(x, y)
tree_div(x, y)
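
The listing above gives only signatures, so the following is a minimal sketch of what helpers with these names might do. It rests on assumptions that this page does not confirm: that `x` and `y` are pytrees with identical structure, that `tree_mul`, `tree_sub` and `tree_div` act leaf by leaf, that `tree_dot` sums the products of matching leaves, and that `tree_norm` is the Euclidean norm over all leaves. The definitions below are illustrative stand-ins, not the jaxns source, and the actual implementation may differ (for example in how it handles complex leaves or broadcasting).

```python
import jax
import jax.numpy as jnp


# Illustrative stand-ins only: the behaviour is assumed from the function
# names and has not been checked against jaxns.internals.pytree_utils.

def tree_mul(x, y):
    # Assumed: element-wise product of matching leaves.
    return jax.tree_util.tree_map(jnp.multiply, x, y)


def tree_sub(x, y):
    # Assumed: element-wise difference of matching leaves.
    return jax.tree_util.tree_map(jnp.subtract, x, y)


def tree_div(x, y):
    # Assumed: element-wise quotient of matching leaves.
    return jax.tree_util.tree_map(jnp.divide, x, y)


def tree_dot(x, y):
    # Assumed: inner product, i.e. sum over every leaf of sum(x_leaf * y_leaf).
    per_leaf = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), x, y)
    return sum(jax.tree_util.tree_leaves(per_leaf))


def tree_norm(x):
    # Assumed: Euclidean norm across all leaves, sqrt(<x, x>).
    return jnp.sqrt(tree_dot(x, x))


# Two pytrees with the same structure (a dict of arrays).
x = {"a": jnp.array([3.0, 0.0]), "b": jnp.array(4.0)}
y = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(2.0)}

dot_xy = tree_dot(x, y)    # 3*1 + 0*2 + 4*2
norm_x = tree_norm(x)      # sqrt(3**2 + 0**2 + 4**2)
diff = tree_sub(x, y)      # same dict structure as x and y
ratio = tree_div(x, y)     # same dict structure as x and y
```

Under these assumptions the element-wise helpers return a pytree with the same structure as their inputs, while `tree_dot` and `tree_norm` reduce to a scalar.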