"""Tests for solution preference: half-space tracking and corrections.""" import math import numpy as np import pytest from kindred_solver.constraints import ( AngleConstraint, DistancePointPointConstraint, ParallelConstraint, PerpendicularConstraint, ) from kindred_solver.entities import RigidBody from kindred_solver.newton import newton_solve from kindred_solver.params import ParamTable from kindred_solver.preference import ( apply_half_space_correction, compute_half_spaces, ) def _make_two_bodies( params, pos_a=(0, 0, 0), pos_b=(5, 0, 0), quat_a=(1, 0, 0, 0), quat_b=(1, 0, 0, 0), ground_a=True, ground_b=False, ): """Create two bodies with given positions/orientations.""" body_a = RigidBody( "a", params, position=pos_a, quaternion=quat_a, grounded=ground_a ) body_b = RigidBody( "b", params, position=pos_b, quaternion=quat_b, grounded=ground_b ) return body_a, body_b class TestDistanceHalfSpace: """Half-space tracking for DistancePointPoint constraint.""" def test_positive_x_stays_positive(self): """Body starting at +X should stay at +X after solve.""" params = ParamTable() body_a, body_b = _make_two_bodies(params, pos_b=(3, 0, 0)) c = DistancePointPointConstraint( body_a, (0, 0, 0), body_b, (0, 0, 0), distance=5.0, ) hs = compute_half_spaces([c], [0], params) assert len(hs) == 1 # Solve with half-space correction residuals = c.residuals() residuals.append(body_b.quat_norm_residual()) quat_groups = [body_b.quat_param_names()] def post_step(p): apply_half_space_correction(p, hs) converged = newton_solve( residuals, params, quat_groups=quat_groups, post_step=post_step, ) assert converged env = params.get_env() # Body b should be at +X (x > 0), not -X bx = env["b/tx"] assert bx > 0, f"Expected positive X, got {bx}" # Distance should be 5 dist = math.sqrt(bx**2 + env["b/ty"] ** 2 + env["b/tz"] ** 2) assert dist == pytest.approx(5.0, abs=1e-8) def test_negative_x_stays_negative(self): """Body starting at -X should stay at -X after solve.""" params = ParamTable() body_a, body_b = _make_two_bodies(params, pos_b=(-3, 0, 0)) c = DistancePointPointConstraint( body_a, (0, 0, 0), body_b, (0, 0, 0), distance=5.0, ) hs = compute_half_spaces([c], [0], params) assert len(hs) == 1 residuals = c.residuals() residuals.append(body_b.quat_norm_residual()) quat_groups = [body_b.quat_param_names()] def post_step(p): apply_half_space_correction(p, hs) converged = newton_solve( residuals, params, quat_groups=quat_groups, post_step=post_step, ) assert converged env = params.get_env() bx = env["b/tx"] assert bx < 0, f"Expected negative X, got {bx}" def test_zero_distance_no_halfspace(self): """Zero distance constraint has no branch ambiguity.""" params = ParamTable() body_a, body_b = _make_two_bodies(params, pos_b=(3, 0, 0)) c = DistancePointPointConstraint( body_a, (0, 0, 0), body_b, (0, 0, 0), distance=0.0, ) hs = compute_half_spaces([c], [0], params) assert len(hs) == 0 class TestParallelHalfSpace: """Half-space tracking for Parallel constraint.""" def test_same_direction_tracked(self): """Same-direction parallel: positive reference sign.""" params = ParamTable() body_a, body_b = _make_two_bodies(params) c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0)) hs = compute_half_spaces([c], [0], params) assert len(hs) == 1 assert hs[0].reference_sign == 1.0 def test_opposite_direction_tracked(self): """Opposite-direction parallel: negative reference sign.""" params = ParamTable() # Rotate body_b by 180 degrees about X: Z-axis flips q_flip = (0, 1, 0, 0) # 180 deg about X body_a, body_b = _make_two_bodies(params, quat_b=q_flip) c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0)) hs = compute_half_spaces([c], [0], params) assert len(hs) == 1 assert hs[0].reference_sign == -1.0 class TestAngleHalfSpace: """Half-space tracking for Angle constraint.""" def test_90_degree_angle(self): """90-degree angle constraint creates a half-space.""" params = ParamTable() # Rotate body_b by 90 degrees about X q_90x = (math.cos(math.pi / 4), math.sin(math.pi / 4), 0, 0) body_a, body_b = _make_two_bodies(params, quat_b=q_90x) c = AngleConstraint( body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0), angle=math.pi / 2, ) hs = compute_half_spaces([c], [0], params) assert len(hs) == 1 def test_zero_angle_no_halfspace(self): """0-degree angle has no branch ambiguity.""" params = ParamTable() body_a, body_b = _make_two_bodies(params) c = AngleConstraint( body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0), angle=0.0, ) hs = compute_half_spaces([c], [0], params) assert len(hs) == 0 def test_180_angle_no_halfspace(self): """180-degree angle has no branch ambiguity.""" params = ParamTable() body_a, body_b = _make_two_bodies(params) c = AngleConstraint( body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0), angle=math.pi, ) hs = compute_half_spaces([c], [0], params) assert len(hs) == 0 class TestPerpendicularHalfSpace: """Half-space tracking for Perpendicular constraint.""" def test_perpendicular_tracked(self): """Perpendicular constraint creates a half-space.""" params = ParamTable() # Rotate body_b by 90 degrees about X q_90x = (math.cos(math.pi / 4), math.sin(math.pi / 4), 0, 0) body_a, body_b = _make_two_bodies(params, quat_b=q_90x) c = PerpendicularConstraint( body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0), ) hs = compute_half_spaces([c], [0], params) assert len(hs) == 1 class TestNewtonPostStep: """Verify Newton post_step callback works correctly.""" def test_callback_fires(self): """post_step callback is invoked during Newton iterations.""" params = ParamTable() x = params.add("x", 2.0) from kindred_solver.expr import Const residuals = [x - Const(5.0)] call_count = [0] def counter(p): call_count[0] += 1 converged = newton_solve(residuals, params, post_step=counter) assert converged assert call_count[0] >= 1 def test_callback_does_not_break_convergence(self): """A no-op callback doesn't prevent convergence.""" params = ParamTable() x = params.add("x", 1.0) y = params.add("y", 1.0) from kindred_solver.expr import Const residuals = [x - Const(3.0), y - Const(7.0)] converged = newton_solve(residuals, params, post_step=lambda p: None) assert converged assert params.get_value("x") == pytest.approx(3.0) assert params.get_value("y") == pytest.approx(7.0) class TestMixedHalfSpaces: """Multiple branching constraints in one system.""" def test_multiple_constraints(self): """compute_half_spaces handles mixed constraint types.""" params = ParamTable() body_a, body_b = _make_two_bodies(params, pos_b=(5, 0, 0)) dist_c = DistancePointPointConstraint( body_a, (0, 0, 0), body_b, (0, 0, 0), distance=5.0, ) par_c = ParallelConstraint(body_a, (1, 0, 0, 0), body_b, (1, 0, 0, 0)) hs = compute_half_spaces([dist_c, par_c], [0, 1], params) assert len(hs) == 2 class TestBuildWeightVector: """Weight vector construction.""" def test_translation_weight_one(self): """Translation params get weight 1.0.""" from kindred_solver.preference import build_weight_vector params = ParamTable() params.add("body/tx", 0.0) params.add("body/ty", 0.0) params.add("body/tz", 0.0) w = build_weight_vector(params) np.testing.assert_array_equal(w, [1.0, 1.0, 1.0]) def test_quaternion_weight_high(self): """Quaternion params get QUAT_WEIGHT.""" from kindred_solver.preference import QUAT_WEIGHT, build_weight_vector params = ParamTable() params.add("body/qw", 1.0) params.add("body/qx", 0.0) params.add("body/qy", 0.0) params.add("body/qz", 0.0) w = build_weight_vector(params) np.testing.assert_array_equal(w, [QUAT_WEIGHT] * 4) def test_mixed_params(self): """Mixed translation and quaternion params get correct weights.""" from kindred_solver.preference import QUAT_WEIGHT, build_weight_vector params = ParamTable() params.add("b/tx", 0.0) params.add("b/qw", 1.0) params.add("b/ty", 0.0) params.add("b/qx", 0.0) w = build_weight_vector(params) assert w[0] == pytest.approx(1.0) assert w[1] == pytest.approx(QUAT_WEIGHT) assert w[2] == pytest.approx(1.0) assert w[3] == pytest.approx(QUAT_WEIGHT) def test_fixed_params_excluded(self): """Fixed params are not in free list, so not in weight vector.""" from kindred_solver.preference import build_weight_vector params = ParamTable() params.add("b/tx", 0.0, fixed=True) params.add("b/ty", 0.0) w = build_weight_vector(params) assert len(w) == 1 assert w[0] == pytest.approx(1.0) class TestWeightedNewton: """Weighted minimum-norm Newton solve.""" def test_well_constrained_same_result(self): """Weighted and unweighted produce identical results for unique solution.""" from kindred_solver.expr import Const # Fully determined system: x = 3, y = 7 params1 = ParamTable() x1 = params1.add("x", 1.0) y1 = params1.add("y", 1.0) r1 = [x1 - Const(3.0), y1 - Const(7.0)] params2 = ParamTable() x2 = params2.add("x", 1.0) y2 = params2.add("y", 1.0) r2 = [x2 - Const(3.0), y2 - Const(7.0)] newton_solve(r1, params1) newton_solve(r2, params2, weight_vector=np.array([1.0, 100.0])) assert params1.get_value("x") == pytest.approx( params2.get_value("x"), abs=1e-10 ) assert params1.get_value("y") == pytest.approx( params2.get_value("y"), abs=1e-10 ) def test_underconstrained_prefers_low_weight(self): """Under-constrained: weighted solve moves high-weight params less.""" from kindred_solver.expr import Const # 1 equation, 2 unknowns: x + y = 10 (from x=0, y=0) params_unw = ParamTable() xu = params_unw.add("x", 0.0) yu = params_unw.add("y", 0.0) ru = [xu + yu - Const(10.0)] params_w = ParamTable() xw = params_w.add("x", 0.0) yw = params_w.add("y", 0.0) rw = [xw + yw - Const(10.0)] # Unweighted: lstsq gives equal movement newton_solve(ru, params_unw) # Weighted: y is 100x more expensive to move newton_solve(rw, params_w, weight_vector=np.array([1.0, 100.0])) # Both should satisfy x + y = 10 assert params_unw.get_value("x") + params_unw.get_value("y") == pytest.approx( 10.0 ) assert params_w.get_value("x") + params_w.get_value("y") == pytest.approx(10.0) # Weighted solve should move y less than x assert abs(params_w.get_value("y")) < abs(params_w.get_value("x"))