"""Tests for the Newton-Raphson solver."""

import math

import pytest

from kindred_solver.expr import Const, Var
from kindred_solver.newton import newton_solve
from kindred_solver.params import ParamTable

# Absolute tolerance for converged parameter values. Used with
# ``pytest.approx(..., abs=TOL)``; specifying only ``abs`` disables the
# default relative tolerance, matching the original ``abs(a - b) < 1e-10``.
TOL = 1e-10


class TestNewtonBasic:
    """Scalar and small-system solves with free and fixed parameters."""

    def test_single_linear(self):
        """Solve x - 3 = 0 → x = 3."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        residuals = [x - Const(3.0)]
        assert newton_solve(residuals, pt) is True
        assert pt.get_value("x") == pytest.approx(3.0, abs=TOL)

    def test_single_quadratic(self):
        """Solve x^2 - 4 = 0 starting from x=1 → x = 2."""
        pt = ParamTable()
        x = pt.add("x", 1.0)
        residuals = [x * x - Const(4.0)]
        assert newton_solve(residuals, pt) is True
        assert pt.get_value("x") == pytest.approx(2.0, abs=TOL)

    def test_two_variables(self):
        """Solve x + y = 5, x - y = 1 → x=3, y=2."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        y = pt.add("y", 0.0)
        residuals = [x + y - Const(5.0), x - y - Const(1.0)]
        assert newton_solve(residuals, pt) is True
        assert pt.get_value("x") == pytest.approx(3.0, abs=TOL)
        assert pt.get_value("y") == pytest.approx(2.0, abs=TOL)

    def test_with_fixed(self):
        """Fixed parameter is not updated."""
        pt = ParamTable()
        x = pt.add("x", 0.0)
        y = pt.add("y", 5.0, fixed=True)
        residuals = [x + y - Const(10.0)]
        assert newton_solve(residuals, pt) is True
        assert pt.get_value("x") == pytest.approx(5.0, abs=TOL)
        # Exact comparison on purpose: a fixed parameter must be untouched,
        # not merely close to its initial value.
        assert pt.get_value("y") == 5.0

    def test_empty_system(self):
        """Empty residual list converges trivially."""
        pt = ParamTable()
        assert newton_solve([], pt) is True


class TestNewtonQuat:
    """Solves that pass ``quat_groups`` to ``newton_solve``."""

    def test_quat_renorm(self):
        """Quaternion re-normalization keeps unit length."""
        pt = ParamTable()
        qw = pt.add("qw", 0.9)
        qx = pt.add("qx", 0.1)
        qy = pt.add("qy", 0.1)
        qz = pt.add("qz", 0.1)
        # Residual: qw^2 + qx^2 + qy^2 + qz^2 - 1 = 0
        r = qw * qw + qx * qx + qy * qy + qz * qz - Const(1.0)
        quat_groups = [("qw", "qx", "qy", "qz")]
        assert newton_solve([r], pt, quat_groups=quat_groups) is True
        # Check unit length
        w, x, y, z = (pt.get_value(n) for n in ["qw", "qx", "qy", "qz"])
        norm = math.sqrt(w**2 + x**2 + y**2 + z**2)
        assert norm == pytest.approx(1.0, abs=TOL)


class TestNewtonGeometric:
    """Residual systems shaped like simple geometric constraints."""

    def test_point_coincidence(self):
        """Free point (px, py, pz) should meet fixed target (3, 0, 0).

        Unlike a single scalar solve, every coordinate starts off-target,
        and the target coordinates are fixed parameters that must not move.
        """
        pt = ParamTable()
        px = pt.add("px", 0.0)
        py = pt.add("py", 1.0)
        pz = pt.add("pz", -2.0)
        tx = pt.add("tx", 3.0, fixed=True)
        ty = pt.add("ty", 0.0, fixed=True)
        tz = pt.add("tz", 0.0, fixed=True)
        # One residual per coordinate: p - t = 0.
        residuals = [px - tx, py - ty, pz - tz]
        assert newton_solve(residuals, pt) is True
        assert pt.get_value("px") == pytest.approx(3.0, abs=TOL)
        assert pt.get_value("py") == pytest.approx(0.0, abs=TOL)
        assert pt.get_value("pz") == pytest.approx(0.0, abs=TOL)
        # The fixed target must be left exactly where it was.
        assert pt.get_value("tx") == 3.0
        assert pt.get_value("ty") == 0.0
        assert pt.get_value("tz") == 0.0

    def test_distance_constraint(self):
        """Point at (x,0,0) should be distance 5 from origin.

        Squared: x^2 - 25 = 0, starting from x=3 → x=5."""
        pt = ParamTable()
        x = pt.add("x", 3.0)
        residuals = [x * x - Const(25.0)]
        assert newton_solve(residuals, pt) is True
        assert pt.get_value("x") == pytest.approx(5.0, abs=TOL)