diff --git a/kindred_solver/bfgs.py b/kindred_solver/bfgs.py new file mode 100644 index 0000000..44e65ce --- /dev/null +++ b/kindred_solver/bfgs.py @@ -0,0 +1,127 @@ +"""L-BFGS-B fallback solver for when Newton-Raphson fails to converge. + +Minimizes f(x) = 0.5 * sum(r_i(x)^2) using scipy's L-BFGS-B with +analytic gradient from the Expr DAG's symbolic differentiation. +""" + +from __future__ import annotations + +import math +from typing import List + +import numpy as np + +from .expr import Expr +from .params import ParamTable + +try: + from scipy.optimize import minimize as _scipy_minimize + + _HAS_SCIPY = True +except ImportError: + _HAS_SCIPY = False + + +def bfgs_solve( + residuals: List[Expr], + params: ParamTable, + quat_groups: List[tuple[str, str, str, str]] | None = None, + max_iter: int = 200, + tol: float = 1e-10, +) -> bool: + """Solve ``residuals == 0`` by minimizing sum of squared residuals. + + Falls back gracefully to False if scipy is not available. + + Returns True if converged (||r|| < tol). + """ + if not _HAS_SCIPY: + return False + + free = params.free_names() + n_free = len(free) + n_res = len(residuals) + + if n_free == 0 or n_res == 0: + return True + + # Build symbolic gradient expressions once: d(r_i)/d(x_j) + jac_exprs: List[List[Expr]] = [] + for r in residuals: + row = [] + for name in free: + row.append(r.diff(name).simplify()) + jac_exprs.append(row) + + def objective_and_grad(x_vec): + # Update params + params.set_free_vector(x_vec) + if quat_groups: + _renormalize_quats(params, quat_groups) + + env = params.get_env() + + # Evaluate residuals + r_vals = np.array([r.eval(env) for r in residuals]) + f = 0.5 * np.dot(r_vals, r_vals) + + # Evaluate Jacobian + J = np.empty((n_res, n_free)) + for i in range(n_res): + for j in range(n_free): + J[i, j] = jac_exprs[i][j].eval(env) + + # Gradient of f = sum(r_i * dr_i/dx_j) = J^T @ r + grad = J.T @ r_vals + + return f, grad + + x0 = params.get_free_vector().copy() + + result = _scipy_minimize( + objective_and_grad, + x0, + method="L-BFGS-B", + jac=True, + options={"maxiter": max_iter, "ftol": tol * tol, "gtol": tol}, + ) + + # Apply final result + params.set_free_vector(result.x) + if quat_groups: + _renormalize_quats(params, quat_groups) + + # Check convergence on actual residual norm + env = params.get_env() + r_vals = np.array([r.eval(env) for r in residuals]) + return bool(np.linalg.norm(r_vals) < tol) + + +def _renormalize_quats( + params: ParamTable, + groups: List[tuple[str, str, str, str]], +): + """Project quaternion params back onto the unit sphere.""" + for qw_name, qx_name, qy_name, qz_name in groups: + if ( + params.is_fixed(qw_name) + and params.is_fixed(qx_name) + and params.is_fixed(qy_name) + and params.is_fixed(qz_name) + ): + continue + w = params.get_value(qw_name) + x = params.get_value(qx_name) + y = params.get_value(qy_name) + z = params.get_value(qz_name) + norm = math.sqrt(w * w + x * x + y * y + z * z) + if norm < 1e-15: + params.set_value(qw_name, 1.0) + params.set_value(qx_name, 0.0) + params.set_value(qy_name, 0.0) + params.set_value(qz_name, 0.0) + else: + params.set_value(qw_name, w / norm) + params.set_value(qx_name, x / norm) + params.set_value(qy_name, y / norm) + params.set_value(qz_name, z / norm) diff --git a/kindred_solver/constraints.py b/kindred_solver/constraints.py index aeec351..47eacb1 100644 --- a/kindred_solver/constraints.py +++ b/kindred_solver/constraints.py @@ -2,14 +2,28 @@ Each constraint takes two RigidBody entities and marker transforms, then generates residual expressions that equal zero when satisfied. + +Phase 1 constraints: Coincident, DistancePointPoint, Fixed +Phase 2 constraints: all remaining BaseJointKind types from Types.h """ from __future__ import annotations +import math from typing import List from .entities import RigidBody from .expr import Const, Expr +from .geometry import ( + cross3, + dot3, + marker_x_axis, + marker_y_axis, + marker_z_axis, + point_line_perp_components, + point_plane_distance, + sub3, +) class ConstraintBase: @@ -145,6 +159,703 @@ class FixedConstraint(ConstraintBase): return pos_res + ori_res +# ============================================================================ +# Phase 2: Point constraints +# ============================================================================ + + +class PointOnLineConstraint(ConstraintBase): + """Point constrained to a line — 2 DOF removed. + + marker_i origin lies on the line through marker_j origin along + marker_j Z-axis. 2 residuals: perpendicular distance components. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + + def residuals(self) -> List[Expr]: + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + cx, cy = point_line_perp_components(p_i, p_j, z_j) + return [cx, cy] + + +class PointInPlaneConstraint(ConstraintBase): + """Point constrained to a plane — 1 DOF removed. + + marker_i origin lies in the plane through marker_j origin with + normal = marker_j Z-axis. Optional offset via params[0]. + 1 residual: signed distance to plane. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + offset: float = 0.0, + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + self.offset = offset + + def residuals(self) -> List[Expr]: + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + n_j = marker_z_axis(self.body_j, self.marker_j_quat) + d = point_plane_distance(p_i, p_j, n_j) + if self.offset != 0.0: + d = d - Const(self.offset) + return [d] + + +# ============================================================================ +# Phase 2: Axis orientation constraints +# ============================================================================ + + +class ParallelConstraint(ConstraintBase): + """Parallel axes — 2 DOF removed. + + marker Z-axes are parallel: z_i x z_j = 0. + 2 residuals from the cross product (only 2 of 3 components are + independent for unit vectors). + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_quat: tuple[float, float, float, float], + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_quat = marker_i_quat + self.marker_j_quat = marker_j_quat + + def residuals(self) -> List[Expr]: + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + cx, cy, cz = cross3(z_i, z_j) + return [cx, cy] + + +class PerpendicularConstraint(ConstraintBase): + """Perpendicular axes — 1 DOF removed. + + marker Z-axes are perpendicular: z_i . z_j = 0. + 1 residual. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_quat: tuple[float, float, float, float], + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_quat = marker_i_quat + self.marker_j_quat = marker_j_quat + + def residuals(self) -> List[Expr]: + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + return [dot3(z_i, z_j)] + + +class AngleConstraint(ConstraintBase): + """Angle between axes — 1 DOF removed. + + z_i . z_j = cos(angle). + 1 residual. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_quat: tuple[float, float, float, float], + angle: float, + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_quat = marker_i_quat + self.marker_j_quat = marker_j_quat + self.angle = angle + + def residuals(self) -> List[Expr]: + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + return [dot3(z_i, z_j) - Const(math.cos(self.angle))] + + +# ============================================================================ +# Phase 2: Axis/surface constraints +# ============================================================================ + + +class ConcentricConstraint(ConstraintBase): + """Coaxial / concentric — 4 DOF removed. + + Axes are collinear: parallel Z-axes (2) + point-on-line (2). + Optional distance offset along axis via params. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + distance: float = 0.0, + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + self.distance = distance + + def residuals(self) -> List[Expr]: + # Parallel axes (2 residuals) + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + cx, cy, _cz = cross3(z_i, z_j) + + # Point-on-line: marker_i origin on line through marker_j along z_j + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + lx, ly = point_line_perp_components(p_i, p_j, z_j) + + return [cx, cy, lx, ly] + + +class TangentConstraint(ConstraintBase): + """Face-on-face tangency — 1 DOF removed. + + Signed distance between marker origins along marker_j normal = 0. + 1 residual: (p_i - p_j) . z_j + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + + def residuals(self) -> List[Expr]: + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + n_j = marker_z_axis(self.body_j, self.marker_j_quat) + return [point_plane_distance(p_i, p_j, n_j)] + + +class PlanarConstraint(ConstraintBase): + """Coplanar faces — 3 DOF removed. + + Parallel normals (2) + point-in-plane (1). Optional offset. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + offset: float = 0.0, + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + self.offset = offset + + def residuals(self) -> List[Expr]: + # Parallel normals + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + cx, cy, _cz = cross3(z_i, z_j) + + # Point-in-plane + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + d = point_plane_distance(p_i, p_j, z_j) + if self.offset != 0.0: + d = d - Const(self.offset) + + return [cx, cy, d] + + +class LineInPlaneConstraint(ConstraintBase): + """Line constrained to a plane — 2 DOF removed. + + Line defined by marker_i Z-axis lies in plane defined by marker_j normal. + 2 residuals: point-in-plane (1) + line direction perpendicular to normal (1). + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + offset: float = 0.0, + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + self.offset = offset + + def residuals(self) -> List[Expr]: + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + n_j = marker_z_axis(self.body_j, self.marker_j_quat) + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + + # Point in plane + d = point_plane_distance(p_i, p_j, n_j) + if self.offset != 0.0: + d = d - Const(self.offset) + + # Line direction perpendicular to plane normal + dir_dot = dot3(z_i, n_j) + + return [d, dir_dot] + + +# ============================================================================ +# Phase 2: Kinematic joints +# ============================================================================ + + +class BallConstraint(ConstraintBase): + """Spherical joint — 3 DOF removed. + + Coincident marker origins. Same as CoincidentConstraint. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + ): + self._inner = CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos) + + def residuals(self) -> List[Expr]: + return self._inner.residuals() + + +class RevoluteConstraint(ConstraintBase): + """Hinge joint — 5 DOF removed. + + Coincident origins (3) + parallel Z-axes (2). + 1 rotational DOF remains (about the common Z-axis). + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + + def residuals(self) -> List[Expr]: + # Coincident origins + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + pos = [p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]] + + # Parallel Z-axes + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + cx, cy, _cz = cross3(z_i, z_j) + + return pos + [cx, cy] + + +class CylindricalConstraint(ConstraintBase): + """Cylindrical joint — 4 DOF removed. + + Parallel Z-axes (2) + point-on-line (2). + 2 DOF remain: rotation about and translation along the common axis. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + + def residuals(self) -> List[Expr]: + # Parallel Z-axes + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + cx, cy, _cz = cross3(z_i, z_j) + + # Point-on-line + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + lx, ly = point_line_perp_components(p_i, p_j, z_j) + + return [cx, cy, lx, ly] + + +class SliderConstraint(ConstraintBase): + """Prismatic / slider joint — 5 DOF removed. + + Parallel Z-axes (2) + point-on-line (2) + rotation lock (1). + 1 DOF remains: translation along the common Z-axis. + + Rotation lock: x_i . y_j = 0 (prevents twist about Z). + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + + def residuals(self) -> List[Expr]: + # Parallel Z-axes + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + cx, cy, _cz = cross3(z_i, z_j) + + # Point-on-line + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + lx, ly = point_line_perp_components(p_i, p_j, z_j) + + # Rotation lock: x_i . y_j = 0 + x_i = marker_x_axis(self.body_i, self.marker_i_quat) + y_j = marker_y_axis(self.body_j, self.marker_j_quat) + twist = dot3(x_i, y_j) + + return [cx, cy, lx, ly, twist] + + +class ScrewConstraint(ConstraintBase): + """Helical / screw joint — 5 DOF removed. + + Cylindrical (4) + coupled rotation-translation via pitch (1). + 1 DOF remains: screw motion (rotation + proportional translation). + + The coupling residual uses the relative quaternion's Z-component + (proportional to the rotation angle for small angles) and the axial + displacement: axial_disp - pitch * (2 * qz_rel / qw_rel) / (2*pi) = 0. + For the Newton solver operating near the solution, the linear + approximation angle ≈ 2 * qz_rel is adequate. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + pitch: float = 1.0, + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + self.pitch = pitch + + def residuals(self) -> List[Expr]: + # Cylindrical residuals (4) + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + cx, cy, _cz = cross3(z_i, z_j) + + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + lx, ly = point_line_perp_components(p_i, p_j, z_j) + + # Pitch coupling: axial_disp = pitch * angle / (2*pi) + # Axial displacement + d = sub3(p_i, p_j) + axial = dot3(d, z_j) + + # Relative rotation about Z via quaternion + # q_rel = conj(q_i_total) * q_j_total + qi = _quat_mul_const( + self.body_i.qw, + self.body_i.qx, + self.body_i.qy, + self.body_i.qz, + *self.marker_i_quat, + ) + qj = _quat_mul_const( + self.body_j.qw, + self.body_j.qx, + self.body_j.qy, + self.body_j.qz, + *self.marker_j_quat, + ) + rel = _quat_mul_expr(qi[0], -qi[1], -qi[2], -qi[3], qj[0], qj[1], qj[2], qj[3]) + # For small angles: angle ≈ 2 * qz_rel, but qw_rel ≈ 1 + # Use sin(angle/2) form: residual = axial - pitch * 2*qz / (2*pi) + # = axial - pitch * qz / pi + coupling = axial - Const(self.pitch / math.pi) * rel[3] + + return [cx, cy, lx, ly, coupling] + + +class UniversalConstraint(ConstraintBase): + """Universal / Cardan joint — 4 DOF removed. + + Coincident origins (3) + perpendicular Z-axes (1). + 2 DOF remain: rotation about each body's Z-axis. + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + + def residuals(self) -> List[Expr]: + # Coincident origins + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + pos = [p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]] + + # Perpendicular Z-axes + z_i = marker_z_axis(self.body_i, self.marker_i_quat) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + + return pos + [dot3(z_i, z_j)] + + +# ============================================================================ +# Phase 2: Mechanical element constraints +# ============================================================================ + + +class GearConstraint(ConstraintBase): + """Gear pair or belt — 1 DOF removed. + + Couples rotation angles: r_i * theta_i + r_j * theta_j = 0. + For belts (same-direction rotation), r_j is passed as negative. + + Uses the Z-component of the relative quaternion as a proxy for + rotation angle (linear for small angles, which is the regime + where Newton operates). + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_quat: tuple[float, float, float, float], + radius_i: float, + radius_j: float, + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_quat = marker_i_quat + self.marker_j_quat = marker_j_quat + self.radius_i = radius_i + self.radius_j = radius_j + + def residuals(self) -> List[Expr]: + # Rotation angle proxy via relative quaternion Z-component + # For body_i: q_rel_i = conj(q_marker_i) * q_body_i * q_marker_i + # Simplified: use 2*qz of (conj(marker) * body * marker) as angle proxy + qz_i = _rotation_z_component(self.body_i, self.marker_i_quat) + qz_j = _rotation_z_component(self.body_j, self.marker_j_quat) + + # r_i * theta_i + r_j * theta_j = 0 + # Using qz as proportional to theta/2: + # r_i * qz_i + r_j * qz_j = 0 + return [Const(self.radius_i) * qz_i + Const(self.radius_j) * qz_j] + + +class RackPinionConstraint(ConstraintBase): + """Rack-and-pinion — 1 DOF removed. + + Couples rotation of body_i to translation of body_j along marker_j Z-axis. + translation = pitch_radius * theta + """ + + def __init__( + self, + body_i: RigidBody, + marker_i_pos: tuple[float, float, float], + marker_i_quat: tuple[float, float, float, float], + body_j: RigidBody, + marker_j_pos: tuple[float, float, float], + marker_j_quat: tuple[float, float, float, float], + pitch_radius: float, + ): + self.body_i = body_i + self.body_j = body_j + self.marker_i_pos = marker_i_pos + self.marker_i_quat = marker_i_quat + self.marker_j_pos = marker_j_pos + self.marker_j_quat = marker_j_quat + self.pitch_radius = pitch_radius + + def residuals(self) -> List[Expr]: + # Translation of j along its Z-axis + p_i = self.body_i.world_point(*self.marker_i_pos) + p_j = self.body_j.world_point(*self.marker_j_pos) + z_j = marker_z_axis(self.body_j, self.marker_j_quat) + d = sub3(p_j, p_i) + translation = dot3(d, z_j) + + # Rotation angle of i about its Z-axis + qz_i = _rotation_z_component(self.body_i, self.marker_i_quat) + + # translation - pitch_radius * angle = 0 + # angle ≈ 2 * qz, so: translation - pitch_radius * 2 * qz = 0 + return [translation - Const(2.0 * self.pitch_radius) * qz_i] + + +class CamConstraint(ConstraintBase): + """Cam-follower constraint — future, stub.""" + + def residuals(self) -> List[Expr]: + return [] + + +class SlotConstraint(ConstraintBase): + """Slot constraint — future, stub.""" + + def residuals(self) -> List[Expr]: + return [] + + +class DistanceCylSphConstraint(ConstraintBase): + """Cylinder-sphere distance — stub. + + Semantics depend on geometry classification; placeholder for now. + """ + + def residuals(self) -> List[Expr]: + return [] + + +# -- rotation helpers for mechanical constraints ------------------------------ + + +def _rotation_z_component( + body: RigidBody, + marker_quat: tuple[float, float, float, float], +) -> Expr: + """Extract the Z-component of the relative quaternion about a marker axis. + + Returns the qz component of conj(q_marker) * q_body * q_marker, + which is proportional to sin(theta/2) where theta is the rotation + angle about the marker Z-axis. + """ + mw, mx, my, mz = marker_quat + # q_local = conj(marker) * q_body * marker + # Step 1: temp = conj(marker) * q_body + cmw, cmx, cmy, cmz = Const(mw), Const(-mx), Const(-my), Const(-mz) + # temp = conj(marker) * q_body + tw = cmw * body.qw - cmx * body.qx - cmy * body.qy - cmz * body.qz + tx = cmw * body.qx + cmx * body.qw + cmy * body.qz - cmz * body.qy + ty = cmw * body.qy - cmx * body.qz + cmy * body.qw + cmz * body.qx + tz = cmw * body.qz + cmx * body.qy - cmy * body.qx + cmz * body.qw + # q_local = temp * marker + mmw, mmx, mmy, mmz = Const(mw), Const(mx), Const(my), Const(mz) + # rz = tw * mmz + tx * mmy - ty * mmx + tz * mmw + rz = tw * mmz + tx * mmy - ty * mmx + tz * mmw + return rz + + # -- quaternion multiplication helpers ---------------------------------------- diff --git a/kindred_solver/geometry.py b/kindred_solver/geometry.py new file mode 100644 index 0000000..0365921 --- /dev/null +++ b/kindred_solver/geometry.py @@ -0,0 +1,131 @@ +"""Geometric helper functions for constraint equations. + +Provides Expr-level vector operations and marker axis extraction. +All functions work with Expr triples (tuples of 3 Expr nodes) +representing 3D vectors in world coordinates. + +Marker convention (from Types.h): the marker's Z-axis defines the +constraint direction (hinge axis, face normal, line direction, etc.). +""" + +from __future__ import annotations + +from .entities import RigidBody +from .expr import Const, Expr +from .quat import quat_rotate + +# Type alias for an Expr triple (3D vector) +Vec3 = tuple[Expr, Expr, Expr] + + +# -- Marker axis extraction --------------------------------------------------- + + +def _composed_quat( + body: RigidBody, + marker_quat: tuple[float, float, float, float], +) -> tuple[Expr, Expr, Expr, Expr]: + """Compute q_total = q_body * q_marker as Expr quaternion. + + q_body comes from the body's Var params; q_marker is constant. + """ + bw, bx, by, bz = body.qw, body.qx, body.qy, body.qz + mw, mx, my, mz = (Const(v) for v in marker_quat) + # Hamilton product: body * marker + rw = bw * mw - bx * mx - by * my - bz * mz + rx = bw * mx + bx * mw + by * mz - bz * my + ry = bw * my - bx * mz + by * mw + bz * mx + rz = bw * mz + bx * my - by * mx + bz * mw + return rw, rx, ry, rz + + +def marker_z_axis( + body: RigidBody, + marker_quat: tuple[float, float, float, float], +) -> Vec3: + """World-frame Z-axis of a marker on a body. + + Computes rotate(q_body * q_marker, [0, 0, 1]). + """ + qw, qx, qy, qz = _composed_quat(body, marker_quat) + return quat_rotate(qw, qx, qy, qz, Const(0.0), Const(0.0), Const(1.0)) + + +def marker_x_axis( + body: RigidBody, + marker_quat: tuple[float, float, float, float], +) -> Vec3: + """World-frame X-axis of a marker on a body. + + Computes rotate(q_body * q_marker, [1, 0, 0]). + """ + qw, qx, qy, qz = _composed_quat(body, marker_quat) + return quat_rotate(qw, qx, qy, qz, Const(1.0), Const(0.0), Const(0.0)) + + +def marker_y_axis( + body: RigidBody, + marker_quat: tuple[float, float, float, float], +) -> Vec3: + """World-frame Y-axis of a marker on a body. + + Computes rotate(q_body * q_marker, [0, 1, 0]). + """ + qw, qx, qy, qz = _composed_quat(body, marker_quat) + return quat_rotate(qw, qx, qy, qz, Const(0.0), Const(1.0), Const(0.0)) + + +# -- Vector operations on Expr triples ---------------------------------------- + + +def dot3(a: Vec3, b: Vec3) -> Expr: + """Dot product of two Expr triples.""" + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + + +def cross3(a: Vec3, b: Vec3) -> Vec3: + """Cross product of two Expr triples.""" + return ( + a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0], + ) + + +def sub3(a: Vec3, b: Vec3) -> Vec3: + """Vector subtraction a - b.""" + return (a[0] - b[0], a[1] - b[1], a[2] - b[2]) + + +# -- Geometric primitives ----------------------------------------------------- + + +def point_plane_distance( + point: Vec3, + plane_origin: Vec3, + normal: Vec3, +) -> Expr: + """Signed distance from point to plane defined by origin + normal. + + Returns (point - plane_origin) . normal + """ + d = sub3(point, plane_origin) + return dot3(d, normal) + + +def point_line_perp_components( + point: Vec3, + line_origin: Vec3, + line_dir: Vec3, +) -> tuple[Expr, Expr]: + """Two independent perpendicular-distance components from point to line. + + The line passes through line_origin along line_dir. + Returns the x and y components of (point - line_origin) x line_dir, + which are zero when the point lies on the line. + """ + d = sub3(point, line_origin) + cx, cy, cz = cross3(d, line_dir) + # All three components of d x line_dir are zero when d is parallel + # to line_dir, but only 2 are independent. We return x and y. + return cx, cy diff --git a/kindred_solver/solver.py b/kindred_solver/solver.py index 53d1749..bab8c8a 100644 --- a/kindred_solver/solver.py +++ b/kindred_solver/solver.py @@ -5,11 +5,32 @@ from __future__ import annotations import kcsolve +from .bfgs import bfgs_solve from .constraints import ( + AngleConstraint, + BallConstraint, + CamConstraint, CoincidentConstraint, + ConcentricConstraint, ConstraintBase, + CylindricalConstraint, + DistanceCylSphConstraint, DistancePointPointConstraint, FixedConstraint, + GearConstraint, + LineInPlaneConstraint, + ParallelConstraint, + PerpendicularConstraint, + PlanarConstraint, + PointInPlaneConstraint, + PointOnLineConstraint, + RackPinionConstraint, + RevoluteConstraint, + ScrewConstraint, + SliderConstraint, + SlotConstraint, + TangentConstraint, + UniversalConstraint, ) from .dof import count_dof from .entities import RigidBody @@ -17,11 +38,34 @@ from .newton import newton_solve from .params import ParamTable from .prepass import single_equation_pass, substitution_pass -# Map BaseJointKind enum values to handler names +# All BaseJointKind values this solver can handle _SUPPORTED = { + # Phase 1 kcsolve.BaseJointKind.Coincident, kcsolve.BaseJointKind.DistancePointPoint, kcsolve.BaseJointKind.Fixed, + # Phase 2: point constraints + kcsolve.BaseJointKind.PointOnLine, + kcsolve.BaseJointKind.PointInPlane, + # Phase 2: orientation + kcsolve.BaseJointKind.Parallel, + kcsolve.BaseJointKind.Perpendicular, + kcsolve.BaseJointKind.Angle, + # Phase 2: axis/surface + kcsolve.BaseJointKind.Concentric, + kcsolve.BaseJointKind.Tangent, + kcsolve.BaseJointKind.Planar, + kcsolve.BaseJointKind.LineInPlane, + # Phase 2: kinematic joints + kcsolve.BaseJointKind.Ball, + kcsolve.BaseJointKind.Revolute, + kcsolve.BaseJointKind.Cylindrical, + kcsolve.BaseJointKind.Slider, + kcsolve.BaseJointKind.Screw, + kcsolve.BaseJointKind.Universal, + # Phase 2: mechanical + kcsolve.BaseJointKind.Gear, + kcsolve.BaseJointKind.RackPinion, } @@ -92,7 +136,7 @@ class KindredSolver(kcsolve.IKCSolver): all_residuals = substitution_pass(all_residuals, params) all_residuals = single_equation_pass(all_residuals, params) - # 5. Newton-Raphson + # 5. Newton-Raphson (with BFGS fallback) converged = newton_solve( all_residuals, params, @@ -100,6 +144,14 @@ class KindredSolver(kcsolve.IKCSolver): max_iter=100, tol=1e-10, ) + if not converged: + converged = bfgs_solve( + all_residuals, + params, + quat_groups=quat_groups, + max_iter=200, + tol=1e-10, + ) # 6. DOF dof = count_dof(all_residuals, params) @@ -141,6 +193,11 @@ def _build_constraint( c_params, ) -> ConstraintBase | None: """Create the appropriate constraint object from a BaseJointKind.""" + marker_i_quat = tuple(marker_i.quaternion) + marker_j_quat = tuple(marker_j.quaternion) + + # -- Phase 1 constraints -------------------------------------------------- + if kind == kcsolve.BaseJointKind.Coincident: return CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos) @@ -155,8 +212,6 @@ def _build_constraint( ) if kind == kcsolve.BaseJointKind.Fixed: - marker_i_quat = tuple(marker_i.quaternion) - marker_j_quat = tuple(marker_j.quaternion) return FixedConstraint( body_i, marker_i_pos, @@ -166,4 +221,182 @@ def _build_constraint( marker_j_quat, ) + # -- Phase 2: point constraints ------------------------------------------- + + if kind == kcsolve.BaseJointKind.PointOnLine: + return PointOnLineConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + ) + + if kind == kcsolve.BaseJointKind.PointInPlane: + offset = c_params[0] if c_params else 0.0 + return PointInPlaneConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + offset=offset, + ) + + # -- Phase 2: orientation constraints ------------------------------------- + + if kind == kcsolve.BaseJointKind.Parallel: + return ParallelConstraint(body_i, marker_i_quat, body_j, marker_j_quat) + + if kind == kcsolve.BaseJointKind.Perpendicular: + return PerpendicularConstraint(body_i, marker_i_quat, body_j, marker_j_quat) + + if kind == kcsolve.BaseJointKind.Angle: + angle = c_params[0] if c_params else 0.0 + return AngleConstraint(body_i, marker_i_quat, body_j, marker_j_quat, angle) + + # -- Phase 2: axis/surface constraints ------------------------------------ + + if kind == kcsolve.BaseJointKind.Concentric: + distance = c_params[0] if c_params else 0.0 + return ConcentricConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + distance=distance, + ) + + if kind == kcsolve.BaseJointKind.Tangent: + return TangentConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + ) + + if kind == kcsolve.BaseJointKind.Planar: + offset = c_params[0] if c_params else 0.0 + return PlanarConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + offset=offset, + ) + + if kind == kcsolve.BaseJointKind.LineInPlane: + offset = c_params[0] if c_params else 0.0 + return LineInPlaneConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + offset=offset, + ) + + # -- Phase 2: kinematic joints -------------------------------------------- + + if kind == kcsolve.BaseJointKind.Ball: + return BallConstraint(body_i, marker_i_pos, body_j, marker_j_pos) + + if kind == kcsolve.BaseJointKind.Revolute: + return RevoluteConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + ) + + if kind == kcsolve.BaseJointKind.Cylindrical: + return CylindricalConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + ) + + if kind == kcsolve.BaseJointKind.Slider: + return SliderConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + ) + + if kind == kcsolve.BaseJointKind.Screw: + pitch = c_params[0] if c_params else 1.0 + return ScrewConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + pitch=pitch, + ) + + if kind == kcsolve.BaseJointKind.Universal: + return UniversalConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + ) + + # -- Phase 2: mechanical constraints -------------------------------------- + + if kind == kcsolve.BaseJointKind.Gear: + radius_i = c_params[0] if len(c_params) > 0 else 1.0 + radius_j = c_params[1] if len(c_params) > 1 else 1.0 + return GearConstraint( + body_i, + marker_i_quat, + body_j, + marker_j_quat, + radius_i, + radius_j, + ) + + if kind == kcsolve.BaseJointKind.RackPinion: + pitch_radius = c_params[0] if c_params else 1.0 + return RackPinionConstraint( + body_i, + marker_i_pos, + marker_i_quat, + body_j, + marker_j_pos, + marker_j_quat, + pitch_radius=pitch_radius, + ) + + # -- Stubs (accepted but produce no residuals) ---------------------------- + + if kind == kcsolve.BaseJointKind.Cam: + return CamConstraint() + + if kind == kcsolve.BaseJointKind.Slot: + return SlotConstraint() + + if kind == kcsolve.BaseJointKind.DistanceCylSph: + return DistanceCylSphConstraint() + return None diff --git a/tests/test_bfgs.py b/tests/test_bfgs.py new file mode 100644 index 0000000..f434151 --- /dev/null +++ b/tests/test_bfgs.py @@ -0,0 +1,70 @@ +"""Tests for the BFGS fallback solver.""" + +import math + +import pytest +from kindred_solver.bfgs import bfgs_solve +from kindred_solver.expr import Const, Var +from kindred_solver.params import ParamTable + + +class TestBFGSBasic: + def test_single_linear(self): + """Solve x - 3 = 0.""" + pt = ParamTable() + x = pt.add("x", 0.0) + assert bfgs_solve([x - Const(3.0)], pt) is True + assert abs(pt.get_value("x") - 3.0) < 1e-8 + + def test_single_quadratic(self): + """Solve x^2 - 4 = 0 from x=1 → x=2.""" + pt = ParamTable() + x = pt.add("x", 1.0) + assert bfgs_solve([x * x - Const(4.0)], pt) is True + assert abs(pt.get_value("x") - 2.0) < 1e-8 + + def test_two_variables(self): + """Solve x + y = 5, x - y = 1.""" + pt = ParamTable() + x = pt.add("x", 0.0) + y = pt.add("y", 0.0) + assert bfgs_solve([x + y - Const(5.0), x - y - Const(1.0)], pt) is True + assert abs(pt.get_value("x") - 3.0) < 1e-8 + assert abs(pt.get_value("y") - 2.0) < 1e-8 + + def test_empty_system(self): + pt = ParamTable() + assert bfgs_solve([], pt) is True + + def test_with_quat_renorm(self): + """Quaternion re-normalization during BFGS.""" + pt = ParamTable() + qw = pt.add("qw", 0.9) + qx = pt.add("qx", 0.1) + qy = pt.add("qy", 0.1) + qz = pt.add("qz", 0.1) + r = qw * qw + qx * qx + qy * qy + qz * qz - Const(1.0) + groups = [("qw", "qx", "qy", "qz")] + assert bfgs_solve([r], pt, quat_groups=groups) is True + w, x, y, z = (pt.get_value(n) for n in ["qw", "qx", "qy", "qz"]) + norm = math.sqrt(w**2 + x**2 + y**2 + z**2) + assert abs(norm - 1.0) < 1e-8 + + +class TestBFGSGeometric: + def test_distance_constraint(self): + """x^2 - 25 = 0 from x=3 → x=5.""" + pt = ParamTable() + x = pt.add("x", 3.0) + assert bfgs_solve([x * x - Const(25.0)], pt) is True + assert abs(pt.get_value("x") - 5.0) < 1e-8 + + def test_difficult_initial_guess(self): + """BFGS should handle worse initial guesses than Newton.""" + pt = ParamTable() + x = pt.add("x", 100.0) + y = pt.add("y", -50.0) + residuals = [x + y - Const(5.0), x - y - Const(1.0)] + assert bfgs_solve(residuals, pt) is True + assert abs(pt.get_value("x") - 3.0) < 1e-6 + assert abs(pt.get_value("y") - 2.0) < 1e-6 diff --git a/tests/test_constraints_phase2.py b/tests/test_constraints_phase2.py new file mode 100644 index 0000000..210df3b --- /dev/null +++ b/tests/test_constraints_phase2.py @@ -0,0 +1,481 @@ +"""Tests for Phase 2 constraint residual generation.""" + +import math + +import pytest +from kindred_solver.constraints import ( + AngleConstraint, + BallConstraint, + CamConstraint, + ConcentricConstraint, + CylindricalConstraint, + DistanceCylSphConstraint, + GearConstraint, + LineInPlaneConstraint, + ParallelConstraint, + PerpendicularConstraint, + PlanarConstraint, + PointInPlaneConstraint, + PointOnLineConstraint, + RackPinionConstraint, + RevoluteConstraint, + ScrewConstraint, + SliderConstraint, + SlotConstraint, + TangentConstraint, + UniversalConstraint, +) +from kindred_solver.entities import RigidBody +from kindred_solver.params import ParamTable + +ID_QUAT = (1.0, 0.0, 0.0, 0.0) +# 90-deg about Y: Z-axis of body rotates to point along X +_c = math.cos(math.pi / 4) +_s = math.sin(math.pi / 4) +ROT_90Y = (_c, 0.0, _s, 0.0) +ROT_90Z = (_c, 0.0, 0.0, _s) + + +# ── Point constraints ──────────────────────────────────────────────── + + +class TestPointOnLine: + def test_on_line(self): + """Point at (0,0,5) is on Z-axis line through origin.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0)) + c = PointOnLineConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_off_line(self): + """Point at (3,0,5) is NOT on Z-axis line through origin.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (3, 0, 5), (1, 0, 0, 0)) + c = PointOnLineConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT) + env = pt.get_env() + vals = [r.eval(env) for r in c.residuals()] + assert any(abs(v) > 0.1 for v in vals) + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = PointOnLineConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 2 + + +class TestPointInPlane: + def test_in_plane(self): + """Point at (3,4,0) is in XY plane through origin.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (3, 4, 0), (1, 0, 0, 0)) + c = PointInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env)) < 1e-10 + + def test_above_plane(self): + """Point at (0,0,7) is 7 above XY plane.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 7), (1, 0, 0, 0)) + c = PointInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env) - 7.0) < 1e-10 + + def test_with_offset(self): + """Point at (0,0,5) with offset=5 → residual 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0)) + c = PointInPlaneConstraint( + b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT, offset=5.0 + ) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = PointInPlaneConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 1 + + +# ── Orientation constraints ────────────────────────────────────────── + + +class TestParallel: + def test_parallel_same(self): + """Both bodies with identity rotation → Z-axes parallel → residuals 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (5, 0, 0), (1, 0, 0, 0)) + c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_not_parallel(self): + """One body rotated 90-deg about Y → Z-axes perpendicular.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (5, 0, 0), ROT_90Y) + c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT) + env = pt.get_env() + vals = [r.eval(env) for r in c.residuals()] + assert any(abs(v) > 0.1 for v in vals) + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT) + assert len(c.residuals()) == 2 + + +class TestPerpendicular: + def test_perpendicular(self): + """One body rotated 90-deg about Y → Z-axes perpendicular.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y) + c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env)) < 1e-10 + + def test_not_perpendicular(self): + """Same orientation → not perpendicular.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT) + env = pt.get_env() + # dot(z,z) = 1 ≠ 0 + assert abs(c.residuals()[0].eval(env) - 1.0) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT) + assert len(c.residuals()) == 1 + + +class TestAngle: + def test_90_degrees(self): + """90-deg angle between Z-axes rotated 90-deg about Y.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y) + c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, math.pi / 2) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env)) < 1e-10 + + def test_0_degrees(self): + """0-deg angle, same orientation → cos(0)=1, dot=1 → residual 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, 0.0) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0) + assert len(c.residuals()) == 1 + + +# ── Axis/surface constraints ───────────────────────────────────────── + + +class TestConcentric: + def test_coaxial(self): + """Both on Z-axis → coaxial → residuals 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0)) + c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_not_coaxial(self): + """Offset in X → not coaxial.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (5, 0, 0), (1, 0, 0, 0)) + c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + vals = [r.eval(env) for r in c.residuals()] + assert any(abs(v) > 0.1 for v in vals) + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 4 + + +class TestTangent: + def test_touching(self): + """Marker origins at same point → tangent.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env)) < 1e-10 + + def test_separated(self): + """Separated along normal → non-zero residual.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 5), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env) - 5.0) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 1 + + +class TestPlanar: + def test_coplanar(self): + """Same plane, same orientation → all residuals 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (5, 3, 0), (1, 0, 0, 0)) + c = PlanarConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_with_offset(self): + """b_i at z=5, b_j at origin, normal=Z, offset=5. + Signed distance = (p_i - p_j).n = 5, offset=5 → 5-5 = 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0)) + c = PlanarConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT, offset=5.0) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = PlanarConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 3 + + +class TestLineInPlane: + def test_in_plane(self): + """Line along X in XY plane → residuals 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + # b2 has Z-axis = (1,0,0) via 90-deg rotation about Y + b2 = RigidBody("b", pt, (5, 0, 0), ROT_90Y) + # Line = b2's Z-axis (which is world X), plane = b1's XY plane (normal=Z) + c = LineInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = LineInPlaneConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 2 + + +# ── Kinematic joints ───────────────────────────────────────────────── + + +class TestBall: + def test_same_as_coincident(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = BallConstraint(b1, (0, 0, 0), b2, (0, 0, 0)) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + assert len(c.residuals()) == 3 + + +class TestRevolute: + def test_satisfied(self): + """Same position, same Z-axis → satisfied.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Z) # rotated about Z — still parallel + c = RevoluteConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = RevoluteConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 5 + + +class TestCylindrical: + def test_on_axis(self): + """Same axis, displaced along Z → satisfied.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 10), (1, 0, 0, 0)) + c = CylindricalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = CylindricalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 4 + + +class TestSlider: + def test_aligned(self): + """Same axis, no twist, displaced along Z → satisfied.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 10), (1, 0, 0, 0)) + c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_twisted(self): + """Rotated about Z → twist residual non-zero.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Z) + c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + vals = [r.eval(env) for r in c.residuals()] + # First 4 should be ~0 (parallel + on-line), but twist residual should be ~1 + assert abs(vals[4]) > 0.5 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 5 + + +class TestUniversal: + def test_satisfied(self): + """Same origin, perpendicular Z-axes.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y) + c = UniversalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = UniversalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT) + assert len(c.residuals()) == 4 + + +class TestScrew: + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = ScrewConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch=10.0) + assert len(c.residuals()) == 5 + + def test_zero_displacement_zero_rotation(self): + """Both at origin with identity rotation → all residuals 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = ScrewConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch=10.0) + env = pt.get_env() + for r in c.residuals(): + assert abs(r.eval(env)) < 1e-10 + + +# ── Mechanical constraints ─────────────────────────────────────────── + + +class TestGear: + def test_both_at_rest(self): + """Both at identity rotation → residual 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = GearConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0, 1.0) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = GearConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0, 2.0) + assert len(c.residuals()) == 1 + + +class TestRackPinion: + def test_at_rest(self): + """Both at rest → residual 0.""" + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = RackPinionConstraint( + b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch_radius=5.0 + ) + env = pt.get_env() + assert abs(c.residuals()[0].eval(env)) < 1e-10 + + def test_residual_count(self): + pt = ParamTable() + b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0)) + b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0)) + c = RackPinionConstraint( + b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch_radius=1.0 + ) + assert len(c.residuals()) == 1 + + +# ── Stubs ──────────────────────────────────────────────────────────── + + +class TestStubs: + def test_cam(self): + assert CamConstraint().residuals() == [] + + def test_slot(self): + assert SlotConstraint().residuals() == [] + + def test_distance_cyl_sph(self): + assert DistanceCylSphConstraint().residuals() == [] diff --git a/tests/test_geometry.py b/tests/test_geometry.py new file mode 100644 index 0000000..45b90ba --- /dev/null +++ b/tests/test_geometry.py @@ -0,0 +1,187 @@ +"""Tests for geometry helpers.""" + +import math + +import pytest +from kindred_solver.entities import RigidBody +from kindred_solver.expr import Const, Var +from kindred_solver.geometry import ( + cross3, + dot3, + marker_x_axis, + marker_y_axis, + marker_z_axis, + point_line_perp_components, + point_plane_distance, + sub3, +) +from kindred_solver.params import ParamTable + +IDENTITY_QUAT = (1.0, 0.0, 0.0, 0.0) +# 90-deg about Z: (cos45, 0, 0, sin45) +_c = math.cos(math.pi / 4) +_s = math.sin(math.pi / 4) +ROT_90Z_QUAT = (_c, 0.0, 0.0, _s) + + +class TestDot3: + def test_parallel(self): + a = (Const(1.0), Const(0.0), Const(0.0)) + b = (Const(1.0), Const(0.0), Const(0.0)) + assert abs(dot3(a, b).eval({}) - 1.0) < 1e-10 + + def test_perpendicular(self): + a = (Const(1.0), Const(0.0), Const(0.0)) + b = (Const(0.0), Const(1.0), Const(0.0)) + assert abs(dot3(a, b).eval({})) < 1e-10 + + def test_general(self): + a = (Const(1.0), Const(2.0), Const(3.0)) + b = (Const(4.0), Const(5.0), Const(6.0)) + # 1*4 + 2*5 + 3*6 = 32 + assert abs(dot3(a, b).eval({}) - 32.0) < 1e-10 + + +class TestCross3: + def test_x_cross_y(self): + x = (Const(1.0), Const(0.0), Const(0.0)) + y = (Const(0.0), Const(1.0), Const(0.0)) + cx, cy, cz = cross3(x, y) + assert abs(cx.eval({})) < 1e-10 + assert abs(cy.eval({})) < 1e-10 + assert abs(cz.eval({}) - 1.0) < 1e-10 + + def test_parallel_is_zero(self): + a = (Const(2.0), Const(3.0), Const(4.0)) + b = (Const(4.0), Const(6.0), Const(8.0)) + cx, cy, cz = cross3(a, b) + assert abs(cx.eval({})) < 1e-10 + assert abs(cy.eval({})) < 1e-10 + assert abs(cz.eval({})) < 1e-10 + + +class TestSub3: + def test_basic(self): + a = (Const(5.0), Const(3.0), Const(1.0)) + b = (Const(1.0), Const(2.0), Const(3.0)) + dx, dy, dz = sub3(a, b) + assert abs(dx.eval({}) - 4.0) < 1e-10 + assert abs(dy.eval({}) - 1.0) < 1e-10 + assert abs(dz.eval({}) - (-2.0)) < 1e-10 + + +class TestMarkerAxes: + def test_identity_z(self): + """Identity body + identity marker → Z = (0,0,1).""" + pt = ParamTable() + body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0)) + zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT) + env = pt.get_env() + assert abs(zx.eval(env)) < 1e-10 + assert abs(zy.eval(env)) < 1e-10 + assert abs(zz.eval(env) - 1.0) < 1e-10 + + def test_identity_x(self): + """Identity body + identity marker → X = (1,0,0).""" + pt = ParamTable() + body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0)) + xx, xy, xz = marker_x_axis(body, IDENTITY_QUAT) + env = pt.get_env() + assert abs(xx.eval(env) - 1.0) < 1e-10 + assert abs(xy.eval(env)) < 1e-10 + assert abs(xz.eval(env)) < 1e-10 + + def test_identity_y(self): + """Identity body + identity marker → Y = (0,1,0).""" + pt = ParamTable() + body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0)) + yx, yy, yz = marker_y_axis(body, IDENTITY_QUAT) + env = pt.get_env() + assert abs(yx.eval(env)) < 1e-10 + assert abs(yy.eval(env) - 1.0) < 1e-10 + assert abs(yz.eval(env)) < 1e-10 + + def test_rotated_body_z(self): + """Body rotated 90-deg about Z → Z-axis still (0,0,1).""" + pt = ParamTable() + body = RigidBody("p", pt, (0, 0, 0), ROT_90Z_QUAT) + zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT) + env = pt.get_env() + assert abs(zx.eval(env)) < 1e-10 + assert abs(zy.eval(env)) < 1e-10 + assert abs(zz.eval(env) - 1.0) < 1e-10 + + def test_rotated_body_x(self): + """Body rotated 90-deg about Z → X-axis becomes (0,1,0).""" + pt = ParamTable() + body = RigidBody("p", pt, (0, 0, 0), ROT_90Z_QUAT) + xx, xy, xz = marker_x_axis(body, IDENTITY_QUAT) + env = pt.get_env() + assert abs(xx.eval(env)) < 1e-10 + assert abs(xy.eval(env) - 1.0) < 1e-10 + assert abs(xz.eval(env)) < 1e-10 + + def test_marker_rotation(self): + """Identity body + marker rotated 90-deg about Z → Z still (0,0,1).""" + pt = ParamTable() + body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0)) + zx, zy, zz = marker_z_axis(body, ROT_90Z_QUAT) + env = pt.get_env() + assert abs(zx.eval(env)) < 1e-10 + assert abs(zy.eval(env)) < 1e-10 + assert abs(zz.eval(env) - 1.0) < 1e-10 + + def test_marker_rotation_x_axis(self): + """Identity body + marker rotated 90-deg about Z → X becomes (0,1,0).""" + pt = ParamTable() + body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0)) + xx, xy, xz = marker_x_axis(body, ROT_90Z_QUAT) + env = pt.get_env() + assert abs(xx.eval(env)) < 1e-10 + assert abs(xy.eval(env) - 1.0) < 1e-10 + assert abs(xz.eval(env)) < 1e-10 + + def test_differentiable(self): + """Marker axes are differentiable w.r.t. body quat params.""" + pt = ParamTable() + body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0)) + zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT) + # Should not raise + dzx = zx.diff("p/qz").simplify() + env = pt.get_env() + dzx.eval(env) # Should be evaluable + + +class TestPointPlaneDistance: + def test_on_plane(self): + pt = (Const(1.0), Const(2.0), Const(0.0)) + origin = (Const(0.0), Const(0.0), Const(0.0)) + normal = (Const(0.0), Const(0.0), Const(1.0)) + d = point_plane_distance(pt, origin, normal) + assert abs(d.eval({})) < 1e-10 + + def test_above_plane(self): + pt = (Const(1.0), Const(2.0), Const(5.0)) + origin = (Const(0.0), Const(0.0), Const(0.0)) + normal = (Const(0.0), Const(0.0), Const(1.0)) + d = point_plane_distance(pt, origin, normal) + assert abs(d.eval({}) - 5.0) < 1e-10 + + +class TestPointLinePerp: + def test_on_line(self): + pt = (Const(0.0), Const(0.0), Const(5.0)) + origin = (Const(0.0), Const(0.0), Const(0.0)) + direction = (Const(0.0), Const(0.0), Const(1.0)) + cx, cy = point_line_perp_components(pt, origin, direction) + assert abs(cx.eval({})) < 1e-10 + assert abs(cy.eval({})) < 1e-10 + + def test_off_line(self): + pt = (Const(3.0), Const(0.0), Const(0.0)) + origin = (Const(0.0), Const(0.0), Const(0.0)) + direction = (Const(0.0), Const(0.0), Const(1.0)) + cx, cy = point_line_perp_components(pt, origin, direction) + # d = (3,0,0), dir = (0,0,1), d x dir = (0*1-0*0, 0*0-3*1, 3*0-0*0) = (0,-3,0) + assert abs(cx.eval({})) < 1e-10 + assert abs(cy.eval({}) - (-3.0)) < 1e-10 diff --git a/tests/test_joints.py b/tests/test_joints.py new file mode 100644 index 0000000..9b96757 --- /dev/null +++ b/tests/test_joints.py @@ -0,0 +1,612 @@ +"""Integration tests for kinematic joint constraints. + +These tests exercise the full solve pipeline (constraint → residuals → +pre-pass → Newton / BFGS) for multi-body systems with various joint types. +""" + +import math + +import pytest +from kindred_solver.constraints import ( + BallConstraint, + CoincidentConstraint, + CylindricalConstraint, + GearConstraint, + ParallelConstraint, + PerpendicularConstraint, + PlanarConstraint, + PointInPlaneConstraint, + PointOnLineConstraint, + RackPinionConstraint, + RevoluteConstraint, + ScrewConstraint, + SliderConstraint, + UniversalConstraint, +) +from kindred_solver.dof import count_dof +from kindred_solver.entities import RigidBody +from kindred_solver.newton import newton_solve +from kindred_solver.params import ParamTable +from kindred_solver.prepass import single_equation_pass, substitution_pass + +ID_QUAT = (1, 0, 0, 0) +# 90° about Z: (cos(45°), 0, 0, sin(45°)) +c45 = math.cos(math.pi / 4) +s45 = math.sin(math.pi / 4) +ROT_90Z = (c45, 0, 0, s45) +# 90° about Y +ROT_90Y = (c45, 0, s45, 0) +# 90° about X +ROT_90X = (c45, s45, 0, 0) + + +def _solve(bodies, constraint_objs): + """Run the full solve pipeline. Returns (converged, params, bodies).""" + pt = bodies[0].tx # all bodies share the same ParamTable via Var._name + # Actually, we need the ParamTable object. Get it from the first body. + # The Var objects store names, but we need the table. We'll reconstruct. + # Better approach: caller passes pt. + + raise NotImplementedError("Use _solve_with_pt instead") + + +def _solve_with_pt(pt, bodies, constraint_objs): + """Run the full solve pipeline with explicit ParamTable.""" + all_residuals = [] + for c in constraint_objs: + all_residuals.extend(c.residuals()) + + quat_groups = [] + for body in bodies: + if not body.grounded: + all_residuals.append(body.quat_norm_residual()) + quat_groups.append(body.quat_param_names()) + + all_residuals = substitution_pass(all_residuals, pt) + all_residuals = single_equation_pass(all_residuals, pt) + + converged = newton_solve( + all_residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10 + ) + return converged, all_residuals + + +def _dof(pt, bodies, constraint_objs): + """Count DOF for a system.""" + all_residuals = [] + for c in constraint_objs: + all_residuals.extend(c.residuals()) + for body in bodies: + if not body.grounded: + all_residuals.append(body.quat_norm_residual()) + all_residuals = substitution_pass(all_residuals, pt) + return count_dof(all_residuals, pt) + + +# ============================================================================ +# Single-joint DOF counting tests +# ============================================================================ + + +class TestJointDOF: + """Verify each joint type removes the expected number of DOF. + + Setup: ground body + 1 free body (6 DOF) with a single joint. + """ + + def _setup(self, pos_b=(0, 0, 0), quat_b=ID_QUAT): + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + b = RigidBody("b", pt, pos_b, quat_b) + return pt, a, b + + def test_ball_3dof(self): + """Ball joint: 6 - 3 = 3 DOF (3 rotation).""" + pt, a, b = self._setup() + constraints = [BallConstraint(a, (0, 0, 0), b, (0, 0, 0))] + assert _dof(pt, [a, b], constraints) == 3 + + def test_revolute_1dof(self): + """Revolute: 6 - 5 = 1 DOF (rotation about Z).""" + pt, a, b = self._setup() + constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] + assert _dof(pt, [a, b], constraints) == 1 + + def test_cylindrical_2dof(self): + """Cylindrical: 6 - 4 = 2 DOF (rotation + translation along Z).""" + pt, a, b = self._setup() + constraints = [ + CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + ] + assert _dof(pt, [a, b], constraints) == 2 + + def test_slider_1dof(self): + """Slider: 6 - 5 = 1 DOF (translation along Z).""" + pt, a, b = self._setup() + constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] + assert _dof(pt, [a, b], constraints) == 1 + + def test_universal_2dof(self): + """Universal: 6 - 4 = 2 DOF (rotation about each body's Z).""" + pt, a, b = self._setup(quat_b=ROT_90X) + constraints = [ + UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + ] + assert _dof(pt, [a, b], constraints) == 2 + + def test_screw_1dof(self): + """Screw: 6 - 5 = 1 DOF (helical motion).""" + pt, a, b = self._setup() + constraints = [ + ScrewConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT, pitch=10.0) + ] + assert _dof(pt, [a, b], constraints) == 1 + + def test_parallel_4dof(self): + """Parallel: 6 - 2 = 4 DOF.""" + pt, a, b = self._setup() + constraints = [ParallelConstraint(a, ID_QUAT, b, ID_QUAT)] + assert _dof(pt, [a, b], constraints) == 4 + + def test_perpendicular_5dof(self): + """Perpendicular: 6 - 1 = 5 DOF.""" + pt, a, b = self._setup(quat_b=ROT_90X) + constraints = [PerpendicularConstraint(a, ID_QUAT, b, ID_QUAT)] + assert _dof(pt, [a, b], constraints) == 5 + + def test_point_on_line_4dof(self): + """PointOnLine: 6 - 2 = 4 DOF.""" + pt, a, b = self._setup() + constraints = [ + PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + ] + assert _dof(pt, [a, b], constraints) == 4 + + def test_point_in_plane_5dof(self): + """PointInPlane: 6 - 1 = 5 DOF.""" + pt, a, b = self._setup() + constraints = [ + PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + ] + assert _dof(pt, [a, b], constraints) == 5 + + def test_planar_3dof(self): + """Planar: 6 - 3 = 3 DOF (2 translation in plane + 1 rotation about normal).""" + pt, a, b = self._setup() + constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] + assert _dof(pt, [a, b], constraints) == 3 + + +# ============================================================================ +# Solve convergence tests — single joints from displaced initial conditions +# ============================================================================ + + +class TestJointSolve: + """Newton converges to a valid configuration from displaced starting points.""" + + def test_revolute_displaced(self): + """Revolute joint: body B starts displaced, should converge to hinge position.""" + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + b = RigidBody("b", pt, (3, 4, 5), ID_QUAT) # displaced + + constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] + converged, _ = _solve_with_pt(pt, [a, b], constraints) + assert converged + + env = pt.get_env() + pos = b.extract_position(env) + # Coincident origins → position should be at origin + assert abs(pos[0]) < 1e-8 + assert abs(pos[1]) < 1e-8 + assert abs(pos[2]) < 1e-8 + + def test_cylindrical_displaced(self): + """Cylindrical joint: body B can slide along Z but must be on axis.""" + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + b = RigidBody("b", pt, (3, 4, 7), ID_QUAT) # off-axis + + constraints = [ + CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + ] + converged, _ = _solve_with_pt(pt, [a, b], constraints) + assert converged + + env = pt.get_env() + pos = b.extract_position(env) + # X and Y should be zero (on axis), Z can be anything + assert abs(pos[0]) < 1e-8 + assert abs(pos[1]) < 1e-8 + + def test_slider_displaced(self): + """Slider: body B can translate along Z only.""" + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + b = RigidBody("b", pt, (2, 3, 5), ID_QUAT) # displaced + + constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] + converged, _ = _solve_with_pt(pt, [a, b], constraints) + assert converged + + env = pt.get_env() + pos = b.extract_position(env) + # X and Y should be zero (on axis), Z free + assert abs(pos[0]) < 1e-8 + assert abs(pos[1]) < 1e-8 + + def test_ball_displaced(self): + """Ball joint: body B moves so marker origins coincide. + + Ball has 3 rotation DOF free, so we can only verify the + world-frame marker points match, not the body position directly. + """ + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + b = RigidBody("b", pt, (5, 5, 5), ID_QUAT) + + constraints = [BallConstraint(a, (1, 0, 0), b, (-1, 0, 0))] + converged, _ = _solve_with_pt(pt, [a, b], constraints) + assert converged + + env = pt.get_env() + # Verify marker world points match + wp_a = a.world_point(1, 0, 0) + wp_b = b.world_point(-1, 0, 0) + for ea, eb in zip(wp_a, wp_b): + assert abs(ea.eval(env) - eb.eval(env)) < 1e-8 + + def test_universal_displaced(self): + """Universal joint: coincident origins + perpendicular Z-axes.""" + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + # Start B with Z-axis along X (90° about Y) — perpendicular to A's Z + b = RigidBody("b", pt, (3, 4, 5), ROT_90Y) + + constraints = [ + UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + ] + converged, _ = _solve_with_pt(pt, [a, b], constraints) + assert converged + + env = pt.get_env() + pos = b.extract_position(env) + assert abs(pos[0]) < 1e-8 + assert abs(pos[1]) < 1e-8 + assert abs(pos[2]) < 1e-8 + + def test_point_on_line_solve(self): + """Point on line: body B's marker origin constrained to line along Z. + + Under-constrained system (4 DOF remain), so we verify the constraint + residuals are satisfied rather than expecting specific positions. + """ + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + b = RigidBody("b", pt, (5, 3, 7), ID_QUAT) + + constraints = [ + PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + ] + converged, residuals = _solve_with_pt(pt, [a, b], constraints) + assert converged + + env = pt.get_env() + for r in residuals: + assert abs(r.eval(env)) < 1e-8 + + def test_point_in_plane_solve(self): + """Point in plane: body B's marker origin at z=0 plane. + + Under-constrained (5 DOF remain), so verify residuals. + """ + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + b = RigidBody("b", pt, (3, 4, 8), ID_QUAT) + + constraints = [ + PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT) + ] + converged, residuals = _solve_with_pt(pt, [a, b], constraints) + assert converged + + env = pt.get_env() + for r in residuals: + assert abs(r.eval(env)) < 1e-8 + + def test_planar_solve(self): + """Planar: coplanar faces — parallel normals + point in plane.""" + pt = ParamTable() + a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True) + # Start B tilted and displaced + b = RigidBody("b", pt, (3, 4, 8), ID_QUAT) + + constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)] + converged, _ = _solve_with_pt(pt, [a, b], constraints) + assert converged + + env = pt.get_env() + pos = b.extract_position(env) + # Z must be zero (in plane), X and Y free + assert abs(pos[2]) < 1e-8 + + +# ============================================================================ +# Multi-body integration tests +# ============================================================================ + + +class TestFourBarLinkage: + """Four-bar linkage: 4 bodies, 4 revolute joints. + + In 3D with Z-axis revolutes, this yields 2 DOF: the expected planar + motion plus an out-of-plane fold. A truly planar mechanism would + add Planar constraints on each link to eliminate the fold DOF. + """ + + def test_four_bar_dof(self): + """Four-bar linkage in 3D has 2 DOF (planar + fold).""" + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + link1 = RigidBody("l1", pt, (2, 0, 0), ID_QUAT) + link2 = RigidBody("l2", pt, (5, 3, 0), ID_QUAT) + link3 = RigidBody("l3", pt, (8, 0, 0), ID_QUAT) + + constraints = [ + RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT), + RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT), + RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT), + RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT), + ] + + bodies = [ground, link1, link2, link3] + dof = _dof(pt, bodies, constraints) + assert dof == 2 + + def test_four_bar_solves(self): + """Four-bar linkage converges from displaced initial conditions.""" + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + # Initial positions slightly displaced from valid config + link1 = RigidBody("l1", pt, (2, 1, 0), ID_QUAT) + link2 = RigidBody("l2", pt, (5, 4, 0), ID_QUAT) + link3 = RigidBody("l3", pt, (8, 1, 0), ID_QUAT) + + constraints = [ + RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT), + RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT), + RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT), + RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT), + ] + + bodies = [ground, link1, link2, link3] + converged, residuals = _solve_with_pt(pt, bodies, constraints) + assert converged + + # Verify all revolute constraints are satisfied + env = pt.get_env() + for r in residuals: + assert abs(r.eval(env)) < 1e-8 + + +class TestSliderCrank: + """Slider-crank mechanism: crank + connecting rod + piston. + + ground --[Revolute]-- crank --[Revolute]-- rod --[Revolute]-- piston --[Slider]-- ground + + Using Slider (not Cylindrical) for the piston to also lock rotation, + making it a true prismatic joint. In 3D, out-of-plane folding adds + extra DOF beyond the planar 1-DOF. + + 3 free bodies × 6 = 18 DOF + Revolute(5) + Revolute(5) + Revolute(5) + Slider(5) = 20 + But many constraints share bodies, so effective rank < 20. + In 3D: 3 DOF (planar crank + 2 fold modes). + """ + + def test_slider_crank_dof(self): + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + crank = RigidBody("crank", pt, (1, 0, 0), ID_QUAT) + rod = RigidBody("rod", pt, (3, 0, 0), ID_QUAT) + piston = RigidBody("piston", pt, (5, 0, 0), ID_QUAT) + + constraints = [ + RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT), + RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT), + RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT), + SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y), + ] + + bodies = [ground, crank, rod, piston] + dof = _dof(pt, bodies, constraints) + # 3D slider-crank: planar motion + out-of-plane fold modes + assert dof == 3 + + def test_slider_crank_solves(self): + """Slider-crank converges from displaced state.""" + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + crank = RigidBody("crank", pt, (1, 0.5, 0), ID_QUAT) + rod = RigidBody("rod", pt, (3, 1, 0), ID_QUAT) + piston = RigidBody("piston", pt, (5, 0.5, 0), ID_QUAT) + + constraints = [ + RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT), + RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT), + RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT), + SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y), + ] + + bodies = [ground, crank, rod, piston] + converged, residuals = _solve_with_pt(pt, bodies, constraints) + assert converged + + env = pt.get_env() + for r in residuals: + assert abs(r.eval(env)) < 1e-8 + + +class TestRevoluteChain: + """Chain of revolute joints: ground → body1 → body2. + + Each revolute removes 5 DOF. Two free bodies = 12 DOF. + 2 revolutes = 10 constraints + 2 quat norms = 12. + Expected: 2 DOF (one rotation per hinge). + """ + + def test_chain_dof(self): + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + b1 = RigidBody("b1", pt, (3, 0, 0), ID_QUAT) + b2 = RigidBody("b2", pt, (6, 0, 0), ID_QUAT) + + constraints = [ + RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT), + RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT), + ] + + assert _dof(pt, [ground, b1, b2], constraints) == 2 + + def test_chain_solves(self): + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + b1 = RigidBody("b1", pt, (3, 2, 0), ID_QUAT) + b2 = RigidBody("b2", pt, (6, 3, 0), ID_QUAT) + + constraints = [ + RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT), + RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT), + ] + + converged, residuals = _solve_with_pt(pt, [ground, b1, b2], constraints) + assert converged + + env = pt.get_env() + # b1 origin at ground hinge point (0,0,0) + pos1 = b1.extract_position(env) + assert abs(pos1[0]) < 1e-8 + assert abs(pos1[1]) < 1e-8 + assert abs(pos1[2]) < 1e-8 + + +class TestSliderOnRail: + """Slider constraint: body translates along ground Z-axis only. + + 1 free body, 1 slider = 6 - 5 = 1 DOF. + """ + + def test_slider_on_rail(self): + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + block = RigidBody("block", pt, (3, 4, 5), ID_QUAT) + + constraints = [ + SliderConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT) + ] + + converged, _ = _solve_with_pt(pt, [ground, block], constraints) + assert converged + + env = pt.get_env() + pos = block.extract_position(env) + # X, Y must be zero; Z is free + assert abs(pos[0]) < 1e-8 + assert abs(pos[1]) < 1e-8 + # Z should remain near initial value (minimum-norm solution) + + # Check orientation unchanged (no twist) + quat = block.extract_quaternion(env) + assert abs(quat[0] - 1.0) < 1e-6 + assert abs(quat[1]) < 1e-6 + assert abs(quat[2]) < 1e-6 + assert abs(quat[3]) < 1e-6 + + +class TestPlanarOnTable: + """Planar constraint: body slides on XY plane. + + 1 free body, 1 planar = 6 - 3 = 3 DOF. + """ + + def test_planar_on_table(self): + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + block = RigidBody("block", pt, (3, 4, 5), ID_QUAT) + + constraints = [ + PlanarConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT) + ] + + converged, _ = _solve_with_pt(pt, [ground, block], constraints) + assert converged + + env = pt.get_env() + pos = block.extract_position(env) + # Z must be zero, X and Y are free + assert abs(pos[2]) < 1e-8 + + +class TestPlanarWithOffset: + """Planar with offset: body floats at z=3 above ground.""" + + def test_planar_offset(self): + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + block = RigidBody("block", pt, (1, 2, 5), ID_QUAT) + + # PlanarConstraint residual: (p_i - p_j) . z_j - offset = 0 + # body_i=block, body_j=ground: (block_z - 0) * 1 - offset = 0 + # For block at z=3: offset = 3 + constraints = [ + PlanarConstraint( + block, (0, 0, 0), ID_QUAT, ground, (0, 0, 0), ID_QUAT, offset=3.0 + ) + ] + + converged, _ = _solve_with_pt(pt, [ground, block], constraints) + assert converged + + env = pt.get_env() + pos = block.extract_position(env) + assert abs(pos[2] - 3.0) < 1e-8 + + +class TestMixedConstraints: + """System with mixed constraint types.""" + + def test_revolute_plus_parallel(self): + """Two free bodies: revolute between ground and b1, parallel between b1 and b2. + + b1: 6 DOF - 5 (revolute) = 1 DOF + b2: 6 DOF - 2 (parallel) = 4 DOF + Total: 5 DOF + """ + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + b1 = RigidBody("b1", pt, (0, 0, 0), ID_QUAT) + b2 = RigidBody("b2", pt, (5, 0, 0), ID_QUAT) + + constraints = [ + RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT), + ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT), + ] + + assert _dof(pt, [ground, b1, b2], constraints) == 5 + + def test_coincident_plus_perpendicular(self): + """Coincident + perpendicular = ball + 1 angle constraint. + + 6 - 3 (coincident) - 1 (perpendicular) = 2 DOF. + """ + pt = ParamTable() + ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True) + b = RigidBody("b", pt, (0, 0, 0), ROT_90X) + + constraints = [ + CoincidentConstraint(ground, (0, 0, 0), b, (0, 0, 0)), + PerpendicularConstraint(ground, ID_QUAT, b, ID_QUAT), + ] + + assert _dof(pt, [ground, b], constraints) == 2