from functools import partial
import jax
import jax.numpy as jnp
# Variants of jnp.dot / jnp.vdot / jnp.einsum with the `precision` keyword
# pre-bound to jax.lax.Precision.HIGHEST, so callers in this module do not
# have to pass it explicitly.
_dot = partial(jnp.dot, precision=jax.lax.Precision.HIGHEST)
_vdot = partial(jnp.vdot, precision=jax.lax.Precision.HIGHEST)
_einsum = partial(jnp.einsum, precision=jax.lax.Precision.HIGHEST)
[docs]
def tree_dot(x, y):
    """Sum of ``jnp.vdot`` applied to the matching leaves of two pytrees.

    Args:
        x: first pytree.
        y: second pytree, mapped together with ``x`` by ``jax.tree.map``.

    Returns:
        The accumulated sum of the per-leaf ``jnp.vdot`` results.

    Raises:
        IndexError: if the trees contain no leaves.
    """
    leaf_dots = jax.tree.leaves(jax.tree.map(jnp.vdot, x, y))
    # Seed the accumulator with the first leaf result so that no extra
    # Python ``0`` enters the sum.
    total = leaf_dots[0]
    for term in leaf_dots[1:]:
        total = total + term
    return total
[docs]
def tree_norm(x):
    """Square root of the real part of ``tree_dot(x, x)``.

    Args:
        x: pytree whose norm is requested.

    Returns:
        ``jnp.sqrt`` of the real part of the pytree's dot product with itself.
    """
    squared_norm = tree_dot(x, x)
    return jnp.sqrt(squared_norm.real)
[docs]
def tree_mul(x, y):
    """Multiply two pytrees leaf by leaf with ``jax.lax.mul``.

    Args:
        x: left-hand pytree.
        y: right-hand pytree, mapped together with ``x``.

    Returns:
        A pytree holding ``jax.lax.mul(x_leaf, y_leaf)`` for each leaf pair.
    """
    product = jax.tree.map(jax.lax.mul, x, y)
    return product
[docs]
def tree_div(x, y):
    """Divide two pytrees leaf by leaf with ``jax.lax.div``.

    Args:
        x: pytree of numerators.
        y: pytree of denominators, mapped together with ``x``.

    Returns:
        A pytree holding ``jax.lax.div(x_leaf, y_leaf)`` for each leaf pair.
    """
    quotient = jax.tree.map(jax.lax.div, x, y)
    return quotient
# aliases for working with pytrees
def _vdot_real_part(x, y):
    """Real-valued dot product of two possibly complex arrays.

    Computes ``vdot(x.real, y.real)`` and, when either input is complex,
    adds ``vdot(x.imag, y.imag)``. The real-imaginary cross-terms are
    deliberately left out, so the result is always real.
    """
    # Every use of vdot() in CG evaluates a form z^H M z with M Hermitian
    # and positive definite, whose value is real:
    # https://en.wikipedia.org/wiki/Definiteness_of_a_matrix#Definitions_for_complex_matrices
    real_contribution = _vdot(x.real, y.real)
    if not (jnp.iscomplexobj(x) or jnp.iscomplexobj(y)):
        # Purely real inputs: there is no imaginary contribution to add.
        return real_contribution
    return real_contribution + _vdot(x.imag, y.imag)
[docs]
def tree_vdot(x, y):
    """Sum of the highest-precision ``_vdot`` over matching leaves.

    Args:
        x: first pytree.
        y: second pytree, mapped together with ``x``.

    Returns:
        The accumulated sum of the per-leaf ``_vdot`` results.

    Raises:
        IndexError: if the trees contain no leaves.
    """
    leaf_products = jax.tree.leaves(jax.tree.map(_vdot, x, y))
    head = leaf_products[0]
    tail = leaf_products[1:]
    return sum(tail, head)
[docs]
def tree_vdot_real_part(x, y):
    """Sum of ``_vdot_real_part`` over the matching leaves of two pytrees.

    Args:
        x: first pytree.
        y: second pytree, mapped together with ``x``.

    Returns:
        The accumulated sum of the per-leaf ``_vdot_real_part`` results.

    Raises:
        IndexError: if the trees contain no leaves.
    """
    leaf_products = jax.tree.leaves(jax.tree.map(_vdot_real_part, x, y))
    head = leaf_products[0]
    tail = leaf_products[1:]
    return sum(tail, head)
[docs]
def tree_add(x, y):
    """Add two pytrees leaf by leaf with ``jax.lax.add``.

    Args:
        x: left-hand pytree.
        y: right-hand pytree, mapped together with ``x``.

    Returns:
        A pytree holding ``jax.lax.add(x_leaf, y_leaf)`` for each leaf pair.
    """
    total = jax.tree.map(jax.lax.add, x, y)
    return total
[docs]
def tree_scalar_mul(alpha, x):
    """Scale every leaf of a pytree by ``alpha``.

    Args:
        alpha: the factor placed on the left of each product.
        x: pytree whose leaves are scaled.

    Returns:
        A pytree holding ``alpha * leaf`` for each leaf of ``x``.
    """
    def scale(leaf):
        return alpha * leaf

    return jax.tree.map(scale, x)
[docs]
def tree_neg(x):
    """Negate every leaf of a pytree with ``jax.lax.neg``.

    Args:
        x: pytree to negate.

    Returns:
        A pytree holding ``jax.lax.neg(leaf)`` for each leaf of ``x``.
    """
    negated = jax.tree.map(jax.lax.neg, x)
    return negated
[docs]
def tree_sub(x, y):
    """Subtract two pytrees leaf by leaf with ``jax.lax.sub``.

    Args:
        x: pytree of minuends.
        y: pytree of subtrahends, mapped together with ``x``.

    Returns:
        A pytree holding ``jax.lax.sub(x_leaf, y_leaf)`` for each leaf pair.
    """
    difference = jax.tree.map(jax.lax.sub, x, y)
    return difference
[docs]
def hvp_linearized(f, params):
    """Build a Hessian-vector-product function by linearizing ``grad(f)``.

    Args:
        f: function passed to ``jax.grad``.
        params: point at which the gradient function is linearized.

    Returns:
        The linear function produced by ``jax.linearize``; calling it with
        ``v`` gives the JVP of ``jax.grad(f)`` at ``params`` along ``v``.
    """
    # jax.linearize returns (grad_f(params), linear_map); only the linear
    # map is needed, and it can be applied to many different v.
    linearized = jax.linearize(jax.grad(f), params)
    hessian_vector_product = linearized[1]
    return hessian_vector_product
[docs]
def hvp_forward_over_reverse(f, params):
    """Build a Hessian-vector-product function as a JVP of ``grad(f)``.

    Args:
        f: function passed to ``jax.grad``.
        params: point at which the JVP is evaluated.

    Returns:
        A function ``hvp(v)`` returning the tangent output of
        ``jax.jvp(jax.grad(f), (params,), (v,))``.
    """
    def hvp(v):
        # jax.jvp yields (primal_out, tangent_out); the tangent is the HVP.
        _, tangent_out = jax.jvp(jax.grad(f), (params,), (v,))
        return tangent_out

    return hvp
[docs]
def hvp_reverse_over_reverse(f, params):
    """Build a Hessian-vector-product function as a gradient of ``<grad f, v>``.

    The inner product between the gradient and ``v`` is accumulated leaf by
    leaf, so ``params`` (and ``v``) may be an arbitrary pytree, matching the
    other HVP builders in this module. For a single array the computation is
    the same as a direct ``jnp.vdot`` of the gradient with ``v``.

    Args:
        f: scalar function passed to ``jax.grad``.
        params: point (array or pytree) at which the HVP is evaluated.

    Returns:
        A function ``hvp(v)`` where ``v`` has the same tree structure as
        ``params``; the result has that structure as well.
    """
    grad_f = jax.grad(f)

    def hvp(v):
        def grad_dot_v(y):
            # Contract each gradient leaf with the matching leaf of v, then
            # add the per-leaf scalars to get one scalar to differentiate.
            leaf_dots = jax.tree.leaves(jax.tree.map(jnp.vdot, grad_f(y), v))
            return sum(leaf_dots[1:], start=leaf_dots[0])

        return jax.grad(grad_dot_v)(params)

    return hvp
[docs]
def hvp_reverse_over_forward(f, params):
    """Build a Hessian-vector-product function as a gradient of a JVP.

    Args:
        f: function passed to ``jax.jvp``.
        params: point at which the gradient of the JVP is evaluated.

    Returns:
        A function ``hvp(v)`` returning ``jax.grad`` of the map
        ``p -> jax.jvp(f, (p,), (v,))[1]`` evaluated at ``params``.
    """
    def hvp(v):
        def directional_derivative(point):
            # Tangent output of f at `point` along the fixed direction v.
            _, tangent_out = jax.jvp(f, (point,), (v,))
            return tangent_out

        return jax.grad(directional_derivative)(params)

    return hvp
[docs]
def grad_and_hvp(f, params, v):
    """
    Evaluate the gradient of ``f`` and its Hessian-vector product in one pass.

    Args:
        f: the function to differentiate, should be scalar output
        params: the parameters to differentiate with respect to
        v: the vector to multiply the Hessian with

    Returns:
        the pair returned by ``jax.jvp`` applied to ``jax.grad(f)``: the
        gradient at ``params`` and the Hessian-vector product with ``v``
    """
    grad_f = jax.grad(f)
    primals = (params,)
    tangents = (v,)
    return jax.jvp(grad_f, primals, tangents)
[docs]
def build_hvp(f, params, linearise: bool = True):
    """
    Create a callable that applies the Hessian of ``f`` at ``params`` to a vector.

    Args:
        f: scalar function to differentiate
        params: the parameters to differentiate with respect to
        linearise: when true, linearize the gradient function at params once
            up front, which can be better when the HVP is applied many times;
            when false, each call goes through ``grad_and_hvp``.

    Returns:
        a function ``matvec(v)`` that computes the Hessian-vector product
    """
    if not linearise:
        def matvec(v):
            # grad_and_hvp returns (gradient, hvp); only the hvp is wanted.
            _, hessian_vector = grad_and_hvp(f, params, v)
            return hessian_vector

        return matvec

    # jax.linearize returns (grad_f(params), linear_map); the linear map
    # computes the JVP of the gradient function at params.
    _, linearized_grad = jax.linearize(jax.grad(f), params)

    def matvec(v):
        return linearized_grad(v)

    return matvec