"""Tests for RigidBody entities."""

import math

import pytest

from kindred_solver.entities import RigidBody
from kindred_solver.params import ParamTable

# Absolute tolerance shared by every floating-point comparison in this module.
ABS_TOL = 1e-10


class TestRigidBody:
    """Checks of RigidBody world points, residuals, grounding and extraction.

    Every test builds a fresh ParamTable so that no parameters leak between
    tests. Quaternions are passed as 4-tuples; the 90-degree Z-rotation test
    below uses (cos, 0, 0, sin), i.e. the scalar component comes first.
    """

    def test_identity_world_point(self):
        """Body at origin with identity rotation: local == world."""
        pt = ParamTable()
        body = RigidBody("p1", pt, (0, 0, 0), (1, 0, 0, 0))
        wx, wy, wz = body.world_point(1.0, 2.0, 3.0)
        env = pt.get_env()
        assert wx.eval(env) == pytest.approx(1.0, abs=ABS_TOL)
        assert wy.eval(env) == pytest.approx(2.0, abs=ABS_TOL)
        assert wz.eval(env) == pytest.approx(3.0, abs=ABS_TOL)

    def test_translated(self):
        """Body translated by (10, 20, 30) with identity rotation."""
        pt = ParamTable()
        body = RigidBody("p1", pt, (10, 20, 30), (1, 0, 0, 0))
        wx, wy, wz = body.world_point(1.0, 2.0, 3.0)
        env = pt.get_env()
        # Local point (1, 2, 3) is simply offset by the body translation.
        assert wx.eval(env) == pytest.approx(11.0, abs=ABS_TOL)
        assert wy.eval(env) == pytest.approx(22.0, abs=ABS_TOL)
        assert wz.eval(env) == pytest.approx(33.0, abs=ABS_TOL)

    def test_rotated_90_z(self):
        """90-degree rotation about Z: (1,0,0) -> (0,1,0)."""
        pt = ParamTable()
        # 90-deg about Z: q = (cos(45), 0, 0, sin(45)) -- half-angle form.
        c = math.cos(math.pi / 4)
        s = math.sin(math.pi / 4)
        body = RigidBody("p1", pt, (0, 0, 0), (c, 0, 0, s))
        wx, wy, wz = body.world_point(1.0, 0.0, 0.0)
        env = pt.get_env()
        assert wx.eval(env) == pytest.approx(0.0, abs=ABS_TOL)
        assert wy.eval(env) == pytest.approx(1.0, abs=ABS_TOL)
        assert wz.eval(env) == pytest.approx(0.0, abs=ABS_TOL)

    def test_quat_norm_residual(self):
        """Normalization residual is zero for unit quaternion."""
        pt = ParamTable()
        body = RigidBody("p1", pt, (0, 0, 0), (1, 0, 0, 0))
        r = body.quat_norm_residual()
        env = pt.get_env()
        assert r.eval(env) == pytest.approx(0.0, abs=ABS_TOL)

    def test_grounded(self):
        """Grounded body has all params fixed."""
        pt = ParamTable()
        body = RigidBody("p1", pt, (1, 2, 3), (1, 0, 0, 0), grounded=True)
        # The table holds only this body's params, so none may remain free.
        assert pt.n_free() == 0
        assert body.grounded

    def test_extract(self):
        """Extract position and quaternion from env."""
        pt = ParamTable()
        body = RigidBody("p1", pt, (1, 2, 3), (0.5, 0.5, 0.5, 0.5))
        env = pt.get_env()
        pos = body.extract_position(env)
        quat = body.extract_quaternion(env)
        # Exact equality is intentional: the expected values are exactly
        # representable floats and are compared as plain tuples.
        assert pos == (1.0, 2.0, 3.0)
        assert quat == (0.5, 0.5, 0.5, 0.5)

    def test_differentiation(self):
        """World point expressions are differentiable w.r.t. body params."""
        pt = ParamTable()
        body = RigidBody("p1", pt, (0, 0, 0), (1, 0, 0, 0))
        # Only the x component is differentiated here.
        wx, _, _ = body.world_point(1.0, 0.0, 0.0)
        # d(wx)/d(tx) should be 1
        dtx = wx.diff("p1/tx").simplify()
        env = pt.get_env()
        assert dtx.eval(env) == pytest.approx(1.0, abs=ABS_TOL)