Add a code generation pipeline that compiles Expr DAGs into flat Python functions, eliminating recursive tree-walk dispatch in the Newton-Raphson inner loop.

Key changes:
- Add to_code() method to all 11 Expr node types (expr.py)
- New codegen.py module with CSE (common subexpression elimination), sparsity detection, and a compile()/exec() compilation pipeline
- Add ParamTable.env_ref() to avoid dict copies per iteration (params.py)
- Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval to avoid redundant diff/simplify and enable compiled evaluation
- count_dof() and diagnostics accept pre-built jac_exprs
- solver.py builds the symbolic Jacobian once, compiles once, and passes the result to all consumers (_monolithic_solve, count_dof, diagnostics)
- Automatic fallback: if codegen fails, tree-walk eval is used

Expected performance impact:
- ~10-20x faster Jacobian evaluation (no recursive dispatch)
- ~2-5x additional speedup from CSE on quaternion-heavy systems
- ~3x fewer entries evaluated via sparsity detection
- Eliminates redundant diff().simplify() in DOF/diagnostics
162 lines
5.2 KiB
Python
162 lines
5.2 KiB
Python
"""Newton-Raphson solver with symbolic Jacobian and numpy linear algebra."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import math
|
|
from typing import Callable, List
|
|
|
|
import numpy as np
|
|
|
|
from .expr import Expr
|
|
from .params import ParamTable
|
|
|
|
|
|
def newton_solve(
    residuals: List[Expr],
    params: ParamTable,
    quat_groups: List[tuple[str, str, str, str]] | None = None,
    max_iter: int = 50,
    tol: float = 1e-10,
    post_step: "Callable[[ParamTable], None] | None" = None,
    weight_vector: "np.ndarray | None" = None,
    jac_exprs: "List[List[Expr]] | None" = None,
    compiled_eval: "Callable | None" = None,
) -> bool:
    """Solve ``residuals == 0`` by Newton-Raphson.

    Parameters
    ----------
    residuals:
        List of Expr that should each evaluate to zero.
    params:
        Parameter table with current values as initial guess.  It is
        updated in place on every accepted step.
    quat_groups:
        List of (qw, qx, qy, qz) parameter name tuples.
        After each Newton step these are re-normalized to unit length.
    max_iter:
        Maximum Newton iterations.
    tol:
        Convergence threshold on ``||r||``.
    post_step:
        Optional callback invoked after each parameter update, before
        quaternion renormalization.  Used for half-space correction.
    weight_vector:
        Optional 1-D array of length ``n_free``.  When provided, the
        lstsq step is column-scaled to produce the weighted
        minimum-norm solution (prefer small movements in
        high-weight parameters).
    jac_exprs:
        Pre-built symbolic Jacobian (list-of-lists of Expr).  When
        provided, skips the ``diff().simplify()`` step.  It is only
        consulted when no *compiled_eval* is supplied.
    compiled_eval:
        Pre-compiled evaluation function from :mod:`codegen`, called as
        ``compiled_eval(env, r_vec, J)`` and expected to fill ``r_vec``
        and ``J`` in place.  When provided, uses flat compiled code
        instead of tree-walk eval.

    Returns
    -------
    bool
        True if ``||r|| < tol`` was reached within *max_iter*
        iterations.  True is also returned immediately when there are
        no free parameters or no residuals.  False is returned as soon
        as the residual or the computed step becomes non-finite, or the
        least-squares solve fails; in those cases the parameters keep
        the values from the last successful step.
    """
    free = params.free_names()
    n_free = len(free)
    n_res = len(residuals)

    if n_free == 0 or n_res == 0:
        return True

    if compiled_eval is None:
        # The symbolic Jacobian is only needed to compile the system or to
        # drive the tree-walk fallback, so a caller-supplied evaluator lets
        # us skip the (expensive) diff/simplify pass entirely.
        if jac_exprs is None:
            jac_exprs = [
                [r.diff(name).simplify() for name in free] for r in residuals
            ]

        from .codegen import try_compile_system

        # May still be None afterwards; the loop then uses tree-walk eval.
        compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free)

    # Column scaling W^{-1/2} does not change between iterations.
    w_inv_sqrt = None
    if weight_vector is not None:
        w_inv_sqrt = 1.0 / np.sqrt(weight_vector)

    # Pre-allocate arrays reused across iterations
    r_vec = np.empty(n_res)
    J = np.zeros((n_res, n_free))

    for _it in range(max_iter):
        if compiled_eval is not None:
            # Clear J first: the evaluator is not assumed to write every
            # entry (the codegen pipeline does sparsity detection).
            J[:] = 0.0
            compiled_eval(params.env_ref(), r_vec, J)
        else:
            env = params.get_env()
            for i, r in enumerate(residuals):
                r_vec[i] = r.eval(env)
            for i, jac_row in enumerate(jac_exprs):
                for j, entry in enumerate(jac_row):
                    J[i, j] = entry.eval(env)

        r_norm = np.linalg.norm(r_vec)
        if r_norm < tol:
            return True
        if not np.isfinite(r_norm):
            # NaN/inf residual: no meaningful step exists, and continuing
            # would only write garbage into the parameter table.
            return False

        # Solve J @ dx = -r (least-squares handles rank-deficient)
        try:
            if w_inv_sqrt is not None:
                # Column-scale J by W^{-1/2} for weighted minimum-norm
                J_scaled = J * w_inv_sqrt[np.newaxis, :]
                dx_scaled = np.linalg.lstsq(J_scaled, -r_vec, rcond=None)[0]
                dx = dx_scaled * w_inv_sqrt
            else:
                dx = np.linalg.lstsq(J, -r_vec, rcond=None)[0]
        except np.linalg.LinAlgError:
            # lstsq raises when its SVD fails (e.g. non-finite Jacobian).
            return False

        if not np.all(np.isfinite(dx)):
            return False

        # Update parameters
        x = params.get_free_vector()
        x += dx
        params.set_free_vector(x)

        # Half-space correction (before quat renormalization)
        if post_step:
            post_step(params)

        # Re-normalize quaternions
        if quat_groups:
            _renormalize_quats(params, quat_groups)

    # Check final residual at the parameters produced by the last step
    if compiled_eval is not None:
        compiled_eval(params.env_ref(), r_vec, J)
    else:
        env = params.get_env()
        for i, r in enumerate(residuals):
            r_vec[i] = r.eval(env)
    return bool(np.linalg.norm(r_vec) < tol)
|
|
|
|
|
|
def _renormalize_quats(
    params: ParamTable,
    groups: List[tuple[str, str, str, str]],
):
    """Rescale every quaternion parameter group to unit length, in place.

    Each entry of *groups* names the (qw, qx, qy, qz) parameters of one
    quaternion.  A group whose four components are all fixed is left
    untouched.  A group whose norm is below ``1e-15`` is reset to the
    identity quaternion ``(1, 0, 0, 0)``; any other group has each
    component divided by the norm.
    """
    identity = (1.0, 0.0, 0.0, 0.0)

    for group in groups:
        qw_name, qx_name, qy_name, qz_name = group
        names = (qw_name, qx_name, qy_name, qz_name)

        # Nothing to do when the solver cannot move any component.
        if all(params.is_fixed(name) for name in names):
            continue

        comps = [params.get_value(name) for name in names]
        squares = [c * c for c in comps]
        # Explicit left-to-right sum keeps the float rounding unchanged.
        length = math.sqrt(squares[0] + squares[1] + squares[2] + squares[3])

        if length < 1e-15:
            # Degenerate quaternion: fall back to the identity rotation.
            unit = identity
        else:
            unit = tuple(c / length for c in comps)

        for name, value in zip(names, unit):
            params.set_value(name, value)
|