2 Commits

Author SHA1 Message Date
forbes-0023
92ae57751f feat(solver): graph decomposition for cluster-by-cluster solving (phase 3)
Add a Python decomposition layer using NetworkX that partitions the
constraint graph into biconnected components (rigid clusters), orders
them via a block-cut tree, and solves each cluster independently.
Articulation-point bodies propagate as boundary conditions between
clusters.

New module kindred_solver/decompose.py:
- DOF table mapping BaseJointKind to residual counts
- Constraint graph construction (nx.MultiGraph)
- Biconnected component detection + articulation points
- Block-cut tree solve ordering (root-first from grounded cluster)
- Cluster-by-cluster solver with boundary body fix/unfix cycling
- Pebble game integration for per-cluster rigidity classification

Changes to existing modules:
- params.py: add unfix() for boundary body cycling
- solver.py: extract _monolithic_solve(), add decomposition branch
  for assemblies with >= 8 free bodies

Performance: for k clusters of ~n/k params each, total cost drops
from O(n^3) to O(n^3/k^2).

220 tests passing (up from 207).
2026-02-20 22:19:35 -06:00
forbes-0023
533ca91774 feat(solver): full constraint vocabulary — all 24 BaseJointKind types (phase 2)
Add 18 new constraint classes covering all BaseJointKind types from Types.h:
- Point: PointOnLine (2r), PointInPlane (1r)
- Orientation: Parallel (2r), Perpendicular (1r), Angle (1r)
- Surface: Concentric (4r), Tangent (1r), Planar (3r), LineInPlane (2r)
- Kinematic: Ball (3r), Revolute (5r), Cylindrical (4r), Slider (5r),
  Screw (5r), Universal (4r)
- Mechanical: Gear (1r), RackPinion (1r)
- Stubs: Cam, Slot, DistanceCylSph

New modules:
- geometry.py: marker axis extraction, vector ops (dot3, cross3, sub3),
  geometric primitives (point_plane_distance, point_line_perp_components)
- bfgs.py: L-BFGS-B fallback solver via scipy for when Newton fails

solver.py changes:
- Wire all 20 supported types in _build_constraint()
- BFGS fallback after Newton-Raphson in solve()

183 tests passing (up from 82), including:
- DOF counting for every joint type
- Solve convergence from displaced initial conditions
- Multi-body mechanisms (four-bar linkage, slider-crank, revolute chain)
2026-02-20 21:15:15 -06:00
12 changed files with 4351 additions and 14 deletions

127
kindred_solver/bfgs.py Normal file
View File

@@ -0,0 +1,127 @@
"""L-BFGS-B fallback solver for when Newton-Raphson fails to converge.
Minimizes f(x) = 0.5 * sum(r_i(x)^2) using scipy's L-BFGS-B with
analytic gradient from the Expr DAG's symbolic differentiation.
"""
from __future__ import annotations
import math
from typing import List
import numpy as np
from .expr import Expr
from .params import ParamTable
try:
from scipy.optimize import minimize as _scipy_minimize
_HAS_SCIPY = True
except ImportError:
_HAS_SCIPY = False
def bfgs_solve(
residuals: List[Expr],
params: ParamTable,
quat_groups: List[tuple[str, str, str, str]] | None = None,
max_iter: int = 200,
tol: float = 1e-10,
) -> bool:
"""Solve ``residuals == 0`` by minimizing sum of squared residuals.
Falls back gracefully to False if scipy is not available.
Returns True if converged (||r|| < tol).
"""
if not _HAS_SCIPY:
return False
free = params.free_names()
n_free = len(free)
n_res = len(residuals)
if n_free == 0 or n_res == 0:
return True
# Build symbolic gradient expressions once: d(r_i)/d(x_j)
jac_exprs: List[List[Expr]] = []
for r in residuals:
row = []
for name in free:
row.append(r.diff(name).simplify())
jac_exprs.append(row)
def objective_and_grad(x_vec):
# Update params
params.set_free_vector(x_vec)
if quat_groups:
_renormalize_quats(params, quat_groups)
env = params.get_env()
# Evaluate residuals
r_vals = np.array([r.eval(env) for r in residuals])
f = 0.5 * np.dot(r_vals, r_vals)
# Evaluate Jacobian
J = np.empty((n_res, n_free))
for i in range(n_res):
for j in range(n_free):
J[i, j] = jac_exprs[i][j].eval(env)
# Gradient of f = sum(r_i * dr_i/dx_j) = J^T @ r
grad = J.T @ r_vals
return f, grad
x0 = params.get_free_vector().copy()
result = _scipy_minimize(
objective_and_grad,
x0,
method="L-BFGS-B",
jac=True,
options={"maxiter": max_iter, "ftol": tol * tol, "gtol": tol},
)
# Apply final result
params.set_free_vector(result.x)
if quat_groups:
_renormalize_quats(params, quat_groups)
# Check convergence on actual residual norm
env = params.get_env()
r_vals = np.array([r.eval(env) for r in residuals])
return bool(np.linalg.norm(r_vals) < tol)
def _renormalize_quats(
params: ParamTable,
groups: List[tuple[str, str, str, str]],
):
"""Project quaternion params back onto the unit sphere."""
for qw_name, qx_name, qy_name, qz_name in groups:
if (
params.is_fixed(qw_name)
and params.is_fixed(qx_name)
and params.is_fixed(qy_name)
and params.is_fixed(qz_name)
):
continue
w = params.get_value(qw_name)
x = params.get_value(qx_name)
y = params.get_value(qy_name)
z = params.get_value(qz_name)
norm = math.sqrt(w * w + x * x + y * y + z * z)
if norm < 1e-15:
params.set_value(qw_name, 1.0)
params.set_value(qx_name, 0.0)
params.set_value(qy_name, 0.0)
params.set_value(qz_name, 0.0)
else:
params.set_value(qw_name, w / norm)
params.set_value(qx_name, x / norm)
params.set_value(qy_name, y / norm)
params.set_value(qz_name, z / norm)

View File

@@ -2,14 +2,28 @@
Each constraint takes two RigidBody entities and marker transforms,
then generates residual expressions that equal zero when satisfied.
Phase 1 constraints: Coincident, DistancePointPoint, Fixed
Phase 2 constraints: all remaining BaseJointKind types from Types.h
"""
from __future__ import annotations
import math
from typing import List
from .entities import RigidBody
from .expr import Const, Expr
from .geometry import (
cross3,
dot3,
marker_x_axis,
marker_y_axis,
marker_z_axis,
point_line_perp_components,
point_plane_distance,
sub3,
)
class ConstraintBase:
@@ -145,6 +159,703 @@ class FixedConstraint(ConstraintBase):
return pos_res + ori_res
# ============================================================================
# Phase 2: Point constraints
# ============================================================================
class PointOnLineConstraint(ConstraintBase):
"""Point constrained to a line — 2 DOF removed.
marker_i origin lies on the line through marker_j origin along
marker_j Z-axis. 2 residuals: perpendicular distance components.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy = point_line_perp_components(p_i, p_j, z_j)
return [cx, cy]
class PointInPlaneConstraint(ConstraintBase):
"""Point constrained to a plane — 1 DOF removed.
marker_i origin lies in the plane through marker_j origin with
normal = marker_j Z-axis. Optional offset via params[0].
1 residual: signed distance to plane.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
offset: float = 0.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.offset = offset
def residuals(self) -> List[Expr]:
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
d = point_plane_distance(p_i, p_j, n_j)
if self.offset != 0.0:
d = d - Const(self.offset)
return [d]
# ============================================================================
# Phase 2: Axis orientation constraints
# ============================================================================
class ParallelConstraint(ConstraintBase):
"""Parallel axes — 2 DOF removed.
marker Z-axes are parallel: z_i x z_j = 0.
2 residuals from the cross product (only 2 of 3 components are
independent for unit vectors).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_quat = marker_i_quat
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, cz = cross3(z_i, z_j)
return [cx, cy]
class PerpendicularConstraint(ConstraintBase):
"""Perpendicular axes — 1 DOF removed.
marker Z-axes are perpendicular: z_i . z_j = 0.
1 residual.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_quat = marker_i_quat
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
return [dot3(z_i, z_j)]
class AngleConstraint(ConstraintBase):
"""Angle between axes — 1 DOF removed.
z_i . z_j = cos(angle).
1 residual.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_quat: tuple[float, float, float, float],
angle: float,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_quat = marker_i_quat
self.marker_j_quat = marker_j_quat
self.angle = angle
def residuals(self) -> List[Expr]:
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
return [dot3(z_i, z_j) - Const(math.cos(self.angle))]
# ============================================================================
# Phase 2: Axis/surface constraints
# ============================================================================
class ConcentricConstraint(ConstraintBase):
"""Coaxial / concentric — 4 DOF removed.
Axes are collinear: parallel Z-axes (2) + point-on-line (2).
Optional distance offset along axis via params.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
distance: float = 0.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.distance = distance
def residuals(self) -> List[Expr]:
# Parallel axes (2 residuals)
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
# Point-on-line: marker_i origin on line through marker_j along z_j
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
lx, ly = point_line_perp_components(p_i, p_j, z_j)
return [cx, cy, lx, ly]
class TangentConstraint(ConstraintBase):
"""Face-on-face tangency — 1 DOF removed.
Signed distance between marker origins along marker_j normal = 0.
1 residual: (p_i - p_j) . z_j
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
return [point_plane_distance(p_i, p_j, n_j)]
class PlanarConstraint(ConstraintBase):
"""Coplanar faces — 3 DOF removed.
Parallel normals (2) + point-in-plane (1). Optional offset.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
offset: float = 0.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.offset = offset
def residuals(self) -> List[Expr]:
# Parallel normals
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
# Point-in-plane
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
d = point_plane_distance(p_i, p_j, z_j)
if self.offset != 0.0:
d = d - Const(self.offset)
return [cx, cy, d]
class LineInPlaneConstraint(ConstraintBase):
"""Line constrained to a plane — 2 DOF removed.
Line defined by marker_i Z-axis lies in plane defined by marker_j normal.
2 residuals: point-in-plane (1) + line direction perpendicular to normal (1).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
offset: float = 0.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.offset = offset
def residuals(self) -> List[Expr]:
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
# Point in plane
d = point_plane_distance(p_i, p_j, n_j)
if self.offset != 0.0:
d = d - Const(self.offset)
# Line direction perpendicular to plane normal
dir_dot = dot3(z_i, n_j)
return [d, dir_dot]
# ============================================================================
# Phase 2: Kinematic joints
# ============================================================================
class BallConstraint(ConstraintBase):
"""Spherical joint — 3 DOF removed.
Coincident marker origins. Same as CoincidentConstraint.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
):
self._inner = CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos)
def residuals(self) -> List[Expr]:
return self._inner.residuals()
class RevoluteConstraint(ConstraintBase):
"""Hinge joint — 5 DOF removed.
Coincident origins (3) + parallel Z-axes (2).
1 rotational DOF remains (about the common Z-axis).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
# Coincident origins
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
pos = [p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]]
# Parallel Z-axes
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
return pos + [cx, cy]
class CylindricalConstraint(ConstraintBase):
"""Cylindrical joint — 4 DOF removed.
Parallel Z-axes (2) + point-on-line (2).
2 DOF remain: rotation about and translation along the common axis.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
# Parallel Z-axes
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
# Point-on-line
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
lx, ly = point_line_perp_components(p_i, p_j, z_j)
return [cx, cy, lx, ly]
class SliderConstraint(ConstraintBase):
"""Prismatic / slider joint — 5 DOF removed.
Parallel Z-axes (2) + point-on-line (2) + rotation lock (1).
1 DOF remains: translation along the common Z-axis.
Rotation lock: x_i . y_j = 0 (prevents twist about Z).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
# Parallel Z-axes
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
# Point-on-line
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
lx, ly = point_line_perp_components(p_i, p_j, z_j)
# Rotation lock: x_i . y_j = 0
x_i = marker_x_axis(self.body_i, self.marker_i_quat)
y_j = marker_y_axis(self.body_j, self.marker_j_quat)
twist = dot3(x_i, y_j)
return [cx, cy, lx, ly, twist]
class ScrewConstraint(ConstraintBase):
"""Helical / screw joint — 5 DOF removed.
Cylindrical (4) + coupled rotation-translation via pitch (1).
1 DOF remains: screw motion (rotation + proportional translation).
The coupling residual uses the relative quaternion's Z-component
(proportional to the rotation angle for small angles) and the axial
displacement: axial_disp - pitch * (2 * qz_rel / qw_rel) / (2*pi) = 0.
For the Newton solver operating near the solution, the linear
approximation angle ≈ 2 * qz_rel is adequate.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
pitch: float = 1.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.pitch = pitch
def residuals(self) -> List[Expr]:
# Cylindrical residuals (4)
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
lx, ly = point_line_perp_components(p_i, p_j, z_j)
# Pitch coupling: axial_disp = pitch * angle / (2*pi)
# Axial displacement
d = sub3(p_i, p_j)
axial = dot3(d, z_j)
# Relative rotation about Z via quaternion
# q_rel = conj(q_i_total) * q_j_total
qi = _quat_mul_const(
self.body_i.qw,
self.body_i.qx,
self.body_i.qy,
self.body_i.qz,
*self.marker_i_quat,
)
qj = _quat_mul_const(
self.body_j.qw,
self.body_j.qx,
self.body_j.qy,
self.body_j.qz,
*self.marker_j_quat,
)
rel = _quat_mul_expr(qi[0], -qi[1], -qi[2], -qi[3], qj[0], qj[1], qj[2], qj[3])
# For small angles: angle ≈ 2 * qz_rel, but qw_rel ≈ 1
# Use sin(angle/2) form: residual = axial - pitch * 2*qz / (2*pi)
# = axial - pitch * qz / pi
coupling = axial - Const(self.pitch / math.pi) * rel[3]
return [cx, cy, lx, ly, coupling]
class UniversalConstraint(ConstraintBase):
"""Universal / Cardan joint — 4 DOF removed.
Coincident origins (3) + perpendicular Z-axes (1).
2 DOF remain: rotation about each body's Z-axis.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
# Coincident origins
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
pos = [p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]]
# Perpendicular Z-axes
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
return pos + [dot3(z_i, z_j)]
# ============================================================================
# Phase 2: Mechanical element constraints
# ============================================================================
class GearConstraint(ConstraintBase):
"""Gear pair or belt — 1 DOF removed.
Couples rotation angles: r_i * theta_i + r_j * theta_j = 0.
For belts (same-direction rotation), r_j is passed as negative.
Uses the Z-component of the relative quaternion as a proxy for
rotation angle (linear for small angles, which is the regime
where Newton operates).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_quat: tuple[float, float, float, float],
radius_i: float,
radius_j: float,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_quat = marker_i_quat
self.marker_j_quat = marker_j_quat
self.radius_i = radius_i
self.radius_j = radius_j
def residuals(self) -> List[Expr]:
# Rotation angle proxy via relative quaternion Z-component
# For body_i: q_rel_i = conj(q_marker_i) * q_body_i * q_marker_i
# Simplified: use 2*qz of (conj(marker) * body * marker) as angle proxy
qz_i = _rotation_z_component(self.body_i, self.marker_i_quat)
qz_j = _rotation_z_component(self.body_j, self.marker_j_quat)
# r_i * theta_i + r_j * theta_j = 0
# Using qz as proportional to theta/2:
# r_i * qz_i + r_j * qz_j = 0
return [Const(self.radius_i) * qz_i + Const(self.radius_j) * qz_j]
class RackPinionConstraint(ConstraintBase):
"""Rack-and-pinion — 1 DOF removed.
Couples rotation of body_i to translation of body_j along marker_j Z-axis.
translation = pitch_radius * theta
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
pitch_radius: float,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.pitch_radius = pitch_radius
def residuals(self) -> List[Expr]:
# Translation of j along its Z-axis
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
d = sub3(p_j, p_i)
translation = dot3(d, z_j)
# Rotation angle of i about its Z-axis
qz_i = _rotation_z_component(self.body_i, self.marker_i_quat)
# translation - pitch_radius * angle = 0
# angle ≈ 2 * qz, so: translation - pitch_radius * 2 * qz = 0
return [translation - Const(2.0 * self.pitch_radius) * qz_i]
class CamConstraint(ConstraintBase):
"""Cam-follower constraint — future, stub."""
def residuals(self) -> List[Expr]:
return []
class SlotConstraint(ConstraintBase):
"""Slot constraint — future, stub."""
def residuals(self) -> List[Expr]:
return []
class DistanceCylSphConstraint(ConstraintBase):
"""Cylinder-sphere distance — stub.
Semantics depend on geometry classification; placeholder for now.
"""
def residuals(self) -> List[Expr]:
return []
# -- rotation helpers for mechanical constraints ------------------------------
def _rotation_z_component(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> Expr:
"""Extract the Z-component of the relative quaternion about a marker axis.
Returns the qz component of conj(q_marker) * q_body * q_marker,
which is proportional to sin(theta/2) where theta is the rotation
angle about the marker Z-axis.
"""
mw, mx, my, mz = marker_quat
# q_local = conj(marker) * q_body * marker
# Step 1: temp = conj(marker) * q_body
cmw, cmx, cmy, cmz = Const(mw), Const(-mx), Const(-my), Const(-mz)
# temp = conj(marker) * q_body
tw = cmw * body.qw - cmx * body.qx - cmy * body.qy - cmz * body.qz
tx = cmw * body.qx + cmx * body.qw + cmy * body.qz - cmz * body.qy
ty = cmw * body.qy - cmx * body.qz + cmy * body.qw + cmz * body.qx
tz = cmw * body.qz + cmx * body.qy - cmy * body.qx + cmz * body.qw
# q_local = temp * marker
mmw, mmx, mmy, mmz = Const(mw), Const(mx), Const(my), Const(mz)
# rz = tw * mmz + tx * mmy - ty * mmx + tz * mmw
rz = tw * mmz + tx * mmy - ty * mmx + tz * mmw
return rz
# -- quaternion multiplication helpers ----------------------------------------

661
kindred_solver/decompose.py Normal file
View File

@@ -0,0 +1,661 @@
"""Graph decomposition for cluster-by-cluster constraint solving.
Builds a constraint graph from the SolveContext, decomposes it into
biconnected components (rigid clusters), orders them via a block-cut
tree, and solves each cluster independently. Articulation-point bodies
are temporarily fixed when solving adjacent clusters so their solved
values propagate as boundary conditions.
Requires: networkx
"""
from __future__ import annotations
import importlib.util
import logging
import sys
import types as stdlib_types
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, List
import networkx as nx
from .bfgs import bfgs_solve
from .newton import newton_solve
from .prepass import substitution_pass
if TYPE_CHECKING:
from .constraints import ConstraintBase
from .entities import RigidBody
from .params import ParamTable
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DOF table: BaseJointKind → number of residuals (= DOF removed)
# ---------------------------------------------------------------------------
# Imported lazily to avoid hard kcsolve dependency in tests.
# Use residual_count() accessor instead of this dict directly.
_RESIDUAL_COUNT: dict[str, int] | None = None
def _ensure_residual_count() -> dict:
"""Build the residual count table on first use."""
global _RESIDUAL_COUNT
if _RESIDUAL_COUNT is not None:
return _RESIDUAL_COUNT
import kcsolve
_RESIDUAL_COUNT = {
kcsolve.BaseJointKind.Fixed: 6,
kcsolve.BaseJointKind.Coincident: 3,
kcsolve.BaseJointKind.Ball: 3,
kcsolve.BaseJointKind.Revolute: 5,
kcsolve.BaseJointKind.Cylindrical: 4,
kcsolve.BaseJointKind.Slider: 5,
kcsolve.BaseJointKind.Screw: 5,
kcsolve.BaseJointKind.Universal: 4,
kcsolve.BaseJointKind.Parallel: 2,
kcsolve.BaseJointKind.Perpendicular: 1,
kcsolve.BaseJointKind.Angle: 1,
kcsolve.BaseJointKind.Concentric: 4,
kcsolve.BaseJointKind.Tangent: 1,
kcsolve.BaseJointKind.Planar: 3,
kcsolve.BaseJointKind.LineInPlane: 2,
kcsolve.BaseJointKind.PointOnLine: 2,
kcsolve.BaseJointKind.PointInPlane: 1,
kcsolve.BaseJointKind.DistancePointPoint: 1,
kcsolve.BaseJointKind.Gear: 1,
kcsolve.BaseJointKind.RackPinion: 1,
kcsolve.BaseJointKind.Cam: 0,
kcsolve.BaseJointKind.Slot: 0,
kcsolve.BaseJointKind.DistanceCylSph: 0,
}
return _RESIDUAL_COUNT
def residual_count(kind) -> int:
"""Number of residuals a constraint type produces."""
return _ensure_residual_count().get(kind, 0)
# ---------------------------------------------------------------------------
# Standalone residual-count table (no kcsolve dependency, string-keyed)
# Used by tests that don't have kcsolve available.
# ---------------------------------------------------------------------------
_RESIDUAL_COUNT_BY_NAME: dict[str, int] = {
"Fixed": 6,
"Coincident": 3,
"Ball": 3,
"Revolute": 5,
"Cylindrical": 4,
"Slider": 5,
"Screw": 5,
"Universal": 4,
"Parallel": 2,
"Perpendicular": 1,
"Angle": 1,
"Concentric": 4,
"Tangent": 1,
"Planar": 3,
"LineInPlane": 2,
"PointOnLine": 2,
"PointInPlane": 1,
"DistancePointPoint": 1,
"Gear": 1,
"RackPinion": 1,
"Cam": 0,
"Slot": 0,
"DistanceCylSph": 0,
}
def residual_count_by_name(kind_name: str) -> int:
"""Number of residuals by constraint type name (no kcsolve needed)."""
return _RESIDUAL_COUNT_BY_NAME.get(kind_name, 0)
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class SolveCluster:
"""A cluster of bodies to solve together."""
bodies: set[str] # Body IDs in this cluster
constraint_indices: list[int] # Indices into the constraint list
boundary_bodies: set[str] # Articulation points shared with other clusters
has_ground: bool # Whether any body in the cluster is grounded
# ---------------------------------------------------------------------------
# Graph construction
# ---------------------------------------------------------------------------
def build_constraint_graph(
constraints: list,
grounded_bodies: set[str],
) -> nx.MultiGraph:
"""Build a body-level constraint multigraph.
Nodes: part_id strings (one per body referenced by constraints).
Edges: one per active constraint with attributes:
- constraint_index: position in the constraints list
- weight: number of residuals
Grounded bodies are tagged with ``grounded=True``.
Constraints with 0 residuals (stubs) are excluded.
"""
G = nx.MultiGraph()
for idx, c in enumerate(constraints):
if not c.activated:
continue
weight = residual_count(c.type)
if weight == 0:
continue
part_i = c.part_i
part_j = c.part_j
# Ensure nodes exist
if part_i not in G:
G.add_node(part_i, grounded=(part_i in grounded_bodies))
if part_j not in G:
G.add_node(part_j, grounded=(part_j in grounded_bodies))
# Store kind_name for pebble game integration
kind_name = c.type.name if hasattr(c.type, "name") else str(c.type)
G.add_edge(
part_i, part_j, constraint_index=idx, weight=weight, kind_name=kind_name
)
return G
def build_constraint_graph_simple(
edges: list[tuple[str, str, str, int]],
grounded: set[str] | None = None,
) -> nx.MultiGraph:
"""Build a constraint graph from simple edge tuples (for testing).
Each edge is ``(body_i, body_j, kind_name, constraint_index)``.
"""
grounded = grounded or set()
G = nx.MultiGraph()
for body_i, body_j, kind_name, idx in edges:
weight = residual_count_by_name(kind_name)
if weight == 0:
continue
if body_i not in G:
G.add_node(body_i, grounded=(body_i in grounded))
if body_j not in G:
G.add_node(body_j, grounded=(body_j in grounded))
G.add_edge(
body_i, body_j, constraint_index=idx, weight=weight, kind_name=kind_name
)
return G
# ---------------------------------------------------------------------------
# Decomposition
# ---------------------------------------------------------------------------
def find_clusters(
G: nx.MultiGraph,
) -> tuple[list[set[str]], set[str]]:
"""Find biconnected components and articulation points.
Returns:
clusters: list of body-ID sets (one per biconnected component)
articulation_points: body-IDs shared between clusters
"""
# biconnected_components requires a simple Graph
simple = nx.Graph(G)
clusters = [set(c) for c in nx.biconnected_components(simple)]
artic = set(nx.articulation_points(simple))
return clusters, artic
def build_solve_order(
G: nx.MultiGraph,
clusters: list[set[str]],
articulation_points: set[str],
grounded_bodies: set[str],
) -> list[SolveCluster]:
"""Order clusters for solving via the block-cut tree.
Builds the block-cut tree (bipartite graph of clusters and
articulation points), roots it at a grounded cluster, and returns
clusters in root-to-leaf order (grounded first, outward to leaves).
This ensures boundary bodies are solved before clusters that
depend on them.
"""
if not clusters:
return []
# Single cluster — no ordering needed
if len(clusters) == 1:
bodies = clusters[0]
indices = _constraints_for_bodies(G, bodies)
has_ground = bool(bodies & grounded_bodies)
return [
SolveCluster(
bodies=bodies,
constraint_indices=indices,
boundary_bodies=set(),
has_ground=has_ground,
)
]
# Build block-cut tree
# Nodes: ("C", i) for cluster i, ("A", body_id) for articulation points
bct = nx.Graph()
for i, cluster in enumerate(clusters):
bct.add_node(("C", i))
for ap in articulation_points:
if ap in cluster:
bct.add_edge(("C", i), ("A", ap))
# Find root: prefer a cluster containing a grounded body
root = ("C", 0)
for i, cluster in enumerate(clusters):
if cluster & grounded_bodies:
root = ("C", i)
break
# BFS from root: grounded cluster first, outward to leaves
visited = set()
order = []
queue = deque([root])
visited.add(root)
while queue:
node = queue.popleft()
if node[0] == "C":
order.append(node[1])
for neighbor in bct.neighbors(node):
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
# Build SolveCluster objects
solve_clusters = []
for i in order:
bodies = clusters[i]
indices = _constraints_for_bodies(G, bodies)
boundary = bodies & articulation_points
has_ground = bool(bodies & grounded_bodies)
solve_clusters.append(
SolveCluster(
bodies=bodies,
constraint_indices=indices,
boundary_bodies=boundary,
has_ground=has_ground,
)
)
return solve_clusters
def _constraints_for_bodies(G: nx.MultiGraph, bodies: set[str]) -> list[int]:
"""Collect constraint indices for edges where both endpoints are in bodies."""
indices = []
seen = set()
for u, v, data in G.edges(data=True):
idx = data["constraint_index"]
if idx in seen:
continue
if u in bodies and v in bodies:
seen.add(idx)
indices.append(idx)
return sorted(indices)
# ---------------------------------------------------------------------------
# Top-level decompose entry point
# ---------------------------------------------------------------------------
def decompose(
constraints: list,
grounded_bodies: set[str],
) -> list[SolveCluster]:
"""Full decomposition pipeline: graph → clusters → solve order.
Returns a list of SolveCluster in solve order (leaves first).
If the system is a single cluster, returns a 1-element list.
"""
G = build_constraint_graph(constraints, grounded_bodies)
# Handle disconnected sub-assemblies
all_clusters = []
for component_nodes in nx.connected_components(G):
sub = G.subgraph(component_nodes).copy()
clusters, artic = find_clusters(sub)
if len(clusters) <= 1:
# Single cluster in this component
bodies = component_nodes if not clusters else clusters[0]
indices = _constraints_for_bodies(sub, bodies)
has_ground = bool(bodies & grounded_bodies)
all_clusters.append(
SolveCluster(
bodies=set(bodies),
constraint_indices=indices,
boundary_bodies=set(),
has_ground=has_ground,
)
)
else:
ordered = build_solve_order(sub, clusters, artic, grounded_bodies)
all_clusters.extend(ordered)
return all_clusters
# ---------------------------------------------------------------------------
# Cluster solver
# ---------------------------------------------------------------------------
def solve_decomposed(
clusters: list[SolveCluster],
bodies: dict[str, "RigidBody"],
constraint_objs: list["ConstraintBase"],
constraint_indices_map: list[int],
params: "ParamTable",
) -> bool:
"""Solve clusters in order, fixing boundary bodies between solves.
Args:
clusters: SolveCluster list in solve order (from decompose()).
bodies: part_id → RigidBody mapping.
constraint_objs: constraint objects (parallel to constraint_indices_map).
constraint_indices_map: for each constraint_obj, its index in ctx.constraints.
params: shared ParamTable.
Returns True if all clusters converged.
"""
# Build reverse map: constraint_index → position in constraint_objs list
idx_to_obj: dict[int, "ConstraintBase"] = {}
for pos, ci in enumerate(constraint_indices_map):
idx_to_obj[ci] = constraint_objs[pos]
solved_bodies: set[str] = set()
all_converged = True
for cluster in clusters:
# 1. Fix boundary bodies that were already solved
fixed_boundary_params: list[str] = []
for body_id in cluster.boundary_bodies:
if body_id in solved_bodies:
body = bodies[body_id]
for pname in body._param_names:
if not params.is_fixed(pname):
params.fix(pname)
fixed_boundary_params.append(pname)
# 2. Collect residuals for this cluster
cluster_residuals = []
for ci in cluster.constraint_indices:
obj = idx_to_obj.get(ci)
if obj is not None:
cluster_residuals.extend(obj.residuals())
# 3. Add quat norm residuals for free, non-grounded bodies in this cluster
quat_groups = []
for body_id in cluster.bodies:
body = bodies[body_id]
if body.grounded:
continue
if body_id in cluster.boundary_bodies and body_id in solved_bodies:
continue # Already fixed as boundary
cluster_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
# 4. Substitution pass (compiles fixed boundary params to constants)
cluster_residuals = substitution_pass(cluster_residuals, params)
# 5. Newton solve (+ BFGS fallback)
if cluster_residuals:
converged = newton_solve(
cluster_residuals,
params,
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
)
if not converged:
converged = bfgs_solve(
cluster_residuals,
params,
quat_groups=quat_groups,
max_iter=200,
tol=1e-10,
)
if not converged:
all_converged = False
# 6. Mark this cluster's bodies as solved
solved_bodies.update(cluster.bodies)
# 7. Unfix boundary params
for pname in fixed_boundary_params:
params.unfix(pname)
return all_converged
# ---------------------------------------------------------------------------
# Pebble game integration (rigidity classification)
# ---------------------------------------------------------------------------
_PEBBLE_MODULES_LOADED = False
_PebbleGame3D = None
_PebbleJointType = None
_PebbleJoint = None
def _load_pebble_modules():
"""Lazily load PebbleGame3D and related types from GNN/solver/datagen/.
The GNN package has its own import structure (``from solver.datagen.types
import ...``) that conflicts with the top-level module layout, so we
register shim modules in ``sys.modules`` to make it work.
"""
global _PEBBLE_MODULES_LOADED, _PebbleGame3D, _PebbleJointType, _PebbleJoint
if _PEBBLE_MODULES_LOADED:
return
# Find GNN/solver/datagen relative to this package
pkg_dir = Path(__file__).resolve().parent.parent # mods/solver/
datagen_dir = pkg_dir / "GNN" / "solver" / "datagen"
if not datagen_dir.exists():
log.warning("GNN/solver/datagen/ not found; pebble game unavailable")
_PEBBLE_MODULES_LOADED = True
return
# Register shim modules so ``from solver.datagen.types import ...`` works
if "solver" not in sys.modules:
sys.modules["solver"] = stdlib_types.ModuleType("solver")
if "solver.datagen" not in sys.modules:
dg = stdlib_types.ModuleType("solver.datagen")
sys.modules["solver.datagen"] = dg
sys.modules["solver"].datagen = dg # type: ignore[attr-defined]
# Load types.py
types_path = datagen_dir / "types.py"
spec_t = importlib.util.spec_from_file_location(
"solver.datagen.types", str(types_path)
)
types_mod = importlib.util.module_from_spec(spec_t)
sys.modules["solver.datagen.types"] = types_mod
spec_t.loader.exec_module(types_mod)
# Load pebble_game.py
pg_path = datagen_dir / "pebble_game.py"
spec_p = importlib.util.spec_from_file_location(
"solver.datagen.pebble_game", str(pg_path)
)
pg_mod = importlib.util.module_from_spec(spec_p)
sys.modules["solver.datagen.pebble_game"] = pg_mod
spec_p.loader.exec_module(pg_mod)
_PebbleGame3D = pg_mod.PebbleGame3D
_PebbleJointType = types_mod.JointType
_PebbleJoint = types_mod.Joint
_PEBBLE_MODULES_LOADED = True
# BaseJointKind name → PebbleGame JointType name.
# Types not listed here use manual edge insertion with the residual count.
_KIND_NAME_TO_PEBBLE_NAME: dict[str, str] = {
"Fixed": "FIXED",
"Coincident": "BALL", # Same DOF count (3)
"Ball": "BALL",
"Revolute": "REVOLUTE",
"Cylindrical": "CYLINDRICAL",
"Slider": "SLIDER",
"Screw": "SCREW",
"Universal": "UNIVERSAL",
"Planar": "PLANAR",
"Perpendicular": "PERPENDICULAR",
"DistancePointPoint": "DISTANCE",
}
# Parallel: pebble game uses 3 DOF, but our solver uses 2.
# We handle it with manual edge insertion.
# Types that need manual edge insertion (no direct JointType mapping,
# or DOF mismatch like Parallel).
_MANUAL_EDGE_TYPES: set[str] = {
"Parallel", # 2 residuals, but JointType.PARALLEL = 3
"Angle", # 1 residual, no JointType
"Concentric", # 4 residuals, no JointType
"Tangent", # 1 residual, no JointType
"LineInPlane", # 2 residuals, no JointType
"PointOnLine", # 2 residuals, no JointType
"PointInPlane", # 1 residual, no JointType
"Gear", # 1 residual, no JointType
"RackPinion", # 1 residual, no JointType
}
_GROUND_BODY_ID = -1
def classify_cluster_rigidity(
cluster: SolveCluster,
constraint_graph: nx.MultiGraph,
grounded_bodies: set[str],
) -> str | None:
"""Run pebble game on a cluster and return rigidity classification.
Returns one of: "well-constrained", "underconstrained",
"overconstrained", "mixed", or None if pebble game unavailable.
"""
import numpy as np
_load_pebble_modules()
if _PebbleGame3D is None:
return None
pg = _PebbleGame3D()
# Map string body IDs → integer IDs for pebble game
body_list = sorted(cluster.bodies)
body_to_int: dict[str, int] = {b: i for i, b in enumerate(body_list)}
for b in body_list:
pg.add_body(body_to_int[b])
# Add virtual ground body if cluster has grounded bodies
has_ground = bool(cluster.bodies & grounded_bodies)
if has_ground:
pg.add_body(_GROUND_BODY_ID)
for b in cluster.bodies & grounded_bodies:
ground_joint = _PebbleJoint(
joint_id=-1,
body_a=body_to_int[b],
body_b=_GROUND_BODY_ID,
joint_type=_PebbleJointType["FIXED"],
anchor_a=np.zeros(3),
anchor_b=np.zeros(3),
)
pg.add_joint(ground_joint)
# Add constraint edges
joint_counter = 0
zero = np.zeros(3)
for u, v, data in constraint_graph.edges(data=True):
if u not in cluster.bodies or v not in cluster.bodies:
continue
ci = data["constraint_index"]
if ci not in cluster.constraint_indices:
continue
# Determine the constraint kind name from the graph edge
kind_name = data.get("kind_name", "")
n_residuals = data.get("weight", 0)
if not kind_name or n_residuals == 0:
continue
int_u = body_to_int[u]
int_v = body_to_int[v]
pebble_name = _KIND_NAME_TO_PEBBLE_NAME.get(kind_name)
if pebble_name and kind_name not in _MANUAL_EDGE_TYPES:
# Direct JointType mapping
jt = _PebbleJointType[pebble_name]
joint = _PebbleJoint(
joint_id=joint_counter,
body_a=int_u,
body_b=int_v,
joint_type=jt,
anchor_a=zero,
anchor_b=zero,
)
pg.add_joint(joint)
joint_counter += 1
else:
# Manual edge insertion: one DISTANCE edge per residual
for _ in range(n_residuals):
joint = _PebbleJoint(
joint_id=joint_counter,
body_a=int_u,
body_b=int_v,
joint_type=_PebbleJointType["DISTANCE"],
anchor_a=zero,
anchor_b=zero,
)
pg.add_joint(joint)
joint_counter += 1
# Classify using raw pebble counts (adjusting for virtual ground)
total_dof = pg.get_dof()
redundant = pg.get_redundant_count()
# The virtual ground body contributes 6 pebbles that are never consumed.
# Subtract them to get the effective DOF.
if has_ground:
total_dof -= 6 # virtual ground's unconstrained pebbles
baseline = 0
else:
baseline = 6 # trivial rigid-body motion
if redundant > 0 and total_dof > baseline:
return "mixed"
elif redundant > 0:
return "overconstrained"
elif total_dof > baseline:
return "underconstrained"
elif total_dof == baseline:
return "well-constrained"
else:
return "overconstrained"

131
kindred_solver/geometry.py Normal file
View File

@@ -0,0 +1,131 @@
"""Geometric helper functions for constraint equations.
Provides Expr-level vector operations and marker axis extraction.
All functions work with Expr triples (tuples of 3 Expr nodes)
representing 3D vectors in world coordinates.
Marker convention (from Types.h): the marker's Z-axis defines the
constraint direction (hinge axis, face normal, line direction, etc.).
"""
from __future__ import annotations
from .entities import RigidBody
from .expr import Const, Expr
from .quat import quat_rotate
# Type alias for an Expr triple (3D vector)
Vec3 = tuple[Expr, Expr, Expr]
# -- Marker axis extraction ---------------------------------------------------
def _composed_quat(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> tuple[Expr, Expr, Expr, Expr]:
"""Compute q_total = q_body * q_marker as Expr quaternion.
q_body comes from the body's Var params; q_marker is constant.
"""
bw, bx, by, bz = body.qw, body.qx, body.qy, body.qz
mw, mx, my, mz = (Const(v) for v in marker_quat)
# Hamilton product: body * marker
rw = bw * mw - bx * mx - by * my - bz * mz
rx = bw * mx + bx * mw + by * mz - bz * my
ry = bw * my - bx * mz + by * mw + bz * mx
rz = bw * mz + bx * my - by * mx + bz * mw
return rw, rx, ry, rz
def marker_z_axis(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> Vec3:
"""World-frame Z-axis of a marker on a body.
Computes rotate(q_body * q_marker, [0, 0, 1]).
"""
qw, qx, qy, qz = _composed_quat(body, marker_quat)
return quat_rotate(qw, qx, qy, qz, Const(0.0), Const(0.0), Const(1.0))
def marker_x_axis(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> Vec3:
"""World-frame X-axis of a marker on a body.
Computes rotate(q_body * q_marker, [1, 0, 0]).
"""
qw, qx, qy, qz = _composed_quat(body, marker_quat)
return quat_rotate(qw, qx, qy, qz, Const(1.0), Const(0.0), Const(0.0))
def marker_y_axis(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> Vec3:
"""World-frame Y-axis of a marker on a body.
Computes rotate(q_body * q_marker, [0, 1, 0]).
"""
qw, qx, qy, qz = _composed_quat(body, marker_quat)
return quat_rotate(qw, qx, qy, qz, Const(0.0), Const(1.0), Const(0.0))
# -- Vector operations on Expr triples ----------------------------------------
def dot3(a: Vec3, b: Vec3) -> Expr:
"""Dot product of two Expr triples."""
return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
def cross3(a: Vec3, b: Vec3) -> Vec3:
"""Cross product of two Expr triples."""
return (
a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0],
)
def sub3(a: Vec3, b: Vec3) -> Vec3:
"""Vector subtraction a - b."""
return (a[0] - b[0], a[1] - b[1], a[2] - b[2])
# -- Geometric primitives -----------------------------------------------------
def point_plane_distance(
point: Vec3,
plane_origin: Vec3,
normal: Vec3,
) -> Expr:
"""Signed distance from point to plane defined by origin + normal.
Returns (point - plane_origin) . normal
"""
d = sub3(point, plane_origin)
return dot3(d, normal)
def point_line_perp_components(
point: Vec3,
line_origin: Vec3,
line_dir: Vec3,
) -> tuple[Expr, Expr]:
"""Two independent perpendicular-distance components from point to line.
The line passes through line_origin along line_dir.
Returns the x and y components of (point - line_origin) x line_dir,
which are zero when the point lies on the line.
"""
d = sub3(point, line_origin)
cx, cy, cz = cross3(d, line_dir)
# All three components of d x line_dir are zero when d is parallel
# to line_dir, but only 2 are independent. We return x and y.
return cx, cy

View File

@@ -49,6 +49,13 @@ class ParamTable:
if name in self._free_order:
self._free_order.remove(name)
def unfix(self, name: str):
"""Restore a fixed parameter to free status."""
if name in self._fixed:
self._fixed.discard(name)
if name not in self._free_order:
self._free_order.append(name)
def get_env(self) -> Dict[str, float]:
"""Return a snapshot of all current values (for Expr.eval)."""
return dict(self._values)

View File

@@ -5,23 +5,71 @@ from __future__ import annotations
import kcsolve
from .bfgs import bfgs_solve
from .constraints import (
AngleConstraint,
BallConstraint,
CamConstraint,
CoincidentConstraint,
ConcentricConstraint,
ConstraintBase,
CylindricalConstraint,
DistanceCylSphConstraint,
DistancePointPointConstraint,
FixedConstraint,
GearConstraint,
LineInPlaneConstraint,
ParallelConstraint,
PerpendicularConstraint,
PlanarConstraint,
PointInPlaneConstraint,
PointOnLineConstraint,
RackPinionConstraint,
RevoluteConstraint,
ScrewConstraint,
SliderConstraint,
SlotConstraint,
TangentConstraint,
UniversalConstraint,
)
from .decompose import decompose, solve_decomposed
from .dof import count_dof
from .entities import RigidBody
from .newton import newton_solve
from .params import ParamTable
from .prepass import single_equation_pass, substitution_pass
# Map BaseJointKind enum values to handler names
# Assemblies with fewer free bodies than this use the monolithic path.
_DECOMPOSE_THRESHOLD = 8
# All BaseJointKind values this solver can handle
_SUPPORTED = {
# Phase 1
kcsolve.BaseJointKind.Coincident,
kcsolve.BaseJointKind.DistancePointPoint,
kcsolve.BaseJointKind.Fixed,
# Phase 2: point constraints
kcsolve.BaseJointKind.PointOnLine,
kcsolve.BaseJointKind.PointInPlane,
# Phase 2: orientation
kcsolve.BaseJointKind.Parallel,
kcsolve.BaseJointKind.Perpendicular,
kcsolve.BaseJointKind.Angle,
# Phase 2: axis/surface
kcsolve.BaseJointKind.Concentric,
kcsolve.BaseJointKind.Tangent,
kcsolve.BaseJointKind.Planar,
kcsolve.BaseJointKind.LineInPlane,
# Phase 2: kinematic joints
kcsolve.BaseJointKind.Ball,
kcsolve.BaseJointKind.Revolute,
kcsolve.BaseJointKind.Cylindrical,
kcsolve.BaseJointKind.Slider,
kcsolve.BaseJointKind.Screw,
kcsolve.BaseJointKind.Universal,
# Phase 2: mechanical
kcsolve.BaseJointKind.Gear,
kcsolve.BaseJointKind.RackPinion,
}
@@ -51,11 +99,12 @@ class KindredSolver(kcsolve.IKCSolver):
)
bodies[part.id] = body
# 2. Build constraint residuals
# 2. Build constraint residuals (track index mapping for decomposition)
all_residuals = []
constraint_objs = []
constraint_indices = [] # parallel to constraint_objs: index in ctx.constraints
for c in ctx.constraints:
for idx, c in enumerate(ctx.constraints):
if not c.activated:
continue
body_i = bodies.get(c.part_i)
@@ -79,6 +128,7 @@ class KindredSolver(kcsolve.IKCSolver):
if obj is None:
continue
constraint_objs.append(obj)
constraint_indices.append(idx)
all_residuals.extend(obj.residuals())
# 3. Add quaternion normalization residuals for non-grounded bodies
@@ -88,18 +138,31 @@ class KindredSolver(kcsolve.IKCSolver):
all_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
# 4. Pre-passes
# 4. Pre-passes on full system
all_residuals = substitution_pass(all_residuals, params)
all_residuals = single_equation_pass(all_residuals, params)
# 5. Newton-Raphson
converged = newton_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
)
# 5. Solve (decomposed for large assemblies, monolithic for small)
n_free_bodies = sum(1 for b in bodies.values() if not b.grounded)
if n_free_bodies >= _DECOMPOSE_THRESHOLD:
grounded_ids = {pid for pid, b in bodies.items() if b.grounded}
clusters = decompose(ctx.constraints, grounded_ids)
if len(clusters) > 1:
converged = solve_decomposed(
clusters,
bodies,
constraint_objs,
constraint_indices,
params,
)
else:
converged = _monolithic_solve(
all_residuals,
params,
quat_groups,
)
else:
converged = _monolithic_solve(all_residuals, params, quat_groups)
# 6. DOF
dof = count_dof(all_residuals, params)
@@ -130,6 +193,26 @@ class KindredSolver(kcsolve.IKCSolver):
return True
def _monolithic_solve(all_residuals, params, quat_groups):
"""Newton-Raphson solve with BFGS fallback on the full system."""
converged = newton_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
)
if not converged:
converged = bfgs_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=200,
tol=1e-10,
)
return converged
def _build_constraint(
kind,
body_i,
@@ -141,6 +224,11 @@ def _build_constraint(
c_params,
) -> ConstraintBase | None:
"""Create the appropriate constraint object from a BaseJointKind."""
marker_i_quat = tuple(marker_i.quaternion)
marker_j_quat = tuple(marker_j.quaternion)
# -- Phase 1 constraints --------------------------------------------------
if kind == kcsolve.BaseJointKind.Coincident:
return CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos)
@@ -155,8 +243,6 @@ def _build_constraint(
)
if kind == kcsolve.BaseJointKind.Fixed:
marker_i_quat = tuple(marker_i.quaternion)
marker_j_quat = tuple(marker_j.quaternion)
return FixedConstraint(
body_i,
marker_i_pos,
@@ -166,4 +252,182 @@ def _build_constraint(
marker_j_quat,
)
# -- Phase 2: point constraints -------------------------------------------
if kind == kcsolve.BaseJointKind.PointOnLine:
return PointOnLineConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.PointInPlane:
offset = c_params[0] if c_params else 0.0
return PointInPlaneConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
offset=offset,
)
# -- Phase 2: orientation constraints -------------------------------------
if kind == kcsolve.BaseJointKind.Parallel:
return ParallelConstraint(body_i, marker_i_quat, body_j, marker_j_quat)
if kind == kcsolve.BaseJointKind.Perpendicular:
return PerpendicularConstraint(body_i, marker_i_quat, body_j, marker_j_quat)
if kind == kcsolve.BaseJointKind.Angle:
angle = c_params[0] if c_params else 0.0
return AngleConstraint(body_i, marker_i_quat, body_j, marker_j_quat, angle)
# -- Phase 2: axis/surface constraints ------------------------------------
if kind == kcsolve.BaseJointKind.Concentric:
distance = c_params[0] if c_params else 0.0
return ConcentricConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
distance=distance,
)
if kind == kcsolve.BaseJointKind.Tangent:
return TangentConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.Planar:
offset = c_params[0] if c_params else 0.0
return PlanarConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
offset=offset,
)
if kind == kcsolve.BaseJointKind.LineInPlane:
offset = c_params[0] if c_params else 0.0
return LineInPlaneConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
offset=offset,
)
# -- Phase 2: kinematic joints --------------------------------------------
if kind == kcsolve.BaseJointKind.Ball:
return BallConstraint(body_i, marker_i_pos, body_j, marker_j_pos)
if kind == kcsolve.BaseJointKind.Revolute:
return RevoluteConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.Cylindrical:
return CylindricalConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.Slider:
return SliderConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.Screw:
pitch = c_params[0] if c_params else 1.0
return ScrewConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
pitch=pitch,
)
if kind == kcsolve.BaseJointKind.Universal:
return UniversalConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
# -- Phase 2: mechanical constraints --------------------------------------
if kind == kcsolve.BaseJointKind.Gear:
radius_i = c_params[0] if len(c_params) > 0 else 1.0
radius_j = c_params[1] if len(c_params) > 1 else 1.0
return GearConstraint(
body_i,
marker_i_quat,
body_j,
marker_j_quat,
radius_i,
radius_j,
)
if kind == kcsolve.BaseJointKind.RackPinion:
pitch_radius = c_params[0] if c_params else 1.0
return RackPinionConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
pitch_radius=pitch_radius,
)
# -- Stubs (accepted but produce no residuals) ----------------------------
if kind == kcsolve.BaseJointKind.Cam:
return CamConstraint()
if kind == kcsolve.BaseJointKind.Slot:
return SlotConstraint()
if kind == kcsolve.BaseJointKind.DistanceCylSph:
return DistanceCylSphConstraint()
return None

70
tests/test_bfgs.py Normal file
View File

@@ -0,0 +1,70 @@
"""Tests for the BFGS fallback solver."""
import math
import pytest
from kindred_solver.bfgs import bfgs_solve
from kindred_solver.expr import Const, Var
from kindred_solver.params import ParamTable
class TestBFGSBasic:
def test_single_linear(self):
"""Solve x - 3 = 0."""
pt = ParamTable()
x = pt.add("x", 0.0)
assert bfgs_solve([x - Const(3.0)], pt) is True
assert abs(pt.get_value("x") - 3.0) < 1e-8
def test_single_quadratic(self):
"""Solve x^2 - 4 = 0 from x=1 → x=2."""
pt = ParamTable()
x = pt.add("x", 1.0)
assert bfgs_solve([x * x - Const(4.0)], pt) is True
assert abs(pt.get_value("x") - 2.0) < 1e-8
def test_two_variables(self):
"""Solve x + y = 5, x - y = 1."""
pt = ParamTable()
x = pt.add("x", 0.0)
y = pt.add("y", 0.0)
assert bfgs_solve([x + y - Const(5.0), x - y - Const(1.0)], pt) is True
assert abs(pt.get_value("x") - 3.0) < 1e-8
assert abs(pt.get_value("y") - 2.0) < 1e-8
def test_empty_system(self):
pt = ParamTable()
assert bfgs_solve([], pt) is True
def test_with_quat_renorm(self):
"""Quaternion re-normalization during BFGS."""
pt = ParamTable()
qw = pt.add("qw", 0.9)
qx = pt.add("qx", 0.1)
qy = pt.add("qy", 0.1)
qz = pt.add("qz", 0.1)
r = qw * qw + qx * qx + qy * qy + qz * qz - Const(1.0)
groups = [("qw", "qx", "qy", "qz")]
assert bfgs_solve([r], pt, quat_groups=groups) is True
w, x, y, z = (pt.get_value(n) for n in ["qw", "qx", "qy", "qz"])
norm = math.sqrt(w**2 + x**2 + y**2 + z**2)
assert abs(norm - 1.0) < 1e-8
class TestBFGSGeometric:
def test_distance_constraint(self):
"""x^2 - 25 = 0 from x=3 → x=5."""
pt = ParamTable()
x = pt.add("x", 3.0)
assert bfgs_solve([x * x - Const(25.0)], pt) is True
assert abs(pt.get_value("x") - 5.0) < 1e-8
def test_difficult_initial_guess(self):
"""BFGS should handle worse initial guesses than Newton."""
pt = ParamTable()
x = pt.add("x", 100.0)
y = pt.add("y", -50.0)
residuals = [x + y - Const(5.0), x - y - Const(1.0)]
assert bfgs_solve(residuals, pt) is True
assert abs(pt.get_value("x") - 3.0) < 1e-6
assert abs(pt.get_value("y") - 2.0) < 1e-6

View File

@@ -0,0 +1,481 @@
"""Tests for Phase 2 constraint residual generation."""
import math
import pytest
from kindred_solver.constraints import (
AngleConstraint,
BallConstraint,
CamConstraint,
ConcentricConstraint,
CylindricalConstraint,
DistanceCylSphConstraint,
GearConstraint,
LineInPlaneConstraint,
ParallelConstraint,
PerpendicularConstraint,
PlanarConstraint,
PointInPlaneConstraint,
PointOnLineConstraint,
RackPinionConstraint,
RevoluteConstraint,
ScrewConstraint,
SliderConstraint,
SlotConstraint,
TangentConstraint,
UniversalConstraint,
)
from kindred_solver.entities import RigidBody
from kindred_solver.params import ParamTable
ID_QUAT = (1.0, 0.0, 0.0, 0.0)
# 90-deg about Y: Z-axis of body rotates to point along X
_c = math.cos(math.pi / 4)
_s = math.sin(math.pi / 4)
ROT_90Y = (_c, 0.0, _s, 0.0)
ROT_90Z = (_c, 0.0, 0.0, _s)
# ── Point constraints ────────────────────────────────────────────────
class TestPointOnLine:
def test_on_line(self):
"""Point at (0,0,5) is on Z-axis line through origin."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
c = PointOnLineConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_off_line(self):
"""Point at (3,0,5) is NOT on Z-axis line through origin."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (3, 0, 5), (1, 0, 0, 0))
c = PointOnLineConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
vals = [r.eval(env) for r in c.residuals()]
assert any(abs(v) > 0.1 for v in vals)
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PointOnLineConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 2
class TestPointInPlane:
def test_in_plane(self):
"""Point at (3,4,0) is in XY plane through origin."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (3, 4, 0), (1, 0, 0, 0))
c = PointInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_above_plane(self):
"""Point at (0,0,7) is 7 above XY plane."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 7), (1, 0, 0, 0))
c = PointInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env) - 7.0) < 1e-10
def test_with_offset(self):
"""Point at (0,0,5) with offset=5 → residual 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
c = PointInPlaneConstraint(
b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT, offset=5.0
)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PointInPlaneConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 1
# ── Orientation constraints ──────────────────────────────────────────
class TestParallel:
def test_parallel_same(self):
"""Both bodies with identity rotation → Z-axes parallel → residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (5, 0, 0), (1, 0, 0, 0))
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_not_parallel(self):
"""One body rotated 90-deg about Y → Z-axes perpendicular."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (5, 0, 0), ROT_90Y)
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
env = pt.get_env()
vals = [r.eval(env) for r in c.residuals()]
assert any(abs(v) > 0.1 for v in vals)
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
assert len(c.residuals()) == 2
class TestPerpendicular:
def test_perpendicular(self):
"""One body rotated 90-deg about Y → Z-axes perpendicular."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_not_perpendicular(self):
"""Same orientation → not perpendicular."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
env = pt.get_env()
# dot(z,z) = 1 ≠ 0
assert abs(c.residuals()[0].eval(env) - 1.0) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
assert len(c.residuals()) == 1
class TestAngle:
def test_90_degrees(self):
"""90-deg angle between Z-axes rotated 90-deg about Y."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, math.pi / 2)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_0_degrees(self):
"""0-deg angle, same orientation → cos(0)=1, dot=1 → residual 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, 0.0)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0)
assert len(c.residuals()) == 1
# ── Axis/surface constraints ─────────────────────────────────────────
class TestConcentric:
def test_coaxial(self):
"""Both on Z-axis → coaxial → residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_not_coaxial(self):
"""Offset in X → not coaxial."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (5, 0, 0), (1, 0, 0, 0))
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
vals = [r.eval(env) for r in c.residuals()]
assert any(abs(v) > 0.1 for v in vals)
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 4
class TestTangent:
def test_touching(self):
"""Marker origins at same point → tangent."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_separated(self):
"""Separated along normal → non-zero residual."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 5), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env) - 5.0) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 1
class TestPlanar:
def test_coplanar(self):
"""Same plane, same orientation → all residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (5, 3, 0), (1, 0, 0, 0))
c = PlanarConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_with_offset(self):
"""b_i at z=5, b_j at origin, normal=Z, offset=5.
Signed distance = (p_i - p_j).n = 5, offset=5 → 5-5 = 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
c = PlanarConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT, offset=5.0)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PlanarConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 3
class TestLineInPlane:
def test_in_plane(self):
"""Line along X in XY plane → residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
# b2 has Z-axis = (1,0,0) via 90-deg rotation about Y
b2 = RigidBody("b", pt, (5, 0, 0), ROT_90Y)
# Line = b2's Z-axis (which is world X), plane = b1's XY plane (normal=Z)
c = LineInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = LineInPlaneConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 2
# ── Kinematic joints ─────────────────────────────────────────────────
class TestBall:
def test_same_as_coincident(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = BallConstraint(b1, (0, 0, 0), b2, (0, 0, 0))
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
assert len(c.residuals()) == 3
class TestRevolute:
def test_satisfied(self):
"""Same position, same Z-axis → satisfied."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Z) # rotated about Z — still parallel
c = RevoluteConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = RevoluteConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 5
class TestCylindrical:
def test_on_axis(self):
"""Same axis, displaced along Z → satisfied."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 10), (1, 0, 0, 0))
c = CylindricalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = CylindricalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 4
class TestSlider:
def test_aligned(self):
"""Same axis, no twist, displaced along Z → satisfied."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 10), (1, 0, 0, 0))
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_twisted(self):
"""Rotated about Z → twist residual non-zero."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Z)
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
vals = [r.eval(env) for r in c.residuals()]
# First 4 should be ~0 (parallel + on-line), but twist residual should be ~1
assert abs(vals[4]) > 0.5
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 5
class TestUniversal:
def test_satisfied(self):
"""Same origin, perpendicular Z-axes."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
c = UniversalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = UniversalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 4
class TestScrew:
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = ScrewConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch=10.0)
assert len(c.residuals()) == 5
def test_zero_displacement_zero_rotation(self):
"""Both at origin with identity rotation → all residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = ScrewConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch=10.0)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
# ── Mechanical constraints ───────────────────────────────────────────
class TestGear:
def test_both_at_rest(self):
"""Both at identity rotation → residual 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = GearConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0, 1.0)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = GearConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0, 2.0)
assert len(c.residuals()) == 1
class TestRackPinion:
def test_at_rest(self):
"""Both at rest → residual 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = RackPinionConstraint(
b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch_radius=5.0
)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = RackPinionConstraint(
b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch_radius=1.0
)
assert len(c.residuals()) == 1
# ── Stubs ────────────────────────────────────────────────────────────
class TestStubs:
def test_cam(self):
assert CamConstraint().residuals() == []
def test_slot(self):
assert SlotConstraint().residuals() == []
def test_distance_cyl_sph(self):
assert DistanceCylSphConstraint().residuals() == []

1052
tests/test_decompose.py Normal file

File diff suppressed because it is too large Load Diff

187
tests/test_geometry.py Normal file
View File

@@ -0,0 +1,187 @@
"""Tests for geometry helpers."""
import math
import pytest
from kindred_solver.entities import RigidBody
from kindred_solver.expr import Const, Var
from kindred_solver.geometry import (
cross3,
dot3,
marker_x_axis,
marker_y_axis,
marker_z_axis,
point_line_perp_components,
point_plane_distance,
sub3,
)
from kindred_solver.params import ParamTable
IDENTITY_QUAT = (1.0, 0.0, 0.0, 0.0)
# 90-deg about Z: (cos45, 0, 0, sin45)
_c = math.cos(math.pi / 4)
_s = math.sin(math.pi / 4)
ROT_90Z_QUAT = (_c, 0.0, 0.0, _s)
class TestDot3:
def test_parallel(self):
a = (Const(1.0), Const(0.0), Const(0.0))
b = (Const(1.0), Const(0.0), Const(0.0))
assert abs(dot3(a, b).eval({}) - 1.0) < 1e-10
def test_perpendicular(self):
a = (Const(1.0), Const(0.0), Const(0.0))
b = (Const(0.0), Const(1.0), Const(0.0))
assert abs(dot3(a, b).eval({})) < 1e-10
def test_general(self):
a = (Const(1.0), Const(2.0), Const(3.0))
b = (Const(4.0), Const(5.0), Const(6.0))
# 1*4 + 2*5 + 3*6 = 32
assert abs(dot3(a, b).eval({}) - 32.0) < 1e-10
class TestCross3:
def test_x_cross_y(self):
x = (Const(1.0), Const(0.0), Const(0.0))
y = (Const(0.0), Const(1.0), Const(0.0))
cx, cy, cz = cross3(x, y)
assert abs(cx.eval({})) < 1e-10
assert abs(cy.eval({})) < 1e-10
assert abs(cz.eval({}) - 1.0) < 1e-10
def test_parallel_is_zero(self):
a = (Const(2.0), Const(3.0), Const(4.0))
b = (Const(4.0), Const(6.0), Const(8.0))
cx, cy, cz = cross3(a, b)
assert abs(cx.eval({})) < 1e-10
assert abs(cy.eval({})) < 1e-10
assert abs(cz.eval({})) < 1e-10
class TestSub3:
def test_basic(self):
a = (Const(5.0), Const(3.0), Const(1.0))
b = (Const(1.0), Const(2.0), Const(3.0))
dx, dy, dz = sub3(a, b)
assert abs(dx.eval({}) - 4.0) < 1e-10
assert abs(dy.eval({}) - 1.0) < 1e-10
assert abs(dz.eval({}) - (-2.0)) < 1e-10
class TestMarkerAxes:
def test_identity_z(self):
"""Identity body + identity marker → Z = (0,0,1)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(zx.eval(env)) < 1e-10
assert abs(zy.eval(env)) < 1e-10
assert abs(zz.eval(env) - 1.0) < 1e-10
def test_identity_x(self):
"""Identity body + identity marker → X = (1,0,0)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
xx, xy, xz = marker_x_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(xx.eval(env) - 1.0) < 1e-10
assert abs(xy.eval(env)) < 1e-10
assert abs(xz.eval(env)) < 1e-10
def test_identity_y(self):
"""Identity body + identity marker → Y = (0,1,0)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
yx, yy, yz = marker_y_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(yx.eval(env)) < 1e-10
assert abs(yy.eval(env) - 1.0) < 1e-10
assert abs(yz.eval(env)) < 1e-10
def test_rotated_body_z(self):
"""Body rotated 90-deg about Z → Z-axis still (0,0,1)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), ROT_90Z_QUAT)
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(zx.eval(env)) < 1e-10
assert abs(zy.eval(env)) < 1e-10
assert abs(zz.eval(env) - 1.0) < 1e-10
def test_rotated_body_x(self):
"""Body rotated 90-deg about Z → X-axis becomes (0,1,0)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), ROT_90Z_QUAT)
xx, xy, xz = marker_x_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(xx.eval(env)) < 1e-10
assert abs(xy.eval(env) - 1.0) < 1e-10
assert abs(xz.eval(env)) < 1e-10
def test_marker_rotation(self):
"""Identity body + marker rotated 90-deg about Z → Z still (0,0,1)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
zx, zy, zz = marker_z_axis(body, ROT_90Z_QUAT)
env = pt.get_env()
assert abs(zx.eval(env)) < 1e-10
assert abs(zy.eval(env)) < 1e-10
assert abs(zz.eval(env) - 1.0) < 1e-10
def test_marker_rotation_x_axis(self):
"""Identity body + marker rotated 90-deg about Z → X becomes (0,1,0)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
xx, xy, xz = marker_x_axis(body, ROT_90Z_QUAT)
env = pt.get_env()
assert abs(xx.eval(env)) < 1e-10
assert abs(xy.eval(env) - 1.0) < 1e-10
assert abs(xz.eval(env)) < 1e-10
def test_differentiable(self):
"""Marker axes are differentiable w.r.t. body quat params."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
# Should not raise
dzx = zx.diff("p/qz").simplify()
env = pt.get_env()
dzx.eval(env) # Should be evaluable
class TestPointPlaneDistance:
def test_on_plane(self):
pt = (Const(1.0), Const(2.0), Const(0.0))
origin = (Const(0.0), Const(0.0), Const(0.0))
normal = (Const(0.0), Const(0.0), Const(1.0))
d = point_plane_distance(pt, origin, normal)
assert abs(d.eval({})) < 1e-10
def test_above_plane(self):
pt = (Const(1.0), Const(2.0), Const(5.0))
origin = (Const(0.0), Const(0.0), Const(0.0))
normal = (Const(0.0), Const(0.0), Const(1.0))
d = point_plane_distance(pt, origin, normal)
assert abs(d.eval({}) - 5.0) < 1e-10
class TestPointLinePerp:
def test_on_line(self):
pt = (Const(0.0), Const(0.0), Const(5.0))
origin = (Const(0.0), Const(0.0), Const(0.0))
direction = (Const(0.0), Const(0.0), Const(1.0))
cx, cy = point_line_perp_components(pt, origin, direction)
assert abs(cx.eval({})) < 1e-10
assert abs(cy.eval({})) < 1e-10
def test_off_line(self):
pt = (Const(3.0), Const(0.0), Const(0.0))
origin = (Const(0.0), Const(0.0), Const(0.0))
direction = (Const(0.0), Const(0.0), Const(1.0))
cx, cy = point_line_perp_components(pt, origin, direction)
# d = (3,0,0), dir = (0,0,1), d x dir = (0*1-0*0, 0*0-3*1, 3*0-0*0) = (0,-3,0)
assert abs(cx.eval({})) < 1e-10
assert abs(cy.eval({}) - (-3.0)) < 1e-10

612
tests/test_joints.py Normal file
View File

@@ -0,0 +1,612 @@
"""Integration tests for kinematic joint constraints.
These tests exercise the full solve pipeline (constraint → residuals →
pre-pass → Newton / BFGS) for multi-body systems with various joint types.
"""
import math
import pytest
from kindred_solver.constraints import (
BallConstraint,
CoincidentConstraint,
CylindricalConstraint,
GearConstraint,
ParallelConstraint,
PerpendicularConstraint,
PlanarConstraint,
PointInPlaneConstraint,
PointOnLineConstraint,
RackPinionConstraint,
RevoluteConstraint,
ScrewConstraint,
SliderConstraint,
UniversalConstraint,
)
from kindred_solver.dof import count_dof
from kindred_solver.entities import RigidBody
from kindred_solver.newton import newton_solve
from kindred_solver.params import ParamTable
from kindred_solver.prepass import single_equation_pass, substitution_pass
ID_QUAT = (1, 0, 0, 0)
# 90° about Z: (cos(45°), 0, 0, sin(45°))
c45 = math.cos(math.pi / 4)
s45 = math.sin(math.pi / 4)
ROT_90Z = (c45, 0, 0, s45)
# 90° about Y
ROT_90Y = (c45, 0, s45, 0)
# 90° about X
ROT_90X = (c45, s45, 0, 0)
def _solve(bodies, constraint_objs):
"""Run the full solve pipeline. Returns (converged, params, bodies)."""
pt = bodies[0].tx # all bodies share the same ParamTable via Var._name
# Actually, we need the ParamTable object. Get it from the first body.
# The Var objects store names, but we need the table. We'll reconstruct.
# Better approach: caller passes pt.
raise NotImplementedError("Use _solve_with_pt instead")
def _solve_with_pt(pt, bodies, constraint_objs):
"""Run the full solve pipeline with explicit ParamTable."""
all_residuals = []
for c in constraint_objs:
all_residuals.extend(c.residuals())
quat_groups = []
for body in bodies:
if not body.grounded:
all_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
all_residuals = substitution_pass(all_residuals, pt)
all_residuals = single_equation_pass(all_residuals, pt)
converged = newton_solve(
all_residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10
)
return converged, all_residuals
def _dof(pt, bodies, constraint_objs):
"""Count DOF for a system."""
all_residuals = []
for c in constraint_objs:
all_residuals.extend(c.residuals())
for body in bodies:
if not body.grounded:
all_residuals.append(body.quat_norm_residual())
all_residuals = substitution_pass(all_residuals, pt)
return count_dof(all_residuals, pt)
# ============================================================================
# Single-joint DOF counting tests
# ============================================================================
class TestJointDOF:
"""Verify each joint type removes the expected number of DOF.
Setup: ground body + 1 free body (6 DOF) with a single joint.
"""
def _setup(self, pos_b=(0, 0, 0), quat_b=ID_QUAT):
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, pos_b, quat_b)
return pt, a, b
def test_ball_3dof(self):
"""Ball joint: 6 - 3 = 3 DOF (3 rotation)."""
pt, a, b = self._setup()
constraints = [BallConstraint(a, (0, 0, 0), b, (0, 0, 0))]
assert _dof(pt, [a, b], constraints) == 3
def test_revolute_1dof(self):
"""Revolute: 6 - 5 = 1 DOF (rotation about Z)."""
pt, a, b = self._setup()
constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 1
def test_cylindrical_2dof(self):
"""Cylindrical: 6 - 4 = 2 DOF (rotation + translation along Z)."""
pt, a, b = self._setup()
constraints = [
CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
assert _dof(pt, [a, b], constraints) == 2
def test_slider_1dof(self):
"""Slider: 6 - 5 = 1 DOF (translation along Z)."""
pt, a, b = self._setup()
constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 1
def test_universal_2dof(self):
"""Universal: 6 - 4 = 2 DOF (rotation about each body's Z)."""
pt, a, b = self._setup(quat_b=ROT_90X)
constraints = [
UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
assert _dof(pt, [a, b], constraints) == 2
def test_screw_1dof(self):
"""Screw: 6 - 5 = 1 DOF (helical motion)."""
pt, a, b = self._setup()
constraints = [
ScrewConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT, pitch=10.0)
]
assert _dof(pt, [a, b], constraints) == 1
def test_parallel_4dof(self):
"""Parallel: 6 - 2 = 4 DOF."""
pt, a, b = self._setup()
constraints = [ParallelConstraint(a, ID_QUAT, b, ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 4
def test_perpendicular_5dof(self):
"""Perpendicular: 6 - 1 = 5 DOF."""
pt, a, b = self._setup(quat_b=ROT_90X)
constraints = [PerpendicularConstraint(a, ID_QUAT, b, ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 5
def test_point_on_line_4dof(self):
"""PointOnLine: 6 - 2 = 4 DOF."""
pt, a, b = self._setup()
constraints = [
PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
assert _dof(pt, [a, b], constraints) == 4
def test_point_in_plane_5dof(self):
"""PointInPlane: 6 - 1 = 5 DOF."""
pt, a, b = self._setup()
constraints = [
PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
assert _dof(pt, [a, b], constraints) == 5
def test_planar_3dof(self):
"""Planar: 6 - 3 = 3 DOF (2 translation in plane + 1 rotation about normal)."""
pt, a, b = self._setup()
constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 3
# ============================================================================
# Solve convergence tests — single joints from displaced initial conditions
# ============================================================================
class TestJointSolve:
"""Newton converges to a valid configuration from displaced starting points."""
def test_revolute_displaced(self):
"""Revolute joint: body B starts displaced, should converge to hinge position."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (3, 4, 5), ID_QUAT) # displaced
constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
# Coincident origins → position should be at origin
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
assert abs(pos[2]) < 1e-8
def test_cylindrical_displaced(self):
"""Cylindrical joint: body B can slide along Z but must be on axis."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (3, 4, 7), ID_QUAT) # off-axis
constraints = [
CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
# X and Y should be zero (on axis), Z can be anything
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
def test_slider_displaced(self):
"""Slider: body B can translate along Z only."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (2, 3, 5), ID_QUAT) # displaced
constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
# X and Y should be zero (on axis), Z free
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
def test_ball_displaced(self):
"""Ball joint: body B moves so marker origins coincide.
Ball has 3 rotation DOF free, so we can only verify the
world-frame marker points match, not the body position directly.
"""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (5, 5, 5), ID_QUAT)
constraints = [BallConstraint(a, (1, 0, 0), b, (-1, 0, 0))]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
# Verify marker world points match
wp_a = a.world_point(1, 0, 0)
wp_b = b.world_point(-1, 0, 0)
for ea, eb in zip(wp_a, wp_b):
assert abs(ea.eval(env) - eb.eval(env)) < 1e-8
def test_universal_displaced(self):
"""Universal joint: coincident origins + perpendicular Z-axes."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
# Start B with Z-axis along X (90° about Y) — perpendicular to A's Z
b = RigidBody("b", pt, (3, 4, 5), ROT_90Y)
constraints = [
UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
assert abs(pos[2]) < 1e-8
def test_point_on_line_solve(self):
"""Point on line: body B's marker origin constrained to line along Z.
Under-constrained system (4 DOF remain), so we verify the constraint
residuals are satisfied rather than expecting specific positions.
"""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (5, 3, 7), ID_QUAT)
constraints = [
PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
converged, residuals = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
for r in residuals:
assert abs(r.eval(env)) < 1e-8
def test_point_in_plane_solve(self):
"""Point in plane: body B's marker origin at z=0 plane.
Under-constrained (5 DOF remain), so verify residuals.
"""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (3, 4, 8), ID_QUAT)
constraints = [
PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
converged, residuals = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
for r in residuals:
assert abs(r.eval(env)) < 1e-8
def test_planar_solve(self):
"""Planar: coplanar faces — parallel normals + point in plane."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
# Start B tilted and displaced
b = RigidBody("b", pt, (3, 4, 8), ID_QUAT)
constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
# Z must be zero (in plane), X and Y free
assert abs(pos[2]) < 1e-8
# ============================================================================
# Multi-body integration tests
# ============================================================================
class TestFourBarLinkage:
"""Four-bar linkage: 4 bodies, 4 revolute joints.
In 3D with Z-axis revolutes, this yields 2 DOF: the expected planar
motion plus an out-of-plane fold. A truly planar mechanism would
add Planar constraints on each link to eliminate the fold DOF.
"""
def test_four_bar_dof(self):
"""Four-bar linkage in 3D has 2 DOF (planar + fold)."""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
link1 = RigidBody("l1", pt, (2, 0, 0), ID_QUAT)
link2 = RigidBody("l2", pt, (5, 3, 0), ID_QUAT)
link3 = RigidBody("l3", pt, (8, 0, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT),
]
bodies = [ground, link1, link2, link3]
dof = _dof(pt, bodies, constraints)
assert dof == 2
def test_four_bar_solves(self):
"""Four-bar linkage converges from displaced initial conditions."""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
# Initial positions slightly displaced from valid config
link1 = RigidBody("l1", pt, (2, 1, 0), ID_QUAT)
link2 = RigidBody("l2", pt, (5, 4, 0), ID_QUAT)
link3 = RigidBody("l3", pt, (8, 1, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT),
]
bodies = [ground, link1, link2, link3]
converged, residuals = _solve_with_pt(pt, bodies, constraints)
assert converged
# Verify all revolute constraints are satisfied
env = pt.get_env()
for r in residuals:
assert abs(r.eval(env)) < 1e-8
class TestSliderCrank:
"""Slider-crank mechanism: crank + connecting rod + piston.
ground --[Revolute]-- crank --[Revolute]-- rod --[Revolute]-- piston --[Slider]-- ground
Using Slider (not Cylindrical) for the piston to also lock rotation,
making it a true prismatic joint. In 3D, out-of-plane folding adds
extra DOF beyond the planar 1-DOF.
3 free bodies × 6 = 18 DOF
Revolute(5) + Revolute(5) + Revolute(5) + Slider(5) = 20
But many constraints share bodies, so effective rank < 20.
In 3D: 3 DOF (planar crank + 2 fold modes).
"""
def test_slider_crank_dof(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
crank = RigidBody("crank", pt, (1, 0, 0), ID_QUAT)
rod = RigidBody("rod", pt, (3, 0, 0), ID_QUAT)
piston = RigidBody("piston", pt, (5, 0, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT),
RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT),
RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT),
SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y),
]
bodies = [ground, crank, rod, piston]
dof = _dof(pt, bodies, constraints)
# 3D slider-crank: planar motion + out-of-plane fold modes
assert dof == 3
def test_slider_crank_solves(self):
"""Slider-crank converges from displaced state."""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
crank = RigidBody("crank", pt, (1, 0.5, 0), ID_QUAT)
rod = RigidBody("rod", pt, (3, 1, 0), ID_QUAT)
piston = RigidBody("piston", pt, (5, 0.5, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT),
RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT),
RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT),
SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y),
]
bodies = [ground, crank, rod, piston]
converged, residuals = _solve_with_pt(pt, bodies, constraints)
assert converged
env = pt.get_env()
for r in residuals:
assert abs(r.eval(env)) < 1e-8
class TestRevoluteChain:
"""Chain of revolute joints: ground → body1 → body2.
Each revolute removes 5 DOF. Two free bodies = 12 DOF.
2 revolutes = 10 constraints + 2 quat norms = 12.
Expected: 2 DOF (one rotation per hinge).
"""
def test_chain_dof(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
b1 = RigidBody("b1", pt, (3, 0, 0), ID_QUAT)
b2 = RigidBody("b2", pt, (6, 0, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT),
]
assert _dof(pt, [ground, b1, b2], constraints) == 2
def test_chain_solves(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
b1 = RigidBody("b1", pt, (3, 2, 0), ID_QUAT)
b2 = RigidBody("b2", pt, (6, 3, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT),
]
converged, residuals = _solve_with_pt(pt, [ground, b1, b2], constraints)
assert converged
env = pt.get_env()
# b1 origin at ground hinge point (0,0,0)
pos1 = b1.extract_position(env)
assert abs(pos1[0]) < 1e-8
assert abs(pos1[1]) < 1e-8
assert abs(pos1[2]) < 1e-8
class TestSliderOnRail:
"""Slider constraint: body translates along ground Z-axis only.
1 free body, 1 slider = 6 - 5 = 1 DOF.
"""
def test_slider_on_rail(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
block = RigidBody("block", pt, (3, 4, 5), ID_QUAT)
constraints = [
SliderConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT)
]
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
assert converged
env = pt.get_env()
pos = block.extract_position(env)
# X, Y must be zero; Z is free
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
# Z should remain near initial value (minimum-norm solution)
# Check orientation unchanged (no twist)
quat = block.extract_quaternion(env)
assert abs(quat[0] - 1.0) < 1e-6
assert abs(quat[1]) < 1e-6
assert abs(quat[2]) < 1e-6
assert abs(quat[3]) < 1e-6
class TestPlanarOnTable:
"""Planar constraint: body slides on XY plane.
1 free body, 1 planar = 6 - 3 = 3 DOF.
"""
def test_planar_on_table(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
block = RigidBody("block", pt, (3, 4, 5), ID_QUAT)
constraints = [
PlanarConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT)
]
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
assert converged
env = pt.get_env()
pos = block.extract_position(env)
# Z must be zero, X and Y are free
assert abs(pos[2]) < 1e-8
class TestPlanarWithOffset:
"""Planar with offset: body floats at z=3 above ground."""
def test_planar_offset(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
block = RigidBody("block", pt, (1, 2, 5), ID_QUAT)
# PlanarConstraint residual: (p_i - p_j) . z_j - offset = 0
# body_i=block, body_j=ground: (block_z - 0) * 1 - offset = 0
# For block at z=3: offset = 3
constraints = [
PlanarConstraint(
block, (0, 0, 0), ID_QUAT, ground, (0, 0, 0), ID_QUAT, offset=3.0
)
]
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
assert converged
env = pt.get_env()
pos = block.extract_position(env)
assert abs(pos[2] - 3.0) < 1e-8
class TestMixedConstraints:
"""System with mixed constraint types."""
def test_revolute_plus_parallel(self):
"""Two free bodies: revolute between ground and b1, parallel between b1 and b2.
b1: 6 DOF - 5 (revolute) = 1 DOF
b2: 6 DOF - 2 (parallel) = 4 DOF
Total: 5 DOF
"""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
b1 = RigidBody("b1", pt, (0, 0, 0), ID_QUAT)
b2 = RigidBody("b2", pt, (5, 0, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT),
]
assert _dof(pt, [ground, b1, b2], constraints) == 5
def test_coincident_plus_perpendicular(self):
"""Coincident + perpendicular = ball + 1 angle constraint.
6 - 3 (coincident) - 1 (perpendicular) = 2 DOF.
"""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (0, 0, 0), ROT_90X)
constraints = [
CoincidentConstraint(ground, (0, 0, 0), b, (0, 0, 0)),
PerpendicularConstraint(ground, ID_QUAT, b, ID_QUAT),
]
assert _dof(pt, [ground, b], constraints) == 2

View File

@@ -65,3 +65,37 @@ class TestParamTable:
pt.add("b", 0.0, fixed=True)
pt.add("c", 0.0)
assert pt.n_free() == 2
def test_unfix(self):
pt = ParamTable()
pt.add("a", 1.0)
pt.add("b", 2.0)
pt.fix("a")
assert pt.is_fixed("a")
assert "a" not in pt.free_names()
pt.unfix("a")
assert not pt.is_fixed("a")
assert "a" in pt.free_names()
assert pt.n_free() == 2
def test_fix_unfix_roundtrip(self):
"""Fix then unfix preserves value and makes param free again."""
pt = ParamTable()
pt.add("x", 5.0)
pt.add("y", 3.0)
pt.fix("x")
pt.set_value("x", 10.0)
pt.unfix("x")
assert pt.get_value("x") == 10.0
assert "x" in pt.free_names()
# x moves to end of free list
assert pt.free_names() == ["y", "x"]
def test_unfix_noop_if_already_free(self):
"""Unfixing a free parameter is a no-op."""
pt = ParamTable()
pt.add("a", 1.0)
pt.unfix("a")
assert pt.free_names() == ["a"]
assert pt.n_free() == 1