"""Newton-Raphson solver with symbolic Jacobian and numpy linear algebra.""" from __future__ import annotations import math from typing import Callable, List import numpy as np from .expr import Expr from .params import ParamTable def newton_solve( residuals: List[Expr], params: ParamTable, quat_groups: List[tuple[str, str, str, str]] | None = None, max_iter: int = 50, tol: float = 1e-10, post_step: "Callable[[ParamTable], None] | None" = None, weight_vector: "np.ndarray | None" = None, jac_exprs: "List[List[Expr]] | None" = None, compiled_eval: "Callable | None" = None, ) -> bool: """Solve ``residuals == 0`` by Newton-Raphson. Parameters ---------- residuals: List of Expr that should each evaluate to zero. params: Parameter table with current values as initial guess. quat_groups: List of (qw, qx, qy, qz) parameter name tuples. After each Newton step these are re-normalized to unit length. max_iter: Maximum Newton iterations. tol: Convergence threshold on ``||r||``. post_step: Optional callback invoked after each parameter update, before quaternion renormalization. Used for half-space correction. weight_vector: Optional 1-D array of length ``n_free``. When provided, the lstsq step is column-scaled to produce the weighted minimum-norm solution (prefer small movements in high-weight parameters). jac_exprs: Pre-built symbolic Jacobian (list-of-lists of Expr). When provided, skips the ``diff().simplify()`` step. compiled_eval: Pre-compiled evaluation function from :mod:`codegen`. When provided, uses flat compiled code instead of tree-walk eval. Returns True if converged within *max_iter* iterations. """ free = params.free_names() n_free = len(free) n_res = len(residuals) if n_free == 0 or n_res == 0: return True # Build symbolic Jacobian once (or reuse pre-built) if jac_exprs is None: jac_exprs = [] for r in residuals: row = [] for name in free: row.append(r.diff(name).simplify()) jac_exprs.append(row) # Try compilation if not provided if compiled_eval is None: from .codegen import try_compile_system compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free) # Pre-allocate arrays reused across iterations r_vec = np.empty(n_res) J = np.zeros((n_res, n_free)) for _it in range(max_iter): if compiled_eval is not None: J[:] = 0.0 compiled_eval(params.env_ref(), r_vec, J) else: env = params.get_env() for i, r in enumerate(residuals): r_vec[i] = r.eval(env) for i in range(n_res): for j in range(n_free): J[i, j] = jac_exprs[i][j].eval(env) r_norm = np.linalg.norm(r_vec) if r_norm < tol: return True # Solve J @ dx = -r (least-squares handles rank-deficient) if weight_vector is not None: # Column-scale J by W^{-1/2} for weighted minimum-norm w_inv_sqrt = 1.0 / np.sqrt(weight_vector) J_scaled = J * w_inv_sqrt[np.newaxis, :] dx_scaled, _, _, _ = np.linalg.lstsq(J_scaled, -r_vec, rcond=None) dx = dx_scaled * w_inv_sqrt else: dx, _, _, _ = np.linalg.lstsq(J, -r_vec, rcond=None) # Update parameters x = params.get_free_vector() x += dx params.set_free_vector(x) # Half-space correction (before quat renormalization) if post_step: post_step(params) # Re-normalize quaternions if quat_groups: _renormalize_quats(params, quat_groups) # Check final residual if compiled_eval is not None: compiled_eval(params.env_ref(), r_vec, J) else: env = params.get_env() for i, r in enumerate(residuals): r_vec[i] = r.eval(env) return bool(np.linalg.norm(r_vec) < tol) def _renormalize_quats( params: ParamTable, groups: List[tuple[str, str, str, str]], ): """Project quaternion params back onto the unit sphere.""" for qw_name, qx_name, qy_name, qz_name in groups: # Skip if all components are fixed if ( params.is_fixed(qw_name) and params.is_fixed(qx_name) and params.is_fixed(qy_name) and params.is_fixed(qz_name) ): continue w = params.get_value(qw_name) x = params.get_value(qx_name) y = params.get_value(qy_name) z = params.get_value(qz_name) norm = math.sqrt(w * w + x * x + y * y + z * z) if norm < 1e-15: # Degenerate — reset to identity params.set_value(qw_name, 1.0) params.set_value(qx_name, 0.0) params.set_value(qy_name, 0.0) params.set_value(qz_name, 0.0) else: params.set_value(qw_name, w / norm) params.set_value(qx_name, x / norm) params.set_value(qy_name, y / norm) params.set_value(qz_name, z / norm)