"""L-BFGS-B fallback solver for when Newton-Raphson fails to converge. Minimizes f(x) = 0.5 * sum(r_i(x)^2) using scipy's L-BFGS-B with analytic gradient from the Expr DAG's symbolic differentiation. """ from __future__ import annotations import math from typing import Callable, List import numpy as np from .expr import Expr from .params import ParamTable try: from scipy.optimize import minimize as _scipy_minimize _HAS_SCIPY = True except ImportError: _HAS_SCIPY = False def bfgs_solve( residuals: List[Expr], params: ParamTable, quat_groups: List[tuple[str, str, str, str]] | None = None, max_iter: int = 200, tol: float = 1e-10, weight_vector: "np.ndarray | None" = None, jac_exprs: "List[List[Expr]] | None" = None, compiled_eval: "Callable | None" = None, ) -> bool: """Solve ``residuals == 0`` by minimizing sum of squared residuals. Falls back gracefully to False if scipy is not available. When *weight_vector* is provided, residuals are scaled by ``sqrt(w)`` so that the objective becomes ``0.5 * sum(w_i * r_i^2)`` — equivalent to weighted least-squares. Parameters ---------- jac_exprs: Pre-built symbolic Jacobian (list-of-lists of Expr). compiled_eval: Pre-compiled evaluation function from :mod:`codegen`. Returns True if converged (||r|| < tol). """ if not _HAS_SCIPY: return False free = params.free_names() n_free = len(free) n_res = len(residuals) if n_free == 0 or n_res == 0: return True # Build symbolic gradient expressions once: d(r_i)/d(x_j) if jac_exprs is None: jac_exprs = [] for r in residuals: row = [] for name in free: row.append(r.diff(name).simplify()) jac_exprs.append(row) # Try compilation if not provided if compiled_eval is None: from .codegen import try_compile_system compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free) # Pre-compute scaling for weighted minimum-norm if weight_vector is not None: w_sqrt = np.sqrt(weight_vector) w_inv_sqrt = 1.0 / w_sqrt else: w_sqrt = None w_inv_sqrt = None # Pre-allocate arrays reused across objective calls r_vals = np.empty(n_res) J = np.zeros((n_res, n_free)) def objective_and_grad(y_vec): # Transform back from scaled space if weighted if w_inv_sqrt is not None: x_vec = y_vec * w_inv_sqrt else: x_vec = y_vec # Update params params.set_free_vector(x_vec) if quat_groups: _renormalize_quats(params, quat_groups) if compiled_eval is not None: J[:] = 0.0 compiled_eval(params.env_ref(), r_vals, J) else: env = params.get_env() for i, r in enumerate(residuals): r_vals[i] = r.eval(env) for i in range(n_res): for j in range(n_free): J[i, j] = jac_exprs[i][j].eval(env) f = 0.5 * np.dot(r_vals, r_vals) # Gradient of f w.r.t. x = J^T @ r grad_x = J.T @ r_vals # Chain rule: df/dy = df/dx * dx/dy = grad_x * w_inv_sqrt if w_inv_sqrt is not None: grad = grad_x * w_inv_sqrt else: grad = grad_x return f, grad x0 = params.get_free_vector().copy() # Transform initial guess to scaled space if w_sqrt is not None: y0 = x0 * w_sqrt else: y0 = x0 result = _scipy_minimize( objective_and_grad, y0, method="L-BFGS-B", jac=True, options={"maxiter": max_iter, "ftol": tol * tol, "gtol": tol}, ) # Apply final result (transform back from scaled space) if w_inv_sqrt is not None: params.set_free_vector(result.x * w_inv_sqrt) else: params.set_free_vector(result.x) if quat_groups: _renormalize_quats(params, quat_groups) # Check convergence on actual residual norm if compiled_eval is not None: compiled_eval(params.env_ref(), r_vals, J) else: env = params.get_env() for i, r in enumerate(residuals): r_vals[i] = r.eval(env) return bool(np.linalg.norm(r_vals) < tol) def _renormalize_quats( params: ParamTable, groups: List[tuple[str, str, str, str]], ): """Project quaternion params back onto the unit sphere.""" for qw_name, qx_name, qy_name, qz_name in groups: if ( params.is_fixed(qw_name) and params.is_fixed(qx_name) and params.is_fixed(qy_name) and params.is_fixed(qz_name) ): continue w = params.get_value(qw_name) x = params.get_value(qx_name) y = params.get_value(qy_name) z = params.get_value(qz_name) norm = math.sqrt(w * w + x * x + y * y + z * z) if norm < 1e-15: params.set_value(qw_name, 1.0) params.set_value(qx_name, 0.0) params.set_value(qy_name, 0.0) params.set_value(qz_name, 0.0) else: params.set_value(qw_name, w / norm) params.set_value(qx_name, x / norm) params.set_value(qy_name, y / norm) params.set_value(qz_name, z / norm)