"""Integration tests for kinematic joint constraints. These tests exercise the full solve pipeline (constraint → residuals → pre-pass → Newton / BFGS) for multi-body systems with various joint types. """ import math import pytest from kindred_solver.constraints import ( BallConstraint, CoincidentConstraint, CylindricalConstraint, GearConstraint, ParallelConstraint, PerpendicularConstraint, PlanarConstraint, PointInPlaneConstraint, PointOnLineConstraint, RackPinionConstraint, RevoluteConstraint, ScrewConstraint, SliderConstraint, UniversalConstraint, ) from kindred_solver.dof import count_dof from kindred_solver.entities import RigidBody from kindred_solver.newton import newton_solve from kindred_solver.params import ParamTable from kindred_solver.prepass import single_equation_pass, substitution_pass ID_QUAT = (1, 0, 0, 0) # 90° about Z: (cos(45°), 0, 0, sin(45°)) c45 = math.cos(math.pi / 4) s45 = math.sin(math.pi / 4) ROT_90Z = (c45, 0, 0, s45) # 90° about Y ROT_90Y = (c45, 0, s45, 0) # 90° about X ROT_90X = (c45, s45, 0, 0) def _solve(bodies, constraint_objs): """Run the full solve pipeline. Returns (converged, params, bodies).""" pt = bodies[0].tx # all bodies share the same ParamTable via Var._name # Actually, we need the ParamTable object. Get it from the first body. # The Var objects store names, but we need the table. We'll reconstruct. # Better approach: caller passes pt. raise NotImplementedError("Use _solve_with_pt instead") def _solve_with_pt(pt, bodies, constraint_objs): """Run the full solve pipeline with explicit ParamTable.""" all_residuals = [] for c in constraint_objs: all_residuals.extend(c.residuals()) quat_groups = [] for body in bodies: if not body.grounded: all_residuals.append(body.quat_norm_residual()) quat_groups.append(body.quat_param_names()) all_residuals = substitution_pass(all_residuals, pt) all_residuals = single_equation_pass(all_residuals, pt) converged = newton_solve( all_residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10 ) return converged, all_residuals def _dof(pt, bodies, constraint_objs): """Count DOF for a system.""" all_residuals = [] for c in constraint_objs: all_residuals.extend(c.residuals()) for body in bodies: if not body.grounded: all_residuals.append(body.quat_norm_residual()) all_residuals = substitution_pass(all_residuals, pt) return count_dof(all_residuals, pt) # ============================================================================ # Single-joint DOF counting tests # ============================================================================ class TestJointDOF: """Verify each joint type removes the expected number of DOF. Setup: ground body + 1 free body (6 DOF) with a single joint. """ def _setup(self, pos_b=(0, 0, 0), quat_b=ID_QUAT): pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) b = RigidBody("b", pt, pos_b, quat_b) return pt, a, b def test_ball_3dof(self): """Ball joint: 6 - 3 = 3 DOF (3 rotation).""" pt, a, b = self._setup() constraints = [BallConstraint(a, (0, 0, 0), b, (0, 0, 0))] assert _dof(pt, [a, b], constraints) == 3 def test_revolute_1dof(self): """Revolute: 6 - 5 = 1 DOF (rotation about Z).""" pt, a, b = self._setup() constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] assert _dof(pt, [a, b], constraints) == 1 def test_cylindrical_2dof(self): """Cylindrical: 6 - 4 = 2 DOF (rotation + translation along Z).""" pt, a, b = self._setup() constraints = [ CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) ] assert _dof(pt, [a, b], constraints) == 2 def test_slider_1dof(self): """Slider: 6 - 5 = 1 DOF (translation along Z).""" pt, a, b = self._setup() constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] assert _dof(pt, [a, b], constraints) == 1 def test_universal_2dof(self): """Universal: 6 - 4 = 2 DOF (rotation about each body's Z).""" pt, a, b = self._setup(quat_b=ROT_90X) constraints = [ UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) ] assert _dof(pt, [a, b], constraints) == 2 def test_screw_1dof(self): """Screw: 6 - 5 = 1 DOF (helical motion).""" pt, a, b = self._setup() constraints = [ ScrewConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT, pitch=10.0) ] assert _dof(pt, [a, b], constraints) == 1 def test_parallel_4dof(self): """Parallel: 6 - 2 = 4 DOF.""" pt, a, b = self._setup() constraints = [ParallelConstraint(a, ID_QUAT, b, ID_QUAT)] assert _dof(pt, [a, b], constraints) == 4 def test_perpendicular_5dof(self): """Perpendicular: 6 - 1 = 5 DOF.""" pt, a, b = self._setup(quat_b=ROT_90X) constraints = [PerpendicularConstraint(a, ID_QUAT, b, ID_QUAT)] assert _dof(pt, [a, b], constraints) == 5 def test_point_on_line_4dof(self): """PointOnLine: 6 - 2 = 4 DOF.""" pt, a, b = self._setup() constraints = [ PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) ] assert _dof(pt, [a, b], constraints) == 4 def test_point_in_plane_5dof(self): """PointInPlane: 6 - 1 = 5 DOF.""" pt, a, b = self._setup() constraints = [ PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) ] assert _dof(pt, [a, b], constraints) == 5 def test_planar_3dof(self): """Planar: 6 - 3 = 3 DOF (2 translation in plane + 1 rotation about normal).""" pt, a, b = self._setup() constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] assert _dof(pt, [a, b], constraints) == 3 # ============================================================================ # Solve convergence tests — single joints from displaced initial conditions # ============================================================================ class TestJointSolve: """Newton converges to a valid configuration from displaced starting points.""" def test_revolute_displaced(self): """Revolute joint: body B starts displaced, should converge to hinge position.""" pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) b = RigidBody("b", pt, (3, 4, 5), ID_QUAT) # displaced constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] converged, _ = _solve_with_pt(pt, [a, b], constraints) assert converged env = pt.get_env() pos = b.extract_position(env) # Coincident origins → position should be at origin assert abs(pos[0]) < 1e-8 assert abs(pos[1]) < 1e-8 assert abs(pos[2]) < 1e-8 def test_cylindrical_displaced(self): """Cylindrical joint: body B can slide along Z but must be on axis.""" pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) b = RigidBody("b", pt, (3, 4, 7), ID_QUAT) # off-axis constraints = [ CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) ] converged, _ = _solve_with_pt(pt, [a, b], constraints) assert converged env = pt.get_env() pos = b.extract_position(env) # X and Y should be zero (on axis), Z can be anything assert abs(pos[0]) < 1e-8 assert abs(pos[1]) < 1e-8 def test_slider_displaced(self): """Slider: body B can translate along Z only.""" pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) b = RigidBody("b", pt, (2, 3, 5), ID_QUAT) # displaced constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] converged, _ = _solve_with_pt(pt, [a, b], constraints) assert converged env = pt.get_env() pos = b.extract_position(env) # X and Y should be zero (on axis), Z free assert abs(pos[0]) < 1e-8 assert abs(pos[1]) < 1e-8 def test_ball_displaced(self): """Ball joint: body B moves so marker origins coincide. Ball has 3 rotation DOF free, so we can only verify the world-frame marker points match, not the body position directly. """ pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) b = RigidBody("b", pt, (5, 5, 5), ID_QUAT) constraints = [BallConstraint(a, (1, 0, 0), b, (-1, 0, 0))] converged, _ = _solve_with_pt(pt, [a, b], constraints) assert converged env = pt.get_env() # Verify marker world points match wp_a = a.world_point(1, 0, 0) wp_b = b.world_point(-1, 0, 0) for ea, eb in zip(wp_a, wp_b): assert abs(ea.eval(env) - eb.eval(env)) < 1e-8 def test_universal_displaced(self): """Universal joint: coincident origins + perpendicular Z-axes.""" pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) # Start B with Z-axis along X (90° about Y) — perpendicular to A's Z b = RigidBody("b", pt, (3, 4, 5), ROT_90Y) constraints = [ UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) ] converged, _ = _solve_with_pt(pt, [a, b], constraints) assert converged env = pt.get_env() pos = b.extract_position(env) assert abs(pos[0]) < 1e-8 assert abs(pos[1]) < 1e-8 assert abs(pos[2]) < 1e-8 def test_point_on_line_solve(self): """Point on line: body B's marker origin constrained to line along Z. Under-constrained system (4 DOF remain), so we verify the constraint residuals are satisfied rather than expecting specific positions. """ pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) b = RigidBody("b", pt, (5, 3, 7), ID_QUAT) constraints = [ PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) ] converged, residuals = _solve_with_pt(pt, [a, b], constraints) assert converged env = pt.get_env() for r in residuals: assert abs(r.eval(env)) < 1e-8 def test_point_in_plane_solve(self): """Point in plane: body B's marker origin at z=0 plane. Under-constrained (5 DOF remain), so verify residuals. """ pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) b = RigidBody("b", pt, (3, 4, 8), ID_QUAT) constraints = [ PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) ] converged, residuals = _solve_with_pt(pt, [a, b], constraints) assert converged env = pt.get_env() for r in residuals: assert abs(r.eval(env)) < 1e-8 def test_planar_solve(self): """Planar: coplanar faces — parallel normals + point in plane.""" pt = ParamTable() a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) # Start B tilted and displaced b = RigidBody("b", pt, (3, 4, 8), ID_QUAT) constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] converged, _ = _solve_with_pt(pt, [a, b], constraints) assert converged env = pt.get_env() pos = b.extract_position(env) # Z must be zero (in plane), X and Y free assert abs(pos[2]) < 1e-8 # ============================================================================ # Multi-body integration tests # ============================================================================ class TestFourBarLinkage: """Four-bar linkage: 4 bodies, 4 revolute joints. In 3D with Z-axis revolutes, this yields 2 DOF: the expected planar motion plus an out-of-plane fold. A truly planar mechanism would add Planar constraints on each link to eliminate the fold DOF. """ def test_four_bar_dof(self): """Four-bar linkage in 3D has 2 DOF (planar + fold).""" pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) link1 = RigidBody("l1", pt, (2, 0, 0), ID_QUAT) link2 = RigidBody("l2", pt, (5, 3, 0), ID_QUAT) link3 = RigidBody("l3", pt, (8, 0, 0), ID_QUAT) constraints = [ RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT), RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT), RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT), RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT), ] bodies = [ground, link1, link2, link3] dof = _dof(pt, bodies, constraints) assert dof == 2 def test_four_bar_solves(self): """Four-bar linkage converges from displaced initial conditions.""" pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) # Initial positions slightly displaced from valid config link1 = RigidBody("l1", pt, (2, 1, 0), ID_QUAT) link2 = RigidBody("l2", pt, (5, 4, 0), ID_QUAT) link3 = RigidBody("l3", pt, (8, 1, 0), ID_QUAT) constraints = [ RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT), RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT), RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT), RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT), ] bodies = [ground, link1, link2, link3] converged, residuals = _solve_with_pt(pt, bodies, constraints) assert converged # Verify all revolute constraints are satisfied env = pt.get_env() for r in residuals: assert abs(r.eval(env)) < 1e-8 class TestSliderCrank: """Slider-crank mechanism: crank + connecting rod + piston. ground --[Revolute]-- crank --[Revolute]-- rod --[Revolute]-- piston --[Slider]-- ground Using Slider (not Cylindrical) for the piston to also lock rotation, making it a true prismatic joint. In 3D, out-of-plane folding adds extra DOF beyond the planar 1-DOF. 3 free bodies × 6 = 18 DOF Revolute(5) + Revolute(5) + Revolute(5) + Slider(5) = 20 But many constraints share bodies, so effective rank < 20. In 3D: 3 DOF (planar crank + 2 fold modes). """ def test_slider_crank_dof(self): pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) crank = RigidBody("crank", pt, (1, 0, 0), ID_QUAT) rod = RigidBody("rod", pt, (3, 0, 0), ID_QUAT) piston = RigidBody("piston", pt, (5, 0, 0), ID_QUAT) constraints = [ RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT), RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT), RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT), SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y), ] bodies = [ground, crank, rod, piston] dof = _dof(pt, bodies, constraints) # With full 3-component cross products, the redundant constraint rows # eliminate the out-of-plane fold modes, giving the correct 1 DOF # (crank rotation only). assert dof == 1 def test_slider_crank_solves(self): """Slider-crank converges from displaced state.""" pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) crank = RigidBody("crank", pt, (1, 0.5, 0), ID_QUAT) rod = RigidBody("rod", pt, (3, 1, 0), ID_QUAT) piston = RigidBody("piston", pt, (5, 0.5, 0), ID_QUAT) constraints = [ RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT), RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT), RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT), SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y), ] bodies = [ground, crank, rod, piston] converged, residuals = _solve_with_pt(pt, bodies, constraints) assert converged env = pt.get_env() for r in residuals: assert abs(r.eval(env)) < 1e-8 class TestRevoluteChain: """Chain of revolute joints: ground → body1 → body2. Each revolute removes 5 DOF. Two free bodies = 12 DOF. 2 revolutes = 10 constraints + 2 quat norms = 12. Expected: 2 DOF (one rotation per hinge). """ def test_chain_dof(self): pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) b1 = RigidBody("b1", pt, (3, 0, 0), ID_QUAT) b2 = RigidBody("b2", pt, (6, 0, 0), ID_QUAT) constraints = [ RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT), RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT), ] assert _dof(pt, [ground, b1, b2], constraints) == 2 def test_chain_solves(self): pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) b1 = RigidBody("b1", pt, (3, 2, 0), ID_QUAT) b2 = RigidBody("b2", pt, (6, 3, 0), ID_QUAT) constraints = [ RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT), RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT), ] converged, residuals = _solve_with_pt(pt, [ground, b1, b2], constraints) assert converged env = pt.get_env() # b1 origin at ground hinge point (0,0,0) pos1 = b1.extract_position(env) assert abs(pos1[0]) < 1e-8 assert abs(pos1[1]) < 1e-8 assert abs(pos1[2]) < 1e-8 class TestSliderOnRail: """Slider constraint: body translates along ground Z-axis only. 1 free body, 1 slider = 6 - 5 = 1 DOF. """ def test_slider_on_rail(self): pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) block = RigidBody("block", pt, (3, 4, 5), ID_QUAT) constraints = [ SliderConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT) ] converged, _ = _solve_with_pt(pt, [ground, block], constraints) assert converged env = pt.get_env() pos = block.extract_position(env) # X, Y must be zero; Z is free assert abs(pos[0]) < 1e-8 assert abs(pos[1]) < 1e-8 # Z should remain near initial value (minimum-norm solution) # Check orientation unchanged (no twist) quat = block.extract_quaternion(env) assert abs(quat[0] - 1.0) < 1e-6 assert abs(quat[1]) < 1e-6 assert abs(quat[2]) < 1e-6 assert abs(quat[3]) < 1e-6 class TestPlanarOnTable: """Planar constraint: body slides on XY plane. 1 free body, 1 planar = 6 - 3 = 3 DOF. """ def test_planar_on_table(self): pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) block = RigidBody("block", pt, (3, 4, 5), ID_QUAT) constraints = [ PlanarConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT) ] converged, _ = _solve_with_pt(pt, [ground, block], constraints) assert converged env = pt.get_env() pos = block.extract_position(env) # Z must be zero, X and Y are free assert abs(pos[2]) < 1e-8 class TestPlanarWithOffset: """Planar with offset: body floats at z=3 above ground.""" def test_planar_offset(self): pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) block = RigidBody("block", pt, (1, 2, 5), ID_QUAT) # PlanarConstraint residual: (p_i - p_j) . z_j - offset = 0 # body_i=block, body_j=ground: (block_z - 0) * 1 - offset = 0 # For block at z=3: offset = 3 constraints = [ PlanarConstraint( block, (0, 0, 0), ID_QUAT, ground, (0, 0, 0), ID_QUAT, offset=3.0 ) ] converged, _ = _solve_with_pt(pt, [ground, block], constraints) assert converged env = pt.get_env() pos = block.extract_position(env) assert abs(pos[2] - 3.0) < 1e-8 class TestMixedConstraints: """System with mixed constraint types.""" def test_revolute_plus_parallel(self): """Two free bodies: revolute between ground and b1, parallel between b1 and b2. b1: 6 DOF - 5 (revolute) = 1 DOF b2: 6 DOF - 2 (parallel) = 4 DOF Total: 5 DOF """ pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) b1 = RigidBody("b1", pt, (0, 0, 0), ID_QUAT) b2 = RigidBody("b2", pt, (5, 0, 0), ID_QUAT) constraints = [ RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT), ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT), ] assert _dof(pt, [ground, b1, b2], constraints) == 5 def test_coincident_plus_perpendicular(self): """Coincident + perpendicular = ball + 1 angle constraint. 6 - 3 (coincident) - 1 (perpendicular) = 2 DOF. """ pt = ParamTable() ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) b = RigidBody("b", pt, (0, 0, 0), ROT_90X) constraints = [ CoincidentConstraint(ground, (0, 0, 0), b, (0, 0, 0)), PerpendicularConstraint(ground, ID_QUAT, b, ID_QUAT), ] assert _dof(pt, [ground, b], constraints) == 2