"""Tests for the codegen module — CSE, compilation, and compiled evaluation.""" import math import numpy as np import pytest from kindred_solver.codegen import ( _build_cse, _find_nonzero_entries, compile_system, try_compile_system, ) from kindred_solver.expr import ( ZERO, Add, Const, Cos, Div, Mul, Neg, Pow, Sin, Sqrt, Sub, Var, ) from kindred_solver.newton import newton_solve from kindred_solver.params import ParamTable # --------------------------------------------------------------------------- # to_code() — round-trip correctness for each Expr type # --------------------------------------------------------------------------- class TestToCode: """Verify that eval(expr.to_code()) == expr.eval(env) for each node.""" NS = {"_sin": math.sin, "_cos": math.cos, "_sqrt": math.sqrt} def _check(self, expr, env): code = expr.to_code() ns = dict(self.NS) ns["env"] = env compiled = eval(code, ns) expected = expr.eval(env) assert abs(compiled - expected) < 1e-15, ( f"{code} = {compiled}, expected {expected}" ) def test_const(self): self._check(Const(3.14), {}) def test_const_negative(self): self._check(Const(-2.5), {}) def test_const_zero(self): self._check(Const(0.0), {}) def test_var(self): self._check(Var("x"), {"x": 7.0}) def test_neg(self): self._check(Neg(Var("x")), {"x": 3.0}) def test_add(self): self._check(Add(Var("x"), Const(2.0)), {"x": 5.0}) def test_sub(self): self._check(Sub(Var("x"), Var("y")), {"x": 5.0, "y": 3.0}) def test_mul(self): self._check(Mul(Var("x"), Const(3.0)), {"x": 4.0}) def test_div(self): self._check(Div(Var("x"), Const(2.0)), {"x": 6.0}) def test_pow(self): self._check(Pow(Var("x"), Const(3.0)), {"x": 2.0}) def test_sin(self): self._check(Sin(Var("x")), {"x": 1.0}) def test_cos(self): self._check(Cos(Var("x")), {"x": 1.0}) def test_sqrt(self): self._check(Sqrt(Var("x")), {"x": 9.0}) def test_nested(self): """Complex nested expression.""" x, y = Var("x"), Var("y") expr = Add(Mul(Sin(x), Cos(y)), Sqrt(Sub(x, Neg(y)))) self._check(expr, {"x": 2.0, "y": 1.0}) # --------------------------------------------------------------------------- # CSE # --------------------------------------------------------------------------- class TestCSE: def test_no_sharing(self): """Distinct expressions produce no CSE temps.""" a = Var("x") + Const(1.0) b = Var("y") + Const(2.0) id_to_temp, temps = _build_cse([a, b]) assert len(temps) == 0 def test_shared_subtree(self): """Same node object used in two places is extracted.""" x = Var("x") shared = x * Const(2.0) # single Mul node a = shared + Const(1.0) b = shared + Const(3.0) id_to_temp, temps = _build_cse([a, b]) assert len(temps) >= 1 # The shared Mul node should be a temp assert id(shared) in id_to_temp def test_leaf_nodes_not_extracted(self): """Const and Var nodes are never extracted as temps.""" x = Var("x") c = Const(5.0) a = x + c b = x + c id_to_temp, temps = _build_cse([a, b]) for _, expr in temps: assert not isinstance(expr, (Const, Var)) def test_dependency_order(self): """Temps are in dependency order (dependencies first).""" x = Var("x") inner = x * Const(2.0) outer = inner + inner # uses inner twice wrapper_a = outer * Const(3.0) wrapper_b = outer * Const(4.0) id_to_temp, temps = _build_cse([wrapper_a, wrapper_b]) # If both inner and outer are temps, inner must come first temp_names = [name for name, _ in temps] temp_ids = [id(expr) for _, expr in temps] if id(inner) in set(id_to_temp) and id(outer) in set(id_to_temp): inner_idx = temp_ids.index(id(inner)) outer_idx = temp_ids.index(id(outer)) assert inner_idx < outer_idx # --------------------------------------------------------------------------- # Sparsity detection # --------------------------------------------------------------------------- class TestSparsity: def test_zero_entries_skipped(self): nz = _find_nonzero_entries( [ [Const(0.0), Var("x"), Const(0.0)], [Const(1.0), Const(0.0), Var("y")], ] ) assert nz == [(0, 1), (1, 0), (1, 2)] def test_all_nonzero(self): nz = _find_nonzero_entries( [ [Var("x"), Const(1.0)], ] ) assert nz == [(0, 0), (0, 1)] def test_all_zero(self): nz = _find_nonzero_entries( [ [Const(0.0), Const(0.0)], ] ) assert nz == [] # --------------------------------------------------------------------------- # Full compilation pipeline # --------------------------------------------------------------------------- class TestCompileSystem: def test_simple_linear(self): """Compile and evaluate a trivial system: r = x - 3, J = [[1]].""" x = Var("x") residuals = [x - Const(3.0)] jac_exprs = [[Const(1.0)]] # d(x-3)/dx = 1 fn = compile_system(residuals, jac_exprs, 1, 1) env = {"x": 5.0} r_vec = np.empty(1) J = np.zeros((1, 1)) fn(env, r_vec, J) assert abs(r_vec[0] - 2.0) < 1e-15 # 5 - 3 = 2 assert abs(J[0, 0] - 1.0) < 1e-15 def test_two_variable_system(self): """Compile: r0 = x + y - 5, r1 = x - y - 1.""" x, y = Var("x"), Var("y") residuals = [x + y - Const(5.0), x - y - Const(1.0)] jac_exprs = [ [Const(1.0), Const(1.0)], # d(r0)/dx, d(r0)/dy [Const(1.0), Const(-1.0)], # d(r1)/dx, d(r1)/dy ] fn = compile_system(residuals, jac_exprs, 2, 2) env = {"x": 3.0, "y": 2.0} r_vec = np.empty(2) J = np.zeros((2, 2)) fn(env, r_vec, J) assert abs(r_vec[0] - 0.0) < 1e-15 assert abs(r_vec[1] - 0.0) < 1e-15 assert abs(J[0, 0] - 1.0) < 1e-15 assert abs(J[0, 1] - 1.0) < 1e-15 assert abs(J[1, 0] - 1.0) < 1e-15 assert abs(J[1, 1] - (-1.0)) < 1e-15 def test_sparse_jacobian(self): """Zero Jacobian entries remain zero after compiled evaluation.""" x = Var("x") y = Var("y") # r0 depends on x only, r1 depends on y only residuals = [x - Const(1.0), y - Const(2.0)] jac_exprs = [ [Const(1.0), Const(0.0)], [Const(0.0), Const(1.0)], ] fn = compile_system(residuals, jac_exprs, 2, 2) env = {"x": 3.0, "y": 4.0} r_vec = np.empty(2) J = np.zeros((2, 2)) fn(env, r_vec, J) assert abs(J[0, 1]) < 1e-15 # should remain zero assert abs(J[1, 0]) < 1e-15 # should remain zero assert abs(J[0, 0] - 1.0) < 1e-15 assert abs(J[1, 1] - 1.0) < 1e-15 def test_trig_functions(self): """Compiled evaluation handles Sin/Cos/Sqrt.""" x = Var("x") residuals = [Sin(x), Cos(x), Sqrt(x)] jac_exprs = [ [Cos(x)], [Neg(Sin(x))], [Div(Const(1.0), Mul(Const(2.0), Sqrt(x)))], ] fn = compile_system(residuals, jac_exprs, 3, 1) env = {"x": 1.0} r_vec = np.empty(3) J = np.zeros((3, 1)) fn(env, r_vec, J) assert abs(r_vec[0] - math.sin(1.0)) < 1e-15 assert abs(r_vec[1] - math.cos(1.0)) < 1e-15 assert abs(r_vec[2] - math.sqrt(1.0)) < 1e-15 assert abs(J[0, 0] - math.cos(1.0)) < 1e-15 assert abs(J[1, 0] - (-math.sin(1.0))) < 1e-15 assert abs(J[2, 0] - (1.0 / (2.0 * math.sqrt(1.0)))) < 1e-15 def test_matches_tree_walk(self): """Compiled eval produces identical results to tree-walk eval.""" pt = ParamTable() x = pt.add("x", 2.0) y = pt.add("y", 3.0) residuals = [x * y - Const(6.0), x * x + y - Const(7.0)] free = pt.free_names() jac_exprs = [[r.diff(name).simplify() for name in free] for r in residuals] fn = compile_system(residuals, jac_exprs, 2, 2) # Tree-walk eval env = pt.get_env() r_tree = np.array([r.eval(env) for r in residuals]) J_tree = np.empty((2, 2)) for i in range(2): for j in range(2): J_tree[i, j] = jac_exprs[i][j].eval(env) # Compiled eval r_comp = np.empty(2) J_comp = np.zeros((2, 2)) fn(pt.env_ref(), r_comp, J_comp) np.testing.assert_allclose(r_comp, r_tree, atol=1e-15) np.testing.assert_allclose(J_comp, J_tree, atol=1e-15) class TestTryCompile: def test_returns_callable(self): x = Var("x") fn = try_compile_system([x], [[Const(1.0)]], 1, 1) assert fn is not None def test_empty_system(self): """Empty system returns None (nothing to compile).""" fn = try_compile_system([], [], 0, 0) # Empty system is handled by the solver before codegen is reached, # so returning None is acceptable. assert fn is None or callable(fn) # --------------------------------------------------------------------------- # Integration: Newton with compiled eval matches tree-walk # --------------------------------------------------------------------------- class TestCompiledNewton: def test_single_linear(self): """Solve x - 3 = 0 with compiled eval.""" pt = ParamTable() x = pt.add("x", 0.0) residuals = [x - Const(3.0)] assert newton_solve(residuals, pt) is True assert abs(pt.get_value("x") - 3.0) < 1e-10 def test_two_variables(self): """Solve x + y = 5, x - y = 1 with compiled eval.""" pt = ParamTable() x = pt.add("x", 0.0) y = pt.add("y", 0.0) residuals = [x + y - Const(5.0), x - y - Const(1.0)] assert newton_solve(residuals, pt) is True assert abs(pt.get_value("x") - 3.0) < 1e-10 assert abs(pt.get_value("y") - 2.0) < 1e-10 def test_quadratic(self): """Solve x^2 - 4 = 0 starting from x=1.""" pt = ParamTable() x = pt.add("x", 1.0) residuals = [x * x - Const(4.0)] assert newton_solve(residuals, pt) is True assert abs(pt.get_value("x") - 2.0) < 1e-10 def test_nonlinear_system(self): """Compiled eval converges for a nonlinear system: xy=6, x+y=5.""" pt = ParamTable() x = pt.add("x", 2.0) y = pt.add("y", 3.5) residuals = [x * y - Const(6.0), x + y - Const(5.0)] assert newton_solve(residuals, pt, max_iter=100) is True # Solutions are (2, 3) or (3, 2) — check they satisfy both equations xv, yv = pt.get_value("x"), pt.get_value("y") assert abs(xv * yv - 6.0) < 1e-10 assert abs(xv + yv - 5.0) < 1e-10