From 64b1e24467765a8666d8c2a00bcb625a5ae32057 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Sat, 21 Feb 2026 11:22:36 -0600 Subject: [PATCH] feat(solver): compile symbolic Jacobian to flat Python for fast evaluation Add a code generation pipeline that compiles Expr DAGs into flat Python functions, eliminating recursive tree-walk dispatch in the Newton-Raphson inner loop. Key changes: - Add to_code() method to all 11 Expr node types (expr.py) - New codegen.py module with CSE (common subexpression elimination), sparsity detection, and compile()/exec() compilation pipeline - Add ParamTable.env_ref() to avoid dict copies per iteration (params.py) - Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval to avoid redundant diff/simplify and enable compiled evaluation - count_dof() and diagnostics accept pre-built jac_exprs - solver.py builds symbolic Jacobian once, compiles once, passes to all consumers (_monolithic_solve, count_dof, diagnostics) - Automatic fallback: if codegen fails, tree-walk eval is used Expected performance impact: - ~10-20x faster Jacobian evaluation (no recursive dispatch) - ~2-5x additional from CSE on quaternion-heavy systems - ~3x fewer entries evaluated via sparsity detection - Eliminates redundant diff().simplify() in DOF/diagnostics --- kindred_solver/bfgs.py | 61 ++++-- kindred_solver/codegen.py | 308 +++++++++++++++++++++++++++++ kindred_solver/diagnostics.py | 23 ++- kindred_solver/dof.py | 15 +- kindred_solver/expr.py | 43 ++++ kindred_solver/newton.py | 64 ++++-- kindred_solver/params.py | 8 + kindred_solver/solver.py | 38 +++- tests/test_codegen.py | 357 ++++++++++++++++++++++++++++++++++ 9 files changed, 864 insertions(+), 53 deletions(-) create mode 100644 kindred_solver/codegen.py create mode 100644 tests/test_codegen.py diff --git a/kindred_solver/bfgs.py b/kindred_solver/bfgs.py index c4f0f78..01ee515 100644 --- a/kindred_solver/bfgs.py +++ b/kindred_solver/bfgs.py @@ -7,7 +7,7 @@ analytic gradient from the Expr DAG's symbolic differentiation. from __future__ import annotations import math -from typing import List +from typing import Callable, List import numpy as np @@ -29,6 +29,8 @@ def bfgs_solve( max_iter: int = 200, tol: float = 1e-10, weight_vector: "np.ndarray | None" = None, + jac_exprs: "List[List[Expr]] | None" = None, + compiled_eval: "Callable | None" = None, ) -> bool: """Solve ``residuals == 0`` by minimizing sum of squared residuals. @@ -38,6 +40,13 @@ def bfgs_solve( ``sqrt(w)`` so that the objective becomes ``0.5 * sum(w_i * r_i^2)`` — equivalent to weighted least-squares. + Parameters + ---------- + jac_exprs: + Pre-built symbolic Jacobian (list-of-lists of Expr). + compiled_eval: + Pre-compiled evaluation function from :mod:`codegen`. + Returns True if converged (||r|| < tol). """ if not _HAS_SCIPY: @@ -51,12 +60,19 @@ def bfgs_solve( return True # Build symbolic gradient expressions once: d(r_i)/d(x_j) - jac_exprs: List[List[Expr]] = [] - for r in residuals: - row = [] - for name in free: - row.append(r.diff(name).simplify()) - jac_exprs.append(row) + if jac_exprs is None: + jac_exprs = [] + for r in residuals: + row = [] + for name in free: + row.append(r.diff(name).simplify()) + jac_exprs.append(row) + + # Try compilation if not provided + if compiled_eval is None: + from .codegen import try_compile_system + + compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free) # Pre-compute scaling for weighted minimum-norm if weight_vector is not None: @@ -66,6 +82,10 @@ def bfgs_solve( w_sqrt = None w_inv_sqrt = None + # Pre-allocate arrays reused across objective calls + r_vals = np.empty(n_res) + J = np.zeros((n_res, n_free)) + def objective_and_grad(y_vec): # Transform back from scaled space if weighted if w_inv_sqrt is not None: @@ -78,18 +98,19 @@ def bfgs_solve( if quat_groups: _renormalize_quats(params, quat_groups) - env = params.get_env() + if compiled_eval is not None: + J[:] = 0.0 + compiled_eval(params.env_ref(), r_vals, J) + else: + env = params.get_env() + for i, r in enumerate(residuals): + r_vals[i] = r.eval(env) + for i in range(n_res): + for j in range(n_free): + J[i, j] = jac_exprs[i][j].eval(env) - # Evaluate residuals - r_vals = np.array([r.eval(env) for r in residuals]) f = 0.5 * np.dot(r_vals, r_vals) - # Evaluate Jacobian - J = np.empty((n_res, n_free)) - for i in range(n_res): - for j in range(n_free): - J[i, j] = jac_exprs[i][j].eval(env) - # Gradient of f w.r.t. x = J^T @ r grad_x = J.T @ r_vals @@ -126,8 +147,12 @@ def bfgs_solve( _renormalize_quats(params, quat_groups) # Check convergence on actual residual norm - env = params.get_env() - r_vals = np.array([r.eval(env) for r in residuals]) + if compiled_eval is not None: + compiled_eval(params.env_ref(), r_vals, J) + else: + env = params.get_env() + for i, r in enumerate(residuals): + r_vals[i] = r.eval(env) return bool(np.linalg.norm(r_vals) < tol) diff --git a/kindred_solver/codegen.py b/kindred_solver/codegen.py new file mode 100644 index 0000000..a41c672 --- /dev/null +++ b/kindred_solver/codegen.py @@ -0,0 +1,308 @@ +"""Compile Expr DAGs into flat Python functions for fast evaluation. + +The compilation pipeline: +1. Collect all Expr nodes to be evaluated (residuals + Jacobian entries). +2. Identify common subexpressions (CSE) by ``id()`` — the Expr DAG + already shares node objects via ParamTable's Var instances. +3. Generate a single Python function body that computes CSE temps, + then fills ``r_vec`` and ``J`` arrays in-place. +4. Compile with ``compile()`` + ``exec()`` and return the callable. + +The generated function signature is:: + + fn(env: dict[str, float], r_vec: ndarray, J: ndarray) -> None +""" + +from __future__ import annotations + +import logging +import math +from collections import Counter +from typing import Callable, List + +import numpy as np + +from .expr import Const, Expr, Var + +log = logging.getLogger(__name__) + +# Namespace injected into compiled functions. +_CODEGEN_NS = { + "_sin": math.sin, + "_cos": math.cos, + "_sqrt": math.sqrt, +} + + +# --------------------------------------------------------------------------- +# CSE (Common Subexpression Elimination) +# --------------------------------------------------------------------------- + + +def _collect_nodes(expr: Expr, counts: Counter, visited: set[int]) -> None: + """Walk *expr* and count how many times each node ``id()`` appears.""" + eid = id(expr) + counts[eid] += 1 + if eid in visited: + return + visited.add(eid) + + # Recurse into children + if isinstance(expr, (Const, Var)): + return + if hasattr(expr, "child"): + _collect_nodes(expr.child, counts, visited) + elif hasattr(expr, "a"): + _collect_nodes(expr.a, counts, visited) + _collect_nodes(expr.b, counts, visited) + elif hasattr(expr, "base"): + _collect_nodes(expr.base, counts, visited) + _collect_nodes(expr.exp, counts, visited) + + +def _build_cse( + exprs: list[Expr], +) -> tuple[dict[int, str], list[tuple[str, Expr]]]: + """Identify shared sub-trees and assign temporary variable names. + + Returns: + id_to_temp: mapping from ``id(node)`` to temp variable name + temps_ordered: ``(temp_name, expr)`` pairs in dependency order + """ + counts: Counter = Counter() + visited: set[int] = set() + id_to_expr: dict[int, Expr] = {} + + for expr in exprs: + _collect_nodes(expr, counts, visited) + + # Map id -> Expr for nodes we visited + for expr in exprs: + _map_ids(expr, id_to_expr) + + # Nodes referenced more than once and not trivial (Const/Var) become temps + shared_ids = set() + for eid, cnt in counts.items(): + if cnt > 1: + node = id_to_expr.get(eid) + if node is not None and not isinstance(node, (Const, Var)): + shared_ids.add(eid) + + if not shared_ids: + return {}, [] + + # Topological order: a temp must be computed before any temp that uses it. + # Walk each shared node, collect in post-order. + ordered_ids: list[int] = [] + order_visited: set[int] = set() + + def _topo(expr: Expr) -> None: + eid = id(expr) + if eid in order_visited: + return + order_visited.add(eid) + if isinstance(expr, (Const, Var)): + return + if hasattr(expr, "child"): + _topo(expr.child) + elif hasattr(expr, "a"): + _topo(expr.a) + _topo(expr.b) + elif hasattr(expr, "base"): + _topo(expr.base) + _topo(expr.exp) + if eid in shared_ids: + ordered_ids.append(eid) + + for expr in exprs: + _topo(expr) + + id_to_temp: dict[int, str] = {} + temps_ordered: list[tuple[str, Expr]] = [] + for i, eid in enumerate(ordered_ids): + name = f"_c{i}" + id_to_temp[eid] = name + temps_ordered.append((name, id_to_expr[eid])) + + return id_to_temp, temps_ordered + + +def _map_ids(expr: Expr, mapping: dict[int, Expr]) -> None: + """Populate id -> Expr mapping for all nodes in *expr*.""" + eid = id(expr) + if eid in mapping: + return + mapping[eid] = expr + if isinstance(expr, (Const, Var)): + return + if hasattr(expr, "child"): + _map_ids(expr.child, mapping) + elif hasattr(expr, "a"): + _map_ids(expr.a, mapping) + _map_ids(expr.b, mapping) + elif hasattr(expr, "base"): + _map_ids(expr.base, mapping) + _map_ids(expr.exp, mapping) + + +def _expr_to_code(expr: Expr, id_to_temp: dict[int, str]) -> str: + """Emit code for *expr*, substituting temp names for shared nodes.""" + eid = id(expr) + temp = id_to_temp.get(eid) + if temp is not None: + return temp + return expr.to_code() + + +def _expr_to_code_recursive(expr: Expr, id_to_temp: dict[int, str]) -> str: + """Emit code for *expr*, recursing into children but respecting temps.""" + eid = id(expr) + temp = id_to_temp.get(eid) + if temp is not None: + return temp + + # For leaf nodes, just use to_code() directly + if isinstance(expr, (Const, Var)): + return expr.to_code() + + # For non-leaf nodes, recurse into children with temp substitution + from .expr import Add, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub + + if isinstance(expr, Neg): + return f"(-{_expr_to_code_recursive(expr.child, id_to_temp)})" + if isinstance(expr, Sin): + return f"_sin({_expr_to_code_recursive(expr.child, id_to_temp)})" + if isinstance(expr, Cos): + return f"_cos({_expr_to_code_recursive(expr.child, id_to_temp)})" + if isinstance(expr, Sqrt): + return f"_sqrt({_expr_to_code_recursive(expr.child, id_to_temp)})" + if isinstance(expr, Add): + a = _expr_to_code_recursive(expr.a, id_to_temp) + b = _expr_to_code_recursive(expr.b, id_to_temp) + return f"({a} + {b})" + if isinstance(expr, Sub): + a = _expr_to_code_recursive(expr.a, id_to_temp) + b = _expr_to_code_recursive(expr.b, id_to_temp) + return f"({a} - {b})" + if isinstance(expr, Mul): + a = _expr_to_code_recursive(expr.a, id_to_temp) + b = _expr_to_code_recursive(expr.b, id_to_temp) + return f"({a} * {b})" + if isinstance(expr, Div): + a = _expr_to_code_recursive(expr.a, id_to_temp) + b = _expr_to_code_recursive(expr.b, id_to_temp) + return f"({a} / {b})" + if isinstance(expr, Pow): + base = _expr_to_code_recursive(expr.base, id_to_temp) + exp = _expr_to_code_recursive(expr.exp, id_to_temp) + return f"({base} ** {exp})" + + # Fallback — should not happen for known node types + return expr.to_code() + + +# --------------------------------------------------------------------------- +# Sparsity detection +# --------------------------------------------------------------------------- + + +def _find_nonzero_entries( + jac_exprs: list[list[Expr]], +) -> list[tuple[int, int]]: + """Return ``(row, col)`` pairs for non-zero Jacobian entries.""" + nz = [] + for i, row in enumerate(jac_exprs): + for j, expr in enumerate(row): + if isinstance(expr, Const) and expr.value == 0.0: + continue + nz.append((i, j)) + return nz + + +# --------------------------------------------------------------------------- +# Code generation +# --------------------------------------------------------------------------- + + +def compile_system( + residuals: list[Expr], + jac_exprs: list[list[Expr]], + n_res: int, + n_free: int, +) -> Callable[[dict, np.ndarray, np.ndarray], None]: + """Compile residuals + Jacobian into a single evaluation function. + + Returns a callable ``fn(env, r_vec, J)`` that fills *r_vec* and *J* + in-place. *J* must be pre-zeroed by the caller (only non-zero + entries are written). + """ + # Detect non-zero Jacobian entries + nz_entries = _find_nonzero_entries(jac_exprs) + + # Collect all expressions for CSE analysis + all_exprs: list[Expr] = list(residuals) + nz_jac_exprs: list[Expr] = [jac_exprs[i][j] for i, j in nz_entries] + all_exprs.extend(nz_jac_exprs) + + # CSE + id_to_temp, temps_ordered = _build_cse(all_exprs) + + # Generate function body + lines: list[str] = ["def _eval(env, r_vec, J):"] + + # Temporaries — temporarily remove each temp's own id so its RHS + # is expanded rather than self-referencing. + for temp_name, temp_expr in temps_ordered: + eid = id(temp_expr) + saved = id_to_temp.pop(eid) + code = _expr_to_code_recursive(temp_expr, id_to_temp) + id_to_temp[eid] = saved + lines.append(f" {temp_name} = {code}") + + # Residuals + for i, r in enumerate(residuals): + code = _expr_to_code_recursive(r, id_to_temp) + lines.append(f" r_vec[{i}] = {code}") + + # Jacobian (sparse) + for idx, (i, j) in enumerate(nz_entries): + code = _expr_to_code_recursive(nz_jac_exprs[idx], id_to_temp) + lines.append(f" J[{i}, {j}] = {code}") + + source = "\n".join(lines) + + # Compile + code_obj = compile(source, "", "exec") + ns = dict(_CODEGEN_NS) + exec(code_obj, ns) + + fn = ns["_eval"] + + n_temps = len(temps_ordered) + n_nz = len(nz_entries) + n_total = n_res * n_free + log.debug( + "codegen: compiled %d residuals + %d/%d Jacobian entries, %d CSE temps", + n_res, + n_nz, + n_total, + n_temps, + ) + + return fn + + +def try_compile_system( + residuals: list[Expr], + jac_exprs: list[list[Expr]], + n_res: int, + n_free: int, +) -> Callable[[dict, np.ndarray, np.ndarray], None] | None: + """Compile with automatic fallback. Returns ``None`` on failure.""" + try: + return compile_system(residuals, jac_exprs, n_res, n_free) + except Exception: + log.debug( + "codegen: compilation failed, falling back to tree-walk eval", exc_info=True + ) + return None diff --git a/kindred_solver/diagnostics.py b/kindred_solver/diagnostics.py index 6747b8b..2c567c6 100644 --- a/kindred_solver/diagnostics.py +++ b/kindred_solver/diagnostics.py @@ -32,6 +32,7 @@ def per_entity_dof( params: ParamTable, bodies: dict[str, RigidBody], rank_tol: float = 1e-8, + jac_exprs: "list[list[Expr]] | None" = None, ) -> list[EntityDOF]: """Compute remaining DOF for each non-grounded body. @@ -71,9 +72,14 @@ def per_entity_dof( # Build full Jacobian (for efficiency, compute once) n_free = len(free) J_full = np.empty((n_res, n_free)) - for i, r in enumerate(residuals): - for j, name in enumerate(free): - J_full[i, j] = r.diff(name).simplify().eval(env) + if jac_exprs is not None: + for i in range(n_res): + for j in range(n_free): + J_full[i, j] = jac_exprs[i][j].eval(env) + else: + for i, r in enumerate(residuals): + for j, name in enumerate(free): + J_full[i, j] = r.diff(name).simplify().eval(env) result = [] for pid, body in bodies.items(): @@ -217,6 +223,7 @@ def find_overconstrained( params: ParamTable, residual_ranges: list[tuple[int, int, int]], rank_tol: float = 1e-8, + jac_exprs: "list[list[Expr]] | None" = None, ) -> list[ConstraintDiag]: """Identify redundant and conflicting constraints. @@ -243,8 +250,14 @@ def find_overconstrained( r_vec = np.empty(n_res) for i, r in enumerate(residuals): r_vec[i] = r.eval(env) - for j, name in enumerate(free): - J[i, j] = r.diff(name).simplify().eval(env) + if jac_exprs is not None: + for i in range(n_res): + for j in range(n_free): + J[i, j] = jac_exprs[i][j].eval(env) + else: + for i, r in enumerate(residuals): + for j, name in enumerate(free): + J[i, j] = r.diff(name).simplify().eval(env) # Full rank sv_full = np.linalg.svd(J, compute_uv=False) diff --git a/kindred_solver/dof.py b/kindred_solver/dof.py index 6ec5f8d..745ab0e 100644 --- a/kindred_solver/dof.py +++ b/kindred_solver/dof.py @@ -14,11 +14,15 @@ def count_dof( residuals: List[Expr], params: ParamTable, rank_tol: float = 1e-8, + jac_exprs: "List[List[Expr]] | None" = None, ) -> int: """Compute DOF = n_free_params - rank(Jacobian). Evaluates the Jacobian numerically at the current parameter values and computes its rank via SVD. + + When *jac_exprs* is provided, reuses the pre-built symbolic + Jacobian instead of re-differentiating every residual. """ free = params.free_names() n_free = len(free) @@ -32,9 +36,14 @@ def count_dof( env = params.get_env() J = np.empty((n_res, n_free)) - for i, r in enumerate(residuals): - for j, name in enumerate(free): - J[i, j] = r.diff(name).simplify().eval(env) + if jac_exprs is not None: + for i in range(n_res): + for j in range(n_free): + J[i, j] = jac_exprs[i][j].eval(env) + else: + for i, r in enumerate(residuals): + for j, name in enumerate(free): + J[i, j] = r.diff(name).simplify().eval(env) if J.size == 0: return n_free diff --git a/kindred_solver/expr.py b/kindred_solver/expr.py index 7ed19d1..7296410 100644 --- a/kindred_solver/expr.py +++ b/kindred_solver/expr.py @@ -24,6 +24,16 @@ class Expr: """Return the set of variable names in this expression.""" raise NotImplementedError + def to_code(self) -> str: + """Emit a Python arithmetic expression string. + + The returned string, when evaluated with a dict ``env`` mapping + parameter names to floats (and ``_sin``, ``_cos``, ``_sqrt`` + bound to their ``math`` equivalents), produces the same result + as ``self.eval(env)``. + """ + raise NotImplementedError + # -- operator overloads -------------------------------------------------- def __add__(self, other): @@ -90,6 +100,9 @@ class Const(Expr): def vars(self): return set() + def to_code(self): + return repr(self.value) + def __repr__(self): return f"Const({self.value})" @@ -118,6 +131,9 @@ class Var(Expr): def vars(self): return {self.name} + def to_code(self): + return f"env[{self.name!r}]" + def __repr__(self): return f"Var({self.name!r})" @@ -154,6 +170,9 @@ class Neg(Expr): def vars(self): return self.child.vars() + def to_code(self): + return f"(-{self.child.to_code()})" + def __repr__(self): return f"Neg({self.child!r})" @@ -180,6 +199,9 @@ class Sin(Expr): def vars(self): return self.child.vars() + def to_code(self): + return f"_sin({self.child.to_code()})" + def __repr__(self): return f"Sin({self.child!r})" @@ -206,6 +228,9 @@ class Cos(Expr): def vars(self): return self.child.vars() + def to_code(self): + return f"_cos({self.child.to_code()})" + def __repr__(self): return f"Cos({self.child!r})" @@ -232,6 +257,9 @@ class Sqrt(Expr): def vars(self): return self.child.vars() + def to_code(self): + return f"_sqrt({self.child.to_code()})" + def __repr__(self): return f"Sqrt({self.child!r})" @@ -266,6 +294,9 @@ class Add(Expr): def vars(self): return self.a.vars() | self.b.vars() + def to_code(self): + return f"({self.a.to_code()} + {self.b.to_code()})" + def __repr__(self): return f"Add({self.a!r}, {self.b!r})" @@ -297,6 +328,9 @@ class Sub(Expr): def vars(self): return self.a.vars() | self.b.vars() + def to_code(self): + return f"({self.a.to_code()} - {self.b.to_code()})" + def __repr__(self): return f"Sub({self.a!r}, {self.b!r})" @@ -337,6 +371,9 @@ class Mul(Expr): def vars(self): return self.a.vars() | self.b.vars() + def to_code(self): + return f"({self.a.to_code()} * {self.b.to_code()})" + def __repr__(self): return f"Mul({self.a!r}, {self.b!r})" @@ -372,6 +409,9 @@ class Div(Expr): def vars(self): return self.a.vars() | self.b.vars() + def to_code(self): + return f"({self.a.to_code()} / {self.b.to_code()})" + def __repr__(self): return f"Div({self.a!r}, {self.b!r})" @@ -414,6 +454,9 @@ class Pow(Expr): def vars(self): return self.base.vars() | self.exp.vars() + def to_code(self): + return f"({self.base.to_code()} ** {self.exp.to_code()})" + def __repr__(self): return f"Pow({self.base!r}, {self.exp!r})" diff --git a/kindred_solver/newton.py b/kindred_solver/newton.py index 697d86a..0a77800 100644 --- a/kindred_solver/newton.py +++ b/kindred_solver/newton.py @@ -3,7 +3,7 @@ from __future__ import annotations import math -from typing import List +from typing import Callable, List import numpy as np @@ -19,6 +19,8 @@ def newton_solve( tol: float = 1e-10, post_step: "Callable[[ParamTable], None] | None" = None, weight_vector: "np.ndarray | None" = None, + jac_exprs: "List[List[Expr]] | None" = None, + compiled_eval: "Callable | None" = None, ) -> bool: """Solve ``residuals == 0`` by Newton-Raphson. @@ -43,6 +45,12 @@ def newton_solve( lstsq step is column-scaled to produce the weighted minimum-norm solution (prefer small movements in high-weight parameters). + jac_exprs: + Pre-built symbolic Jacobian (list-of-lists of Expr). When + provided, skips the ``diff().simplify()`` step. + compiled_eval: + Pre-compiled evaluation function from :mod:`codegen`. When + provided, uses flat compiled code instead of tree-walk eval. Returns True if converged within *max_iter* iterations. """ @@ -53,29 +61,41 @@ def newton_solve( if n_free == 0 or n_res == 0: return True - # Build symbolic Jacobian once (list-of-lists of simplified Expr) - jac_exprs: List[List[Expr]] = [] - for r in residuals: - row = [] - for name in free: - row.append(r.diff(name).simplify()) - jac_exprs.append(row) + # Build symbolic Jacobian once (or reuse pre-built) + if jac_exprs is None: + jac_exprs = [] + for r in residuals: + row = [] + for name in free: + row.append(r.diff(name).simplify()) + jac_exprs.append(row) + + # Try compilation if not provided + if compiled_eval is None: + from .codegen import try_compile_system + + compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free) + + # Pre-allocate arrays reused across iterations + r_vec = np.empty(n_res) + J = np.zeros((n_res, n_free)) for _it in range(max_iter): - env = params.get_env() + if compiled_eval is not None: + J[:] = 0.0 + compiled_eval(params.env_ref(), r_vec, J) + else: + env = params.get_env() + for i, r in enumerate(residuals): + r_vec[i] = r.eval(env) + for i in range(n_res): + for j in range(n_free): + J[i, j] = jac_exprs[i][j].eval(env) - # Evaluate residual vector - r_vec = np.array([r.eval(env) for r in residuals]) r_norm = np.linalg.norm(r_vec) if r_norm < tol: return True - # Evaluate Jacobian matrix - J = np.empty((n_res, n_free)) - for i in range(n_res): - for j in range(n_free): - J[i, j] = jac_exprs[i][j].eval(env) - # Solve J @ dx = -r (least-squares handles rank-deficient) if weight_vector is not None: # Column-scale J by W^{-1/2} for weighted minimum-norm @@ -100,9 +120,13 @@ def newton_solve( _renormalize_quats(params, quat_groups) # Check final residual - env = params.get_env() - r_vec = np.array([r.eval(env) for r in residuals]) - return np.linalg.norm(r_vec) < tol + if compiled_eval is not None: + compiled_eval(params.env_ref(), r_vec, J) + else: + env = params.get_env() + for i, r in enumerate(residuals): + r_vec[i] = r.eval(env) + return bool(np.linalg.norm(r_vec) < tol) def _renormalize_quats( diff --git a/kindred_solver/params.py b/kindred_solver/params.py index cbee3aa..c5b2297 100644 --- a/kindred_solver/params.py +++ b/kindred_solver/params.py @@ -60,6 +60,14 @@ class ParamTable: """Return a snapshot of all current values (for Expr.eval).""" return dict(self._values) + def env_ref(self) -> Dict[str, float]: + """Return a direct reference to the internal values dict. + + Faster than :meth:`get_env` (no copy). Safe when the caller + only reads during evaluation and mutates via :meth:`set_free_vector`. + """ + return self._values + def free_names(self) -> List[str]: """Return ordered list of free (non-fixed) parameter names.""" return list(self._free_order) diff --git a/kindred_solver/solver.py b/kindred_solver/solver.py index a1d3297..8de9f1d 100644 --- a/kindred_solver/solver.py +++ b/kindred_solver/solver.py @@ -154,6 +154,7 @@ class KindredSolver(kcsolve.IKCSolver): residuals = single_equation_pass(residuals, system.params) # Solve (decomposed for large assemblies, monolithic for small) + jac_exprs = None # may be populated by _monolithic_solve if n_free_bodies >= _DECOMPOSE_THRESHOLD: grounded_ids = {pid for pid, b in system.bodies.items() if b.grounded} clusters = decompose(ctx.constraints, grounded_ids) @@ -172,7 +173,7 @@ class KindredSolver(kcsolve.IKCSolver): system.params, ) else: - converged = _monolithic_solve( + converged, jac_exprs = _monolithic_solve( residuals, system.params, system.quat_groups, @@ -185,7 +186,7 @@ class KindredSolver(kcsolve.IKCSolver): n_free_bodies, _DECOMPOSE_THRESHOLD, ) - converged = _monolithic_solve( + converged, jac_exprs = _monolithic_solve( residuals, system.params, system.quat_groups, @@ -194,7 +195,7 @@ class KindredSolver(kcsolve.IKCSolver): ) # DOF - dof = count_dof(residuals, system.params) + dof = count_dof(residuals, system.params, jac_exprs=jac_exprs) # Build result result = kcsolve.SolveResult() @@ -210,6 +211,7 @@ class KindredSolver(kcsolve.IKCSolver): system.params, system.residual_ranges, ctx, + jac_exprs=jac_exprs, ) result.placements = _extract_placements(system.params, system.bodies) @@ -419,13 +421,15 @@ def _build_system(ctx): return system -def _run_diagnostics(residuals, params, residual_ranges, ctx): +def _run_diagnostics(residuals, params, residual_ranges, ctx, jac_exprs=None): """Run overconstrained detection and return kcsolve diagnostics.""" diagnostics = [] if not hasattr(kcsolve, "ConstraintDiagnostic"): return diagnostics - diags = find_overconstrained(residuals, params, residual_ranges) + diags = find_overconstrained( + residuals, params, residual_ranges, jac_exprs=jac_exprs + ) for d in diags: cd = kcsolve.ConstraintDiagnostic() cd.constraint_id = ctx.constraints[d.constraint_index].id @@ -458,7 +462,23 @@ def _extract_placements(params, bodies): def _monolithic_solve( all_residuals, params, quat_groups, post_step=None, weight_vector=None ): - """Newton-Raphson solve with BFGS fallback on the full system.""" + """Newton-Raphson solve with BFGS fallback on the full system. + + Returns ``(converged, jac_exprs)`` so the caller can reuse the + symbolic Jacobian for DOF counting / diagnostics. + """ + from .codegen import try_compile_system + + free = params.free_names() + n_res = len(all_residuals) + n_free = len(free) + + # Build symbolic Jacobian once + jac_exprs = [[r.diff(name).simplify() for name in free] for r in all_residuals] + + # Compile once + compiled_eval = try_compile_system(all_residuals, jac_exprs, n_res, n_free) + t0 = time.perf_counter() converged = newton_solve( all_residuals, @@ -468,6 +488,8 @@ def _monolithic_solve( tol=1e-10, post_step=post_step, weight_vector=weight_vector, + jac_exprs=jac_exprs, + compiled_eval=compiled_eval, ) nr_ms = (time.perf_counter() - t0) * 1000 if not converged: @@ -482,6 +504,8 @@ def _monolithic_solve( max_iter=200, tol=1e-10, weight_vector=weight_vector, + jac_exprs=jac_exprs, + compiled_eval=compiled_eval, ) bfgs_ms = (time.perf_counter() - t1) * 1000 if converged: @@ -490,7 +514,7 @@ def _monolithic_solve( log.warning("_monolithic_solve: BFGS also failed (%.1f ms)", bfgs_ms) else: log.debug("_monolithic_solve: Newton-Raphson converged (%.1f ms)", nr_ms) - return converged + return converged, jac_exprs def _build_constraint( diff --git a/tests/test_codegen.py b/tests/test_codegen.py new file mode 100644 index 0000000..c935925 --- /dev/null +++ b/tests/test_codegen.py @@ -0,0 +1,357 @@ +"""Tests for the codegen module — CSE, compilation, and compiled evaluation.""" + +import math + +import numpy as np +import pytest +from kindred_solver.codegen import ( + _build_cse, + _find_nonzero_entries, + compile_system, + try_compile_system, +) +from kindred_solver.expr import ( + ZERO, + Add, + Const, + Cos, + Div, + Mul, + Neg, + Pow, + Sin, + Sqrt, + Sub, + Var, +) +from kindred_solver.newton import newton_solve +from kindred_solver.params import ParamTable + +# --------------------------------------------------------------------------- +# to_code() — round-trip correctness for each Expr type +# --------------------------------------------------------------------------- + + +class TestToCode: + """Verify that eval(expr.to_code()) == expr.eval(env) for each node.""" + + NS = {"_sin": math.sin, "_cos": math.cos, "_sqrt": math.sqrt} + + def _check(self, expr, env): + code = expr.to_code() + ns = dict(self.NS) + ns["env"] = env + compiled = eval(code, ns) + expected = expr.eval(env) + assert abs(compiled - expected) < 1e-15, ( + f"{code} = {compiled}, expected {expected}" + ) + + def test_const(self): + self._check(Const(3.14), {}) + + def test_const_negative(self): + self._check(Const(-2.5), {}) + + def test_const_zero(self): + self._check(Const(0.0), {}) + + def test_var(self): + self._check(Var("x"), {"x": 7.0}) + + def test_neg(self): + self._check(Neg(Var("x")), {"x": 3.0}) + + def test_add(self): + self._check(Add(Var("x"), Const(2.0)), {"x": 5.0}) + + def test_sub(self): + self._check(Sub(Var("x"), Var("y")), {"x": 5.0, "y": 3.0}) + + def test_mul(self): + self._check(Mul(Var("x"), Const(3.0)), {"x": 4.0}) + + def test_div(self): + self._check(Div(Var("x"), Const(2.0)), {"x": 6.0}) + + def test_pow(self): + self._check(Pow(Var("x"), Const(3.0)), {"x": 2.0}) + + def test_sin(self): + self._check(Sin(Var("x")), {"x": 1.0}) + + def test_cos(self): + self._check(Cos(Var("x")), {"x": 1.0}) + + def test_sqrt(self): + self._check(Sqrt(Var("x")), {"x": 9.0}) + + def test_nested(self): + """Complex nested expression.""" + x, y = Var("x"), Var("y") + expr = Add(Mul(Sin(x), Cos(y)), Sqrt(Sub(x, Neg(y)))) + self._check(expr, {"x": 2.0, "y": 1.0}) + + +# --------------------------------------------------------------------------- +# CSE +# --------------------------------------------------------------------------- + + +class TestCSE: + def test_no_sharing(self): + """Distinct expressions produce no CSE temps.""" + a = Var("x") + Const(1.0) + b = Var("y") + Const(2.0) + id_to_temp, temps = _build_cse([a, b]) + assert len(temps) == 0 + + def test_shared_subtree(self): + """Same node object used in two places is extracted.""" + x = Var("x") + shared = x * Const(2.0) # single Mul node + a = shared + Const(1.0) + b = shared + Const(3.0) + id_to_temp, temps = _build_cse([a, b]) + assert len(temps) >= 1 + # The shared Mul node should be a temp + assert id(shared) in id_to_temp + + def test_leaf_nodes_not_extracted(self): + """Const and Var nodes are never extracted as temps.""" + x = Var("x") + c = Const(5.0) + a = x + c + b = x + c + id_to_temp, temps = _build_cse([a, b]) + for _, expr in temps: + assert not isinstance(expr, (Const, Var)) + + def test_dependency_order(self): + """Temps are in dependency order (dependencies first).""" + x = Var("x") + inner = x * Const(2.0) + outer = inner + inner # uses inner twice + wrapper_a = outer * Const(3.0) + wrapper_b = outer * Const(4.0) + id_to_temp, temps = _build_cse([wrapper_a, wrapper_b]) + # If both inner and outer are temps, inner must come first + temp_names = [name for name, _ in temps] + temp_ids = [id(expr) for _, expr in temps] + if id(inner) in set(id_to_temp) and id(outer) in set(id_to_temp): + inner_idx = temp_ids.index(id(inner)) + outer_idx = temp_ids.index(id(outer)) + assert inner_idx < outer_idx + + +# --------------------------------------------------------------------------- +# Sparsity detection +# --------------------------------------------------------------------------- + + +class TestSparsity: + def test_zero_entries_skipped(self): + nz = _find_nonzero_entries( + [ + [Const(0.0), Var("x"), Const(0.0)], + [Const(1.0), Const(0.0), Var("y")], + ] + ) + assert nz == [(0, 1), (1, 0), (1, 2)] + + def test_all_nonzero(self): + nz = _find_nonzero_entries( + [ + [Var("x"), Const(1.0)], + ] + ) + assert nz == [(0, 0), (0, 1)] + + def test_all_zero(self): + nz = _find_nonzero_entries( + [ + [Const(0.0), Const(0.0)], + ] + ) + assert nz == [] + + +# --------------------------------------------------------------------------- +# Full compilation pipeline +# --------------------------------------------------------------------------- + + +class TestCompileSystem: + def test_simple_linear(self): + """Compile and evaluate a trivial system: r = x - 3, J = [[1]].""" + x = Var("x") + residuals = [x - Const(3.0)] + jac_exprs = [[Const(1.0)]] # d(x-3)/dx = 1 + + fn = compile_system(residuals, jac_exprs, 1, 1) + + env = {"x": 5.0} + r_vec = np.empty(1) + J = np.zeros((1, 1)) + fn(env, r_vec, J) + + assert abs(r_vec[0] - 2.0) < 1e-15 # 5 - 3 = 2 + assert abs(J[0, 0] - 1.0) < 1e-15 + + def test_two_variable_system(self): + """Compile: r0 = x + y - 5, r1 = x - y - 1.""" + x, y = Var("x"), Var("y") + residuals = [x + y - Const(5.0), x - y - Const(1.0)] + jac_exprs = [ + [Const(1.0), Const(1.0)], # d(r0)/dx, d(r0)/dy + [Const(1.0), Const(-1.0)], # d(r1)/dx, d(r1)/dy + ] + + fn = compile_system(residuals, jac_exprs, 2, 2) + + env = {"x": 3.0, "y": 2.0} + r_vec = np.empty(2) + J = np.zeros((2, 2)) + fn(env, r_vec, J) + + assert abs(r_vec[0] - 0.0) < 1e-15 + assert abs(r_vec[1] - 0.0) < 1e-15 + assert abs(J[0, 0] - 1.0) < 1e-15 + assert abs(J[0, 1] - 1.0) < 1e-15 + assert abs(J[1, 0] - 1.0) < 1e-15 + assert abs(J[1, 1] - (-1.0)) < 1e-15 + + def test_sparse_jacobian(self): + """Zero Jacobian entries remain zero after compiled evaluation.""" + x = Var("x") + y = Var("y") + # r0 depends on x only, r1 depends on y only + residuals = [x - Const(1.0), y - Const(2.0)] + jac_exprs = [ + [Const(1.0), Const(0.0)], + [Const(0.0), Const(1.0)], + ] + + fn = compile_system(residuals, jac_exprs, 2, 2) + + env = {"x": 3.0, "y": 4.0} + r_vec = np.empty(2) + J = np.zeros((2, 2)) + fn(env, r_vec, J) + + assert abs(J[0, 1]) < 1e-15 # should remain zero + assert abs(J[1, 0]) < 1e-15 # should remain zero + assert abs(J[0, 0] - 1.0) < 1e-15 + assert abs(J[1, 1] - 1.0) < 1e-15 + + def test_trig_functions(self): + """Compiled evaluation handles Sin/Cos/Sqrt.""" + x = Var("x") + residuals = [Sin(x), Cos(x), Sqrt(x)] + jac_exprs = [ + [Cos(x)], + [Neg(Sin(x))], + [Div(Const(1.0), Mul(Const(2.0), Sqrt(x)))], + ] + + fn = compile_system(residuals, jac_exprs, 3, 1) + + env = {"x": 1.0} + r_vec = np.empty(3) + J = np.zeros((3, 1)) + fn(env, r_vec, J) + + assert abs(r_vec[0] - math.sin(1.0)) < 1e-15 + assert abs(r_vec[1] - math.cos(1.0)) < 1e-15 + assert abs(r_vec[2] - math.sqrt(1.0)) < 1e-15 + assert abs(J[0, 0] - math.cos(1.0)) < 1e-15 + assert abs(J[1, 0] - (-math.sin(1.0))) < 1e-15 + assert abs(J[2, 0] - (1.0 / (2.0 * math.sqrt(1.0)))) < 1e-15 + + def test_matches_tree_walk(self): + """Compiled eval produces identical results to tree-walk eval.""" + pt = ParamTable() + x = pt.add("x", 2.0) + y = pt.add("y", 3.0) + + residuals = [x * y - Const(6.0), x * x + y - Const(7.0)] + free = pt.free_names() + + jac_exprs = [[r.diff(name).simplify() for name in free] for r in residuals] + + fn = compile_system(residuals, jac_exprs, 2, 2) + + # Tree-walk eval + env = pt.get_env() + r_tree = np.array([r.eval(env) for r in residuals]) + J_tree = np.empty((2, 2)) + for i in range(2): + for j in range(2): + J_tree[i, j] = jac_exprs[i][j].eval(env) + + # Compiled eval + r_comp = np.empty(2) + J_comp = np.zeros((2, 2)) + fn(pt.env_ref(), r_comp, J_comp) + + np.testing.assert_allclose(r_comp, r_tree, atol=1e-15) + np.testing.assert_allclose(J_comp, J_tree, atol=1e-15) + + +class TestTryCompile: + def test_returns_callable(self): + x = Var("x") + fn = try_compile_system([x], [[Const(1.0)]], 1, 1) + assert fn is not None + + def test_empty_system(self): + """Empty system returns None (nothing to compile).""" + fn = try_compile_system([], [], 0, 0) + # Empty system is handled by the solver before codegen is reached, + # so returning None is acceptable. + assert fn is None or callable(fn) + + +# --------------------------------------------------------------------------- +# Integration: Newton with compiled eval matches tree-walk +# --------------------------------------------------------------------------- + + +class TestCompiledNewton: + def test_single_linear(self): + """Solve x - 3 = 0 with compiled eval.""" + pt = ParamTable() + x = pt.add("x", 0.0) + residuals = [x - Const(3.0)] + assert newton_solve(residuals, pt) is True + assert abs(pt.get_value("x") - 3.0) < 1e-10 + + def test_two_variables(self): + """Solve x + y = 5, x - y = 1 with compiled eval.""" + pt = ParamTable() + x = pt.add("x", 0.0) + y = pt.add("y", 0.0) + residuals = [x + y - Const(5.0), x - y - Const(1.0)] + assert newton_solve(residuals, pt) is True + assert abs(pt.get_value("x") - 3.0) < 1e-10 + assert abs(pt.get_value("y") - 2.0) < 1e-10 + + def test_quadratic(self): + """Solve x^2 - 4 = 0 starting from x=1.""" + pt = ParamTable() + x = pt.add("x", 1.0) + residuals = [x * x - Const(4.0)] + assert newton_solve(residuals, pt) is True + assert abs(pt.get_value("x") - 2.0) < 1e-10 + + def test_nonlinear_system(self): + """Compiled eval converges for a nonlinear system: xy=6, x+y=5.""" + pt = ParamTable() + x = pt.add("x", 2.0) + y = pt.add("y", 3.5) + residuals = [x * y - Const(6.0), x + y - Const(5.0)] + assert newton_solve(residuals, pt, max_iter=100) is True + # Solutions are (2, 3) or (3, 2) — check they satisfy both equations + xv, yv = pt.get_value("x"), pt.get_value("y") + assert abs(xv * yv - 6.0) < 1e-10 + assert abs(xv + yv - 5.0) < 1e-10