Source code for jaxns.experimental.solvers.test_cg
from jax import numpy as jnp
from jaxns.experimental.solvers.cg import cg_solve
[docs]
def test_cg():
    """Run ``cg_solve`` on a 2x2 identity system and check its output.

    The system matrix is the identity, the right-hand side is a vector of
    ones and the starting guess is all zeros. The test asserts that the
    returned solution is close to the right-hand side and that the returned
    diagnostics object reports exactly one iteration.
    """
    identity = jnp.eye(2)
    rhs = jnp.ones(2)
    initial_guess = jnp.zeros_like(rhs)

    def apply_identity(vec):
        # The matrix is handed to the solver as a callable computing A @ vec.
        return jnp.dot(identity, vec)

    solution, diagnostics = cg_solve(apply_identity, rhs, initial_guess)

    # With A = I the exact solution of A x = b is b itself.
    assert jnp.allclose(solution, rhs)
    assert diagnostics.iterations == 1