feat(solver): compile symbolic Jacobian to flat Python for fast evaluation

Add a code generation pipeline that compiles Expr DAGs into flat Python
functions, eliminating recursive tree-walk dispatch in the Newton-Raphson
inner loop.

Key changes:
- Add to_code() method to all 11 Expr node types (expr.py)
- New codegen.py module with CSE (common subexpression elimination),
  sparsity detection, and compile()/exec() compilation pipeline
- Add ParamTable.env_ref() to avoid dict copies per iteration (params.py)
- Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval
  to avoid redundant diff/simplify and enable compiled evaluation
- count_dof() and diagnostics accept pre-built jac_exprs
- solver.py builds symbolic Jacobian once, compiles once, passes to all
  consumers (_monolithic_solve, count_dof, diagnostics)
- Automatic fallback: if codegen fails, tree-walk eval is used

Expected performance impact:
- ~10-20x faster Jacobian evaluation (no recursive dispatch)
- ~2-5x additional from CSE on quaternion-heavy systems
- ~3x fewer entries evaluated via sparsity detection
- Eliminates redundant diff().simplify() in DOF/diagnostics
This commit is contained in:
forbes-0023
2026-02-21 11:22:36 -06:00
parent d20b38e760
commit 64b1e24467
9 changed files with 864 additions and 53 deletions

View File

@@ -7,7 +7,7 @@ analytic gradient from the Expr DAG's symbolic differentiation.
from __future__ import annotations
import math
from typing import List
from typing import Callable, List
import numpy as np
@@ -29,6 +29,8 @@ def bfgs_solve(
max_iter: int = 200,
tol: float = 1e-10,
weight_vector: "np.ndarray | None" = None,
jac_exprs: "List[List[Expr]] | None" = None,
compiled_eval: "Callable | None" = None,
) -> bool:
"""Solve ``residuals == 0`` by minimizing sum of squared residuals.
@@ -38,6 +40,13 @@ def bfgs_solve(
``sqrt(w)`` so that the objective becomes
``0.5 * sum(w_i * r_i^2)`` — equivalent to weighted least-squares.
Parameters
----------
jac_exprs:
Pre-built symbolic Jacobian (list-of-lists of Expr).
compiled_eval:
Pre-compiled evaluation function from :mod:`codegen`.
Returns True if converged (||r|| < tol).
"""
if not _HAS_SCIPY:
@@ -51,12 +60,19 @@ def bfgs_solve(
return True
# Build symbolic gradient expressions once: d(r_i)/d(x_j)
jac_exprs: List[List[Expr]] = []
for r in residuals:
row = []
for name in free:
row.append(r.diff(name).simplify())
jac_exprs.append(row)
if jac_exprs is None:
jac_exprs = []
for r in residuals:
row = []
for name in free:
row.append(r.diff(name).simplify())
jac_exprs.append(row)
# Try compilation if not provided
if compiled_eval is None:
from .codegen import try_compile_system
compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free)
# Pre-compute scaling for weighted minimum-norm
if weight_vector is not None:
@@ -66,6 +82,10 @@ def bfgs_solve(
w_sqrt = None
w_inv_sqrt = None
# Pre-allocate arrays reused across objective calls
r_vals = np.empty(n_res)
J = np.zeros((n_res, n_free))
def objective_and_grad(y_vec):
# Transform back from scaled space if weighted
if w_inv_sqrt is not None:
@@ -78,18 +98,19 @@ def bfgs_solve(
if quat_groups:
_renormalize_quats(params, quat_groups)
env = params.get_env()
if compiled_eval is not None:
J[:] = 0.0
compiled_eval(params.env_ref(), r_vals, J)
else:
env = params.get_env()
for i, r in enumerate(residuals):
r_vals[i] = r.eval(env)
for i in range(n_res):
for j in range(n_free):
J[i, j] = jac_exprs[i][j].eval(env)
# Evaluate residuals
r_vals = np.array([r.eval(env) for r in residuals])
f = 0.5 * np.dot(r_vals, r_vals)
# Evaluate Jacobian
J = np.empty((n_res, n_free))
for i in range(n_res):
for j in range(n_free):
J[i, j] = jac_exprs[i][j].eval(env)
# Gradient of f w.r.t. x = J^T @ r
grad_x = J.T @ r_vals
@@ -126,8 +147,12 @@ def bfgs_solve(
_renormalize_quats(params, quat_groups)
# Check convergence on actual residual norm
env = params.get_env()
r_vals = np.array([r.eval(env) for r in residuals])
if compiled_eval is not None:
compiled_eval(params.env_ref(), r_vals, J)
else:
env = params.get_env()
for i, r in enumerate(residuals):
r_vals[i] = r.eval(env)
return bool(np.linalg.norm(r_vals) < tol)

308
kindred_solver/codegen.py Normal file
View File

@@ -0,0 +1,308 @@
"""Compile Expr DAGs into flat Python functions for fast evaluation.
The compilation pipeline:
1. Collect all Expr nodes to be evaluated (residuals + Jacobian entries).
2. Identify common subexpressions (CSE) by ``id()`` — the Expr DAG
already shares node objects via ParamTable's Var instances.
3. Generate a single Python function body that computes CSE temps,
then fills ``r_vec`` and ``J`` arrays in-place.
4. Compile with ``compile()`` + ``exec()`` and return the callable.
The generated function signature is::
fn(env: dict[str, float], r_vec: ndarray, J: ndarray) -> None
"""
from __future__ import annotations
import logging
import math
from collections import Counter
from typing import Callable, List
import numpy as np
from .expr import Const, Expr, Var
log = logging.getLogger(__name__)
# Namespace injected into compiled functions.
_CODEGEN_NS = {
"_sin": math.sin,
"_cos": math.cos,
"_sqrt": math.sqrt,
}
# ---------------------------------------------------------------------------
# CSE (Common Subexpression Elimination)
# ---------------------------------------------------------------------------
def _collect_nodes(expr: Expr, counts: Counter, visited: set[int]) -> None:
    """Count occurrences of every node ``id()`` reachable from *expr*.

    Each reference increments the counter, but children are walked only
    on a node's first visit, so a shared sub-DAG is not re-traversed.
    """
    node_id = id(expr)
    counts[node_id] += 1
    if node_id in visited:
        return
    visited.add(node_id)
    if isinstance(expr, (Const, Var)):
        return  # leaves have no children
    # Select children by node shape: unary / binary / power.
    if hasattr(expr, "child"):
        children = (expr.child,)
    elif hasattr(expr, "a"):
        children = (expr.a, expr.b)
    elif hasattr(expr, "base"):
        children = (expr.base, expr.exp)
    else:
        children = ()
    for sub in children:
        _collect_nodes(sub, counts, visited)
def _build_cse(
    exprs: list[Expr],
) -> tuple[dict[int, str], list[tuple[str, Expr]]]:
    """Identify shared sub-trees and assign temporary variable names.

    Sharing is detected by object identity (``id()``): the Expr DAG reuses
    node objects, so two references to the same object denote the same
    computation.

    Returns:
        id_to_temp: mapping from ``id(node)`` to temp variable name
        temps_ordered: ``(temp_name, expr)`` pairs in dependency order
    """
    counts: Counter = Counter()
    visited: set[int] = set()
    id_to_expr: dict[int, Expr] = {}
    # First pass: count how many times each node id is referenced across
    # all expressions (visited is shared so counts accumulate globally).
    for expr in exprs:
        _collect_nodes(expr, counts, visited)
    # Map id -> Expr for nodes we visited
    for expr in exprs:
        _map_ids(expr, id_to_expr)
    # Nodes referenced more than once and not trivial (Const/Var) become temps
    shared_ids = set()
    for eid, cnt in counts.items():
        if cnt > 1:
            node = id_to_expr.get(eid)
            if node is not None and not isinstance(node, (Const, Var)):
                shared_ids.add(eid)
    if not shared_ids:
        # Nothing is shared; no temps needed.
        return {}, []
    # Topological order: a temp must be computed before any temp that uses it.
    # Walk each shared node, collect in post-order.
    ordered_ids: list[int] = []
    order_visited: set[int] = set()
    def _topo(expr: Expr) -> None:
        # Post-order DFS: children are appended before their parents,
        # which yields a valid dependency order for temp assignments.
        eid = id(expr)
        if eid in order_visited:
            return
        order_visited.add(eid)
        if isinstance(expr, (Const, Var)):
            return
        if hasattr(expr, "child"):
            _topo(expr.child)
        elif hasattr(expr, "a"):
            _topo(expr.a)
            _topo(expr.b)
        elif hasattr(expr, "base"):
            _topo(expr.base)
            _topo(expr.exp)
        if eid in shared_ids:
            ordered_ids.append(eid)
    for expr in exprs:
        _topo(expr)
    # Assign sequential names _c0, _c1, ... following dependency order.
    id_to_temp: dict[int, str] = {}
    temps_ordered: list[tuple[str, Expr]] = []
    for i, eid in enumerate(ordered_ids):
        name = f"_c{i}"
        id_to_temp[eid] = name
        temps_ordered.append((name, id_to_expr[eid]))
    return id_to_temp, temps_ordered
def _map_ids(expr: Expr, mapping: dict[int, Expr]) -> None:
    """Record every node of *expr* in *mapping* keyed by ``id()``."""
    key = id(expr)
    if key in mapping:
        return  # this subtree was already recorded
    mapping[key] = expr
    if isinstance(expr, (Const, Var)):
        return
    # Children by node shape: unary / binary / power.
    if hasattr(expr, "child"):
        kids = (expr.child,)
    elif hasattr(expr, "a"):
        kids = (expr.a, expr.b)
    elif hasattr(expr, "base"):
        kids = (expr.base, expr.exp)
    else:
        kids = ()
    for sub in kids:
        _map_ids(sub, mapping)
def _expr_to_code(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Emit code for *expr*; a shared node is replaced by its temp name."""
    try:
        return id_to_temp[id(expr)]
    except KeyError:
        return expr.to_code()
def _expr_to_code_recursive(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Emit code for *expr*, recursing into children but respecting temps.

    Any node present in *id_to_temp* is emitted as its temp variable
    name instead of being expanded inline.
    """
    temp = id_to_temp.get(id(expr))
    if temp is not None:
        return temp
    # Leaves emit themselves directly.
    if isinstance(expr, (Const, Var)):
        return expr.to_code()
    from .expr import Add, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub

    def rec(child: Expr) -> str:
        return _expr_to_code_recursive(child, id_to_temp)

    # Unary nodes: a template applied to the single child.
    for cls, template in (
        (Neg, "(-{})"),
        (Sin, "_sin({})"),
        (Cos, "_cos({})"),
        (Sqrt, "_sqrt({})"),
    ):
        if isinstance(expr, cls):
            return template.format(rec(expr.child))
    # Binary arithmetic nodes: infix operator between .a and .b.
    for cls, op in ((Add, "+"), (Sub, "-"), (Mul, "*"), (Div, "/")):
        if isinstance(expr, cls):
            return f"({rec(expr.a)} {op} {rec(expr.b)})"
    if isinstance(expr, Pow):
        return f"({rec(expr.base)} ** {rec(expr.exp)})"
    # Fallback — should not happen for known node types.
    return expr.to_code()
# ---------------------------------------------------------------------------
# Sparsity detection
# ---------------------------------------------------------------------------
def _find_nonzero_entries(
    jac_exprs: list[list[Expr]],
) -> list[tuple[int, int]]:
    """Return ``(row, col)`` pairs for non-zero Jacobian entries."""
    def _is_structural_zero(entry: Expr) -> bool:
        # Only a literal Const(0.0) counts as structurally zero.
        return isinstance(entry, Const) and entry.value == 0.0

    return [
        (i, j)
        for i, row in enumerate(jac_exprs)
        for j, entry in enumerate(row)
        if not _is_structural_zero(entry)
    ]
# ---------------------------------------------------------------------------
# Code generation
# ---------------------------------------------------------------------------
def compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None]:
    """Compile residuals + Jacobian into a single evaluation function.

    Parameters
    ----------
    residuals:
        Residual expressions; entry *i* is written to ``r_vec[i]``.
    jac_exprs:
        Symbolic Jacobian; entry ``[i][j]`` is written to ``J[i, j]``
        unless it is a structural zero (``Const(0.0)``).
    n_res, n_free:
        System dimensions (used for logging only).

    Returns a callable ``fn(env, r_vec, J)`` that fills *r_vec* and *J*
    in-place. *J* must be pre-zeroed by the caller (only non-zero
    entries are written).
    """
    # Detect non-zero Jacobian entries
    nz_entries = _find_nonzero_entries(jac_exprs)
    # Collect all expressions for CSE analysis
    all_exprs: list[Expr] = list(residuals)
    nz_jac_exprs: list[Expr] = [jac_exprs[i][j] for i, j in nz_entries]
    all_exprs.extend(nz_jac_exprs)
    # CSE
    id_to_temp, temps_ordered = _build_cse(all_exprs)
    # Generate function body
    lines: list[str] = ["def _eval(env, r_vec, J):"]
    # Temporaries — temporarily remove each temp's own id so its RHS
    # is expanded rather than self-referencing.
    for temp_name, temp_expr in temps_ordered:
        eid = id(temp_expr)
        saved = id_to_temp.pop(eid)
        code = _expr_to_code_recursive(temp_expr, id_to_temp)
        id_to_temp[eid] = saved
        lines.append(f"    {temp_name} = {code}")
    # Residuals
    for i, r in enumerate(residuals):
        code = _expr_to_code_recursive(r, id_to_temp)
        lines.append(f"    r_vec[{i}] = {code}")
    # Jacobian (sparse)
    for idx, (i, j) in enumerate(nz_entries):
        code = _expr_to_code_recursive(nz_jac_exprs[idx], id_to_temp)
        lines.append(f"    J[{i}, {j}] = {code}")
    # BUGFIX: an empty system (no residuals and no non-zero Jacobian
    # entries) would otherwise emit a def with an empty body, raising
    # IndentationError at compile(). Emit a valid no-op function instead.
    if len(lines) == 1:
        lines.append("    pass")
    source = "\n".join(lines)
    # Compile the generated source in an isolated namespace seeded with
    # _sin/_cos/_sqrt so emitted trig/sqrt calls resolve.
    code_obj = compile(source, "<kindred_codegen>", "exec")
    ns = dict(_CODEGEN_NS)
    exec(code_obj, ns)
    fn = ns["_eval"]
    n_temps = len(temps_ordered)
    n_nz = len(nz_entries)
    n_total = n_res * n_free
    log.debug(
        "codegen: compiled %d residuals + %d/%d Jacobian entries, %d CSE temps",
        n_res,
        n_nz,
        n_total,
        n_temps,
    )
    return fn
def try_compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None] | None:
    """Compile with automatic fallback. Returns ``None`` on failure."""
    fn = None
    try:
        fn = compile_system(residuals, jac_exprs, n_res, n_free)
    except Exception:
        # Any compilation failure is non-fatal: callers fall back to
        # the tree-walk evaluator.
        log.debug(
            "codegen: compilation failed, falling back to tree-walk eval", exc_info=True
        )
    return fn

View File

@@ -32,6 +32,7 @@ def per_entity_dof(
params: ParamTable,
bodies: dict[str, RigidBody],
rank_tol: float = 1e-8,
jac_exprs: "list[list[Expr]] | None" = None,
) -> list[EntityDOF]:
"""Compute remaining DOF for each non-grounded body.
@@ -71,9 +72,14 @@ def per_entity_dof(
# Build full Jacobian (for efficiency, compute once)
n_free = len(free)
J_full = np.empty((n_res, n_free))
for i, r in enumerate(residuals):
for j, name in enumerate(free):
J_full[i, j] = r.diff(name).simplify().eval(env)
if jac_exprs is not None:
for i in range(n_res):
for j in range(n_free):
J_full[i, j] = jac_exprs[i][j].eval(env)
else:
for i, r in enumerate(residuals):
for j, name in enumerate(free):
J_full[i, j] = r.diff(name).simplify().eval(env)
result = []
for pid, body in bodies.items():
@@ -217,6 +223,7 @@ def find_overconstrained(
params: ParamTable,
residual_ranges: list[tuple[int, int, int]],
rank_tol: float = 1e-8,
jac_exprs: "list[list[Expr]] | None" = None,
) -> list[ConstraintDiag]:
"""Identify redundant and conflicting constraints.
@@ -243,8 +250,14 @@ def find_overconstrained(
r_vec = np.empty(n_res)
for i, r in enumerate(residuals):
r_vec[i] = r.eval(env)
for j, name in enumerate(free):
J[i, j] = r.diff(name).simplify().eval(env)
if jac_exprs is not None:
for i in range(n_res):
for j in range(n_free):
J[i, j] = jac_exprs[i][j].eval(env)
else:
for i, r in enumerate(residuals):
for j, name in enumerate(free):
J[i, j] = r.diff(name).simplify().eval(env)
# Full rank
sv_full = np.linalg.svd(J, compute_uv=False)

View File

@@ -14,11 +14,15 @@ def count_dof(
residuals: List[Expr],
params: ParamTable,
rank_tol: float = 1e-8,
jac_exprs: "List[List[Expr]] | None" = None,
) -> int:
"""Compute DOF = n_free_params - rank(Jacobian).
Evaluates the Jacobian numerically at the current parameter values
and computes its rank via SVD.
When *jac_exprs* is provided, reuses the pre-built symbolic
Jacobian instead of re-differentiating every residual.
"""
free = params.free_names()
n_free = len(free)
@@ -32,9 +36,14 @@ def count_dof(
env = params.get_env()
J = np.empty((n_res, n_free))
for i, r in enumerate(residuals):
for j, name in enumerate(free):
J[i, j] = r.diff(name).simplify().eval(env)
if jac_exprs is not None:
for i in range(n_res):
for j in range(n_free):
J[i, j] = jac_exprs[i][j].eval(env)
else:
for i, r in enumerate(residuals):
for j, name in enumerate(free):
J[i, j] = r.diff(name).simplify().eval(env)
if J.size == 0:
return n_free

View File

@@ -24,6 +24,16 @@ class Expr:
"""Return the set of variable names in this expression."""
raise NotImplementedError
def to_code(self) -> str:
"""Emit a Python arithmetic expression string.
The returned string, when evaluated with a dict ``env`` mapping
parameter names to floats (and ``_sin``, ``_cos``, ``_sqrt``
bound to their ``math`` equivalents), produces the same result
as ``self.eval(env)``.
"""
raise NotImplementedError
# -- operator overloads --------------------------------------------------
def __add__(self, other):
@@ -90,6 +100,9 @@ class Const(Expr):
def vars(self):
return set()
def to_code(self):
return repr(self.value)
def __repr__(self):
return f"Const({self.value})"
@@ -118,6 +131,9 @@ class Var(Expr):
def vars(self):
return {self.name}
def to_code(self):
return f"env[{self.name!r}]"
def __repr__(self):
return f"Var({self.name!r})"
@@ -154,6 +170,9 @@ class Neg(Expr):
def vars(self):
return self.child.vars()
def to_code(self):
return f"(-{self.child.to_code()})"
def __repr__(self):
return f"Neg({self.child!r})"
@@ -180,6 +199,9 @@ class Sin(Expr):
def vars(self):
return self.child.vars()
def to_code(self):
return f"_sin({self.child.to_code()})"
def __repr__(self):
return f"Sin({self.child!r})"
@@ -206,6 +228,9 @@ class Cos(Expr):
def vars(self):
return self.child.vars()
def to_code(self):
return f"_cos({self.child.to_code()})"
def __repr__(self):
return f"Cos({self.child!r})"
@@ -232,6 +257,9 @@ class Sqrt(Expr):
def vars(self):
return self.child.vars()
def to_code(self):
return f"_sqrt({self.child.to_code()})"
def __repr__(self):
return f"Sqrt({self.child!r})"
@@ -266,6 +294,9 @@ class Add(Expr):
def vars(self):
return self.a.vars() | self.b.vars()
def to_code(self):
return f"({self.a.to_code()} + {self.b.to_code()})"
def __repr__(self):
return f"Add({self.a!r}, {self.b!r})"
@@ -297,6 +328,9 @@ class Sub(Expr):
def vars(self):
return self.a.vars() | self.b.vars()
def to_code(self):
return f"({self.a.to_code()} - {self.b.to_code()})"
def __repr__(self):
return f"Sub({self.a!r}, {self.b!r})"
@@ -337,6 +371,9 @@ class Mul(Expr):
def vars(self):
return self.a.vars() | self.b.vars()
def to_code(self):
return f"({self.a.to_code()} * {self.b.to_code()})"
def __repr__(self):
return f"Mul({self.a!r}, {self.b!r})"
@@ -372,6 +409,9 @@ class Div(Expr):
def vars(self):
return self.a.vars() | self.b.vars()
def to_code(self):
return f"({self.a.to_code()} / {self.b.to_code()})"
def __repr__(self):
return f"Div({self.a!r}, {self.b!r})"
@@ -414,6 +454,9 @@ class Pow(Expr):
def vars(self):
return self.base.vars() | self.exp.vars()
def to_code(self):
return f"({self.base.to_code()} ** {self.exp.to_code()})"
def __repr__(self):
return f"Pow({self.base!r}, {self.exp!r})"

View File

@@ -3,7 +3,7 @@
from __future__ import annotations
import math
from typing import List
from typing import Callable, List
import numpy as np
@@ -19,6 +19,8 @@ def newton_solve(
tol: float = 1e-10,
post_step: "Callable[[ParamTable], None] | None" = None,
weight_vector: "np.ndarray | None" = None,
jac_exprs: "List[List[Expr]] | None" = None,
compiled_eval: "Callable | None" = None,
) -> bool:
"""Solve ``residuals == 0`` by Newton-Raphson.
@@ -43,6 +45,12 @@ def newton_solve(
lstsq step is column-scaled to produce the weighted
minimum-norm solution (prefer small movements in
high-weight parameters).
jac_exprs:
Pre-built symbolic Jacobian (list-of-lists of Expr). When
provided, skips the ``diff().simplify()`` step.
compiled_eval:
Pre-compiled evaluation function from :mod:`codegen`. When
provided, uses flat compiled code instead of tree-walk eval.
Returns True if converged within *max_iter* iterations.
"""
@@ -53,29 +61,41 @@ def newton_solve(
if n_free == 0 or n_res == 0:
return True
# Build symbolic Jacobian once (list-of-lists of simplified Expr)
jac_exprs: List[List[Expr]] = []
for r in residuals:
row = []
for name in free:
row.append(r.diff(name).simplify())
jac_exprs.append(row)
# Build symbolic Jacobian once (or reuse pre-built)
if jac_exprs is None:
jac_exprs = []
for r in residuals:
row = []
for name in free:
row.append(r.diff(name).simplify())
jac_exprs.append(row)
# Try compilation if not provided
if compiled_eval is None:
from .codegen import try_compile_system
compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free)
# Pre-allocate arrays reused across iterations
r_vec = np.empty(n_res)
J = np.zeros((n_res, n_free))
for _it in range(max_iter):
env = params.get_env()
if compiled_eval is not None:
J[:] = 0.0
compiled_eval(params.env_ref(), r_vec, J)
else:
env = params.get_env()
for i, r in enumerate(residuals):
r_vec[i] = r.eval(env)
for i in range(n_res):
for j in range(n_free):
J[i, j] = jac_exprs[i][j].eval(env)
# Evaluate residual vector
r_vec = np.array([r.eval(env) for r in residuals])
r_norm = np.linalg.norm(r_vec)
if r_norm < tol:
return True
# Evaluate Jacobian matrix
J = np.empty((n_res, n_free))
for i in range(n_res):
for j in range(n_free):
J[i, j] = jac_exprs[i][j].eval(env)
# Solve J @ dx = -r (least-squares handles rank-deficient)
if weight_vector is not None:
# Column-scale J by W^{-1/2} for weighted minimum-norm
@@ -100,9 +120,13 @@ def newton_solve(
_renormalize_quats(params, quat_groups)
# Check final residual
env = params.get_env()
r_vec = np.array([r.eval(env) for r in residuals])
return np.linalg.norm(r_vec) < tol
if compiled_eval is not None:
compiled_eval(params.env_ref(), r_vec, J)
else:
env = params.get_env()
for i, r in enumerate(residuals):
r_vec[i] = r.eval(env)
return bool(np.linalg.norm(r_vec) < tol)
def _renormalize_quats(

View File

@@ -60,6 +60,14 @@ class ParamTable:
"""Return a snapshot of all current values (for Expr.eval)."""
return dict(self._values)
def env_ref(self) -> Dict[str, float]:
"""Return a direct reference to the internal values dict.
Faster than :meth:`get_env` (no copy). Safe when the caller
only reads during evaluation and mutates via :meth:`set_free_vector`.
"""
return self._values
def free_names(self) -> List[str]:
"""Return ordered list of free (non-fixed) parameter names."""
return list(self._free_order)

View File

@@ -154,6 +154,7 @@ class KindredSolver(kcsolve.IKCSolver):
residuals = single_equation_pass(residuals, system.params)
# Solve (decomposed for large assemblies, monolithic for small)
jac_exprs = None # may be populated by _monolithic_solve
if n_free_bodies >= _DECOMPOSE_THRESHOLD:
grounded_ids = {pid for pid, b in system.bodies.items() if b.grounded}
clusters = decompose(ctx.constraints, grounded_ids)
@@ -172,7 +173,7 @@ class KindredSolver(kcsolve.IKCSolver):
system.params,
)
else:
converged = _monolithic_solve(
converged, jac_exprs = _monolithic_solve(
residuals,
system.params,
system.quat_groups,
@@ -185,7 +186,7 @@ class KindredSolver(kcsolve.IKCSolver):
n_free_bodies,
_DECOMPOSE_THRESHOLD,
)
converged = _monolithic_solve(
converged, jac_exprs = _monolithic_solve(
residuals,
system.params,
system.quat_groups,
@@ -194,7 +195,7 @@ class KindredSolver(kcsolve.IKCSolver):
)
# DOF
dof = count_dof(residuals, system.params)
dof = count_dof(residuals, system.params, jac_exprs=jac_exprs)
# Build result
result = kcsolve.SolveResult()
@@ -210,6 +211,7 @@ class KindredSolver(kcsolve.IKCSolver):
system.params,
system.residual_ranges,
ctx,
jac_exprs=jac_exprs,
)
result.placements = _extract_placements(system.params, system.bodies)
@@ -419,13 +421,15 @@ def _build_system(ctx):
return system
def _run_diagnostics(residuals, params, residual_ranges, ctx):
def _run_diagnostics(residuals, params, residual_ranges, ctx, jac_exprs=None):
"""Run overconstrained detection and return kcsolve diagnostics."""
diagnostics = []
if not hasattr(kcsolve, "ConstraintDiagnostic"):
return diagnostics
diags = find_overconstrained(residuals, params, residual_ranges)
diags = find_overconstrained(
residuals, params, residual_ranges, jac_exprs=jac_exprs
)
for d in diags:
cd = kcsolve.ConstraintDiagnostic()
cd.constraint_id = ctx.constraints[d.constraint_index].id
@@ -458,7 +462,23 @@ def _extract_placements(params, bodies):
def _monolithic_solve(
all_residuals, params, quat_groups, post_step=None, weight_vector=None
):
"""Newton-Raphson solve with BFGS fallback on the full system."""
"""Newton-Raphson solve with BFGS fallback on the full system.
Returns ``(converged, jac_exprs)`` so the caller can reuse the
symbolic Jacobian for DOF counting / diagnostics.
"""
from .codegen import try_compile_system
free = params.free_names()
n_res = len(all_residuals)
n_free = len(free)
# Build symbolic Jacobian once
jac_exprs = [[r.diff(name).simplify() for name in free] for r in all_residuals]
# Compile once
compiled_eval = try_compile_system(all_residuals, jac_exprs, n_res, n_free)
t0 = time.perf_counter()
converged = newton_solve(
all_residuals,
@@ -468,6 +488,8 @@ def _monolithic_solve(
tol=1e-10,
post_step=post_step,
weight_vector=weight_vector,
jac_exprs=jac_exprs,
compiled_eval=compiled_eval,
)
nr_ms = (time.perf_counter() - t0) * 1000
if not converged:
@@ -482,6 +504,8 @@ def _monolithic_solve(
max_iter=200,
tol=1e-10,
weight_vector=weight_vector,
jac_exprs=jac_exprs,
compiled_eval=compiled_eval,
)
bfgs_ms = (time.perf_counter() - t1) * 1000
if converged:
@@ -490,7 +514,7 @@ def _monolithic_solve(
log.warning("_monolithic_solve: BFGS also failed (%.1f ms)", bfgs_ms)
else:
log.debug("_monolithic_solve: Newton-Raphson converged (%.1f ms)", nr_ms)
return converged
return converged, jac_exprs
def _build_constraint(

357
tests/test_codegen.py Normal file
View File

@@ -0,0 +1,357 @@
"""Tests for the codegen module — CSE, compilation, and compiled evaluation."""
import math
import numpy as np
import pytest
from kindred_solver.codegen import (
_build_cse,
_find_nonzero_entries,
compile_system,
try_compile_system,
)
from kindred_solver.expr import (
ZERO,
Add,
Const,
Cos,
Div,
Mul,
Neg,
Pow,
Sin,
Sqrt,
Sub,
Var,
)
from kindred_solver.newton import newton_solve
from kindred_solver.params import ParamTable
# ---------------------------------------------------------------------------
# to_code() — round-trip correctness for each Expr type
# ---------------------------------------------------------------------------
class TestToCode:
    """Verify that eval(expr.to_code()) == expr.eval(env) for each node."""
    # Evaluation namespace mirroring codegen._CODEGEN_NS: the helper
    # names that emitted code expects to be in scope.
    NS = {"_sin": math.sin, "_cos": math.cos, "_sqrt": math.sqrt}
    def _check(self, expr, env):
        # Round-trip: emit code, eval it against env, compare with the
        # tree-walk evaluation of the same expression.
        code = expr.to_code()
        ns = dict(self.NS)
        ns["env"] = env
        compiled = eval(code, ns)
        expected = expr.eval(env)
        assert abs(compiled - expected) < 1e-15, (
            f"{code} = {compiled}, expected {expected}"
        )
    def test_const(self):
        self._check(Const(3.14), {})
    def test_const_negative(self):
        self._check(Const(-2.5), {})
    def test_const_zero(self):
        self._check(Const(0.0), {})
    def test_var(self):
        self._check(Var("x"), {"x": 7.0})
    def test_neg(self):
        self._check(Neg(Var("x")), {"x": 3.0})
    def test_add(self):
        self._check(Add(Var("x"), Const(2.0)), {"x": 5.0})
    def test_sub(self):
        self._check(Sub(Var("x"), Var("y")), {"x": 5.0, "y": 3.0})
    def test_mul(self):
        self._check(Mul(Var("x"), Const(3.0)), {"x": 4.0})
    def test_div(self):
        self._check(Div(Var("x"), Const(2.0)), {"x": 6.0})
    def test_pow(self):
        self._check(Pow(Var("x"), Const(3.0)), {"x": 2.0})
    def test_sin(self):
        self._check(Sin(Var("x")), {"x": 1.0})
    def test_cos(self):
        self._check(Cos(Var("x")), {"x": 1.0})
    def test_sqrt(self):
        self._check(Sqrt(Var("x")), {"x": 9.0})
    def test_nested(self):
        """Complex nested expression."""
        x, y = Var("x"), Var("y")
        expr = Add(Mul(Sin(x), Cos(y)), Sqrt(Sub(x, Neg(y))))
        self._check(expr, {"x": 2.0, "y": 1.0})
# ---------------------------------------------------------------------------
# CSE
# ---------------------------------------------------------------------------
class TestCSE:
    """CSE behavior: sharing is by object identity (``id()``), so these
    tests deliberately reuse the same Expr node objects."""
    def test_no_sharing(self):
        """Distinct expressions produce no CSE temps."""
        a = Var("x") + Const(1.0)
        b = Var("y") + Const(2.0)
        id_to_temp, temps = _build_cse([a, b])
        assert len(temps) == 0
    def test_shared_subtree(self):
        """Same node object used in two places is extracted."""
        x = Var("x")
        shared = x * Const(2.0)  # single Mul node
        a = shared + Const(1.0)
        b = shared + Const(3.0)
        id_to_temp, temps = _build_cse([a, b])
        assert len(temps) >= 1
        # The shared Mul node should be a temp
        assert id(shared) in id_to_temp
    def test_leaf_nodes_not_extracted(self):
        """Const and Var nodes are never extracted as temps."""
        x = Var("x")
        c = Const(5.0)
        a = x + c
        b = x + c
        id_to_temp, temps = _build_cse([a, b])
        for _, expr in temps:
            assert not isinstance(expr, (Const, Var))
    def test_dependency_order(self):
        """Temps are in dependency order (dependencies first)."""
        x = Var("x")
        inner = x * Const(2.0)
        outer = inner + inner  # uses inner twice
        wrapper_a = outer * Const(3.0)
        wrapper_b = outer * Const(4.0)
        id_to_temp, temps = _build_cse([wrapper_a, wrapper_b])
        # If both inner and outer are temps, inner must come first
        temp_names = [name for name, _ in temps]
        temp_ids = [id(expr) for _, expr in temps]
        if id(inner) in set(id_to_temp) and id(outer) in set(id_to_temp):
            inner_idx = temp_ids.index(id(inner))
            outer_idx = temp_ids.index(id(outer))
            assert inner_idx < outer_idx
# ---------------------------------------------------------------------------
# Sparsity detection
# ---------------------------------------------------------------------------
class TestSparsity:
    """Structural-zero detection in symbolic Jacobians."""

    def test_zero_entries_skipped(self):
        # Const(0.0) cells are omitted from the non-zero entry list.
        jac = [
            [Const(0.0), Var("x"), Const(0.0)],
            [Const(1.0), Const(0.0), Var("y")],
        ]
        assert _find_nonzero_entries(jac) == [(0, 1), (1, 0), (1, 2)]

    def test_all_nonzero(self):
        # Nothing is structurally zero, so every cell is reported.
        jac = [[Var("x"), Const(1.0)]]
        assert _find_nonzero_entries(jac) == [(0, 0), (0, 1)]

    def test_all_zero(self):
        # A fully-zero row yields an empty entry list.
        jac = [[Const(0.0), Const(0.0)]]
        assert _find_nonzero_entries(jac) == []
# ---------------------------------------------------------------------------
# Full compilation pipeline
# ---------------------------------------------------------------------------
class TestCompileSystem:
    """End-to-end tests for compile_system: compiled evaluation must fill
    r_vec and J in-place with exact values."""
    def test_simple_linear(self):
        """Compile and evaluate a trivial system: r = x - 3, J = [[1]]."""
        x = Var("x")
        residuals = [x - Const(3.0)]
        jac_exprs = [[Const(1.0)]]  # d(x-3)/dx = 1
        fn = compile_system(residuals, jac_exprs, 1, 1)
        env = {"x": 5.0}
        r_vec = np.empty(1)
        J = np.zeros((1, 1))
        fn(env, r_vec, J)
        assert abs(r_vec[0] - 2.0) < 1e-15  # 5 - 3 = 2
        assert abs(J[0, 0] - 1.0) < 1e-15
    def test_two_variable_system(self):
        """Compile: r0 = x + y - 5, r1 = x - y - 1."""
        x, y = Var("x"), Var("y")
        residuals = [x + y - Const(5.0), x - y - Const(1.0)]
        jac_exprs = [
            [Const(1.0), Const(1.0)],  # d(r0)/dx, d(r0)/dy
            [Const(1.0), Const(-1.0)],  # d(r1)/dx, d(r1)/dy
        ]
        fn = compile_system(residuals, jac_exprs, 2, 2)
        # (3, 2) is the exact root, so both residuals evaluate to 0.
        env = {"x": 3.0, "y": 2.0}
        r_vec = np.empty(2)
        J = np.zeros((2, 2))
        fn(env, r_vec, J)
        assert abs(r_vec[0] - 0.0) < 1e-15
        assert abs(r_vec[1] - 0.0) < 1e-15
        assert abs(J[0, 0] - 1.0) < 1e-15
        assert abs(J[0, 1] - 1.0) < 1e-15
        assert abs(J[1, 0] - 1.0) < 1e-15
        assert abs(J[1, 1] - (-1.0)) < 1e-15
    def test_sparse_jacobian(self):
        """Zero Jacobian entries remain zero after compiled evaluation."""
        x = Var("x")
        y = Var("y")
        # r0 depends on x only, r1 depends on y only
        residuals = [x - Const(1.0), y - Const(2.0)]
        jac_exprs = [
            [Const(1.0), Const(0.0)],
            [Const(0.0), Const(1.0)],
        ]
        fn = compile_system(residuals, jac_exprs, 2, 2)
        env = {"x": 3.0, "y": 4.0}
        r_vec = np.empty(2)
        J = np.zeros((2, 2))
        fn(env, r_vec, J)
        assert abs(J[0, 1]) < 1e-15  # should remain zero
        assert abs(J[1, 0]) < 1e-15  # should remain zero
        assert abs(J[0, 0] - 1.0) < 1e-15
        assert abs(J[1, 1] - 1.0) < 1e-15
    def test_trig_functions(self):
        """Compiled evaluation handles Sin/Cos/Sqrt."""
        x = Var("x")
        residuals = [Sin(x), Cos(x), Sqrt(x)]
        # Hand-written analytic derivatives of each residual w.r.t. x.
        jac_exprs = [
            [Cos(x)],
            [Neg(Sin(x))],
            [Div(Const(1.0), Mul(Const(2.0), Sqrt(x)))],
        ]
        fn = compile_system(residuals, jac_exprs, 3, 1)
        env = {"x": 1.0}
        r_vec = np.empty(3)
        J = np.zeros((3, 1))
        fn(env, r_vec, J)
        assert abs(r_vec[0] - math.sin(1.0)) < 1e-15
        assert abs(r_vec[1] - math.cos(1.0)) < 1e-15
        assert abs(r_vec[2] - math.sqrt(1.0)) < 1e-15
        assert abs(J[0, 0] - math.cos(1.0)) < 1e-15
        assert abs(J[1, 0] - (-math.sin(1.0))) < 1e-15
        assert abs(J[2, 0] - (1.0 / (2.0 * math.sqrt(1.0)))) < 1e-15
    def test_matches_tree_walk(self):
        """Compiled eval produces identical results to tree-walk eval."""
        pt = ParamTable()
        x = pt.add("x", 2.0)
        y = pt.add("y", 3.5)
        residuals = [x * y - Const(6.0), x * x + y - Const(7.0)]
        free = pt.free_names()
        # Build the symbolic Jacobian exactly as the solvers do.
        jac_exprs = [[r.diff(name).simplify() for name in free] for r in residuals]
        fn = compile_system(residuals, jac_exprs, 2, 2)
        # Tree-walk eval
        env = pt.get_env()
        r_tree = np.array([r.eval(env) for r in residuals])
        J_tree = np.empty((2, 2))
        for i in range(2):
            for j in range(2):
                J_tree[i, j] = jac_exprs[i][j].eval(env)
        # Compiled eval
        r_comp = np.empty(2)
        J_comp = np.zeros((2, 2))
        fn(pt.env_ref(), r_comp, J_comp)
        np.testing.assert_allclose(r_comp, r_tree, atol=1e-15)
        np.testing.assert_allclose(J_comp, J_tree, atol=1e-15)
class TestTryCompile:
    """try_compile_system never raises: it returns a callable or None."""

    def test_returns_callable(self):
        # A minimal one-residual system compiles successfully.
        compiled = try_compile_system([Var("x")], [[Const(1.0)]], 1, 1)
        assert compiled is not None

    def test_empty_system(self):
        """Empty system returns None (nothing to compile)."""
        compiled = try_compile_system([], [], 0, 0)
        # Empty system is handled by the solver before codegen is reached,
        # so returning None is acceptable.
        assert compiled is None or callable(compiled)
# ---------------------------------------------------------------------------
# Integration: Newton with compiled eval matches tree-walk
# ---------------------------------------------------------------------------
class TestCompiledNewton:
    """Integration tests: newton_solve (which builds + uses compiled eval
    internally) converges to the same roots as the tree-walk path."""
    def test_single_linear(self):
        """Solve x - 3 = 0 with compiled eval."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        residuals = [x - Const(3.0)]
        assert newton_solve(residuals, pt) is True
        assert abs(pt.get_value("x") - 3.0) < 1e-10
    def test_two_variables(self):
        """Solve x + y = 5, x - y = 1 with compiled eval."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        y = pt.add("y", 0.0)
        residuals = [x + y - Const(5.0), x - y - Const(1.0)]
        assert newton_solve(residuals, pt) is True
        assert abs(pt.get_value("x") - 3.0) < 1e-10
        assert abs(pt.get_value("y") - 2.0) < 1e-10
    def test_quadratic(self):
        """Solve x^2 - 4 = 0 starting from x=1."""
        # Starting at x=1 puts Newton in the basin of the root x=+2.
        pt = ParamTable()
        x = pt.add("x", 1.0)
        residuals = [x * x - Const(4.0)]
        assert newton_solve(residuals, pt) is True
        assert abs(pt.get_value("x") - 2.0) < 1e-10
    def test_nonlinear_system(self):
        """Compiled eval converges for a nonlinear system: xy=6, x+y=5."""
        pt = ParamTable()
        x = pt.add("x", 2.0)
        y = pt.add("y", 3.5)
        residuals = [x * y - Const(6.0), x + y - Const(5.0)]
        assert newton_solve(residuals, pt, max_iter=100) is True
        # Solutions are (2, 3) or (3, 2) — check they satisfy both equations
        xv, yv = pt.get_value("x"), pt.get_value("y")
        assert abs(xv * yv - 6.0) < 1e-10
        assert abs(xv + yv - 5.0) < 1e-10