Add a code generation pipeline that compiles Expr DAGs into flat Python functions, eliminating recursive tree-walk dispatch in the Newton-Raphson inner loop. Key changes: - Add to_code() method to all 11 Expr node types (expr.py) - New codegen.py module with CSE (common subexpression elimination), sparsity detection, and compile()/exec() compilation pipeline - Add ParamTable.env_ref() to avoid dict copies per iteration (params.py) - Newton and BFGS solvers accept pre-built jac_exprs and compiled_eval to avoid redundant diff/simplify and enable compiled evaluation - count_dof() and diagnostics accept pre-built jac_exprs - solver.py builds symbolic Jacobian once, compiles once, passes to all consumers (_monolithic_solve, count_dof, diagnostics) - Automatic fallback: if codegen fails, tree-walk eval is used Expected performance impact: - ~10-20x faster Jacobian evaluation (no recursive dispatch) - ~2-5x additional from CSE on quaternion-heavy systems - ~3x fewer entries evaluated via sparsity detection - Eliminates redundant diff().simplify() in DOF/diagnostics
358 lines
11 KiB
Python
358 lines
11 KiB
Python
"""Tests for the codegen module — CSE, compilation, and compiled evaluation."""
|
|
|
|
import math
|
|
|
|
import numpy as np
|
|
import pytest
|
|
from kindred_solver.codegen import (
|
|
_build_cse,
|
|
_find_nonzero_entries,
|
|
compile_system,
|
|
try_compile_system,
|
|
)
|
|
from kindred_solver.expr import (
|
|
ZERO,
|
|
Add,
|
|
Const,
|
|
Cos,
|
|
Div,
|
|
Mul,
|
|
Neg,
|
|
Pow,
|
|
Sin,
|
|
Sqrt,
|
|
Sub,
|
|
Var,
|
|
)
|
|
from kindred_solver.newton import newton_solve
|
|
from kindred_solver.params import ParamTable
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# to_code() — round-trip correctness for each Expr type
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestToCode:
|
|
"""Verify that eval(expr.to_code()) == expr.eval(env) for each node."""
|
|
|
|
NS = {"_sin": math.sin, "_cos": math.cos, "_sqrt": math.sqrt}
|
|
|
|
def _check(self, expr, env):
|
|
code = expr.to_code()
|
|
ns = dict(self.NS)
|
|
ns["env"] = env
|
|
compiled = eval(code, ns)
|
|
expected = expr.eval(env)
|
|
assert abs(compiled - expected) < 1e-15, (
|
|
f"{code} = {compiled}, expected {expected}"
|
|
)
|
|
|
|
def test_const(self):
|
|
self._check(Const(3.14), {})
|
|
|
|
def test_const_negative(self):
|
|
self._check(Const(-2.5), {})
|
|
|
|
def test_const_zero(self):
|
|
self._check(Const(0.0), {})
|
|
|
|
def test_var(self):
|
|
self._check(Var("x"), {"x": 7.0})
|
|
|
|
def test_neg(self):
|
|
self._check(Neg(Var("x")), {"x": 3.0})
|
|
|
|
def test_add(self):
|
|
self._check(Add(Var("x"), Const(2.0)), {"x": 5.0})
|
|
|
|
def test_sub(self):
|
|
self._check(Sub(Var("x"), Var("y")), {"x": 5.0, "y": 3.0})
|
|
|
|
def test_mul(self):
|
|
self._check(Mul(Var("x"), Const(3.0)), {"x": 4.0})
|
|
|
|
def test_div(self):
|
|
self._check(Div(Var("x"), Const(2.0)), {"x": 6.0})
|
|
|
|
def test_pow(self):
|
|
self._check(Pow(Var("x"), Const(3.0)), {"x": 2.0})
|
|
|
|
def test_sin(self):
|
|
self._check(Sin(Var("x")), {"x": 1.0})
|
|
|
|
def test_cos(self):
|
|
self._check(Cos(Var("x")), {"x": 1.0})
|
|
|
|
def test_sqrt(self):
|
|
self._check(Sqrt(Var("x")), {"x": 9.0})
|
|
|
|
def test_nested(self):
|
|
"""Complex nested expression."""
|
|
x, y = Var("x"), Var("y")
|
|
expr = Add(Mul(Sin(x), Cos(y)), Sqrt(Sub(x, Neg(y))))
|
|
self._check(expr, {"x": 2.0, "y": 1.0})
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CSE
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCSE:
    """Tests for ``_build_cse``, which returns ``(id_to_temp, temps)``.

    As used here, ``id_to_temp`` supports membership tests on ``id(node)``
    and ``temps`` is a sequence of ``(name, expr)`` pairs.
    """

    def test_no_sharing(self):
        """Distinct expressions produce no CSE temps."""
        a = Var("x") + Const(1.0)
        b = Var("y") + Const(2.0)
        _, temps = _build_cse([a, b])
        assert len(temps) == 0

    def test_shared_subtree(self):
        """Same node object used in two places is extracted."""
        x = Var("x")
        shared = x * Const(2.0)  # single Mul node
        a = shared + Const(1.0)
        b = shared + Const(3.0)
        id_to_temp, temps = _build_cse([a, b])
        assert len(temps) >= 1
        # The shared Mul node should be a temp
        assert id(shared) in id_to_temp

    def test_leaf_nodes_not_extracted(self):
        """Const and Var nodes are never extracted as temps."""
        x = Var("x")
        c = Const(5.0)
        a = x + c
        b = x + c
        _, temps = _build_cse([a, b])
        for _, expr in temps:
            assert not isinstance(expr, (Const, Var))

    def test_dependency_order(self):
        """Temps are in dependency order (dependencies first)."""
        x = Var("x")
        inner = x * Const(2.0)
        outer = inner + inner  # uses inner twice
        wrapper_a = outer * Const(3.0)
        wrapper_b = outer * Const(4.0)
        id_to_temp, temps = _build_cse([wrapper_a, wrapper_b])
        # Ordering is only checked when both nodes were extracted; whether
        # `inner` becomes a temp is left to the implementation.
        if id(inner) in id_to_temp and id(outer) in id_to_temp:
            temp_ids = [id(expr) for _, expr in temps]
            assert temp_ids.index(id(inner)) < temp_ids.index(id(outer))
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Sparsity detection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSparsity:
    """Tests for ``_find_nonzero_entries`` on small Jacobian expression grids."""

    def test_zero_entries_skipped(self):
        """Only the positions holding something other than Const(0.0) are listed."""
        jac = [
            [Const(0.0), Var("x"), Const(0.0)],
            [Const(1.0), Const(0.0), Var("y")],
        ]
        found = _find_nonzero_entries(jac)
        assert found == [(0, 1), (1, 0), (1, 2)]

    def test_all_nonzero(self):
        """A row without zero constants reports every column."""
        jac = [[Var("x"), Const(1.0)]]
        found = _find_nonzero_entries(jac)
        assert found == [(0, 0), (0, 1)]

    def test_all_zero(self):
        """A row of zero constants reports nothing."""
        jac = [[Const(0.0), Const(0.0)]]
        found = _find_nonzero_entries(jac)
        assert found == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Full compilation pipeline
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCompileSystem:
|
|
def test_simple_linear(self):
|
|
"""Compile and evaluate a trivial system: r = x - 3, J = [[1]]."""
|
|
x = Var("x")
|
|
residuals = [x - Const(3.0)]
|
|
jac_exprs = [[Const(1.0)]] # d(x-3)/dx = 1
|
|
|
|
fn = compile_system(residuals, jac_exprs, 1, 1)
|
|
|
|
env = {"x": 5.0}
|
|
r_vec = np.empty(1)
|
|
J = np.zeros((1, 1))
|
|
fn(env, r_vec, J)
|
|
|
|
assert abs(r_vec[0] - 2.0) < 1e-15 # 5 - 3 = 2
|
|
assert abs(J[0, 0] - 1.0) < 1e-15
|
|
|
|
def test_two_variable_system(self):
|
|
"""Compile: r0 = x + y - 5, r1 = x - y - 1."""
|
|
x, y = Var("x"), Var("y")
|
|
residuals = [x + y - Const(5.0), x - y - Const(1.0)]
|
|
jac_exprs = [
|
|
[Const(1.0), Const(1.0)], # d(r0)/dx, d(r0)/dy
|
|
[Const(1.0), Const(-1.0)], # d(r1)/dx, d(r1)/dy
|
|
]
|
|
|
|
fn = compile_system(residuals, jac_exprs, 2, 2)
|
|
|
|
env = {"x": 3.0, "y": 2.0}
|
|
r_vec = np.empty(2)
|
|
J = np.zeros((2, 2))
|
|
fn(env, r_vec, J)
|
|
|
|
assert abs(r_vec[0] - 0.0) < 1e-15
|
|
assert abs(r_vec[1] - 0.0) < 1e-15
|
|
assert abs(J[0, 0] - 1.0) < 1e-15
|
|
assert abs(J[0, 1] - 1.0) < 1e-15
|
|
assert abs(J[1, 0] - 1.0) < 1e-15
|
|
assert abs(J[1, 1] - (-1.0)) < 1e-15
|
|
|
|
def test_sparse_jacobian(self):
|
|
"""Zero Jacobian entries remain zero after compiled evaluation."""
|
|
x = Var("x")
|
|
y = Var("y")
|
|
# r0 depends on x only, r1 depends on y only
|
|
residuals = [x - Const(1.0), y - Const(2.0)]
|
|
jac_exprs = [
|
|
[Const(1.0), Const(0.0)],
|
|
[Const(0.0), Const(1.0)],
|
|
]
|
|
|
|
fn = compile_system(residuals, jac_exprs, 2, 2)
|
|
|
|
env = {"x": 3.0, "y": 4.0}
|
|
r_vec = np.empty(2)
|
|
J = np.zeros((2, 2))
|
|
fn(env, r_vec, J)
|
|
|
|
assert abs(J[0, 1]) < 1e-15 # should remain zero
|
|
assert abs(J[1, 0]) < 1e-15 # should remain zero
|
|
assert abs(J[0, 0] - 1.0) < 1e-15
|
|
assert abs(J[1, 1] - 1.0) < 1e-15
|
|
|
|
def test_trig_functions(self):
|
|
"""Compiled evaluation handles Sin/Cos/Sqrt."""
|
|
x = Var("x")
|
|
residuals = [Sin(x), Cos(x), Sqrt(x)]
|
|
jac_exprs = [
|
|
[Cos(x)],
|
|
[Neg(Sin(x))],
|
|
[Div(Const(1.0), Mul(Const(2.0), Sqrt(x)))],
|
|
]
|
|
|
|
fn = compile_system(residuals, jac_exprs, 3, 1)
|
|
|
|
env = {"x": 1.0}
|
|
r_vec = np.empty(3)
|
|
J = np.zeros((3, 1))
|
|
fn(env, r_vec, J)
|
|
|
|
assert abs(r_vec[0] - math.sin(1.0)) < 1e-15
|
|
assert abs(r_vec[1] - math.cos(1.0)) < 1e-15
|
|
assert abs(r_vec[2] - math.sqrt(1.0)) < 1e-15
|
|
assert abs(J[0, 0] - math.cos(1.0)) < 1e-15
|
|
assert abs(J[1, 0] - (-math.sin(1.0))) < 1e-15
|
|
assert abs(J[2, 0] - (1.0 / (2.0 * math.sqrt(1.0)))) < 1e-15
|
|
|
|
def test_matches_tree_walk(self):
|
|
"""Compiled eval produces identical results to tree-walk eval."""
|
|
pt = ParamTable()
|
|
x = pt.add("x", 2.0)
|
|
y = pt.add("y", 3.0)
|
|
|
|
residuals = [x * y - Const(6.0), x * x + y - Const(7.0)]
|
|
free = pt.free_names()
|
|
|
|
jac_exprs = [[r.diff(name).simplify() for name in free] for r in residuals]
|
|
|
|
fn = compile_system(residuals, jac_exprs, 2, 2)
|
|
|
|
# Tree-walk eval
|
|
env = pt.get_env()
|
|
r_tree = np.array([r.eval(env) for r in residuals])
|
|
J_tree = np.empty((2, 2))
|
|
for i in range(2):
|
|
for j in range(2):
|
|
J_tree[i, j] = jac_exprs[i][j].eval(env)
|
|
|
|
# Compiled eval
|
|
r_comp = np.empty(2)
|
|
J_comp = np.zeros((2, 2))
|
|
fn(pt.env_ref(), r_comp, J_comp)
|
|
|
|
np.testing.assert_allclose(r_comp, r_tree, atol=1e-15)
|
|
np.testing.assert_allclose(J_comp, J_tree, atol=1e-15)
|
|
|
|
|
|
class TestTryCompile:
    """Tests for ``try_compile_system``, whose result is checked against ``None`` / ``callable``."""

    def test_returns_callable(self):
        """A one-equation system compiles to something callable."""
        x = Var("x")
        fn = try_compile_system([x], [[Const(1.0)]], 1, 1)
        # callable() also rules out None, matching what the test name promises.
        assert callable(fn)

    def test_empty_system(self):
        """Empty system yields either None or a callable; both are accepted."""
        fn = try_compile_system([], [], 0, 0)
        # Empty system is handled by the solver before codegen is reached,
        # so returning None is acceptable.
        assert fn is None or callable(fn)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration: Newton with compiled eval matches tree-walk
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCompiledNewton:
    """Convergence checks for ``newton_solve`` on small systems.

    NOTE(review): no ``compiled_eval`` argument is passed in these tests, so
    they cover compiled evaluation only if ``newton_solve`` compiles the
    system on its own — confirm against newton.py.
    """

    def test_single_linear(self):
        """Solve x - 3 = 0 with compiled eval."""
        table = ParamTable()
        x = table.add("x", 0.0)
        converged = newton_solve([x - Const(3.0)], table)
        assert converged is True
        assert abs(table.get_value("x") - 3.0) < 1e-10

    def test_two_variables(self):
        """Solve x + y = 5, x - y = 1 with compiled eval."""
        table = ParamTable()
        x = table.add("x", 0.0)
        y = table.add("y", 0.0)
        equations = [x + y - Const(5.0), x - y - Const(1.0)]
        converged = newton_solve(equations, table)
        assert converged is True
        assert abs(table.get_value("x") - 3.0) < 1e-10
        assert abs(table.get_value("y") - 2.0) < 1e-10

    def test_quadratic(self):
        """Solve x^2 - 4 = 0 starting from x=1."""
        table = ParamTable()
        x = table.add("x", 1.0)
        converged = newton_solve([x * x - Const(4.0)], table)
        assert converged is True
        assert abs(table.get_value("x") - 2.0) < 1e-10

    def test_nonlinear_system(self):
        """Compiled eval converges for a nonlinear system: xy=6, x+y=5."""
        table = ParamTable()
        x = table.add("x", 2.0)
        y = table.add("y", 3.5)
        equations = [x * y - Const(6.0), x + y - Const(5.0)]
        converged = newton_solve(equations, table, max_iter=100)
        assert converged is True
        # Solutions are (2, 3) or (3, 2) — check they satisfy both equations
        x_val = table.get_value("x")
        y_val = table.get_value("y")
        assert abs(x_val * y_val - 6.0) < 1e-10
        assert abs(x_val + y_val - 5.0) < 1e-10
|