cg

jaxns.experimental.solvers.cg

Module Contents

class CGDiagnostics

Bases: NamedTuple

iterations: jaxns.internals.types.IntArray
final_res_norm: jaxns.internals.types.FloatArray
DomainType
cg_solve(A, b, x0, M=_identity, maxiter=100, tol=1e-05, atol=0.0)

Solve a linear system Ax = b using the conjugate gradient method.

Parameters:
  • A (Callable[[DomainType], DomainType]) – a square, symmetric positive semi-definite (PSD) linear operator, given as a function that applies A to a vector

  • b (DomainType) – the right-hand side

  • x0 (DomainType) – an initial guess for the solution

  • M (Callable[[DomainType], DomainType]) – a preconditioner for A; defaults to the identity (no preconditioning)

  • maxiter (int | None) – the maximum number of iterations; if None, the size of b is used

  • tol (float) – the relative tolerance for the residual norm

  • atol (float) – the absolute tolerance for the residual norm

Returns:

a tuple of the solution x and a CGDiagnostics instance holding the iteration count and the final residual norm

Return type:

Tuple[DomainType, CGDiagnostics]
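
The parameters above map onto the standard preconditioned conjugate gradient iteration. The following is a minimal NumPy sketch of that iteration, not the jaxns implementation: the names cg_sketch and SketchDiagnostics are hypothetical, the inputs are assumed to be flat arrays rather than general DomainType values, and the stopping rule ||r|| <= max(tol * ||b||, atol) is an assumption about how tol and atol combine.

```python
from typing import Callable, NamedTuple, Optional, Tuple

import numpy as np


class SketchDiagnostics(NamedTuple):
    # Hypothetical stand-in mirroring the documented CGDiagnostics fields.
    iterations: int
    final_res_norm: float


def cg_sketch(
    A: Callable[[np.ndarray], np.ndarray],
    b: np.ndarray,
    x0: np.ndarray,
    M: Callable[[np.ndarray], np.ndarray] = lambda v: v,
    maxiter: Optional[int] = 100,
    tol: float = 1e-5,
    atol: float = 0.0,
) -> Tuple[np.ndarray, SketchDiagnostics]:
    """Preconditioned CG for A x = b (illustrative sketch only)."""
    if maxiter is None:
        maxiter = b.size  # documented fallback: the size of b
    # Assumed stopping rule: ||r|| <= max(tol * ||b||, atol).
    threshold = max(tol * np.linalg.norm(b), atol)

    x = x0.astype(float)       # working copy of the initial guess
    r = b - A(x)               # initial residual
    z = M(r)                   # preconditioned residual
    p = z.copy()               # first search direction
    rz = float(r @ z)
    iterations = 0
    while iterations < maxiter and np.linalg.norm(r) > threshold:
        Ap = A(p)
        alpha = rz / float(p @ Ap)    # step length along p
        x = x + alpha * p
        r = r - alpha * Ap
        z = M(r)
        rz_new = float(r @ z)
        p = z + (rz_new / rz) * p     # next A-conjugate direction
        rz = rz_new
        iterations += 1
    return x, SketchDiagnostics(iterations, float(np.linalg.norm(r)))


# Usage: a 2x2 symmetric positive definite system with a Jacobi
# (diagonal) preconditioner. Analytically, x = [1/11, 7/11].
A_mat = np.array([[4.0, 1.0], [1.0, 3.0]])
b = np.array([1.0, 2.0])
x, diag = cg_sketch(
    lambda v: A_mat @ v,
    b,
    np.zeros(2),
    M=lambda r: r / np.diag(A_mat),
)
```

Passing A and M as functions rather than matrices, as the documented signature does, means the operator never has to be materialised; only its action on a vector is needed.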