ad_utils
jaxns.experimental.solvers.ad_utils
Module Contents
- grad_and_hvp(f, params, v)
Compute the gradient and Hessian-vector product of a function.
- Parameters:
f – the function to differentiate; it must return a scalar
params – the parameters to differentiate with respect to
v – the vector to multiply the Hessian with
- Returns:
the gradient of f at params, and the product of the Hessian of f at params with v
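The description above can be illustrated with a minimal sketch in plain JAX. This is not the jaxns source: `grad_and_hvp_sketch` is a hypothetical stand-in that assumes the common forward-over-reverse approach, where the gradient function is differentiated along `v` with `jax.jvp`. The test function and input values are also chosen only for illustration.

```python
import jax
import jax.numpy as jnp


def grad_and_hvp_sketch(f, params, v):
    """Hypothetical stand-in, not the jaxns implementation.

    Takes the JVP of grad(f) at params along v. The primal output is the
    gradient and the tangent output is the Hessian-vector product H @ v.
    """
    grad, hvp = jax.jvp(jax.grad(f), (params,), (v,))
    return grad, hvp


def f(x):
    # Scalar-valued function: gradient is 3 * x**2, Hessian is diag(6 * x).
    return jnp.sum(x ** 3)


params = jnp.array([1.0, 2.0, 3.0])
v = jnp.array([1.0, 0.0, -1.0])

grad, hvp = grad_and_hvp_sketch(f, params, v)
print(grad)  # gradient 3 * params**2, i.e. [3, 12, 27]
print(hvp)   # diag(6 * params) @ v, i.e. [6, 0, -18]
```

With this approach the full Hessian is never formed, so the cost stays comparable to a small number of gradient evaluations.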
- build_hvp(f, params, linearise=True)
Build a function that computes the Hessian-vector product of a function.
- Parameters:
f – scalar function to differentiate
params – the parameters to differentiate with respect to
linearise (bool) – whether to linearize the gradient function at params. This can be more efficient when the HVP is applied multiple times.
- Returns:
a function that computes the Hessian-vector product
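As a rough illustration of the linearise option described above, the sketch below builds an HVP closure in plain JAX. It is not the jaxns source: `build_hvp_sketch` is a hypothetical name, and the sketch assumes that linearising means calling `jax.linearize` on the gradient function once at `params`, so that later applications reuse the stored linearisation instead of re-tracing the gradient.

```python
import jax
import jax.numpy as jnp


def build_hvp_sketch(f, params, linearise=True):
    """Hypothetical stand-in, not the jaxns implementation."""
    grad_f = jax.grad(f)
    if linearise:
        # jax.linearize evaluates grad_f at params once and returns a
        # function that applies the stored linear map (the Hessian) to v.
        _, hvp_fn = jax.linearize(grad_f, params)
        return hvp_fn

    def hvp_fn(v):
        # Without linearisation, each call re-evaluates the JVP of grad_f.
        return jax.jvp(grad_f, (params,), (v,))[1]

    return hvp_fn


def f(x):
    # Scalar-valued function with Hessian diag(6 * x).
    return jnp.sum(x ** 3)


params = jnp.array([1.0, 2.0, 3.0])
hvp_lin = build_hvp_sketch(f, params, linearise=True)
hvp_direct = build_hvp_sketch(f, params, linearise=False)

# Apply the same HVP to several vectors, e.g. inside an iterative solver.
for v in (jnp.array([1.0, 0.0, 0.0]), jnp.array([0.0, 1.0, 0.0])):
    print(hvp_lin(v), hvp_direct(v))
```

Both closures should return the same values. The linearised variant trades extra memory for cheaper repeated applications, which tends to pay off in iterative methods such as conjugate gradient that call the HVP many times at fixed `params`.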