ad_utils
==================

.. py:module:: jaxns.experimental.solvers.ad_utils

.. rubric:: :code:`jaxns.experimental.solvers.ad_utils`

.. rubric:: Module Contents

.. py:function:: tree_dot(x, y)

.. py:function:: tree_norm(x)

.. py:function:: tree_mul(x, y)

.. py:function:: tree_div(x, y)

.. py:function:: tree_vdot(x, y)

.. py:function:: tree_vdot_real_part(x, y)

.. py:function:: tree_add(x, y)

.. py:function:: tree_scalar_mul(alpha, x)

.. py:function:: tree_neg(x)

.. py:function:: tree_sub(x, y)

.. py:function:: hvp_linearized(f, params)

.. py:function:: hvp_forward_over_reverse(f, params)

.. py:function:: hvp_reverse_over_reverse(f, params)

.. py:function:: hvp_reverse_over_forward(f, params)

.. py:function:: grad_and_hvp(f, params, v)

   Compute the gradient and the Hessian-vector product of a function.

   :param f: the function to differentiate; must return a scalar
   :param params: the parameters to differentiate with respect to
   :param v: the vector to multiply the Hessian by
   :returns: the gradient and the Hessian-vector product

.. py:function:: build_hvp(f, params, linearise = True)

   Build a function that computes the Hessian-vector product of a function.

   :param f: scalar-valued function to differentiate
   :param params: the parameters to differentiate with respect to
   :param linearise: whether to linearise the gradient function at ``params``, which can be more efficient when the HVP is applied repeatedly
   :returns: a function that computes the Hessian-vector product
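The ``tree_*`` functions in this module are documented by signature only. The following is a minimal sketch, not the jaxns implementation, of how leaf-wise pytree arithmetic of this kind is commonly built on ``jax.tree_util.tree_map``. It assumes that ``tree_vdot`` sums ``jnp.vdot`` over matching leaves and that ``tree_norm`` is the Euclidean norm over all leaves; both semantics are assumptions.

```python
import jax
import jax.numpy as jnp

# Hypothetical sketch of leaf-wise pytree arithmetic; the semantics are
# assumed and may differ from jaxns.experimental.solvers.ad_utils.


def tree_add(x, y):
    # Leaf-wise x + y; x and y must have the same pytree structure.
    return jax.tree_util.tree_map(lambda a, b: a + b, x, y)


def tree_sub(x, y):
    # Leaf-wise x - y.
    return jax.tree_util.tree_map(lambda a, b: a - b, x, y)


def tree_scalar_mul(alpha, x):
    # Multiply every leaf by the scalar alpha.
    return jax.tree_util.tree_map(lambda a: alpha * a, x)


def tree_vdot(x, y):
    # Assumed: sum of jnp.vdot over matching leaves. jnp.vdot flattens its
    # inputs and conjugates the first one.
    per_leaf = jax.tree_util.tree_leaves(jax.tree_util.tree_map(jnp.vdot, x, y))
    return sum(per_leaf)


def tree_norm(x):
    # Assumed: Euclidean (L2) norm over all leaves.
    return jnp.sqrt(jnp.real(tree_vdot(x, x)))


# Usage on a small dict pytree.
x = {"a": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
y = {"a": jnp.array([4.0, 5.0]), "b": jnp.array(-1.0)}
dot_xy = tree_vdot(x, y)  # 1*4 + 2*5 + 3*(-1) = 11
norm_x = tree_norm(x)  # sqrt(1 + 4 + 9)
sum_xy = tree_add(x, y)
diff_xy = tree_sub(x, y)
scaled_x = tree_scalar_mul(2.0, x)
```

Mapping over leaves keeps the helpers agnostic to how parameters are nested (dicts, tuples, dataclass pytrees), so solver code can treat a whole parameter tree as a single vector.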
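The four ``hvp_*`` functions are named after the standard ways of composing JAX's forward and reverse modes to obtain a Hessian-vector product. As a hedged sketch of those compositions, the helpers below (``hvp_fwd_rev``, ``hvp_rev_rev``, ``hvp_rev_fwd`` and ``hvp_lin`` are hypothetical names) take the tangent vector ``v`` directly, whereas the documented functions take only ``(f, params)`` and their return type is not described here. The sketch assumes ``params`` is a flat real array.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers illustrating the four HVP strategies; each takes the
# tangent vector v directly. Assumes x is a flat real array.


def hvp_fwd_rev(f, x, v):
    # Forward-over-reverse: JVP of grad(f) at x in direction v.
    return jax.jvp(jax.grad(f), (x,), (v,))[1]


def hvp_rev_rev(f, x, v):
    # Reverse-over-reverse: gradient of z -> <grad f(z), v>.
    return jax.grad(lambda z: jnp.vdot(jax.grad(f)(z), v))(x)


def hvp_rev_fwd(f, x, v):
    # Reverse-over-forward: gradient of the directional derivative z -> grad f(z) . v.
    return jax.grad(lambda z: jax.jvp(f, (z,), (v,))[1])(x)


def hvp_lin(f, x, v):
    # Linearize grad(f) at x once; the resulting linear map is v -> H v.
    _, hvp_fn = jax.linearize(jax.grad(f), x)
    return hvp_fn(v)


def f(x):
    # Hessian at x is [[6*x0, 1], [1, 6*x1]].
    return jnp.sum(x ** 3) + x[0] * x[1]


x = jnp.array([1.0, 2.0])
v = jnp.array([1.0, -1.0])
results = [h(f, x, v) for h in (hvp_fwd_rev, hvp_rev_rev, hvp_rev_fwd, hvp_lin)]
# Each entry should equal H v = [6 - 1, 1 - 12] = [5, -11].
```

All four agree numerically; they differ in cost and memory. The JAX Autodiff Cookbook describes forward-over-reverse as typically the most efficient general choice, while linearising pays off when the same point is reused for many vectors.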
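``grad_and_hvp`` returns both the gradient and the Hessian-vector product. One way to get both from a single pass, shown here as a sketch under the assumption that this is similar in spirit to the jaxns function (``grad_and_hvp_sketch`` is a hypothetical name), is to take the JVP of ``jax.grad(f)``: its primal output is the gradient and its tangent output is ``H v``. This works directly on pytree parameters.

```python
import jax
import jax.numpy as jnp


def grad_and_hvp_sketch(f, params, v):
    # Hypothetical helper. jax.jvp returns (primal_out, tangent_out); applied
    # to grad(f) these are the gradient at params and the product H v.
    return jax.jvp(jax.grad(f), (params,), (v,))


def f(p):
    # Scalar function of a dict pytree: sum(w**2) + b * sum(w).
    return jnp.sum(p["w"] ** 2) + p["b"] * jnp.sum(p["w"])


params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(3.0)}
v = {"w": jnp.array([1.0, 0.0]), "b": jnp.array(1.0)}
g, hv = grad_and_hvp_sketch(f, params, v)
# Gradient: w-part 2*w + b = [5, 7], b-part sum(w) = 3.
# H v: w-part 2*v_w + v_b = [3, 1], b-part sum(v_w) = 1.
```

Because the gradient is the primal output of the same JVP, it comes for free rather than requiring a separate ``jax.grad`` call.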
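The ``linearise`` flag of ``build_hvp`` trades a one-time setup cost for cheaper repeated application. A hedged sketch of that trade-off follows; ``build_hvp_sketch`` is hypothetical and only mirrors the documented signature. With ``linearise=True`` it uses ``jax.linearize`` to linearise ``grad(f)`` at ``params`` once, storing the residuals in memory. Otherwise it reruns a forward-over-reverse JVP on every call. The usage builds the full Hessian column by column and checks it against ``jax.hessian``.

```python
import jax
import jax.numpy as jnp


def build_hvp_sketch(f, params, linearise=True):
    # Hypothetical helper mirroring build_hvp(f, params, linearise=True).
    grad_f = jax.grad(f)
    if linearise:
        # Linearise grad(f) at params once; each later call only evaluates
        # the stored linear map v -> H v.
        _, hvp = jax.linearize(grad_f, params)
        return hvp

    def hvp(v):
        # Recompute a forward-over-reverse JVP on every call.
        return jax.jvp(grad_f, (params,), (v,))[1]

    return hvp


def f(x):
    # Hessian is diag(1 - sin(x)).
    return jnp.sum(jnp.sin(x)) + 0.5 * jnp.dot(x, x)


x = jnp.array([0.1, 0.2, 0.3])
hvp_lin = build_hvp_sketch(f, x, linearise=True)
hvp_jvp = build_hvp_sketch(f, x, linearise=False)

# Apply each HVP to the standard basis vectors to recover H (H is symmetric).
basis = jnp.eye(3)
H_lin = jnp.stack([hvp_lin(e) for e in basis])
H_jvp = jnp.stack([hvp_jvp(e) for e in basis])
H_ref = jax.hessian(f)(x)
```

Repeated application at a fixed point, as in an iterative linear solver such as conjugate gradient, is the case where linearising once is likely to help.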