"""Tests for geometry helpers."""

import math

import pytest

from kindred_solver.entities import RigidBody
from kindred_solver.expr import Const, Var
from kindred_solver.geometry import (
    cross3,
    dot3,
    marker_x_axis,
    marker_y_axis,
    marker_z_axis,
    point_line_perp_components,
    point_plane_distance,
    sub3,
)
from kindred_solver.params import ParamTable

IDENTITY_QUAT = (1.0, 0.0, 0.0, 0.0)

# 90-deg about Z: (cos45, 0, 0, sin45)
_c = math.cos(math.pi / 4)
_s = math.sin(math.pi / 4)
ROT_90Z_QUAT = (_c, 0.0, 0.0, _s)

# Absolute tolerance used by every numeric comparison in this module.
_TOL = 1e-10


def _approx(expected):
    """Wrap *expected* in ``pytest.approx`` with an absolute-only tolerance.

    Passing only ``abs`` makes pytest ignore the relative tolerance, matching
    the ``abs(actual - expected) < 1e-10`` checks these tests were written with.
    """
    return pytest.approx(expected, abs=_TOL)


def _vec(x, y, z):
    """Return a 3-tuple of ``Const`` expressions with the given values."""
    return (Const(x), Const(y), Const(z))


def _eval3(vec, env=None):
    """Evaluate each component of a 3-vector expression; ``{}`` when *env* is None."""
    env = {} if env is None else env
    return tuple(component.eval(env) for component in vec)


def _axis_in_world(axis_fn, body_quat, marker_quat):
    """Evaluate one marker axis for a fresh body ``"p"`` placed at the origin.

    *axis_fn* is one of ``marker_x_axis`` / ``marker_y_axis`` /
    ``marker_z_axis``; it is called as ``axis_fn(body, marker_quat)`` and the
    resulting expressions are evaluated with the table's environment.
    """
    params = ParamTable()
    body = RigidBody("p", params, (0, 0, 0), body_quat)
    axis = axis_fn(body, marker_quat)
    return _eval3(axis, params.get_env())


class TestDot3:
    """Scalar dot product of two 3-vector expressions."""

    def test_parallel(self):
        result = dot3(_vec(1.0, 0.0, 0.0), _vec(1.0, 0.0, 0.0))
        assert result.eval({}) == _approx(1.0)

    def test_perpendicular(self):
        result = dot3(_vec(1.0, 0.0, 0.0), _vec(0.0, 1.0, 0.0))
        assert result.eval({}) == _approx(0.0)

    def test_general(self):
        result = dot3(_vec(1.0, 2.0, 3.0), _vec(4.0, 5.0, 6.0))
        # 1*4 + 2*5 + 3*6 = 32
        assert result.eval({}) == _approx(32.0)


class TestCross3:
    """Component-wise cross product of two 3-vector expressions."""

    def test_x_cross_y(self):
        x_hat = _vec(1.0, 0.0, 0.0)
        y_hat = _vec(0.0, 1.0, 0.0)
        assert _eval3(cross3(x_hat, y_hat)) == _approx((0.0, 0.0, 1.0))

    def test_y_cross_x_anticommutes(self):
        # Swapping the operands must flip the sign: y x x = -z.
        x_hat = _vec(1.0, 0.0, 0.0)
        y_hat = _vec(0.0, 1.0, 0.0)
        assert _eval3(cross3(y_hat, x_hat)) == _approx((0.0, 0.0, -1.0))

    def test_parallel_is_zero(self):
        a = _vec(2.0, 3.0, 4.0)
        b = _vec(4.0, 6.0, 8.0)  # b = 2a
        assert _eval3(cross3(a, b)) == _approx((0.0, 0.0, 0.0))


class TestSub3:
    """Component-wise difference of two 3-vector expressions."""

    def test_basic(self):
        diff = sub3(_vec(5.0, 3.0, 1.0), _vec(1.0, 2.0, 3.0))
        assert _eval3(diff) == _approx((4.0, 1.0, -2.0))


class TestMarkerAxes:
    """World-frame marker axes built from a body quaternion and a marker quaternion."""

    def test_identity_z(self):
        """Identity body + identity marker → Z = (0,0,1)."""
        axis = _axis_in_world(marker_z_axis, (1, 0, 0, 0), IDENTITY_QUAT)
        assert axis == _approx((0.0, 0.0, 1.0))

    def test_identity_x(self):
        """Identity body + identity marker → X = (1,0,0)."""
        axis = _axis_in_world(marker_x_axis, (1, 0, 0, 0), IDENTITY_QUAT)
        assert axis == _approx((1.0, 0.0, 0.0))

    def test_identity_y(self):
        """Identity body + identity marker → Y = (0,1,0)."""
        axis = _axis_in_world(marker_y_axis, (1, 0, 0, 0), IDENTITY_QUAT)
        assert axis == _approx((0.0, 1.0, 0.0))

    def test_rotated_body_z(self):
        """Body rotated 90-deg about Z → Z-axis still (0,0,1)."""
        axis = _axis_in_world(marker_z_axis, ROT_90Z_QUAT, IDENTITY_QUAT)
        assert axis == _approx((0.0, 0.0, 1.0))

    def test_rotated_body_x(self):
        """Body rotated 90-deg about Z → X-axis becomes (0,1,0)."""
        axis = _axis_in_world(marker_x_axis, ROT_90Z_QUAT, IDENTITY_QUAT)
        assert axis == _approx((0.0, 1.0, 0.0))

    def test_rotated_body_y(self):
        """Body rotated 90-deg about Z → Y-axis becomes (-1,0,0)."""
        axis = _axis_in_world(marker_y_axis, ROT_90Z_QUAT, IDENTITY_QUAT)
        assert axis == _approx((-1.0, 0.0, 0.0))

    def test_marker_rotation(self):
        """Identity body + marker rotated 90-deg about Z → Z still (0,0,1)."""
        axis = _axis_in_world(marker_z_axis, (1, 0, 0, 0), ROT_90Z_QUAT)
        assert axis == _approx((0.0, 0.0, 1.0))

    def test_marker_rotation_x_axis(self):
        """Identity body + marker rotated 90-deg about Z → X becomes (0,1,0)."""
        axis = _axis_in_world(marker_x_axis, (1, 0, 0, 0), ROT_90Z_QUAT)
        assert axis == _approx((0.0, 1.0, 0.0))

    def test_rotated_body_and_marker_x(self):
        """Body and marker each rotated 90-deg about Z → X becomes (-1,0,0).

        Both rotations share the Z axis, so they commute and the expected
        result (a net 180-deg turn) does not depend on composition order.
        """
        axis = _axis_in_world(marker_x_axis, ROT_90Z_QUAT, ROT_90Z_QUAT)
        assert axis == _approx((-1.0, 0.0, 0.0))

    def test_differentiable(self):
        """Marker axes are differentiable w.r.t. body quat params."""
        params = ParamTable()
        body = RigidBody("p", params, (0, 0, 0), (1, 0, 0, 0))
        zx, _zy, _zz = marker_z_axis(body, IDENTITY_QUAT)
        # Differentiation and simplification must not raise...
        dzx = zx.diff("p/qz").simplify()
        env = params.get_env()
        # ...and the derivative must evaluate to a real number, not just
        # "be evaluable" with the result thrown away.
        assert math.isfinite(dzx.eval(env))


class TestPointPlaneDistance:
    """Distance from a point to a plane given by an origin and a unit normal."""

    def test_on_plane(self):
        d = point_plane_distance(
            _vec(1.0, 2.0, 0.0), _vec(0.0, 0.0, 0.0), _vec(0.0, 0.0, 1.0)
        )
        assert d.eval({}) == _approx(0.0)

    def test_above_plane(self):
        d = point_plane_distance(
            _vec(1.0, 2.0, 5.0), _vec(0.0, 0.0, 0.0), _vec(0.0, 0.0, 1.0)
        )
        assert d.eval({}) == _approx(5.0)

    def test_offset_origin(self):
        # The plane origin is not the world origin. Only the offset along the
        # normal (7 - 2 = 5) counts; the in-plane offset of the origin does not.
        d = point_plane_distance(
            _vec(1.0, 2.0, 7.0), _vec(4.0, -3.0, 2.0), _vec(0.0, 0.0, 1.0)
        )
        assert d.eval({}) == _approx(5.0)


class TestPointLinePerp:
    """Perpendicular components of a point relative to a line (origin + direction)."""

    def test_on_line(self):
        comps = point_line_perp_components(
            _vec(0.0, 0.0, 5.0), _vec(0.0, 0.0, 0.0), _vec(0.0, 0.0, 1.0)
        )
        assert _eval3(comps) == _approx((0.0, 0.0, 0.0))

    def test_off_line(self):
        comps = point_line_perp_components(
            _vec(3.0, 0.0, 0.0), _vec(0.0, 0.0, 0.0), _vec(0.0, 0.0, 1.0)
        )
        # d = (3,0,0), dir = (0,0,1), d x dir = (0*1-0*0, 0*0-3*1, 3*0-0*0) = (0,-3,0)
        assert _eval3(comps) == _approx((0.0, -3.0, 0.0))