ad_utils

jaxns.experimental.solvers.ad_utils

Module Contents

tree_dot(x, y)
tree_norm(x)
tree_mul(x, y)
tree_div(x, y)
tree_vdot(x, y)
tree_vdot_real_part(x, y)
tree_add(x, y)
tree_scalar_mul(alpha, x)
tree_neg(x)
tree_sub(x, y)
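The tree_* helpers have no docstrings on this page. If they implement elementwise arithmetic and inner products over JAX pytrees, which is an assumption based only on their names, equivalent helpers could look like the minimal sketch below. The sketch_ prefix marks them as hypothetical stand-ins, not the jaxns source. It also assumes that tree_vdot conjugates its first argument like jnp.vdot and that tree_dot does not.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree helpers (assumed semantics, not the jaxns implementation).


def sketch_tree_dot(x, y):
    # Sum of elementwise products over every leaf, without complex conjugation.
    per_leaf = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), x, y)
    return sum(jax.tree_util.tree_leaves(per_leaf))


def sketch_tree_vdot(x, y):
    # Like jnp.vdot: flatten each leaf, conjugate x, then sum over leaves.
    per_leaf = jax.tree_util.tree_map(jnp.vdot, x, y)
    return sum(jax.tree_util.tree_leaves(per_leaf))


def sketch_tree_vdot_real_part(x, y):
    return jnp.real(sketch_tree_vdot(x, y))


def sketch_tree_norm(x):
    # Euclidean norm of the whole pytree treated as one flat vector.
    return jnp.sqrt(sketch_tree_vdot_real_part(x, x))


def sketch_tree_add(x, y):
    return jax.tree_util.tree_map(jnp.add, x, y)


def sketch_tree_sub(x, y):
    return jax.tree_util.tree_map(jnp.subtract, x, y)


def sketch_tree_mul(x, y):
    return jax.tree_util.tree_map(jnp.multiply, x, y)


def sketch_tree_div(x, y):
    return jax.tree_util.tree_map(jnp.divide, x, y)


def sketch_tree_scalar_mul(alpha, x):
    return jax.tree_util.tree_map(lambda a: alpha * a, x)


def sketch_tree_neg(x):
    return jax.tree_util.tree_map(jnp.negative, x)


# Usage on a small dict pytree.
x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"a": jnp.array([4.0, 5.0]), "b": jnp.array(6.0)}
dot = sketch_tree_dot(x, y)    # 1*4 + 2*5 + 3*6
norm = sketch_tree_norm(x)     # sqrt(1 + 4 + 9)
diff = sketch_tree_sub(y, x)
```

Treating the pytree as one flattened vector lets iterative solvers work directly on structured parameters without ravelling them first.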
hvp_linearized(f, params)
hvp_forward_over_reverse(f, params)
hvp_reverse_over_reverse(f, params)
hvp_reverse_over_forward(f, params)
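The four hvp_* builders are named after the standard ways of composing forward- and reverse-mode differentiation to obtain a Hessian-vector product. The sketch below shows each composition using only jax.grad, jax.jvp and jax.linearize. It assumes that each builder takes (f, params) and returns a function of the tangent v, which this page does not state explicitly. All function names here are hypothetical.

```python
import jax
import jax.numpy as jnp


def _tree_real_dot(x, y):
    # Real inner product summed over all pytree leaves (assumes real leaves).
    per_leaf = jax.tree_util.tree_map(lambda a, b: jnp.sum(a * b), x, y)
    return sum(jax.tree_util.tree_leaves(per_leaf))


def sketch_hvp_linearized(f, params):
    # Linearize grad(f) once at params; the returned JVP map is v -> H v.
    _, hvp = jax.linearize(jax.grad(f), params)
    return hvp


def sketch_hvp_forward_over_reverse(f, params):
    # Forward-mode JVP of the reverse-mode gradient.
    def hvp(v):
        return jax.jvp(jax.grad(f), (params,), (v,))[1]
    return hvp


def sketch_hvp_reverse_over_reverse(f, params):
    # Gradient of p -> <grad f(p), v>, which is H^T v = H v (H is symmetric).
    def hvp(v):
        return jax.grad(lambda p: _tree_real_dot(jax.grad(f)(p), v))(params)
    return hvp


def sketch_hvp_reverse_over_forward(f, params):
    # Gradient of the directional derivative p -> grad f(p) . v.
    def hvp(v):
        return jax.grad(lambda p: jax.jvp(f, (p,), (v,))[1])(params)
    return hvp


# Usage: f(p) = sum(p**3) has Hessian diag(6 p), so H v = [6, 12] here.
f = lambda p: jnp.sum(p ** 3)
params = jnp.array([1.0, 2.0])
v = jnp.array([1.0, 1.0])
results = [
    build(f, params)(v)
    for build in (
        sketch_hvp_linearized,
        sketch_hvp_forward_over_reverse,
        sketch_hvp_reverse_over_reverse,
        sketch_hvp_reverse_over_forward,
    )
]
```

All four compositions return the same vector. They differ only in cost and memory: JAX's autodiff documentation generally favours forward-over-reverse, and the linearized form pays the linearization cost once and then reuses it for every v.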
grad_and_hvp(f, params, v)

Compute the gradient and Hessian-vector product of a function.

Parameters:
  • f – the function to differentiate; it must return a scalar

  • params – the parameters to differentiate with respect to

  • v – the vector to multiply the Hessian with

Returns:

the gradient of f at params and the Hessian-vector product Hv, where H is the Hessian of f at params
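One way to obtain both outputs in a single pass is to push the tangent v through jax.grad(f) with jax.jvp: the primal output is the gradient and the tangent output is Hv. The sketch below is written under that assumption and is not necessarily how jaxns implements grad_and_hvp. The name sketch_grad_and_hvp is hypothetical.

```python
import jax
import jax.numpy as jnp


def sketch_grad_and_hvp(f, params, v):
    # The JVP of grad(f) at params along v returns (grad f(params), H v) together.
    grad, hvp = jax.jvp(jax.grad(f), (params,), (v,))
    return grad, hvp


# Quadratic f(p) = 0.5 p^T A p with symmetric A, so grad = A p and H v = A v.
A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
f = lambda p: 0.5 * p @ A @ p
grad, hvp = sketch_grad_and_hvp(f, jnp.array([1.0, 2.0]), jnp.array([1.0, 0.0]))
```

Sharing one forward pass avoids evaluating the gradient twice when a second-order method needs both quantities at the same point.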

build_hvp(f, params, linearise=True)

Build a function that computes the Hessian-vector product of a function.

Parameters:
  • f – scalar function to differentiate

  • params – the parameters to differentiate with respect to

  • linearise (bool) – if True, linearize the gradient function once at params; this can be faster when the returned HVP is applied repeatedly to different vectors.

Returns:

a function that computes the Hessian-vector product
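A minimal sketch of how a linearise flag could switch between the two strategies is shown below. It assumes the returned function takes the tangent v as its only argument. The name sketch_build_hvp is hypothetical and this is not the jaxns source.

```python
import jax
import jax.numpy as jnp


def sketch_build_hvp(f, params, linearise=True):
    if linearise:
        # Linearize grad(f) at params once; each call of the returned function
        # only evaluates the stored linear map v -> H v.
        _, hvp = jax.linearize(jax.grad(f), params)
        return hvp

    def hvp(v):
        # Without linearization, every call re-evaluates grad(f) at params
        # under forward mode (forward-over-reverse).
        return jax.jvp(jax.grad(f), (params,), (v,))[1]

    return hvp


# Usage: apply the same HVP to each basis vector to recover the Hessian column by column.
A = jnp.array([[2.0, 1.0], [1.0, 3.0]])
f = lambda p: 0.5 * p @ A @ p
params = jnp.array([1.0, 2.0])
hvp = sketch_build_hvp(f, params)
H = jnp.stack([hvp(e) for e in jnp.eye(2)], axis=1)
```

Reusing one linearization pays off when the same Hessian is applied to many vectors, for example inside an iterative linear solver. For a single product, the non-linearised path avoids storing the linearization.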