"""Tests for pre-solve passes."""

import pytest

from kindred_solver.expr import Const, Var
from kindred_solver.params import ParamTable
from kindred_solver.prepass import single_equation_pass, substitution_pass

# Absolute tolerance for every floating-point comparison in this module.
TOL = 1e-10


class TestSubstitutionPass:
    """Tests for ``substitution_pass`` (replacing fixed parameters by their values)."""

    def test_fixed_replaced(self):
        """Fixed x=3 turns the residual x + y - 5 into 3 + y - 5, i.e. y - 2."""
        pt = ParamTable()
        x = pt.add("x", 3.0, fixed=True)
        y = pt.add("y", 0.0)
        residuals = [x + y - Const(5.0)]

        result = substitution_pass(residuals, pt)

        assert len(result) == 1
        # Only "y" is supplied to eval, so the substituted residual must no
        # longer need a value for x.
        assert result[0].eval({"y": 2.0}) == pytest.approx(0.0, abs=TOL)
        # A second point pins the residual as y - 2, not just "zero at y=2".
        assert result[0].eval({"y": 0.0}) == pytest.approx(-2.0, abs=TOL)

    def test_no_fixed(self):
        """With no fixed parameters the residual list is passed through unchanged."""
        pt = ParamTable()
        x = pt.add("x", 1.0)
        residuals = [x - Const(1.0)]

        result = substitution_pass(residuals, pt)

        assert len(result) == 1
        # x is still free, so the residual still evaluates as x - 1.
        assert result[0].eval({"x": 1.0}) == pytest.approx(0.0, abs=TOL)
        assert result[0].eval({"x": 4.0}) == pytest.approx(3.0, abs=TOL)


class TestSingleEquationPass:
    """Tests for ``single_equation_pass`` (solving residuals with one free variable)."""

    def test_solve_linear(self):
        """x - 3 = 0 with only x free → solves x=3 and removes residual."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        residuals = [x - Const(3.0)]

        result = single_equation_pass(residuals, pt)

        assert len(result) == 0
        assert pt.get_value("x") == pytest.approx(3.0, abs=TOL)
        assert pt.is_fixed("x")

    def test_two_residuals_chain(self):
        """x - 3 = 0, y - x = 0 → solves x=3, then y=3."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        y = pt.add("y", 0.0)
        residuals = [x - Const(3.0), y - x]

        result = single_equation_pass(residuals, pt)

        assert len(result) == 0
        assert pt.get_value("x") == pytest.approx(3.0, abs=TOL)
        assert pt.get_value("y") == pytest.approx(3.0, abs=TOL)

    def test_multi_var_not_solved(self):
        """x + y - 5 = 0 with both free → not solved."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        y = pt.add("y", 0.0)
        residuals = [x + y - Const(5.0)]

        result = single_equation_pass(residuals, pt)

        assert len(result) == 1  # still unsolved
        # Neither variable can be determined from a single two-unknown equation.
        assert not pt.is_fixed("x")
        assert not pt.is_fixed("y")