feat(solver): compile symbolic Jacobian to flat Python for fast evaluation

Add a code generation pipeline that compiles Expr DAGs into flat Python
functions, eliminating recursive tree-walk dispatch in the Newton-Raphson
inner loop.

Key changes:
- Add to_code() method to all 11 Expr node types (expr.py)
- New codegen.py module with CSE (common subexpression elimination),
  sparsity detection, and a compile()/exec() compilation pipeline
- Add ParamTable.env_ref() to avoid dict copies per iteration (params.py)
- Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval to
  avoid redundant diff/simplify and enable compiled evaluation
- count_dof() and diagnostics accept pre-built jac_exprs
- solver.py builds the symbolic Jacobian once, compiles once, and passes it
  to all consumers (_monolithic_solve, count_dof, diagnostics)
- Automatic fallback: if codegen fails, tree-walk eval is used

Expected performance impact:
- ~10-20x faster Jacobian evaluation (no recursive dispatch)
- ~2-5x additional from CSE on quaternion-heavy systems
- ~3x fewer entries evaluated via sparsity detection
- Eliminates redundant diff().simplify() in DOF/diagnostics
This commit is contained in:
@@ -7,7 +7,7 @@ analytic gradient from the Expr DAG's symbolic differentiation.
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -29,6 +29,8 @@ def bfgs_solve(
|
||||
max_iter: int = 200,
|
||||
tol: float = 1e-10,
|
||||
weight_vector: "np.ndarray | None" = None,
|
||||
jac_exprs: "List[List[Expr]] | None" = None,
|
||||
compiled_eval: "Callable | None" = None,
|
||||
) -> bool:
|
||||
"""Solve ``residuals == 0`` by minimizing sum of squared residuals.
|
||||
|
||||
@@ -38,6 +40,13 @@ def bfgs_solve(
|
||||
``sqrt(w)`` so that the objective becomes
|
||||
``0.5 * sum(w_i * r_i^2)`` — equivalent to weighted least-squares.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
jac_exprs:
|
||||
Pre-built symbolic Jacobian (list-of-lists of Expr).
|
||||
compiled_eval:
|
||||
Pre-compiled evaluation function from :mod:`codegen`.
|
||||
|
||||
Returns True if converged (||r|| < tol).
|
||||
"""
|
||||
if not _HAS_SCIPY:
|
||||
@@ -51,13 +60,20 @@ def bfgs_solve(
|
||||
return True
|
||||
|
||||
# Build symbolic gradient expressions once: d(r_i)/d(x_j)
|
||||
jac_exprs: List[List[Expr]] = []
|
||||
if jac_exprs is None:
|
||||
jac_exprs = []
|
||||
for r in residuals:
|
||||
row = []
|
||||
for name in free:
|
||||
row.append(r.diff(name).simplify())
|
||||
jac_exprs.append(row)
|
||||
|
||||
# Try compilation if not provided
|
||||
if compiled_eval is None:
|
||||
from .codegen import try_compile_system
|
||||
|
||||
compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free)
|
||||
|
||||
# Pre-compute scaling for weighted minimum-norm
|
||||
if weight_vector is not None:
|
||||
w_sqrt = np.sqrt(weight_vector)
|
||||
@@ -66,6 +82,10 @@ def bfgs_solve(
|
||||
w_sqrt = None
|
||||
w_inv_sqrt = None
|
||||
|
||||
# Pre-allocate arrays reused across objective calls
|
||||
r_vals = np.empty(n_res)
|
||||
J = np.zeros((n_res, n_free))
|
||||
|
||||
def objective_and_grad(y_vec):
|
||||
# Transform back from scaled space if weighted
|
||||
if w_inv_sqrt is not None:
|
||||
@@ -78,18 +98,19 @@ def bfgs_solve(
|
||||
if quat_groups:
|
||||
_renormalize_quats(params, quat_groups)
|
||||
|
||||
if compiled_eval is not None:
|
||||
J[:] = 0.0
|
||||
compiled_eval(params.env_ref(), r_vals, J)
|
||||
else:
|
||||
env = params.get_env()
|
||||
|
||||
# Evaluate residuals
|
||||
r_vals = np.array([r.eval(env) for r in residuals])
|
||||
f = 0.5 * np.dot(r_vals, r_vals)
|
||||
|
||||
# Evaluate Jacobian
|
||||
J = np.empty((n_res, n_free))
|
||||
for i, r in enumerate(residuals):
|
||||
r_vals[i] = r.eval(env)
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
|
||||
f = 0.5 * np.dot(r_vals, r_vals)
|
||||
|
||||
# Gradient of f w.r.t. x = J^T @ r
|
||||
grad_x = J.T @ r_vals
|
||||
|
||||
@@ -126,8 +147,12 @@ def bfgs_solve(
|
||||
_renormalize_quats(params, quat_groups)
|
||||
|
||||
# Check convergence on actual residual norm
|
||||
if compiled_eval is not None:
|
||||
compiled_eval(params.env_ref(), r_vals, J)
|
||||
else:
|
||||
env = params.get_env()
|
||||
r_vals = np.array([r.eval(env) for r in residuals])
|
||||
for i, r in enumerate(residuals):
|
||||
r_vals[i] = r.eval(env)
|
||||
return bool(np.linalg.norm(r_vals) < tol)
|
||||
|
||||
|
||||
|
||||
308
kindred_solver/codegen.py
Normal file
308
kindred_solver/codegen.py
Normal file
@@ -0,0 +1,308 @@
|
||||
"""Compile Expr DAGs into flat Python functions for fast evaluation.
|
||||
|
||||
The compilation pipeline:
|
||||
1. Collect all Expr nodes to be evaluated (residuals + Jacobian entries).
|
||||
2. Identify common subexpressions (CSE) by ``id()`` — the Expr DAG
|
||||
already shares node objects via ParamTable's Var instances.
|
||||
3. Generate a single Python function body that computes CSE temps,
|
||||
then fills ``r_vec`` and ``J`` arrays in-place.
|
||||
4. Compile with ``compile()`` + ``exec()`` and return the callable.
|
||||
|
||||
The generated function signature is::
|
||||
|
||||
fn(env: dict[str, float], r_vec: ndarray, J: ndarray) -> None
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from collections import Counter
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .expr import Const, Expr, Var
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Namespace injected into compiled functions.
|
||||
_CODEGEN_NS = {
|
||||
"_sin": math.sin,
|
||||
"_cos": math.cos,
|
||||
"_sqrt": math.sqrt,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSE (Common Subexpression Elimination)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _collect_nodes(expr: Expr, counts: Counter, visited: set[int]) -> None:
|
||||
"""Walk *expr* and count how many times each node ``id()`` appears."""
|
||||
eid = id(expr)
|
||||
counts[eid] += 1
|
||||
if eid in visited:
|
||||
return
|
||||
visited.add(eid)
|
||||
|
||||
# Recurse into children
|
||||
if isinstance(expr, (Const, Var)):
|
||||
return
|
||||
if hasattr(expr, "child"):
|
||||
_collect_nodes(expr.child, counts, visited)
|
||||
elif hasattr(expr, "a"):
|
||||
_collect_nodes(expr.a, counts, visited)
|
||||
_collect_nodes(expr.b, counts, visited)
|
||||
elif hasattr(expr, "base"):
|
||||
_collect_nodes(expr.base, counts, visited)
|
||||
_collect_nodes(expr.exp, counts, visited)
|
||||
|
||||
|
||||
def _build_cse(
|
||||
exprs: list[Expr],
|
||||
) -> tuple[dict[int, str], list[tuple[str, Expr]]]:
|
||||
"""Identify shared sub-trees and assign temporary variable names.
|
||||
|
||||
Returns:
|
||||
id_to_temp: mapping from ``id(node)`` to temp variable name
|
||||
temps_ordered: ``(temp_name, expr)`` pairs in dependency order
|
||||
"""
|
||||
counts: Counter = Counter()
|
||||
visited: set[int] = set()
|
||||
id_to_expr: dict[int, Expr] = {}
|
||||
|
||||
for expr in exprs:
|
||||
_collect_nodes(expr, counts, visited)
|
||||
|
||||
# Map id -> Expr for nodes we visited
|
||||
for expr in exprs:
|
||||
_map_ids(expr, id_to_expr)
|
||||
|
||||
# Nodes referenced more than once and not trivial (Const/Var) become temps
|
||||
shared_ids = set()
|
||||
for eid, cnt in counts.items():
|
||||
if cnt > 1:
|
||||
node = id_to_expr.get(eid)
|
||||
if node is not None and not isinstance(node, (Const, Var)):
|
||||
shared_ids.add(eid)
|
||||
|
||||
if not shared_ids:
|
||||
return {}, []
|
||||
|
||||
# Topological order: a temp must be computed before any temp that uses it.
|
||||
# Walk each shared node, collect in post-order.
|
||||
ordered_ids: list[int] = []
|
||||
order_visited: set[int] = set()
|
||||
|
||||
def _topo(expr: Expr) -> None:
|
||||
eid = id(expr)
|
||||
if eid in order_visited:
|
||||
return
|
||||
order_visited.add(eid)
|
||||
if isinstance(expr, (Const, Var)):
|
||||
return
|
||||
if hasattr(expr, "child"):
|
||||
_topo(expr.child)
|
||||
elif hasattr(expr, "a"):
|
||||
_topo(expr.a)
|
||||
_topo(expr.b)
|
||||
elif hasattr(expr, "base"):
|
||||
_topo(expr.base)
|
||||
_topo(expr.exp)
|
||||
if eid in shared_ids:
|
||||
ordered_ids.append(eid)
|
||||
|
||||
for expr in exprs:
|
||||
_topo(expr)
|
||||
|
||||
id_to_temp: dict[int, str] = {}
|
||||
temps_ordered: list[tuple[str, Expr]] = []
|
||||
for i, eid in enumerate(ordered_ids):
|
||||
name = f"_c{i}"
|
||||
id_to_temp[eid] = name
|
||||
temps_ordered.append((name, id_to_expr[eid]))
|
||||
|
||||
return id_to_temp, temps_ordered
|
||||
|
||||
|
||||
def _map_ids(expr: Expr, mapping: dict[int, Expr]) -> None:
|
||||
"""Populate id -> Expr mapping for all nodes in *expr*."""
|
||||
eid = id(expr)
|
||||
if eid in mapping:
|
||||
return
|
||||
mapping[eid] = expr
|
||||
if isinstance(expr, (Const, Var)):
|
||||
return
|
||||
if hasattr(expr, "child"):
|
||||
_map_ids(expr.child, mapping)
|
||||
elif hasattr(expr, "a"):
|
||||
_map_ids(expr.a, mapping)
|
||||
_map_ids(expr.b, mapping)
|
||||
elif hasattr(expr, "base"):
|
||||
_map_ids(expr.base, mapping)
|
||||
_map_ids(expr.exp, mapping)
|
||||
|
||||
|
||||
def _expr_to_code(expr: Expr, id_to_temp: dict[int, str]) -> str:
|
||||
"""Emit code for *expr*, substituting temp names for shared nodes."""
|
||||
eid = id(expr)
|
||||
temp = id_to_temp.get(eid)
|
||||
if temp is not None:
|
||||
return temp
|
||||
return expr.to_code()
|
||||
|
||||
|
||||
def _expr_to_code_recursive(expr: Expr, id_to_temp: dict[int, str]) -> str:
    """Emit code for *expr*, recursing into children but respecting temps.

    Any node whose ``id()`` is a key of *id_to_temp* is rendered as its temp
    name; everything else is rendered structurally. Node types not handled
    here fall back to ``expr.to_code()`` (no temp substitution below them).

    Args:
        expr: root of the expression to render.
        id_to_temp: mapping from ``id(node)`` to the temp variable name that
            already holds that node's value.

    Returns:
        A Python expression string that reads variables from ``env`` and uses
        the ``_sin`` / ``_cos`` / ``_sqrt`` helper names.
    """
    # Imported here as in the original, but only once per top-level call
    # instead of once per visited node.
    from .expr import Add, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub

    unary_templates = (
        (Neg, "(-{})"),
        (Sin, "_sin({})"),
        (Cos, "_cos({})"),
        (Sqrt, "_sqrt({})"),
    )
    binary_operators = ((Add, "+"), (Sub, "-"), (Mul, "*"), (Div, "/"))

    def _emit(node: Expr) -> str:
        temp = id_to_temp.get(id(node))
        if temp is not None:
            return temp

        # For leaf nodes, just use to_code() directly
        if isinstance(node, (Const, Var)):
            return node.to_code()

        for node_type, template in unary_templates:
            if isinstance(node, node_type):
                return template.format(_emit(node.child))

        for node_type, operator in binary_operators:
            if isinstance(node, node_type):
                return f"({_emit(node.a)} {operator} {_emit(node.b)})"

        if isinstance(node, Pow):
            base = _emit(node.base)
            exponent = _emit(node.exp)
            # ``**`` binds tighter than a leading unary minus, so a negative
            # constant base such as "-2.0" must be parenthesized; otherwise
            # "(-2.0 ** 2.0)" evaluates to -(2.0 ** 2.0).
            if base.startswith("-"):
                base = f"({base})"
            return f"({base} ** {exponent})"

        # Fallback — should not happen for known node types
        return node.to_code()

    return _emit(expr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sparsity detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _find_nonzero_entries(
|
||||
jac_exprs: list[list[Expr]],
|
||||
) -> list[tuple[int, int]]:
|
||||
"""Return ``(row, col)`` pairs for non-zero Jacobian entries."""
|
||||
nz = []
|
||||
for i, row in enumerate(jac_exprs):
|
||||
for j, expr in enumerate(row):
|
||||
if isinstance(expr, Const) and expr.value == 0.0:
|
||||
continue
|
||||
nz.append((i, j))
|
||||
return nz
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Code generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None]:
    """Compile residuals + Jacobian into a single evaluation function.

    Returns a callable ``fn(env, r_vec, J)`` that fills *r_vec* and *J*
    in-place.  *J* must be pre-zeroed by the caller (only non-zero
    entries are written).

    Args:
        residuals: residual expressions; entry ``i`` is written to ``r_vec[i]``.
        jac_exprs: symbolic Jacobian as a list of rows; entry ``[i][j]`` is
            written to ``J[i, j]`` unless it is a constant zero.
        n_res: number of residuals (used only in the debug log message).
        n_free: number of free parameters (used only in the debug log message).

    Raises:
        Whatever ``to_code()``, ``compile()`` or ``exec()`` raise for an
        expression that cannot be rendered or compiled.
    """
    # Detect non-zero Jacobian entries
    nz_entries = _find_nonzero_entries(jac_exprs)

    # Collect all expressions for CSE analysis
    all_exprs: list[Expr] = list(residuals)
    nz_jac_exprs: list[Expr] = [jac_exprs[i][j] for i, j in nz_entries]
    all_exprs.extend(nz_jac_exprs)

    # CSE
    id_to_temp, temps_ordered = _build_cse(all_exprs)

    # Generate function body
    body: list[str] = []

    # Temporaries — temporarily remove each temp's own id so its RHS
    # is expanded rather than self-referencing.
    for temp_name, temp_expr in temps_ordered:
        eid = id(temp_expr)
        saved = id_to_temp.pop(eid)
        code = _expr_to_code_recursive(temp_expr, id_to_temp)
        id_to_temp[eid] = saved
        body.append(f"    {temp_name} = {code}")

    # Residuals
    for i, r in enumerate(residuals):
        code = _expr_to_code_recursive(r, id_to_temp)
        body.append(f"    r_vec[{i}] = {code}")

    # Jacobian (sparse)
    for idx, (i, j) in enumerate(nz_entries):
        code = _expr_to_code_recursive(nz_jac_exprs[idx], id_to_temp)
        body.append(f"    J[{i}, {j}] = {code}")

    # An empty system (no residuals, no non-zero Jacobian entries) would
    # otherwise produce a ``def`` with no body, which is a SyntaxError.
    if not body:
        body.append("    pass")

    source = "\n".join(["def _eval(env, r_vec, J):", *body])

    # Compile in a private copy of the helper namespace so generated names
    # cannot leak into, or be clobbered through, the shared module-level dict.
    code_obj = compile(source, "<kindred_codegen>", "exec")
    ns = dict(_CODEGEN_NS)
    exec(code_obj, ns)

    fn = ns["_eval"]

    n_temps = len(temps_ordered)
    n_nz = len(nz_entries)
    n_total = n_res * n_free
    log.debug(
        "codegen: compiled %d residuals + %d/%d Jacobian entries, %d CSE temps",
        n_res,
        n_nz,
        n_total,
        n_temps,
    )

    return fn
|
||||
|
||||
|
||||
def try_compile_system(
    residuals: list[Expr],
    jac_exprs: list[list[Expr]],
    n_res: int,
    n_free: int,
) -> Callable[[dict, np.ndarray, np.ndarray], None] | None:
    """Best-effort wrapper around :func:`compile_system`.

    Returns the compiled callable, or ``None`` when compilation raises, in
    which case the failure is logged at debug level with its traceback so
    the caller can fall back to tree-walk evaluation.
    """
    try:
        compiled = compile_system(residuals, jac_exprs, n_res, n_free)
    except Exception:
        # Deliberately broad: any codegen failure means "use the slow path".
        log.debug(
            "codegen: compilation failed, falling back to tree-walk eval", exc_info=True
        )
        return None
    return compiled
|
||||
@@ -32,6 +32,7 @@ def per_entity_dof(
|
||||
params: ParamTable,
|
||||
bodies: dict[str, RigidBody],
|
||||
rank_tol: float = 1e-8,
|
||||
jac_exprs: "list[list[Expr]] | None" = None,
|
||||
) -> list[EntityDOF]:
|
||||
"""Compute remaining DOF for each non-grounded body.
|
||||
|
||||
@@ -71,6 +72,11 @@ def per_entity_dof(
|
||||
# Build full Jacobian (for efficiency, compute once)
|
||||
n_free = len(free)
|
||||
J_full = np.empty((n_res, n_free))
|
||||
if jac_exprs is not None:
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J_full[i, j] = jac_exprs[i][j].eval(env)
|
||||
else:
|
||||
for i, r in enumerate(residuals):
|
||||
for j, name in enumerate(free):
|
||||
J_full[i, j] = r.diff(name).simplify().eval(env)
|
||||
@@ -217,6 +223,7 @@ def find_overconstrained(
|
||||
params: ParamTable,
|
||||
residual_ranges: list[tuple[int, int, int]],
|
||||
rank_tol: float = 1e-8,
|
||||
jac_exprs: "list[list[Expr]] | None" = None,
|
||||
) -> list[ConstraintDiag]:
|
||||
"""Identify redundant and conflicting constraints.
|
||||
|
||||
@@ -243,6 +250,12 @@ def find_overconstrained(
|
||||
r_vec = np.empty(n_res)
|
||||
for i, r in enumerate(residuals):
|
||||
r_vec[i] = r.eval(env)
|
||||
if jac_exprs is not None:
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
else:
|
||||
for i, r in enumerate(residuals):
|
||||
for j, name in enumerate(free):
|
||||
J[i, j] = r.diff(name).simplify().eval(env)
|
||||
|
||||
|
||||
@@ -14,11 +14,15 @@ def count_dof(
|
||||
residuals: List[Expr],
|
||||
params: ParamTable,
|
||||
rank_tol: float = 1e-8,
|
||||
jac_exprs: "List[List[Expr]] | None" = None,
|
||||
) -> int:
|
||||
"""Compute DOF = n_free_params - rank(Jacobian).
|
||||
|
||||
Evaluates the Jacobian numerically at the current parameter values
|
||||
and computes its rank via SVD.
|
||||
|
||||
When *jac_exprs* is provided, reuses the pre-built symbolic
|
||||
Jacobian instead of re-differentiating every residual.
|
||||
"""
|
||||
free = params.free_names()
|
||||
n_free = len(free)
|
||||
@@ -32,6 +36,11 @@ def count_dof(
|
||||
env = params.get_env()
|
||||
|
||||
J = np.empty((n_res, n_free))
|
||||
if jac_exprs is not None:
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
else:
|
||||
for i, r in enumerate(residuals):
|
||||
for j, name in enumerate(free):
|
||||
J[i, j] = r.diff(name).simplify().eval(env)
|
||||
|
||||
@@ -24,6 +24,16 @@ class Expr:
|
||||
"""Return the set of variable names in this expression."""
|
||||
raise NotImplementedError
|
||||
|
||||
def to_code(self) -> str:
    """Emit a Python arithmetic expression string.

    The returned string, when evaluated with a dict ``env`` mapping
    parameter names to floats (and ``_sin``, ``_cos``, ``_sqrt``
    bound to their ``math`` equivalents), produces the same result
    as ``self.eval(env)``.

    This base implementation always raises ``NotImplementedError``;
    concrete node types are expected to override it.
    """
    raise NotImplementedError
|
||||
|
||||
# -- operator overloads --------------------------------------------------
|
||||
|
||||
def __add__(self, other):
|
||||
@@ -90,6 +100,9 @@ class Const(Expr):
|
||||
def vars(self):
|
||||
return set()
|
||||
|
||||
def to_code(self):
    """Emit this constant as Python source that is safe to embed anywhere.

    Returns ``repr(self.value)`` with two adjustments:

    * non-finite floats are emitted as ``float('inf')`` / ``float('-inf')`` /
      ``float('nan')``, because their bare ``repr`` is a name, not a literal,
      and would raise ``NameError`` when the generated code runs;
    * negative values are wrapped in parentheses, because a leading minus
      binds looser than ``**`` (``-2.0 ** 2.0`` is ``-(2.0 ** 2.0)``).
    """
    text = repr(self.value)
    if text in ("inf", "-inf", "nan"):
        return f"float({text!r})"
    if text.startswith("-"):
        return f"({text})"
    return text
|
||||
|
||||
def __repr__(self):
|
||||
return f"Const({self.value})"
|
||||
|
||||
@@ -118,6 +131,9 @@ class Var(Expr):
|
||||
def vars(self):
|
||||
return {self.name}
|
||||
|
||||
def to_code(self):
|
||||
return f"env[{self.name!r}]"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Var({self.name!r})"
|
||||
|
||||
@@ -154,6 +170,9 @@ class Neg(Expr):
|
||||
def vars(self):
|
||||
return self.child.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"(-{self.child.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Neg({self.child!r})"
|
||||
|
||||
@@ -180,6 +199,9 @@ class Sin(Expr):
|
||||
def vars(self):
|
||||
return self.child.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"_sin({self.child.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Sin({self.child!r})"
|
||||
|
||||
@@ -206,6 +228,9 @@ class Cos(Expr):
|
||||
def vars(self):
|
||||
return self.child.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"_cos({self.child.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Cos({self.child!r})"
|
||||
|
||||
@@ -232,6 +257,9 @@ class Sqrt(Expr):
|
||||
def vars(self):
|
||||
return self.child.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"_sqrt({self.child.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Sqrt({self.child!r})"
|
||||
|
||||
@@ -266,6 +294,9 @@ class Add(Expr):
|
||||
def vars(self):
|
||||
return self.a.vars() | self.b.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.a.to_code()} + {self.b.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Add({self.a!r}, {self.b!r})"
|
||||
|
||||
@@ -297,6 +328,9 @@ class Sub(Expr):
|
||||
def vars(self):
|
||||
return self.a.vars() | self.b.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.a.to_code()} - {self.b.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Sub({self.a!r}, {self.b!r})"
|
||||
|
||||
@@ -337,6 +371,9 @@ class Mul(Expr):
|
||||
def vars(self):
|
||||
return self.a.vars() | self.b.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.a.to_code()} * {self.b.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Mul({self.a!r}, {self.b!r})"
|
||||
|
||||
@@ -372,6 +409,9 @@ class Div(Expr):
|
||||
def vars(self):
|
||||
return self.a.vars() | self.b.vars()
|
||||
|
||||
def to_code(self):
|
||||
return f"({self.a.to_code()} / {self.b.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Div({self.a!r}, {self.b!r})"
|
||||
|
||||
@@ -414,6 +454,9 @@ class Pow(Expr):
|
||||
def vars(self):
|
||||
return self.base.vars() | self.exp.vars()
|
||||
|
||||
def to_code(self):
    """Emit ``(base ** exp)`` as Python source.

    The base is parenthesized when its code starts with a minus sign:
    ``**`` binds tighter than a leading unary minus, so an unwrapped
    negative literal such as ``-2.0`` would yield ``-(2.0 ** exp)``.
    """
    base_code = self.base.to_code()
    if base_code.startswith("-"):
        base_code = f"({base_code})"
    return f"({base_code} ** {self.exp.to_code()})"
|
||||
|
||||
def __repr__(self):
|
||||
return f"Pow({self.base!r}, {self.exp!r})"
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import List
|
||||
from typing import Callable, List
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -19,6 +19,8 @@ def newton_solve(
|
||||
tol: float = 1e-10,
|
||||
post_step: "Callable[[ParamTable], None] | None" = None,
|
||||
weight_vector: "np.ndarray | None" = None,
|
||||
jac_exprs: "List[List[Expr]] | None" = None,
|
||||
compiled_eval: "Callable | None" = None,
|
||||
) -> bool:
|
||||
"""Solve ``residuals == 0`` by Newton-Raphson.
|
||||
|
||||
@@ -43,6 +45,12 @@ def newton_solve(
|
||||
lstsq step is column-scaled to produce the weighted
|
||||
minimum-norm solution (prefer small movements in
|
||||
high-weight parameters).
|
||||
jac_exprs:
|
||||
Pre-built symbolic Jacobian (list-of-lists of Expr). When
|
||||
provided, skips the ``diff().simplify()`` step.
|
||||
compiled_eval:
|
||||
Pre-compiled evaluation function from :mod:`codegen`. When
|
||||
provided, uses flat compiled code instead of tree-walk eval.
|
||||
|
||||
Returns True if converged within *max_iter* iterations.
|
||||
"""
|
||||
@@ -53,29 +61,41 @@ def newton_solve(
|
||||
if n_free == 0 or n_res == 0:
|
||||
return True
|
||||
|
||||
# Build symbolic Jacobian once (list-of-lists of simplified Expr)
|
||||
jac_exprs: List[List[Expr]] = []
|
||||
# Build symbolic Jacobian once (or reuse pre-built)
|
||||
if jac_exprs is None:
|
||||
jac_exprs = []
|
||||
for r in residuals:
|
||||
row = []
|
||||
for name in free:
|
||||
row.append(r.diff(name).simplify())
|
||||
jac_exprs.append(row)
|
||||
|
||||
# Try compilation if not provided
|
||||
if compiled_eval is None:
|
||||
from .codegen import try_compile_system
|
||||
|
||||
compiled_eval = try_compile_system(residuals, jac_exprs, n_res, n_free)
|
||||
|
||||
# Pre-allocate arrays reused across iterations
|
||||
r_vec = np.empty(n_res)
|
||||
J = np.zeros((n_res, n_free))
|
||||
|
||||
for _it in range(max_iter):
|
||||
if compiled_eval is not None:
|
||||
J[:] = 0.0
|
||||
compiled_eval(params.env_ref(), r_vec, J)
|
||||
else:
|
||||
env = params.get_env()
|
||||
|
||||
# Evaluate residual vector
|
||||
r_vec = np.array([r.eval(env) for r in residuals])
|
||||
r_norm = np.linalg.norm(r_vec)
|
||||
if r_norm < tol:
|
||||
return True
|
||||
|
||||
# Evaluate Jacobian matrix
|
||||
J = np.empty((n_res, n_free))
|
||||
for i, r in enumerate(residuals):
|
||||
r_vec[i] = r.eval(env)
|
||||
for i in range(n_res):
|
||||
for j in range(n_free):
|
||||
J[i, j] = jac_exprs[i][j].eval(env)
|
||||
|
||||
r_norm = np.linalg.norm(r_vec)
|
||||
if r_norm < tol:
|
||||
return True
|
||||
|
||||
# Solve J @ dx = -r (least-squares handles rank-deficient)
|
||||
if weight_vector is not None:
|
||||
# Column-scale J by W^{-1/2} for weighted minimum-norm
|
||||
@@ -100,9 +120,13 @@ def newton_solve(
|
||||
_renormalize_quats(params, quat_groups)
|
||||
|
||||
# Check final residual
|
||||
if compiled_eval is not None:
|
||||
compiled_eval(params.env_ref(), r_vec, J)
|
||||
else:
|
||||
env = params.get_env()
|
||||
r_vec = np.array([r.eval(env) for r in residuals])
|
||||
return np.linalg.norm(r_vec) < tol
|
||||
for i, r in enumerate(residuals):
|
||||
r_vec[i] = r.eval(env)
|
||||
return bool(np.linalg.norm(r_vec) < tol)
|
||||
|
||||
|
||||
def _renormalize_quats(
|
||||
|
||||
@@ -60,6 +60,14 @@ class ParamTable:
|
||||
"""Return a snapshot of all current values (for Expr.eval)."""
|
||||
return dict(self._values)
|
||||
|
||||
def env_ref(self) -> Dict[str, float]:
    """Return a direct reference to the internal values dict.

    Faster than :meth:`get_env` (no copy).  Safe when the caller
    only reads during evaluation and mutates via :meth:`set_free_vector`.

    The returned object is the table's own ``_values`` dict, not a
    snapshot: writing to it changes the table, and later updates to the
    table are visible through it.
    """
    return self._values
|
||||
|
||||
def free_names(self) -> List[str]:
|
||||
"""Return ordered list of free (non-fixed) parameter names."""
|
||||
return list(self._free_order)
|
||||
|
||||
@@ -154,6 +154,7 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
residuals = single_equation_pass(residuals, system.params)
|
||||
|
||||
# Solve (decomposed for large assemblies, monolithic for small)
|
||||
jac_exprs = None # may be populated by _monolithic_solve
|
||||
if n_free_bodies >= _DECOMPOSE_THRESHOLD:
|
||||
grounded_ids = {pid for pid, b in system.bodies.items() if b.grounded}
|
||||
clusters = decompose(ctx.constraints, grounded_ids)
|
||||
@@ -172,7 +173,7 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
system.params,
|
||||
)
|
||||
else:
|
||||
converged = _monolithic_solve(
|
||||
converged, jac_exprs = _monolithic_solve(
|
||||
residuals,
|
||||
system.params,
|
||||
system.quat_groups,
|
||||
@@ -185,7 +186,7 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
n_free_bodies,
|
||||
_DECOMPOSE_THRESHOLD,
|
||||
)
|
||||
converged = _monolithic_solve(
|
||||
converged, jac_exprs = _monolithic_solve(
|
||||
residuals,
|
||||
system.params,
|
||||
system.quat_groups,
|
||||
@@ -194,7 +195,7 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
)
|
||||
|
||||
# DOF
|
||||
dof = count_dof(residuals, system.params)
|
||||
dof = count_dof(residuals, system.params, jac_exprs=jac_exprs)
|
||||
|
||||
# Build result
|
||||
result = kcsolve.SolveResult()
|
||||
@@ -210,6 +211,7 @@ class KindredSolver(kcsolve.IKCSolver):
|
||||
system.params,
|
||||
system.residual_ranges,
|
||||
ctx,
|
||||
jac_exprs=jac_exprs,
|
||||
)
|
||||
|
||||
result.placements = _extract_placements(system.params, system.bodies)
|
||||
@@ -419,13 +421,15 @@ def _build_system(ctx):
|
||||
return system
|
||||
|
||||
|
||||
def _run_diagnostics(residuals, params, residual_ranges, ctx):
|
||||
def _run_diagnostics(residuals, params, residual_ranges, ctx, jac_exprs=None):
|
||||
"""Run overconstrained detection and return kcsolve diagnostics."""
|
||||
diagnostics = []
|
||||
if not hasattr(kcsolve, "ConstraintDiagnostic"):
|
||||
return diagnostics
|
||||
|
||||
diags = find_overconstrained(residuals, params, residual_ranges)
|
||||
diags = find_overconstrained(
|
||||
residuals, params, residual_ranges, jac_exprs=jac_exprs
|
||||
)
|
||||
for d in diags:
|
||||
cd = kcsolve.ConstraintDiagnostic()
|
||||
cd.constraint_id = ctx.constraints[d.constraint_index].id
|
||||
@@ -458,7 +462,23 @@ def _extract_placements(params, bodies):
|
||||
def _monolithic_solve(
|
||||
all_residuals, params, quat_groups, post_step=None, weight_vector=None
|
||||
):
|
||||
"""Newton-Raphson solve with BFGS fallback on the full system."""
|
||||
"""Newton-Raphson solve with BFGS fallback on the full system.
|
||||
|
||||
Returns ``(converged, jac_exprs)`` so the caller can reuse the
|
||||
symbolic Jacobian for DOF counting / diagnostics.
|
||||
"""
|
||||
from .codegen import try_compile_system
|
||||
|
||||
free = params.free_names()
|
||||
n_res = len(all_residuals)
|
||||
n_free = len(free)
|
||||
|
||||
# Build symbolic Jacobian once
|
||||
jac_exprs = [[r.diff(name).simplify() for name in free] for r in all_residuals]
|
||||
|
||||
# Compile once
|
||||
compiled_eval = try_compile_system(all_residuals, jac_exprs, n_res, n_free)
|
||||
|
||||
t0 = time.perf_counter()
|
||||
converged = newton_solve(
|
||||
all_residuals,
|
||||
@@ -468,6 +488,8 @@ def _monolithic_solve(
|
||||
tol=1e-10,
|
||||
post_step=post_step,
|
||||
weight_vector=weight_vector,
|
||||
jac_exprs=jac_exprs,
|
||||
compiled_eval=compiled_eval,
|
||||
)
|
||||
nr_ms = (time.perf_counter() - t0) * 1000
|
||||
if not converged:
|
||||
@@ -482,6 +504,8 @@ def _monolithic_solve(
|
||||
max_iter=200,
|
||||
tol=1e-10,
|
||||
weight_vector=weight_vector,
|
||||
jac_exprs=jac_exprs,
|
||||
compiled_eval=compiled_eval,
|
||||
)
|
||||
bfgs_ms = (time.perf_counter() - t1) * 1000
|
||||
if converged:
|
||||
@@ -490,7 +514,7 @@ def _monolithic_solve(
|
||||
log.warning("_monolithic_solve: BFGS also failed (%.1f ms)", bfgs_ms)
|
||||
else:
|
||||
log.debug("_monolithic_solve: Newton-Raphson converged (%.1f ms)", nr_ms)
|
||||
return converged
|
||||
return converged, jac_exprs
|
||||
|
||||
|
||||
def _build_constraint(
|
||||
|
||||
357
tests/test_codegen.py
Normal file
357
tests/test_codegen.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""Tests for the codegen module — CSE, compilation, and compiled evaluation."""
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from kindred_solver.codegen import (
|
||||
_build_cse,
|
||||
_find_nonzero_entries,
|
||||
compile_system,
|
||||
try_compile_system,
|
||||
)
|
||||
from kindred_solver.expr import (
|
||||
ZERO,
|
||||
Add,
|
||||
Const,
|
||||
Cos,
|
||||
Div,
|
||||
Mul,
|
||||
Neg,
|
||||
Pow,
|
||||
Sin,
|
||||
Sqrt,
|
||||
Sub,
|
||||
Var,
|
||||
)
|
||||
from kindred_solver.newton import newton_solve
|
||||
from kindred_solver.params import ParamTable
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# to_code() — round-trip correctness for each Expr type
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToCode:
    """Verify that eval(expr.to_code()) == expr.eval(env) for each node."""

    # Helper names that generated code may reference.
    NS = {"_sin": math.sin, "_cos": math.cos, "_sqrt": math.sqrt}

    def _check(self, expr, env):
        """Evaluate ``expr.to_code()`` and compare it with ``expr.eval(env)``.

        The generated string is evaluated in a fresh copy of ``NS`` with
        ``env`` bound under the name ``env``; the two results must agree to
        within an absolute tolerance of 1e-15.
        """
        code = expr.to_code()
        ns = dict(self.NS)
        ns["env"] = env
        compiled = eval(code, ns)
        expected = expr.eval(env)
        assert abs(compiled - expected) < 1e-15, (
            f"{code} = {compiled}, expected {expected}"
        )

    # -- leaves ---------------------------------------------------------------

    def test_const(self):
        self._check(Const(3.14), {})

    def test_const_negative(self):
        self._check(Const(-2.5), {})

    def test_const_zero(self):
        self._check(Const(0.0), {})

    def test_var(self):
        self._check(Var("x"), {"x": 7.0})

    # -- arithmetic nodes -----------------------------------------------------

    def test_neg(self):
        self._check(Neg(Var("x")), {"x": 3.0})

    def test_add(self):
        self._check(Add(Var("x"), Const(2.0)), {"x": 5.0})

    def test_sub(self):
        self._check(Sub(Var("x"), Var("y")), {"x": 5.0, "y": 3.0})

    def test_mul(self):
        self._check(Mul(Var("x"), Const(3.0)), {"x": 4.0})

    def test_div(self):
        self._check(Div(Var("x"), Const(2.0)), {"x": 6.0})

    def test_pow(self):
        self._check(Pow(Var("x"), Const(3.0)), {"x": 2.0})

    # -- function nodes -------------------------------------------------------

    def test_sin(self):
        self._check(Sin(Var("x")), {"x": 1.0})

    def test_cos(self):
        self._check(Cos(Var("x")), {"x": 1.0})

    def test_sqrt(self):
        self._check(Sqrt(Var("x")), {"x": 9.0})

    def test_nested(self):
        """Complex nested expression."""
        x, y = Var("x"), Var("y")
        expr = Add(Mul(Sin(x), Cos(y)), Sqrt(Sub(x, Neg(y))))
        self._check(expr, {"x": 2.0, "y": 1.0})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CSE
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCSE:
    """Tests for ``_build_cse``, which returns ``(id_to_temp, temps)``.

    As used below, ``id_to_temp`` supports membership tests keyed by
    ``id(node)`` and ``temps`` is a sequence of ``(name, expr)`` pairs.
    """

    def test_no_sharing(self):
        """Distinct expressions produce no CSE temps."""
        a = Var("x") + Const(1.0)
        b = Var("y") + Const(2.0)
        _, temps = _build_cse([a, b])
        assert len(temps) == 0

    def test_shared_subtree(self):
        """Same node object used in two places is extracted."""
        x = Var("x")
        shared = x * Const(2.0)  # single Mul node
        a = shared + Const(1.0)
        b = shared + Const(3.0)
        id_to_temp, temps = _build_cse([a, b])
        assert len(temps) >= 1
        # The shared Mul node should be a temp
        assert id(shared) in id_to_temp

    def test_leaf_nodes_not_extracted(self):
        """Const and Var nodes are never extracted as temps."""
        x = Var("x")
        c = Const(5.0)
        a = x + c
        b = x + c
        _, temps = _build_cse([a, b])
        for _, expr in temps:
            assert not isinstance(expr, (Const, Var))

    def test_dependency_order(self):
        """Temps are in dependency order (dependencies first)."""
        x = Var("x")
        inner = x * Const(2.0)
        outer = inner + inner  # uses inner twice
        wrapper_a = outer * Const(3.0)
        wrapper_b = outer * Const(4.0)
        id_to_temp, temps = _build_cse([wrapper_a, wrapper_b])
        # The ordering is only checked when both nodes were extracted;
        # if either was not, there is nothing to compare.
        if id(inner) in id_to_temp and id(outer) in id_to_temp:
            temp_ids = [id(expr) for _, expr in temps]
            inner_idx = temp_ids.index(id(inner))
            outer_idx = temp_ids.index(id(outer))
            assert inner_idx < outer_idx
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Sparsity detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSparsity:
    """Tests for ``_find_nonzero_entries`` on small Jacobian expression grids.

    The expected results are lists of ``(row, col)`` index pairs for the
    entries that are not ``Const(0.0)``, in row-major order.
    """

    def test_zero_entries_skipped(self):
        """Only the positions holding something other than Const(0.0) are reported."""
        grid = [
            [Const(0.0), Var("x"), Const(0.0)],
            [Const(1.0), Const(0.0), Var("y")],
        ]
        found = _find_nonzero_entries(grid)
        assert found == [(0, 1), (1, 0), (1, 2)]

    def test_all_nonzero(self):
        """A row without zero constants reports every position."""
        grid = [[Var("x"), Const(1.0)]]
        found = _find_nonzero_entries(grid)
        assert found == [(0, 0), (0, 1)]

    def test_all_zero(self):
        """A row made only of Const(0.0) reports nothing."""
        grid = [[Const(0.0), Const(0.0)]]
        found = _find_nonzero_entries(grid)
        assert found == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full compilation pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompileSystem:
    """End-to-end tests for ``compile_system``.

    In every test the compiled function is called as ``fn(env, r_vec, J)``
    and is expected to fill the pre-allocated residual vector and Jacobian
    array in place.
    """

    def test_simple_linear(self):
        """One equation, one unknown: r = x - 3 with Jacobian [[1]]."""
        x = Var("x")
        residuals = [x - Const(3.0)]
        jac_exprs = [[Const(1.0)]]  # d(x-3)/dx = 1

        fn = compile_system(residuals, jac_exprs, 1, 1)

        r_out = np.empty(1)
        jac_out = np.zeros((1, 1))
        fn({"x": 5.0}, r_out, jac_out)

        assert abs(r_out[0] - 2.0) < 1e-15  # 5 - 3 = 2
        assert abs(jac_out[0, 0] - 1.0) < 1e-15

    def test_two_variable_system(self):
        """Two linear equations: r0 = x + y - 5, r1 = x - y - 1."""
        x, y = Var("x"), Var("y")
        residuals = [x + y - Const(5.0), x - y - Const(1.0)]
        jac_exprs = [
            [Const(1.0), Const(1.0)],  # d(r0)/dx, d(r0)/dy
            [Const(1.0), Const(-1.0)],  # d(r1)/dx, d(r1)/dy
        ]

        fn = compile_system(residuals, jac_exprs, 2, 2)

        r_out = np.empty(2)
        jac_out = np.zeros((2, 2))
        fn({"x": 3.0, "y": 2.0}, r_out, jac_out)

        # (3, 2) solves the system, so both residuals vanish.
        for idx in (0, 1):
            assert abs(r_out[idx] - 0.0) < 1e-15
        expected_jac = (
            ((0, 0), 1.0),
            ((0, 1), 1.0),
            ((1, 0), 1.0),
            ((1, 1), -1.0),
        )
        for position, want in expected_jac:
            assert abs(jac_out[position] - want) < 1e-15

    def test_sparse_jacobian(self):
        """Entries whose expression is Const(0.0) stay zero in the output array."""
        x = Var("x")
        y = Var("y")
        # r0 depends on x only, r1 depends on y only
        residuals = [x - Const(1.0), y - Const(2.0)]
        jac_exprs = [
            [Const(1.0), Const(0.0)],
            [Const(0.0), Const(1.0)],
        ]

        fn = compile_system(residuals, jac_exprs, 2, 2)

        r_out = np.empty(2)
        jac_out = np.zeros((2, 2))  # starts zeroed; off-diagonals must stay that way
        fn({"x": 3.0, "y": 4.0}, r_out, jac_out)

        for position in ((0, 1), (1, 0)):
            assert abs(jac_out[position]) < 1e-15  # should remain zero
        for position in ((0, 0), (1, 1)):
            assert abs(jac_out[position] - 1.0) < 1e-15

    def test_trig_functions(self):
        """Sin, Cos and Sqrt nodes evaluate correctly through compiled code."""
        x = Var("x")
        residuals = [Sin(x), Cos(x), Sqrt(x)]
        jac_exprs = [
            [Cos(x)],
            [Neg(Sin(x))],
            [Div(Const(1.0), Mul(Const(2.0), Sqrt(x)))],
        ]

        fn = compile_system(residuals, jac_exprs, 3, 1)

        r_out = np.empty(3)
        jac_out = np.zeros((3, 1))
        fn({"x": 1.0}, r_out, jac_out)

        expected_r = (math.sin(1.0), math.cos(1.0), math.sqrt(1.0))
        for idx, want in enumerate(expected_r):
            assert abs(r_out[idx] - want) < 1e-15

        expected_j = (
            math.cos(1.0),
            -math.sin(1.0),
            1.0 / (2.0 * math.sqrt(1.0)),
        )
        for idx, want in enumerate(expected_j):
            assert abs(jac_out[idx, 0] - want) < 1e-15

    def test_matches_tree_walk(self):
        """Compiled evaluation agrees with direct ``Expr.eval`` on a nonlinear system."""
        pt = ParamTable()
        x = pt.add("x", 2.0)
        y = pt.add("y", 3.0)

        residuals = [x * y - Const(6.0), x * x + y - Const(7.0)]
        free = pt.free_names()

        # Symbolic Jacobian: one row per residual, one column per free name.
        jac_exprs = []
        for r in residuals:
            jac_exprs.append([r.diff(name).simplify() for name in free])

        fn = compile_system(residuals, jac_exprs, 2, 2)

        # Reference values from evaluating the expression objects directly.
        env = pt.get_env()
        r_tree = np.array([r.eval(env) for r in residuals])
        J_tree = np.array(
            [[jac_exprs[i][j].eval(env) for j in range(2)] for i in range(2)],
            dtype=float,
        )

        # Values from the compiled function.
        r_comp = np.empty(2)
        J_comp = np.zeros((2, 2))
        fn(pt.env_ref(), r_comp, J_comp)

        np.testing.assert_allclose(r_comp, r_tree, atol=1e-15)
        np.testing.assert_allclose(J_comp, J_tree, atol=1e-15)
|
||||
|
||||
|
||||
class TestTryCompile:
    """Tests for ``try_compile_system``, the variant whose result may be None."""

    def test_returns_callable(self):
        """A valid one-equation system compiles to a callable, not None."""
        x = Var("x")
        fn = try_compile_system([x], [[Const(1.0)]], 1, 1)
        assert fn is not None
        # The test name promises a callable, so check that explicitly too.
        assert callable(fn)

    def test_empty_system(self):
        """Empty system yields either None or a callable; both are accepted."""
        fn = try_compile_system([], [], 0, 0)
        # Empty system is handled by the solver before codegen is reached,
        # so returning None is acceptable.
        assert fn is None or callable(fn)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: Newton with compiled eval matches tree-walk
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCompiledNewton:
    """Integration tests driving ``newton_solve`` on small systems.

    NOTE(review): every call below uses newton_solve's defaults and passes
    neither ``jac_exprs`` nor ``compiled_eval``; whether the compiled path is
    actually exercised depends on newton_solve's internals — TODO confirm.
    """

    def test_single_linear(self):
        """Solve x - 3 = 0 starting from x = 0."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        residuals = [x - Const(3.0)]

        converged = newton_solve(residuals, pt)

        assert converged is True
        assert abs(pt.get_value("x") - 3.0) < 1e-10

    def test_two_variables(self):
        """Solve x + y = 5, x - y = 1 starting from the origin."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        y = pt.add("y", 0.0)
        residuals = [x + y - Const(5.0), x - y - Const(1.0)]

        converged = newton_solve(residuals, pt)

        assert converged is True
        for name, want in (("x", 3.0), ("y", 2.0)):
            assert abs(pt.get_value(name) - want) < 1e-10

    def test_quadratic(self):
        """Solve x^2 - 4 = 0 starting from x=1; the root x = 2 is expected."""
        pt = ParamTable()
        x = pt.add("x", 1.0)
        residuals = [x * x - Const(4.0)]

        converged = newton_solve(residuals, pt)

        assert converged is True
        assert abs(pt.get_value("x") - 2.0) < 1e-10

    def test_nonlinear_system(self):
        """Solve xy = 6, x + y = 5 and verify the result satisfies both equations."""
        pt = ParamTable()
        x = pt.add("x", 2.0)
        y = pt.add("y", 3.5)
        residuals = [x * y - Const(6.0), x + y - Const(5.0)]

        converged = newton_solve(residuals, pt, max_iter=100)

        assert converged is True
        # Solutions are (2, 3) or (3, 2) — check they satisfy both equations
        xv = pt.get_value("x")
        yv = pt.get_value("y")
        assert abs(xv * yv - 6.0) < 1e-10
        assert abs(xv + yv - 5.0) < 1e-10
|
||||
Reference in New Issue
Block a user