# Source code for jaxns.experimental.solvers.gauss_newton_cg

from typing import NamedTuple, TypeVar, Tuple, Callable, Any

import jax
import jax.numpy as jnp

from jaxns.experimental.solvers.ad_utils import tree_neg, tree_vdot_real_part, tree_scalar_mul, tree_add, tree_sub, \
    build_hvp
from jaxns.experimental.solvers.cg import cg_solve
from jaxns.internals.types import IntArray, FloatArray, BoolArray

# ----------------------------------------------------------------
# Type helpers
# ----------------------------------------------------------------
DomainType = TypeVar("DomainType")  # parameter pytree
ObjectiveRet = TypeVar("ObjectiveRet")  # scalar objective (FloatArray or 0-D ndarray)
CT = TypeVar("CT")  # pytree as supplied by the caller (input type of `convert_to_real`)
_CT = TypeVar("_CT")  # real-valued counterpart returned by `convert_to_real`


# ----------------------------------------------------------------
# Utility: split complex pytrees into real pairs so Wirtinger calculus
# works out of the box with JAX’s real autodiff.
# ----------------------------------------------------------------
def convert_to_real(x: CT) -> Tuple[_CT, Callable[[_CT], CT]]:
    """Return a real-valued twin of `x` and a merge-back function.

    The pytree `x` is flattened and every complex leaf is replaced by a
    ``(real, imag)`` tuple; all other leaves are passed through untouched.

    Complex leaves are detected by dtype rather than by container type, so
    NumPy complex arrays and Python ``complex`` scalars are split as well as
    JAX arrays (their parts are returned as JAX arrays).

    Args:
        x: arbitrary pytree, possibly containing complex leaves.

    Returns:
        A pair ``(split_leaves, merge)`` where ``split_leaves`` is the flat
        list of leaves (complex leaves replaced by ``(real, imag)`` tuples)
        and ``merge`` maps such a list back to a pytree with the structure
        of `x`, recombining each tuple with ``jax.lax.complex``.
    """

    def _is_complex_leaf(a: Any) -> bool:
        # Python complex scalars carry no dtype attribute.
        if isinstance(a, complex):
            return True
        dtype = getattr(a, "dtype", None)
        return dtype is not None and bool(jnp.issubdtype(dtype, jnp.complexfloating))

    leaves, treedef = jax.tree.flatten(x)
    split_leaves = [
        (jnp.real(leaf), jnp.imag(leaf)) if _is_complex_leaf(leaf) else leaf
        for leaf in leaves
    ]

    def merge(split_x: _CT) -> CT:
        # Flattened leaves are never tuples themselves, so a tuple entry can
        # only be a (real, imag) pair produced above.
        merged = [
            jax.lax.complex(part[0], part[1]) if isinstance(part, tuple) else part
            for part in split_x
        ]
        return jax.tree.unflatten(treedef, merged)

    return split_leaves, merge
# ----------------------------------------------------------------
# Diagnostics – patterned after LMDiagnostic
# ----------------------------------------------------------------
class NewtonDiagnostic(NamedTuple):
    """Per-iteration diagnostics recorded by `newton_cg_solver`.

    One record is produced per outer iteration; the solver stacks them
    along a leading axis.
    """
    iteration: IntArray  # outer iteration index
    g_norm: FloatArray  # |∇f| at the current point
    mu: FloatArray  # damping parameter
    damping: FloatArray  # g_norm / mu
    cg_iters: IntArray  # iterations reported by the CG solve
    f: FloatArray  # f(x_k)
    f_prop: FloatArray  # f(x_k + δx_k)
    f_quad: FloatArray  # quadratic model at proposal
    delta_f_pred: FloatArray  # predicted decrease
    delta_f_actual: FloatArray  # actual decrease
    gain_ratio: FloatArray  # delta_f_actual / delta_f_pred (NaN replaced by 0)
    accepted: BoolArray  # whether the proposed step was taken
    in_trust_region: BoolArray  # whether the actual decrease fell inside the trust band
    delta_x_norm: FloatArray  # ‖δx‖
    ddelta_x_norm: FloatArray  # ‖δx – δx⁰‖
# ----------------------------------------------------------------
# Main solver
# ----------------------------------------------------------------
def newton_cg_solver(
        obj_fn: Callable[..., ObjectiveRet],
        x0: DomainType,
        args: tuple = (),
        maxiter: int = 100,
        maxiter_cg: int = 100,
        gtol: float = 3e-5,
        p_accept: float = 0.01,
        p_lower: float = 0.25,
        p_upper: float = 1.10,
        mu_init: float = 1.0,
        mu_min: float = 1e-6,
        mu_in_factor: float = 5,
        mu_out_factor: float = 0.1,
        approx_hvp: bool = False,  # reuse gradient between rejections
        verbose: bool = False,
) -> Tuple[DomainType, NewtonDiagnostic]:
    """
    Trust-region Newton-CG minimiser.

    Modelled on `lm_solver` (defined elsewhere), but uses a scalar objective
    instead of residuals, and solves

        (H + damping·I) δx = -∇f,    damping = |∇f| / μ

    by CG with Hessian–vector products.

    Parameters
    ----------
    obj_fn : called as ``obj_fn(x, *args)``; must return a 0-D JAX array.
    x0 : initial parameters (pytree). Complex leaves are split into
        real/imaginary parts internally and merged back on return.
    args : extra positional arguments forwarded to `obj_fn`.
    maxiter : maximum number of outer iterations; also the length of the
        returned diagnostic arrays.
    maxiter_cg : maximum number of CG iterations per linear solve.
    gtol : iteration stops once the gradient norm is no longer above `gtol`.
    p_accept : a step is accepted when the predicted decrease is positive and
        the actual decrease exceeds ``p_accept`` times the predicted one.
    p_lower, p_upper : the step is "in the trust region" when the actual
        decrease lies strictly between ``p_lower`` and ``p_upper`` times the
        (positive) predicted decrease.
    mu_init : starting value of μ for the initial backtracking search.
    mu_min : lower bound on μ.
    mu_in_factor : μ is multiplied by this when the step is in the trust region.
    mu_out_factor : μ is multiplied by this otherwise; also the shrink factor
        of the initial backtracking search.
    approx_hvp : if True, the gradient is recomputed only after an accepted
        step. After a rejection the point is unchanged, so the stored
        gradient is reused.
    verbose : if True, print one line per iteration via ``jax.debug.print``.

    Returns
    -------
    x_final : pytree matching `x0` (merged back to complex if needed)
    diagnostics : NewtonDiagnostic[...] array with length = `maxiter`;
        entries for iterations that were never run remain zero.

    Raises
    ------
    RuntimeError : if `obj_fn` does not return a 0-D JAX array.
    """
    # ---- 1. Handle complex inputs -----------------------------------------
    x0_real, merge_back = convert_to_real(x0)

    # Wrap obj_fn so that it consumes / produces purely real pytrees
    def _obj_fn(x):
        val = obj_fn(merge_back(x), *args)
        if not isinstance(val, jax.Array):
            raise RuntimeError(f"Objective function must return a JAX scalar array, got {type(val)}.")
        if jnp.ndim(val) != 0:
            raise RuntimeError("Objective function must return a scalar.")
        return val

    # ---- 2. State container -----------------------------------------------
    class NState(NamedTuple):
        x: DomainType  # current parameters (real)
        f: FloatArray  # f(x)
        g: DomainType  # -∇f (descent direction)
        g_norm: FloatArray
        mu: FloatArray
        delta_x_prev: DomainType  # δx⁻¹
        delta_x_prev2: DomainType  # δx⁻²
        iter: IntArray

    # ---- 3. Helpers --------------------------------------------------------
    def _gradient(x):
        return jax.grad(_obj_fn)(x)

    def _descent_direction(x):
        # Returns (-∇f(x), ‖∇f(x)‖).
        g = tree_neg(_gradient(x))
        return g, jnp.sqrt(tree_vdot_real_part(g, g))

    def _initial_state(x):
        f0 = _obj_fn(x)
        g0, g_norm0 = _descent_direction(x)
        # Small offset guards against division by zero at a stationary point.
        g_unit = tree_scalar_mul(1.0 / (g_norm0 + 1e-12), g0)

        # Simple backtracking line search along -∇f to pick starting μ
        def ls_cond(mu):
            step = tree_scalar_mul(mu, g_unit)
            f_new = _obj_fn(tree_add(x, step))
            # `~(f_new < f0)` rather than `f_new >= f0`: a NaN trial value is
            # not a decrease, so keep shrinking instead of stopping on it.
            return jnp.logical_not(f_new < f0) & (mu > mu_min)

        def ls_body(mu):
            return mu * mu_out_factor

        mu0 = jax.lax.while_loop(ls_cond, ls_body, mu_init)

        z = jax.tree.map(jnp.zeros_like, x)
        return NState(
            x=x, f=f0, g=g0, g_norm=g_norm0,
            mu=mu0, delta_x_prev=z, delta_x_prev2=z, iter=0
        )

    # ---- 4. Iteration ------------------------------------------------------
    def cond_fn(carry):
        state, _ = carry
        return (state.g_norm > gtol) & (state.iter < maxiter)

    def step_fn(state: NState):
        # 4.1 Warm-start: δx⁰ = 2δx⁻¹ − δx⁻²
        delta_x0 = tree_sub(tree_add(state.delta_x_prev, state.delta_x_prev), state.delta_x_prev2)

        # 4.2 Build Hessian-vector product operator at the current point
        hvp = build_hvp(_obj_fn, state.x, linearise=True)
        damping = state.g_norm / state.mu

        def A_op(v):
            return tree_add(hvp(v), tree_scalar_mul(damping, v))

        # 4.3 Solve Newton system with CG
        delta_x, cg_diag = cg_solve(
            A=A_op, b=state.g, x0=delta_x0,
            maxiter=maxiter_cg, tol=1e-5, atol=0.0
        )

        # 4.4 Book-keeping norms
        delta_x_norm = jnp.sqrt(tree_vdot_real_part(delta_x, delta_x))
        ddelta_x = tree_sub(delta_x, delta_x0)
        ddelta_x_norm = jnp.sqrt(tree_vdot_real_part(ddelta_x, ddelta_x))

        # 4.5 Evaluate objective at proposal
        x_prop = tree_add(state.x, delta_x)
        f_prop = _obj_fn(x_prop)

        # Quadratic model prediction. state.g holds -∇f, so
        # f_quad = f + ∇f·δx + ½ δx·H·δx = f - (g·δx - ½ δx·H·δx).
        hvp_dx = hvp(delta_x)
        quad_term = 0.5 * tree_vdot_real_part(delta_x, hvp_dx)
        lin_term = tree_vdot_real_part(state.g, delta_x)  # g·δx
        f_quad = state.f - (lin_term - quad_term)

        delta_f_pred = state.f - f_quad  # should be >0
        delta_f_actual = state.f - f_prop
        gain_ratio = delta_f_actual / delta_f_pred
        gain_ratio = jnp.where(jnp.isnan(gain_ratio), 0.0, gain_ratio)  # handle NaN

        # 4.6 Trust-region logic
        in_trust = (
                (delta_f_pred > 0)
                & (delta_f_actual > p_lower * delta_f_pred)
                & (delta_f_actual < p_upper * delta_f_pred)
        )
        new_mu = jax.lax.select(in_trust, mu_in_factor * state.mu, state.mu * mu_out_factor)
        new_mu = jnp.maximum(new_mu, mu_min)

        accepted = (delta_f_pred > 0) & (delta_f_actual > p_accept * delta_f_pred)

        # 4.7 Diagnostics -----------------------------------------------------
        diag = NewtonDiagnostic(
            iteration=state.iter,
            g_norm=state.g_norm,
            mu=state.mu,
            damping=damping,
            cg_iters=cg_diag.iterations,
            f=state.f,
            f_prop=f_prop,
            f_quad=f_quad,
            delta_f_pred=delta_f_pred,
            delta_f_actual=delta_f_actual,
            gain_ratio=gain_ratio,
            accepted=accepted,
            in_trust_region=in_trust,
            delta_x_norm=delta_x_norm,
            ddelta_x_norm=ddelta_x_norm
        )

        if verbose:
            jax.debug.print(
                "iter {iteration:3d} f={f:9.3e} |g|={g_norm:9.3e} "
                "μ={mu:9.3e} r={gain_ratio:6.2f} "
                "Δf_pred={delta_f_pred:9.3e} Δf_act={delta_f_actual:9.3e} "
                "CG={cg_iters}",
                iteration=state.iter, f=state.f, g_norm=state.g_norm,
                mu=state.mu, gain_ratio=diag.gain_ratio,
                delta_f_pred=diag.delta_f_pred,
                delta_f_actual=diag.delta_f_actual,
                cg_iters=diag.cg_iters
            )

        # 4.8 Accept / reject step
        (
            x_new, f_new, delta_x_prev_new, delta_x_prev2_new
        ) = jax.tree.map(
            lambda a, b: jax.lax.select(accepted, a, b),
            (x_prop, f_prop, delta_x, state.delta_x_prev),
            (state.x, state.f, state.delta_x_prev, state.delta_x_prev2)
        )

        # 4.9 Gradient at (possibly) new point
        if approx_hvp:
            # Recompute only after an accepted step. After a rejection the
            # point is unchanged, so the stored gradient is still exact.
            # (Previously the gradient was never refreshed in this mode, so
            # the solver kept using the gradient of the initial point.)
            g_new, g_norm_new = jax.lax.cond(
                accepted,
                lambda: _descent_direction(x_new),
                lambda: (state.g, state.g_norm),
            )
        else:
            g_new, g_norm_new = _descent_direction(x_new)

        # 4.10 Next state
        return NState(
            x=x_new,
            f=f_new,
            g=g_new,
            g_norm=g_norm_new,
            mu=new_mu,
            delta_x_prev=delta_x_prev_new,
            delta_x_prev2=delta_x_prev2_new,
            iter=state.iter + 1
        ), diag

    # --------------------------------------------------------------------
    # 5. Main while-loop with diagnostic accumulation
    # --------------------------------------------------------------------
    init_state = _initial_state(x0_real)

    # prototype diag for shape inference
    diag_aval = jax.eval_shape(lambda s: step_fn(s)[1], init_state)
    empty_diag = jax.tree.map(
        lambda arr: jnp.zeros((maxiter,) + arr.shape, arr.dtype),
        diag_aval
    )

    def body_fn(carry):
        st, diag_arr = carry
        new_st, new_d = step_fn(st)
        # Store this iteration's record at row `st.iter`.
        diag_arr = jax.tree.map(
            lambda arr, d: arr.at[st.iter].set(d),
            diag_arr, new_d
        )
        return new_st, diag_arr

    final_state, final_diag = jax.lax.while_loop(
        cond_fn, body_fn, (init_state, empty_diag)
    )

    # Merge complex components back to the user space
    return merge_back(final_state.x), final_diag