gauss_newton_cg

jaxns.experimental.solvers.gauss_newton_cg

Module Contents

DomainType
ObjectiveRet
CT
convert_to_real(x)

Return a real-valued counterpart of x, together with a function that merges a real-valued result back into the original (possibly complex) form of x.

Parameters:

x (CT)

Return type:

Tuple[_CT, Callable[[_CT], CT]]
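
A minimal sketch of the idea behind convert_to_real, assuming a single JAX array rather than an arbitrary pytree: a complex array is split into a real array with a trailing (real, imaginary) axis, and the returned function reassembles it. The name convert_to_real_sketch and this stacking layout are hypothetical; the layout jaxns actually uses is not documented here.

```python
import jax.numpy as jnp


def convert_to_real_sketch(x):
    """Hypothetical stand-in for convert_to_real, for a single array only.

    Complex input becomes a real array with a trailing axis of size 2
    (real part, imaginary part); real input is passed through unchanged.
    """
    if jnp.iscomplexobj(x):
        x_real = jnp.stack([x.real, x.imag], axis=-1)

        def merge(y):
            # Reassemble the complex array from the trailing (re, im) axis.
            return (y[..., 0] + 1j * y[..., 1]).astype(x.dtype)

        return x_real, merge
    # Already real: the twin is x itself and merging is the identity.
    return x, lambda y: y
```

A solver working on the real twin can then call merge on its final iterate to return a value of the caller's original type.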

class NewtonDiagnostic

Bases: NamedTuple

iteration: jaxns.internals.types.IntArray
g_norm: jaxns.internals.types.FloatArray
mu: jaxns.internals.types.FloatArray
damping: jaxns.internals.types.FloatArray
cg_iters: jaxns.internals.types.IntArray
f: jaxns.internals.types.FloatArray
f_prop: jaxns.internals.types.FloatArray
f_quad: jaxns.internals.types.FloatArray
delta_f_pred: jaxns.internals.types.FloatArray
delta_f_actual: jaxns.internals.types.FloatArray
gain_ratio: jaxns.internals.types.FloatArray
accepted: jaxns.internals.types.BoolArray
in_trust_region: jaxns.internals.types.BoolArray
delta_x_norm: jaxns.internals.types.FloatArray
ddelta_x_norm: jaxns.internals.types.FloatArray
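
Several of these fields (f, f_prop, f_quad, delta_f_pred, delta_f_actual, gain_ratio) match the quantities of a standard trust-region step test. The sketch below computes them for one trial step, assuming f_prop is the objective at the proposed point, f_quad is the undamped second-order model evaluated there, and gain_ratio = delta_f_actual / delta_f_pred. These definitions and the helper name gain_ratio_sketch are assumptions, not taken from jaxns.

```python
import jax
import jax.numpy as jnp


def gain_ratio_sketch(f, x, dx):
    """Actual vs. predicted decrease of f for a trial step dx (hypothetical helper).

    Assumed definitions: f_prop = f(x + dx); f_quad = second-order Taylor model
    of f around x, evaluated at x + dx; gain_ratio = delta_f_actual / delta_f_pred.
    """
    f_x = f(x)
    g = jax.grad(f)(x)
    # Hessian-vector product H @ dx via forward-over-reverse differentiation.
    hdx = jax.jvp(jax.grad(f), (x,), (dx,))[1]
    f_quad = f_x + g @ dx + 0.5 * dx @ hdx
    f_prop = f(x + dx)
    delta_f_pred = f_x - f_quad
    delta_f_actual = f_x - f_prop
    return {
        "f_prop": f_prop,
        "f_quad": f_quad,
        "delta_f_pred": delta_f_pred,
        "delta_f_actual": delta_f_actual,
        "gain_ratio": delta_f_actual / delta_f_pred,
    }
```

For an exactly quadratic objective the model is exact, so the ratio is 1 for any step; a small or negative ratio signals that the quadratic model predicted the decrease poorly. How the solver's p_accept, p_lower and p_upper thresholds act on this ratio is not documented here.
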

newton_cg_solver(obj_fn, x0, args=(), maxiter=100, maxiter_cg=100, gtol=3e-05, p_accept=0.01, p_lower=0.25, p_upper=1.1, mu_init=1.0, mu_min=1e-06, mu_in_factor=5, mu_out_factor=0.1, approx_hvp=False, verbose=False)

Trust-region Newton-CG minimiser.

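Uses the same call signature and adaptive-μ logic as lm_solver, but minimises a scalar objective instead of residuals. Each step solves

(H + damping·I) δx = -∇f

for δx by conjugate gradients (CG), where H and ∇f are the Hessian and gradient of the objective at the current iterate. CG needs only Hessian–vector products, so H is never formed explicitly.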
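
The damped linear solve can be sketched in JAX, assuming a flat real-valued x and a scalar objective f: jax.grad gives ∇f, a forward-over-reverse jax.jvp gives Hessian–vector products, and jax.scipy.sparse.linalg.cg solves the system matrix-free. The function name damped_newton_cg_step is hypothetical, and this is an illustration of the technique, not the jaxns implementation.

```python
import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


def damped_newton_cg_step(f, x, damping, maxiter_cg=100):
    """Solve (H + damping*I) dx = -grad f(x) by CG, using only Hessian-vector products.

    Sketch for a flat real array x; the Hessian H is never materialised.
    """
    grad_f = jax.grad(f)
    g = grad_f(x)

    def damped_op(v):
        # Forward-over-reverse: d/dt grad f(x + t v) at t = 0 equals H v.
        hv = jax.jvp(grad_f, (x,), (v,))[1]
        return hv + damping * v

    dx, _ = cg(damped_op, -g, maxiter=maxiter_cg)
    return dx
```

CG assumes a symmetric positive-definite operator. Adding damping·I keeps the system positive definite when H is indefinite, provided damping exceeds the magnitude of H's most negative eigenvalue.
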

Returns:

  • x_final (pytree with the same structure as x0, merged back to complex if needed)

  • diagnostics (NewtonDiagnostic[…] whose fields are arrays of length maxiter, one entry per iteration)
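
One way a fixed-length, per-iteration record arises in JAX is to run the iterations with jax.lax.scan, which stacks each step's NamedTuple output into arrays of length maxiter. The sketch below assumes this structure for illustration. MiniDiagnostic, minimise_sketch and its damping update (halve after an accepted step, multiply by 5 after a rejected one) are hypothetical. The sketch collapses the separate mu and damping fields into one variable and reuses the name p_accept with assumed semantics; jaxns's actual adaptive-μ rule may differ.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg


class MiniDiagnostic(NamedTuple):
    # Hypothetical subset of NewtonDiagnostic's fields, one scalar per iteration.
    iteration: jax.Array
    f: jax.Array
    g_norm: jax.Array
    damping: jax.Array
    accepted: jax.Array


def minimise_sketch(f, x0, maxiter=20, damping0=1.0, p_accept=0.01):
    """Fixed-length damped Newton-CG loop; returns (x_final, stacked diagnostics)."""
    grad_f = jax.grad(f)

    def body(carry, i):
        x, damping = carry
        f_x, g = f(x), grad_f(x)

        def damped_op(v):
            # (H + damping*I) v, with H v from a forward-over-reverse JVP.
            return jax.jvp(grad_f, (x,), (v,))[1] + damping * v

        dx, _ = cg(damped_op, -g)
        hdx = jax.jvp(grad_f, (x,), (dx,))[1]
        delta_f_pred = -(g @ dx + 0.5 * dx @ hdx)  # decrease predicted by the quadratic model
        delta_f_actual = f_x - f(x + dx)
        ok_pred = delta_f_pred > 0
        # Treat a non-positive predicted decrease as a rejected step.
        rho = jnp.where(ok_pred, delta_f_actual / jnp.where(ok_pred, delta_f_pred, 1.0), -jnp.inf)
        accepted = rho > p_accept
        x_new = jnp.where(accepted, x + dx, x)
        # Generic rule, not jaxns's: relax damping after a good step, tighten after a bad one.
        damping_new = jnp.where(accepted, damping * 0.5, damping * 5.0)
        diag = MiniDiagnostic(i, f_x, jnp.linalg.norm(g), damping, accepted)
        return (x_new, damping_new), diag

    init = (x0, jnp.asarray(damping0, dtype=x0.dtype))
    (x_final, _), diags = jax.lax.scan(body, init, jnp.arange(maxiter))
    return x_final, diags
```

Because scan always runs maxiter steps, every field of the returned record has leading dimension maxiter, even if the iterate stops changing early.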

Parameters:
Return type:

Tuple[DomainType, NewtonDiagnostic]