2 Commits

Author SHA1 Message Date
forbes-0023
92ae57751f feat(solver): graph decomposition for cluster-by-cluster solving (phase 3)
Add a Python decomposition layer using NetworkX that partitions the
constraint graph into biconnected components (rigid clusters), orders
them via a block-cut tree, and solves each cluster independently.
Articulation-point bodies propagate as boundary conditions between
clusters.

New module kindred_solver/decompose.py:
- DOF table mapping BaseJointKind to residual counts
- Constraint graph construction (nx.MultiGraph)
- Biconnected component detection + articulation points
- Block-cut tree solve ordering (root-first from grounded cluster)
- Cluster-by-cluster solver with boundary body fix/unfix cycling
- Pebble game integration for per-cluster rigidity classification

Changes to existing modules:
- params.py: add unfix() for boundary body cycling
- solver.py: extract _monolithic_solve(), add decomposition branch
  for assemblies with >= 8 free bodies

Performance: for k clusters of ~n/k params each, total cost drops
from O(n^3) to O(n^3/k^2).

220 tests passing (up from 207).
2026-02-20 22:19:35 -06:00
forbes-0023
533ca91774 feat(solver): full constraint vocabulary — all 24 BaseJointKind types (phase 2)
Add 18 new constraint classes covering all BaseJointKind types from Types.h:
- Point: PointOnLine (2r), PointInPlane (1r)
- Orientation: Parallel (2r), Perpendicular (1r), Angle (1r)
- Surface: Concentric (4r), Tangent (1r), Planar (3r), LineInPlane (2r)
- Kinematic: Ball (3r), Revolute (5r), Cylindrical (4r), Slider (5r),
  Screw (5r), Universal (4r)
- Mechanical: Gear (1r), RackPinion (1r)
- Stubs: Cam, Slot, DistanceCylSph

New modules:
- geometry.py: marker axis extraction, vector ops (dot3, cross3, sub3),
  geometric primitives (point_plane_distance, point_line_perp_components)
- bfgs.py: L-BFGS-B fallback solver via scipy for when Newton fails

solver.py changes:
- Wire all 20 supported types in _build_constraint()
- BFGS fallback after Newton-Raphson in solve()

183 tests passing (up from 82), including:
- DOF counting for every joint type
- Solve convergence from displaced initial conditions
- Multi-body mechanisms (four-bar linkage, slider-crank, revolute chain)
2026-02-20 21:15:15 -06:00
12 changed files with 4351 additions and 14 deletions

127
kindred_solver/bfgs.py Normal file
View File

@@ -0,0 +1,127 @@
"""L-BFGS-B fallback solver for when Newton-Raphson fails to converge.
Minimizes f(x) = 0.5 * sum(r_i(x)^2) using scipy's L-BFGS-B with
analytic gradient from the Expr DAG's symbolic differentiation.
"""
from __future__ import annotations
import math
from typing import List
import numpy as np
from .expr import Expr
from .params import ParamTable
try:
from scipy.optimize import minimize as _scipy_minimize
_HAS_SCIPY = True
except ImportError:
_HAS_SCIPY = False
def bfgs_solve(
residuals: List[Expr],
params: ParamTable,
quat_groups: List[tuple[str, str, str, str]] | None = None,
max_iter: int = 200,
tol: float = 1e-10,
) -> bool:
"""Solve ``residuals == 0`` by minimizing sum of squared residuals.
Falls back gracefully to False if scipy is not available.
Returns True if converged (||r|| < tol).
"""
if not _HAS_SCIPY:
return False
free = params.free_names()
n_free = len(free)
n_res = len(residuals)
if n_free == 0 or n_res == 0:
return True
# Build symbolic gradient expressions once: d(r_i)/d(x_j)
jac_exprs: List[List[Expr]] = []
for r in residuals:
row = []
for name in free:
row.append(r.diff(name).simplify())
jac_exprs.append(row)
def objective_and_grad(x_vec):
# Update params
params.set_free_vector(x_vec)
if quat_groups:
_renormalize_quats(params, quat_groups)
env = params.get_env()
# Evaluate residuals
r_vals = np.array([r.eval(env) for r in residuals])
f = 0.5 * np.dot(r_vals, r_vals)
# Evaluate Jacobian
J = np.empty((n_res, n_free))
for i in range(n_res):
for j in range(n_free):
J[i, j] = jac_exprs[i][j].eval(env)
# Gradient of f = sum(r_i * dr_i/dx_j) = J^T @ r
grad = J.T @ r_vals
return f, grad
x0 = params.get_free_vector().copy()
result = _scipy_minimize(
objective_and_grad,
x0,
method="L-BFGS-B",
jac=True,
options={"maxiter": max_iter, "ftol": tol * tol, "gtol": tol},
)
# Apply final result
params.set_free_vector(result.x)
if quat_groups:
_renormalize_quats(params, quat_groups)
# Check convergence on actual residual norm
env = params.get_env()
r_vals = np.array([r.eval(env) for r in residuals])
return bool(np.linalg.norm(r_vals) < tol)
def _renormalize_quats(
params: ParamTable,
groups: List[tuple[str, str, str, str]],
):
"""Project quaternion params back onto the unit sphere."""
for qw_name, qx_name, qy_name, qz_name in groups:
if (
params.is_fixed(qw_name)
and params.is_fixed(qx_name)
and params.is_fixed(qy_name)
and params.is_fixed(qz_name)
):
continue
w = params.get_value(qw_name)
x = params.get_value(qx_name)
y = params.get_value(qy_name)
z = params.get_value(qz_name)
norm = math.sqrt(w * w + x * x + y * y + z * z)
if norm < 1e-15:
params.set_value(qw_name, 1.0)
params.set_value(qx_name, 0.0)
params.set_value(qy_name, 0.0)
params.set_value(qz_name, 0.0)
else:
params.set_value(qw_name, w / norm)
params.set_value(qx_name, x / norm)
params.set_value(qy_name, y / norm)
params.set_value(qz_name, z / norm)

View File

@@ -2,14 +2,28 @@
Each constraint takes two RigidBody entities and marker transforms, Each constraint takes two RigidBody entities and marker transforms,
then generates residual expressions that equal zero when satisfied. then generates residual expressions that equal zero when satisfied.
Phase 1 constraints: Coincident, DistancePointPoint, Fixed
Phase 2 constraints: all remaining BaseJointKind types from Types.h
""" """
from __future__ import annotations from __future__ import annotations
import math
from typing import List from typing import List
from .entities import RigidBody from .entities import RigidBody
from .expr import Const, Expr from .expr import Const, Expr
from .geometry import (
cross3,
dot3,
marker_x_axis,
marker_y_axis,
marker_z_axis,
point_line_perp_components,
point_plane_distance,
sub3,
)
class ConstraintBase: class ConstraintBase:
@@ -145,6 +159,703 @@ class FixedConstraint(ConstraintBase):
return pos_res + ori_res return pos_res + ori_res
# ============================================================================
# Phase 2: Point constraints
# ============================================================================
class PointOnLineConstraint(ConstraintBase):
"""Point constrained to a line — 2 DOF removed.
marker_i origin lies on the line through marker_j origin along
marker_j Z-axis. 2 residuals: perpendicular distance components.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy = point_line_perp_components(p_i, p_j, z_j)
return [cx, cy]
class PointInPlaneConstraint(ConstraintBase):
"""Point constrained to a plane — 1 DOF removed.
marker_i origin lies in the plane through marker_j origin with
normal = marker_j Z-axis. Optional offset via params[0].
1 residual: signed distance to plane.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
offset: float = 0.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.offset = offset
def residuals(self) -> List[Expr]:
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
d = point_plane_distance(p_i, p_j, n_j)
if self.offset != 0.0:
d = d - Const(self.offset)
return [d]
# ============================================================================
# Phase 2: Axis orientation constraints
# ============================================================================
class ParallelConstraint(ConstraintBase):
"""Parallel axes — 2 DOF removed.
marker Z-axes are parallel: z_i x z_j = 0.
2 residuals from the cross product (only 2 of 3 components are
independent for unit vectors).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_quat = marker_i_quat
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, cz = cross3(z_i, z_j)
return [cx, cy]
class PerpendicularConstraint(ConstraintBase):
"""Perpendicular axes — 1 DOF removed.
marker Z-axes are perpendicular: z_i . z_j = 0.
1 residual.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_quat = marker_i_quat
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
return [dot3(z_i, z_j)]
class AngleConstraint(ConstraintBase):
"""Angle between axes — 1 DOF removed.
z_i . z_j = cos(angle).
1 residual.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_quat: tuple[float, float, float, float],
angle: float,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_quat = marker_i_quat
self.marker_j_quat = marker_j_quat
self.angle = angle
def residuals(self) -> List[Expr]:
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
return [dot3(z_i, z_j) - Const(math.cos(self.angle))]
# ============================================================================
# Phase 2: Axis/surface constraints
# ============================================================================
class ConcentricConstraint(ConstraintBase):
"""Coaxial / concentric — 4 DOF removed.
Axes are collinear: parallel Z-axes (2) + point-on-line (2).
Optional distance offset along axis via params.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
distance: float = 0.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.distance = distance
def residuals(self) -> List[Expr]:
# Parallel axes (2 residuals)
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
# Point-on-line: marker_i origin on line through marker_j along z_j
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
lx, ly = point_line_perp_components(p_i, p_j, z_j)
return [cx, cy, lx, ly]
class TangentConstraint(ConstraintBase):
"""Face-on-face tangency — 1 DOF removed.
Signed distance between marker origins along marker_j normal = 0.
1 residual: (p_i - p_j) . z_j
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
return [point_plane_distance(p_i, p_j, n_j)]
class PlanarConstraint(ConstraintBase):
"""Coplanar faces — 3 DOF removed.
Parallel normals (2) + point-in-plane (1). Optional offset.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
offset: float = 0.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.offset = offset
def residuals(self) -> List[Expr]:
# Parallel normals
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
# Point-in-plane
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
d = point_plane_distance(p_i, p_j, z_j)
if self.offset != 0.0:
d = d - Const(self.offset)
return [cx, cy, d]
class LineInPlaneConstraint(ConstraintBase):
"""Line constrained to a plane — 2 DOF removed.
Line defined by marker_i Z-axis lies in plane defined by marker_j normal.
2 residuals: point-in-plane (1) + line direction perpendicular to normal (1).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
offset: float = 0.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.offset = offset
def residuals(self) -> List[Expr]:
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
n_j = marker_z_axis(self.body_j, self.marker_j_quat)
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
# Point in plane
d = point_plane_distance(p_i, p_j, n_j)
if self.offset != 0.0:
d = d - Const(self.offset)
# Line direction perpendicular to plane normal
dir_dot = dot3(z_i, n_j)
return [d, dir_dot]
# ============================================================================
# Phase 2: Kinematic joints
# ============================================================================
class BallConstraint(ConstraintBase):
"""Spherical joint — 3 DOF removed.
Coincident marker origins. Same as CoincidentConstraint.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
):
self._inner = CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos)
def residuals(self) -> List[Expr]:
return self._inner.residuals()
class RevoluteConstraint(ConstraintBase):
"""Hinge joint — 5 DOF removed.
Coincident origins (3) + parallel Z-axes (2).
1 rotational DOF remains (about the common Z-axis).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
# Coincident origins
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
pos = [p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]]
# Parallel Z-axes
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
return pos + [cx, cy]
class CylindricalConstraint(ConstraintBase):
"""Cylindrical joint — 4 DOF removed.
Parallel Z-axes (2) + point-on-line (2).
2 DOF remain: rotation about and translation along the common axis.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
# Parallel Z-axes
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
# Point-on-line
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
lx, ly = point_line_perp_components(p_i, p_j, z_j)
return [cx, cy, lx, ly]
class SliderConstraint(ConstraintBase):
"""Prismatic / slider joint — 5 DOF removed.
Parallel Z-axes (2) + point-on-line (2) + rotation lock (1).
1 DOF remains: translation along the common Z-axis.
Rotation lock: x_i . y_j = 0 (prevents twist about Z).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
# Parallel Z-axes
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
# Point-on-line
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
lx, ly = point_line_perp_components(p_i, p_j, z_j)
# Rotation lock: x_i . y_j = 0
x_i = marker_x_axis(self.body_i, self.marker_i_quat)
y_j = marker_y_axis(self.body_j, self.marker_j_quat)
twist = dot3(x_i, y_j)
return [cx, cy, lx, ly, twist]
class ScrewConstraint(ConstraintBase):
"""Helical / screw joint — 5 DOF removed.
Cylindrical (4) + coupled rotation-translation via pitch (1).
1 DOF remains: screw motion (rotation + proportional translation).
The coupling residual uses the relative quaternion's Z-component
(proportional to the rotation angle for small angles) and the axial
displacement: axial_disp - pitch * (2 * qz_rel / qw_rel) / (2*pi) = 0.
For the Newton solver operating near the solution, the linear
approximation angle ≈ 2 * qz_rel is adequate.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
pitch: float = 1.0,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.pitch = pitch
def residuals(self) -> List[Expr]:
# Cylindrical residuals (4)
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
cx, cy, _cz = cross3(z_i, z_j)
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
lx, ly = point_line_perp_components(p_i, p_j, z_j)
# Pitch coupling: axial_disp = pitch * angle / (2*pi)
# Axial displacement
d = sub3(p_i, p_j)
axial = dot3(d, z_j)
# Relative rotation about Z via quaternion
# q_rel = conj(q_i_total) * q_j_total
qi = _quat_mul_const(
self.body_i.qw,
self.body_i.qx,
self.body_i.qy,
self.body_i.qz,
*self.marker_i_quat,
)
qj = _quat_mul_const(
self.body_j.qw,
self.body_j.qx,
self.body_j.qy,
self.body_j.qz,
*self.marker_j_quat,
)
rel = _quat_mul_expr(qi[0], -qi[1], -qi[2], -qi[3], qj[0], qj[1], qj[2], qj[3])
# For small angles: angle ≈ 2 * qz_rel, but qw_rel ≈ 1
# Use sin(angle/2) form: residual = axial - pitch * 2*qz / (2*pi)
# = axial - pitch * qz / pi
coupling = axial - Const(self.pitch / math.pi) * rel[3]
return [cx, cy, lx, ly, coupling]
class UniversalConstraint(ConstraintBase):
"""Universal / Cardan joint — 4 DOF removed.
Coincident origins (3) + perpendicular Z-axes (1).
2 DOF remain: rotation about each body's Z-axis.
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
def residuals(self) -> List[Expr]:
# Coincident origins
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
pos = [p_i[0] - p_j[0], p_i[1] - p_j[1], p_i[2] - p_j[2]]
# Perpendicular Z-axes
z_i = marker_z_axis(self.body_i, self.marker_i_quat)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
return pos + [dot3(z_i, z_j)]
# ============================================================================
# Phase 2: Mechanical element constraints
# ============================================================================
class GearConstraint(ConstraintBase):
"""Gear pair or belt — 1 DOF removed.
Couples rotation angles: r_i * theta_i + r_j * theta_j = 0.
For belts (same-direction rotation), r_j is passed as negative.
Uses the Z-component of the relative quaternion as a proxy for
rotation angle (linear for small angles, which is the regime
where Newton operates).
"""
def __init__(
self,
body_i: RigidBody,
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_quat: tuple[float, float, float, float],
radius_i: float,
radius_j: float,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_quat = marker_i_quat
self.marker_j_quat = marker_j_quat
self.radius_i = radius_i
self.radius_j = radius_j
def residuals(self) -> List[Expr]:
# Rotation angle proxy via relative quaternion Z-component
# For body_i: q_rel_i = conj(q_marker_i) * q_body_i * q_marker_i
# Simplified: use 2*qz of (conj(marker) * body * marker) as angle proxy
qz_i = _rotation_z_component(self.body_i, self.marker_i_quat)
qz_j = _rotation_z_component(self.body_j, self.marker_j_quat)
# r_i * theta_i + r_j * theta_j = 0
# Using qz as proportional to theta/2:
# r_i * qz_i + r_j * qz_j = 0
return [Const(self.radius_i) * qz_i + Const(self.radius_j) * qz_j]
class RackPinionConstraint(ConstraintBase):
"""Rack-and-pinion — 1 DOF removed.
Couples rotation of body_i to translation of body_j along marker_j Z-axis.
translation = pitch_radius * theta
"""
def __init__(
self,
body_i: RigidBody,
marker_i_pos: tuple[float, float, float],
marker_i_quat: tuple[float, float, float, float],
body_j: RigidBody,
marker_j_pos: tuple[float, float, float],
marker_j_quat: tuple[float, float, float, float],
pitch_radius: float,
):
self.body_i = body_i
self.body_j = body_j
self.marker_i_pos = marker_i_pos
self.marker_i_quat = marker_i_quat
self.marker_j_pos = marker_j_pos
self.marker_j_quat = marker_j_quat
self.pitch_radius = pitch_radius
def residuals(self) -> List[Expr]:
# Translation of j along its Z-axis
p_i = self.body_i.world_point(*self.marker_i_pos)
p_j = self.body_j.world_point(*self.marker_j_pos)
z_j = marker_z_axis(self.body_j, self.marker_j_quat)
d = sub3(p_j, p_i)
translation = dot3(d, z_j)
# Rotation angle of i about its Z-axis
qz_i = _rotation_z_component(self.body_i, self.marker_i_quat)
# translation - pitch_radius * angle = 0
# angle ≈ 2 * qz, so: translation - pitch_radius * 2 * qz = 0
return [translation - Const(2.0 * self.pitch_radius) * qz_i]
class CamConstraint(ConstraintBase):
"""Cam-follower constraint — future, stub."""
def residuals(self) -> List[Expr]:
return []
class SlotConstraint(ConstraintBase):
"""Slot constraint — future, stub."""
def residuals(self) -> List[Expr]:
return []
class DistanceCylSphConstraint(ConstraintBase):
"""Cylinder-sphere distance — stub.
Semantics depend on geometry classification; placeholder for now.
"""
def residuals(self) -> List[Expr]:
return []
# -- rotation helpers for mechanical constraints ------------------------------
def _rotation_z_component(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> Expr:
"""Extract the Z-component of the relative quaternion about a marker axis.
Returns the qz component of conj(q_marker) * q_body * q_marker,
which is proportional to sin(theta/2) where theta is the rotation
angle about the marker Z-axis.
"""
mw, mx, my, mz = marker_quat
# q_local = conj(marker) * q_body * marker
# Step 1: temp = conj(marker) * q_body
cmw, cmx, cmy, cmz = Const(mw), Const(-mx), Const(-my), Const(-mz)
# temp = conj(marker) * q_body
tw = cmw * body.qw - cmx * body.qx - cmy * body.qy - cmz * body.qz
tx = cmw * body.qx + cmx * body.qw + cmy * body.qz - cmz * body.qy
ty = cmw * body.qy - cmx * body.qz + cmy * body.qw + cmz * body.qx
tz = cmw * body.qz + cmx * body.qy - cmy * body.qx + cmz * body.qw
# q_local = temp * marker
mmw, mmx, mmy, mmz = Const(mw), Const(mx), Const(my), Const(mz)
# rz = tw * mmz + tx * mmy - ty * mmx + tz * mmw
rz = tw * mmz + tx * mmy - ty * mmx + tz * mmw
return rz
# -- quaternion multiplication helpers ---------------------------------------- # -- quaternion multiplication helpers ----------------------------------------

661
kindred_solver/decompose.py Normal file
View File

@@ -0,0 +1,661 @@
"""Graph decomposition for cluster-by-cluster constraint solving.
Builds a constraint graph from the SolveContext, decomposes it into
biconnected components (rigid clusters), orders them via a block-cut
tree, and solves each cluster independently. Articulation-point bodies
are temporarily fixed when solving adjacent clusters so their solved
values propagate as boundary conditions.
Requires: networkx
"""
from __future__ import annotations
import importlib.util
import logging
import sys
import types as stdlib_types
from collections import deque
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, List
import networkx as nx
from .bfgs import bfgs_solve
from .newton import newton_solve
from .prepass import substitution_pass
if TYPE_CHECKING:
from .constraints import ConstraintBase
from .entities import RigidBody
from .params import ParamTable
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# DOF table: BaseJointKind → number of residuals (= DOF removed)
# ---------------------------------------------------------------------------
# Imported lazily to avoid hard kcsolve dependency in tests.
# Use residual_count() accessor instead of this dict directly.
_RESIDUAL_COUNT: dict[str, int] | None = None
def _ensure_residual_count() -> dict:
"""Build the residual count table on first use."""
global _RESIDUAL_COUNT
if _RESIDUAL_COUNT is not None:
return _RESIDUAL_COUNT
import kcsolve
_RESIDUAL_COUNT = {
kcsolve.BaseJointKind.Fixed: 6,
kcsolve.BaseJointKind.Coincident: 3,
kcsolve.BaseJointKind.Ball: 3,
kcsolve.BaseJointKind.Revolute: 5,
kcsolve.BaseJointKind.Cylindrical: 4,
kcsolve.BaseJointKind.Slider: 5,
kcsolve.BaseJointKind.Screw: 5,
kcsolve.BaseJointKind.Universal: 4,
kcsolve.BaseJointKind.Parallel: 2,
kcsolve.BaseJointKind.Perpendicular: 1,
kcsolve.BaseJointKind.Angle: 1,
kcsolve.BaseJointKind.Concentric: 4,
kcsolve.BaseJointKind.Tangent: 1,
kcsolve.BaseJointKind.Planar: 3,
kcsolve.BaseJointKind.LineInPlane: 2,
kcsolve.BaseJointKind.PointOnLine: 2,
kcsolve.BaseJointKind.PointInPlane: 1,
kcsolve.BaseJointKind.DistancePointPoint: 1,
kcsolve.BaseJointKind.Gear: 1,
kcsolve.BaseJointKind.RackPinion: 1,
kcsolve.BaseJointKind.Cam: 0,
kcsolve.BaseJointKind.Slot: 0,
kcsolve.BaseJointKind.DistanceCylSph: 0,
}
return _RESIDUAL_COUNT
def residual_count(kind) -> int:
"""Number of residuals a constraint type produces."""
return _ensure_residual_count().get(kind, 0)
# ---------------------------------------------------------------------------
# Standalone residual-count table (no kcsolve dependency, string-keyed)
# Used by tests that don't have kcsolve available.
# ---------------------------------------------------------------------------
_RESIDUAL_COUNT_BY_NAME: dict[str, int] = {
"Fixed": 6,
"Coincident": 3,
"Ball": 3,
"Revolute": 5,
"Cylindrical": 4,
"Slider": 5,
"Screw": 5,
"Universal": 4,
"Parallel": 2,
"Perpendicular": 1,
"Angle": 1,
"Concentric": 4,
"Tangent": 1,
"Planar": 3,
"LineInPlane": 2,
"PointOnLine": 2,
"PointInPlane": 1,
"DistancePointPoint": 1,
"Gear": 1,
"RackPinion": 1,
"Cam": 0,
"Slot": 0,
"DistanceCylSph": 0,
}
def residual_count_by_name(kind_name: str) -> int:
"""Number of residuals by constraint type name (no kcsolve needed)."""
return _RESIDUAL_COUNT_BY_NAME.get(kind_name, 0)
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class SolveCluster:
"""A cluster of bodies to solve together."""
bodies: set[str] # Body IDs in this cluster
constraint_indices: list[int] # Indices into the constraint list
boundary_bodies: set[str] # Articulation points shared with other clusters
has_ground: bool # Whether any body in the cluster is grounded
# ---------------------------------------------------------------------------
# Graph construction
# ---------------------------------------------------------------------------
def build_constraint_graph(
constraints: list,
grounded_bodies: set[str],
) -> nx.MultiGraph:
"""Build a body-level constraint multigraph.
Nodes: part_id strings (one per body referenced by constraints).
Edges: one per active constraint with attributes:
- constraint_index: position in the constraints list
- weight: number of residuals
Grounded bodies are tagged with ``grounded=True``.
Constraints with 0 residuals (stubs) are excluded.
"""
G = nx.MultiGraph()
for idx, c in enumerate(constraints):
if not c.activated:
continue
weight = residual_count(c.type)
if weight == 0:
continue
part_i = c.part_i
part_j = c.part_j
# Ensure nodes exist
if part_i not in G:
G.add_node(part_i, grounded=(part_i in grounded_bodies))
if part_j not in G:
G.add_node(part_j, grounded=(part_j in grounded_bodies))
# Store kind_name for pebble game integration
kind_name = c.type.name if hasattr(c.type, "name") else str(c.type)
G.add_edge(
part_i, part_j, constraint_index=idx, weight=weight, kind_name=kind_name
)
return G
def build_constraint_graph_simple(
edges: list[tuple[str, str, str, int]],
grounded: set[str] | None = None,
) -> nx.MultiGraph:
"""Build a constraint graph from simple edge tuples (for testing).
Each edge is ``(body_i, body_j, kind_name, constraint_index)``.
"""
grounded = grounded or set()
G = nx.MultiGraph()
for body_i, body_j, kind_name, idx in edges:
weight = residual_count_by_name(kind_name)
if weight == 0:
continue
if body_i not in G:
G.add_node(body_i, grounded=(body_i in grounded))
if body_j not in G:
G.add_node(body_j, grounded=(body_j in grounded))
G.add_edge(
body_i, body_j, constraint_index=idx, weight=weight, kind_name=kind_name
)
return G
# ---------------------------------------------------------------------------
# Decomposition
# ---------------------------------------------------------------------------
def find_clusters(
G: nx.MultiGraph,
) -> tuple[list[set[str]], set[str]]:
"""Find biconnected components and articulation points.
Returns:
clusters: list of body-ID sets (one per biconnected component)
articulation_points: body-IDs shared between clusters
"""
# biconnected_components requires a simple Graph
simple = nx.Graph(G)
clusters = [set(c) for c in nx.biconnected_components(simple)]
artic = set(nx.articulation_points(simple))
return clusters, artic
def build_solve_order(
G: nx.MultiGraph,
clusters: list[set[str]],
articulation_points: set[str],
grounded_bodies: set[str],
) -> list[SolveCluster]:
"""Order clusters for solving via the block-cut tree.
Builds the block-cut tree (bipartite graph of clusters and
articulation points), roots it at a grounded cluster, and returns
clusters in root-to-leaf order (grounded first, outward to leaves).
This ensures boundary bodies are solved before clusters that
depend on them.
"""
if not clusters:
return []
# Single cluster — no ordering needed
if len(clusters) == 1:
bodies = clusters[0]
indices = _constraints_for_bodies(G, bodies)
has_ground = bool(bodies & grounded_bodies)
return [
SolveCluster(
bodies=bodies,
constraint_indices=indices,
boundary_bodies=set(),
has_ground=has_ground,
)
]
# Build block-cut tree
# Nodes: ("C", i) for cluster i, ("A", body_id) for articulation points
bct = nx.Graph()
for i, cluster in enumerate(clusters):
bct.add_node(("C", i))
for ap in articulation_points:
if ap in cluster:
bct.add_edge(("C", i), ("A", ap))
# Find root: prefer a cluster containing a grounded body
root = ("C", 0)
for i, cluster in enumerate(clusters):
if cluster & grounded_bodies:
root = ("C", i)
break
# BFS from root: grounded cluster first, outward to leaves
visited = set()
order = []
queue = deque([root])
visited.add(root)
while queue:
node = queue.popleft()
if node[0] == "C":
order.append(node[1])
for neighbor in bct.neighbors(node):
if neighbor not in visited:
visited.add(neighbor)
queue.append(neighbor)
# Build SolveCluster objects
solve_clusters = []
for i in order:
bodies = clusters[i]
indices = _constraints_for_bodies(G, bodies)
boundary = bodies & articulation_points
has_ground = bool(bodies & grounded_bodies)
solve_clusters.append(
SolveCluster(
bodies=bodies,
constraint_indices=indices,
boundary_bodies=boundary,
has_ground=has_ground,
)
)
return solve_clusters
def _constraints_for_bodies(G: nx.MultiGraph, bodies: set[str]) -> list[int]:
"""Collect constraint indices for edges where both endpoints are in bodies."""
indices = []
seen = set()
for u, v, data in G.edges(data=True):
idx = data["constraint_index"]
if idx in seen:
continue
if u in bodies and v in bodies:
seen.add(idx)
indices.append(idx)
return sorted(indices)
# ---------------------------------------------------------------------------
# Top-level decompose entry point
# ---------------------------------------------------------------------------
def decompose(
constraints: list,
grounded_bodies: set[str],
) -> list[SolveCluster]:
"""Full decomposition pipeline: graph → clusters → solve order.
Returns a list of SolveCluster in solve order (leaves first).
If the system is a single cluster, returns a 1-element list.
"""
G = build_constraint_graph(constraints, grounded_bodies)
# Handle disconnected sub-assemblies
all_clusters = []
for component_nodes in nx.connected_components(G):
sub = G.subgraph(component_nodes).copy()
clusters, artic = find_clusters(sub)
if len(clusters) <= 1:
# Single cluster in this component
bodies = component_nodes if not clusters else clusters[0]
indices = _constraints_for_bodies(sub, bodies)
has_ground = bool(bodies & grounded_bodies)
all_clusters.append(
SolveCluster(
bodies=set(bodies),
constraint_indices=indices,
boundary_bodies=set(),
has_ground=has_ground,
)
)
else:
ordered = build_solve_order(sub, clusters, artic, grounded_bodies)
all_clusters.extend(ordered)
return all_clusters
# ---------------------------------------------------------------------------
# Cluster solver
# ---------------------------------------------------------------------------
def solve_decomposed(
clusters: list[SolveCluster],
bodies: dict[str, "RigidBody"],
constraint_objs: list["ConstraintBase"],
constraint_indices_map: list[int],
params: "ParamTable",
) -> bool:
"""Solve clusters in order, fixing boundary bodies between solves.
Args:
clusters: SolveCluster list in solve order (from decompose()).
bodies: part_id → RigidBody mapping.
constraint_objs: constraint objects (parallel to constraint_indices_map).
constraint_indices_map: for each constraint_obj, its index in ctx.constraints.
params: shared ParamTable.
Returns True if all clusters converged.
"""
# Build reverse map: constraint_index → position in constraint_objs list
idx_to_obj: dict[int, "ConstraintBase"] = {}
for pos, ci in enumerate(constraint_indices_map):
idx_to_obj[ci] = constraint_objs[pos]
solved_bodies: set[str] = set()
all_converged = True
for cluster in clusters:
# 1. Fix boundary bodies that were already solved
fixed_boundary_params: list[str] = []
for body_id in cluster.boundary_bodies:
if body_id in solved_bodies:
body = bodies[body_id]
for pname in body._param_names:
if not params.is_fixed(pname):
params.fix(pname)
fixed_boundary_params.append(pname)
# 2. Collect residuals for this cluster
cluster_residuals = []
for ci in cluster.constraint_indices:
obj = idx_to_obj.get(ci)
if obj is not None:
cluster_residuals.extend(obj.residuals())
# 3. Add quat norm residuals for free, non-grounded bodies in this cluster
quat_groups = []
for body_id in cluster.bodies:
body = bodies[body_id]
if body.grounded:
continue
if body_id in cluster.boundary_bodies and body_id in solved_bodies:
continue # Already fixed as boundary
cluster_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
# 4. Substitution pass (compiles fixed boundary params to constants)
cluster_residuals = substitution_pass(cluster_residuals, params)
# 5. Newton solve (+ BFGS fallback)
if cluster_residuals:
converged = newton_solve(
cluster_residuals,
params,
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
)
if not converged:
converged = bfgs_solve(
cluster_residuals,
params,
quat_groups=quat_groups,
max_iter=200,
tol=1e-10,
)
if not converged:
all_converged = False
# 6. Mark this cluster's bodies as solved
solved_bodies.update(cluster.bodies)
# 7. Unfix boundary params
for pname in fixed_boundary_params:
params.unfix(pname)
return all_converged
# ---------------------------------------------------------------------------
# Pebble game integration (rigidity classification)
# ---------------------------------------------------------------------------
_PEBBLE_MODULES_LOADED = False
_PebbleGame3D = None
_PebbleJointType = None
_PebbleJoint = None
def _load_pebble_modules():
"""Lazily load PebbleGame3D and related types from GNN/solver/datagen/.
The GNN package has its own import structure (``from solver.datagen.types
import ...``) that conflicts with the top-level module layout, so we
register shim modules in ``sys.modules`` to make it work.
"""
global _PEBBLE_MODULES_LOADED, _PebbleGame3D, _PebbleJointType, _PebbleJoint
if _PEBBLE_MODULES_LOADED:
return
# Find GNN/solver/datagen relative to this package
pkg_dir = Path(__file__).resolve().parent.parent # mods/solver/
datagen_dir = pkg_dir / "GNN" / "solver" / "datagen"
if not datagen_dir.exists():
log.warning("GNN/solver/datagen/ not found; pebble game unavailable")
_PEBBLE_MODULES_LOADED = True
return
# Register shim modules so ``from solver.datagen.types import ...`` works
if "solver" not in sys.modules:
sys.modules["solver"] = stdlib_types.ModuleType("solver")
if "solver.datagen" not in sys.modules:
dg = stdlib_types.ModuleType("solver.datagen")
sys.modules["solver.datagen"] = dg
sys.modules["solver"].datagen = dg # type: ignore[attr-defined]
# Load types.py
types_path = datagen_dir / "types.py"
spec_t = importlib.util.spec_from_file_location(
"solver.datagen.types", str(types_path)
)
types_mod = importlib.util.module_from_spec(spec_t)
sys.modules["solver.datagen.types"] = types_mod
spec_t.loader.exec_module(types_mod)
# Load pebble_game.py
pg_path = datagen_dir / "pebble_game.py"
spec_p = importlib.util.spec_from_file_location(
"solver.datagen.pebble_game", str(pg_path)
)
pg_mod = importlib.util.module_from_spec(spec_p)
sys.modules["solver.datagen.pebble_game"] = pg_mod
spec_p.loader.exec_module(pg_mod)
_PebbleGame3D = pg_mod.PebbleGame3D
_PebbleJointType = types_mod.JointType
_PebbleJoint = types_mod.Joint
_PEBBLE_MODULES_LOADED = True
# BaseJointKind name → PebbleGame JointType name.
# Types not listed here use manual edge insertion with the residual count.
_KIND_NAME_TO_PEBBLE_NAME: dict[str, str] = {
"Fixed": "FIXED",
"Coincident": "BALL", # Same DOF count (3)
"Ball": "BALL",
"Revolute": "REVOLUTE",
"Cylindrical": "CYLINDRICAL",
"Slider": "SLIDER",
"Screw": "SCREW",
"Universal": "UNIVERSAL",
"Planar": "PLANAR",
"Perpendicular": "PERPENDICULAR",
"DistancePointPoint": "DISTANCE",
}
# Parallel: pebble game uses 3 DOF, but our solver uses 2.
# We handle it with manual edge insertion.
# Types that need manual edge insertion (no direct JointType mapping,
# or DOF mismatch like Parallel).
_MANUAL_EDGE_TYPES: set[str] = {
"Parallel", # 2 residuals, but JointType.PARALLEL = 3
"Angle", # 1 residual, no JointType
"Concentric", # 4 residuals, no JointType
"Tangent", # 1 residual, no JointType
"LineInPlane", # 2 residuals, no JointType
"PointOnLine", # 2 residuals, no JointType
"PointInPlane", # 1 residual, no JointType
"Gear", # 1 residual, no JointType
"RackPinion", # 1 residual, no JointType
}
_GROUND_BODY_ID = -1
def classify_cluster_rigidity(
cluster: SolveCluster,
constraint_graph: nx.MultiGraph,
grounded_bodies: set[str],
) -> str | None:
"""Run pebble game on a cluster and return rigidity classification.
Returns one of: "well-constrained", "underconstrained",
"overconstrained", "mixed", or None if pebble game unavailable.
"""
import numpy as np
_load_pebble_modules()
if _PebbleGame3D is None:
return None
pg = _PebbleGame3D()
# Map string body IDs → integer IDs for pebble game
body_list = sorted(cluster.bodies)
body_to_int: dict[str, int] = {b: i for i, b in enumerate(body_list)}
for b in body_list:
pg.add_body(body_to_int[b])
# Add virtual ground body if cluster has grounded bodies
has_ground = bool(cluster.bodies & grounded_bodies)
if has_ground:
pg.add_body(_GROUND_BODY_ID)
for b in cluster.bodies & grounded_bodies:
ground_joint = _PebbleJoint(
joint_id=-1,
body_a=body_to_int[b],
body_b=_GROUND_BODY_ID,
joint_type=_PebbleJointType["FIXED"],
anchor_a=np.zeros(3),
anchor_b=np.zeros(3),
)
pg.add_joint(ground_joint)
# Add constraint edges
joint_counter = 0
zero = np.zeros(3)
for u, v, data in constraint_graph.edges(data=True):
if u not in cluster.bodies or v not in cluster.bodies:
continue
ci = data["constraint_index"]
if ci not in cluster.constraint_indices:
continue
# Determine the constraint kind name from the graph edge
kind_name = data.get("kind_name", "")
n_residuals = data.get("weight", 0)
if not kind_name or n_residuals == 0:
continue
int_u = body_to_int[u]
int_v = body_to_int[v]
pebble_name = _KIND_NAME_TO_PEBBLE_NAME.get(kind_name)
if pebble_name and kind_name not in _MANUAL_EDGE_TYPES:
# Direct JointType mapping
jt = _PebbleJointType[pebble_name]
joint = _PebbleJoint(
joint_id=joint_counter,
body_a=int_u,
body_b=int_v,
joint_type=jt,
anchor_a=zero,
anchor_b=zero,
)
pg.add_joint(joint)
joint_counter += 1
else:
# Manual edge insertion: one DISTANCE edge per residual
for _ in range(n_residuals):
joint = _PebbleJoint(
joint_id=joint_counter,
body_a=int_u,
body_b=int_v,
joint_type=_PebbleJointType["DISTANCE"],
anchor_a=zero,
anchor_b=zero,
)
pg.add_joint(joint)
joint_counter += 1
# Classify using raw pebble counts (adjusting for virtual ground)
total_dof = pg.get_dof()
redundant = pg.get_redundant_count()
# The virtual ground body contributes 6 pebbles that are never consumed.
# Subtract them to get the effective DOF.
if has_ground:
total_dof -= 6 # virtual ground's unconstrained pebbles
baseline = 0
else:
baseline = 6 # trivial rigid-body motion
if redundant > 0 and total_dof > baseline:
return "mixed"
elif redundant > 0:
return "overconstrained"
elif total_dof > baseline:
return "underconstrained"
elif total_dof == baseline:
return "well-constrained"
else:
return "overconstrained"

131
kindred_solver/geometry.py Normal file
View File

@@ -0,0 +1,131 @@
"""Geometric helper functions for constraint equations.
Provides Expr-level vector operations and marker axis extraction.
All functions work with Expr triples (tuples of 3 Expr nodes)
representing 3D vectors in world coordinates.
Marker convention (from Types.h): the marker's Z-axis defines the
constraint direction (hinge axis, face normal, line direction, etc.).
"""
from __future__ import annotations
from .entities import RigidBody
from .expr import Const, Expr
from .quat import quat_rotate
# Type alias for an Expr triple (3D vector)
Vec3 = tuple[Expr, Expr, Expr]
# -- Marker axis extraction ---------------------------------------------------
def _composed_quat(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> tuple[Expr, Expr, Expr, Expr]:
"""Compute q_total = q_body * q_marker as Expr quaternion.
q_body comes from the body's Var params; q_marker is constant.
"""
bw, bx, by, bz = body.qw, body.qx, body.qy, body.qz
mw, mx, my, mz = (Const(v) for v in marker_quat)
# Hamilton product: body * marker
rw = bw * mw - bx * mx - by * my - bz * mz
rx = bw * mx + bx * mw + by * mz - bz * my
ry = bw * my - bx * mz + by * mw + bz * mx
rz = bw * mz + bx * my - by * mx + bz * mw
return rw, rx, ry, rz
def marker_z_axis(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> Vec3:
"""World-frame Z-axis of a marker on a body.
Computes rotate(q_body * q_marker, [0, 0, 1]).
"""
qw, qx, qy, qz = _composed_quat(body, marker_quat)
return quat_rotate(qw, qx, qy, qz, Const(0.0), Const(0.0), Const(1.0))
def marker_x_axis(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> Vec3:
"""World-frame X-axis of a marker on a body.
Computes rotate(q_body * q_marker, [1, 0, 0]).
"""
qw, qx, qy, qz = _composed_quat(body, marker_quat)
return quat_rotate(qw, qx, qy, qz, Const(1.0), Const(0.0), Const(0.0))
def marker_y_axis(
body: RigidBody,
marker_quat: tuple[float, float, float, float],
) -> Vec3:
"""World-frame Y-axis of a marker on a body.
Computes rotate(q_body * q_marker, [0, 1, 0]).
"""
qw, qx, qy, qz = _composed_quat(body, marker_quat)
return quat_rotate(qw, qx, qy, qz, Const(0.0), Const(1.0), Const(0.0))
# -- Vector operations on Expr triples ----------------------------------------
def dot3(a: Vec3, b: Vec3) -> Expr:
"""Dot product of two Expr triples."""
return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]
def cross3(a: Vec3, b: Vec3) -> Vec3:
"""Cross product of two Expr triples."""
return (
a[1] * b[2] - a[2] * b[1],
a[2] * b[0] - a[0] * b[2],
a[0] * b[1] - a[1] * b[0],
)
def sub3(a: Vec3, b: Vec3) -> Vec3:
"""Vector subtraction a - b."""
return (a[0] - b[0], a[1] - b[1], a[2] - b[2])
# -- Geometric primitives -----------------------------------------------------
def point_plane_distance(
point: Vec3,
plane_origin: Vec3,
normal: Vec3,
) -> Expr:
"""Signed distance from point to plane defined by origin + normal.
Returns (point - plane_origin) . normal
"""
d = sub3(point, plane_origin)
return dot3(d, normal)
def point_line_perp_components(
point: Vec3,
line_origin: Vec3,
line_dir: Vec3,
) -> tuple[Expr, Expr]:
"""Two independent perpendicular-distance components from point to line.
The line passes through line_origin along line_dir.
Returns the x and y components of (point - line_origin) x line_dir,
which are zero when the point lies on the line.
"""
d = sub3(point, line_origin)
cx, cy, cz = cross3(d, line_dir)
# All three components of d x line_dir are zero when d is parallel
# to line_dir, but only 2 are independent. We return x and y.
return cx, cy

View File

@@ -49,6 +49,13 @@ class ParamTable:
if name in self._free_order: if name in self._free_order:
self._free_order.remove(name) self._free_order.remove(name)
def unfix(self, name: str):
"""Restore a fixed parameter to free status."""
if name in self._fixed:
self._fixed.discard(name)
if name not in self._free_order:
self._free_order.append(name)
def get_env(self) -> Dict[str, float]: def get_env(self) -> Dict[str, float]:
"""Return a snapshot of all current values (for Expr.eval).""" """Return a snapshot of all current values (for Expr.eval)."""
return dict(self._values) return dict(self._values)

View File

@@ -5,23 +5,71 @@ from __future__ import annotations
import kcsolve import kcsolve
from .bfgs import bfgs_solve
from .constraints import ( from .constraints import (
AngleConstraint,
BallConstraint,
CamConstraint,
CoincidentConstraint, CoincidentConstraint,
ConcentricConstraint,
ConstraintBase, ConstraintBase,
CylindricalConstraint,
DistanceCylSphConstraint,
DistancePointPointConstraint, DistancePointPointConstraint,
FixedConstraint, FixedConstraint,
GearConstraint,
LineInPlaneConstraint,
ParallelConstraint,
PerpendicularConstraint,
PlanarConstraint,
PointInPlaneConstraint,
PointOnLineConstraint,
RackPinionConstraint,
RevoluteConstraint,
ScrewConstraint,
SliderConstraint,
SlotConstraint,
TangentConstraint,
UniversalConstraint,
) )
from .decompose import decompose, solve_decomposed
from .dof import count_dof from .dof import count_dof
from .entities import RigidBody from .entities import RigidBody
from .newton import newton_solve from .newton import newton_solve
from .params import ParamTable from .params import ParamTable
from .prepass import single_equation_pass, substitution_pass from .prepass import single_equation_pass, substitution_pass
# Map BaseJointKind enum values to handler names # Assemblies with fewer free bodies than this use the monolithic path.
_DECOMPOSE_THRESHOLD = 8
# All BaseJointKind values this solver can handle
_SUPPORTED = { _SUPPORTED = {
# Phase 1
kcsolve.BaseJointKind.Coincident, kcsolve.BaseJointKind.Coincident,
kcsolve.BaseJointKind.DistancePointPoint, kcsolve.BaseJointKind.DistancePointPoint,
kcsolve.BaseJointKind.Fixed, kcsolve.BaseJointKind.Fixed,
# Phase 2: point constraints
kcsolve.BaseJointKind.PointOnLine,
kcsolve.BaseJointKind.PointInPlane,
# Phase 2: orientation
kcsolve.BaseJointKind.Parallel,
kcsolve.BaseJointKind.Perpendicular,
kcsolve.BaseJointKind.Angle,
# Phase 2: axis/surface
kcsolve.BaseJointKind.Concentric,
kcsolve.BaseJointKind.Tangent,
kcsolve.BaseJointKind.Planar,
kcsolve.BaseJointKind.LineInPlane,
# Phase 2: kinematic joints
kcsolve.BaseJointKind.Ball,
kcsolve.BaseJointKind.Revolute,
kcsolve.BaseJointKind.Cylindrical,
kcsolve.BaseJointKind.Slider,
kcsolve.BaseJointKind.Screw,
kcsolve.BaseJointKind.Universal,
# Phase 2: mechanical
kcsolve.BaseJointKind.Gear,
kcsolve.BaseJointKind.RackPinion,
} }
@@ -51,11 +99,12 @@ class KindredSolver(kcsolve.IKCSolver):
) )
bodies[part.id] = body bodies[part.id] = body
# 2. Build constraint residuals # 2. Build constraint residuals (track index mapping for decomposition)
all_residuals = [] all_residuals = []
constraint_objs = [] constraint_objs = []
constraint_indices = [] # parallel to constraint_objs: index in ctx.constraints
for c in ctx.constraints: for idx, c in enumerate(ctx.constraints):
if not c.activated: if not c.activated:
continue continue
body_i = bodies.get(c.part_i) body_i = bodies.get(c.part_i)
@@ -79,6 +128,7 @@ class KindredSolver(kcsolve.IKCSolver):
if obj is None: if obj is None:
continue continue
constraint_objs.append(obj) constraint_objs.append(obj)
constraint_indices.append(idx)
all_residuals.extend(obj.residuals()) all_residuals.extend(obj.residuals())
# 3. Add quaternion normalization residuals for non-grounded bodies # 3. Add quaternion normalization residuals for non-grounded bodies
@@ -88,18 +138,31 @@ class KindredSolver(kcsolve.IKCSolver):
all_residuals.append(body.quat_norm_residual()) all_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names()) quat_groups.append(body.quat_param_names())
# 4. Pre-passes # 4. Pre-passes on full system
all_residuals = substitution_pass(all_residuals, params) all_residuals = substitution_pass(all_residuals, params)
all_residuals = single_equation_pass(all_residuals, params) all_residuals = single_equation_pass(all_residuals, params)
# 5. Newton-Raphson # 5. Solve (decomposed for large assemblies, monolithic for small)
converged = newton_solve( n_free_bodies = sum(1 for b in bodies.values() if not b.grounded)
all_residuals, if n_free_bodies >= _DECOMPOSE_THRESHOLD:
params, grounded_ids = {pid for pid, b in bodies.items() if b.grounded}
quat_groups=quat_groups, clusters = decompose(ctx.constraints, grounded_ids)
max_iter=100, if len(clusters) > 1:
tol=1e-10, converged = solve_decomposed(
) clusters,
bodies,
constraint_objs,
constraint_indices,
params,
)
else:
converged = _monolithic_solve(
all_residuals,
params,
quat_groups,
)
else:
converged = _monolithic_solve(all_residuals, params, quat_groups)
# 6. DOF # 6. DOF
dof = count_dof(all_residuals, params) dof = count_dof(all_residuals, params)
@@ -130,6 +193,26 @@ class KindredSolver(kcsolve.IKCSolver):
return True return True
def _monolithic_solve(all_residuals, params, quat_groups):
"""Newton-Raphson solve with BFGS fallback on the full system."""
converged = newton_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=100,
tol=1e-10,
)
if not converged:
converged = bfgs_solve(
all_residuals,
params,
quat_groups=quat_groups,
max_iter=200,
tol=1e-10,
)
return converged
def _build_constraint( def _build_constraint(
kind, kind,
body_i, body_i,
@@ -141,6 +224,11 @@ def _build_constraint(
c_params, c_params,
) -> ConstraintBase | None: ) -> ConstraintBase | None:
"""Create the appropriate constraint object from a BaseJointKind.""" """Create the appropriate constraint object from a BaseJointKind."""
marker_i_quat = tuple(marker_i.quaternion)
marker_j_quat = tuple(marker_j.quaternion)
# -- Phase 1 constraints --------------------------------------------------
if kind == kcsolve.BaseJointKind.Coincident: if kind == kcsolve.BaseJointKind.Coincident:
return CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos) return CoincidentConstraint(body_i, marker_i_pos, body_j, marker_j_pos)
@@ -155,8 +243,6 @@ def _build_constraint(
) )
if kind == kcsolve.BaseJointKind.Fixed: if kind == kcsolve.BaseJointKind.Fixed:
marker_i_quat = tuple(marker_i.quaternion)
marker_j_quat = tuple(marker_j.quaternion)
return FixedConstraint( return FixedConstraint(
body_i, body_i,
marker_i_pos, marker_i_pos,
@@ -166,4 +252,182 @@ def _build_constraint(
marker_j_quat, marker_j_quat,
) )
# -- Phase 2: point constraints -------------------------------------------
if kind == kcsolve.BaseJointKind.PointOnLine:
return PointOnLineConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.PointInPlane:
offset = c_params[0] if c_params else 0.0
return PointInPlaneConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
offset=offset,
)
# -- Phase 2: orientation constraints -------------------------------------
if kind == kcsolve.BaseJointKind.Parallel:
return ParallelConstraint(body_i, marker_i_quat, body_j, marker_j_quat)
if kind == kcsolve.BaseJointKind.Perpendicular:
return PerpendicularConstraint(body_i, marker_i_quat, body_j, marker_j_quat)
if kind == kcsolve.BaseJointKind.Angle:
angle = c_params[0] if c_params else 0.0
return AngleConstraint(body_i, marker_i_quat, body_j, marker_j_quat, angle)
# -- Phase 2: axis/surface constraints ------------------------------------
if kind == kcsolve.BaseJointKind.Concentric:
distance = c_params[0] if c_params else 0.0
return ConcentricConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
distance=distance,
)
if kind == kcsolve.BaseJointKind.Tangent:
return TangentConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.Planar:
offset = c_params[0] if c_params else 0.0
return PlanarConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
offset=offset,
)
if kind == kcsolve.BaseJointKind.LineInPlane:
offset = c_params[0] if c_params else 0.0
return LineInPlaneConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
offset=offset,
)
# -- Phase 2: kinematic joints --------------------------------------------
if kind == kcsolve.BaseJointKind.Ball:
return BallConstraint(body_i, marker_i_pos, body_j, marker_j_pos)
if kind == kcsolve.BaseJointKind.Revolute:
return RevoluteConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.Cylindrical:
return CylindricalConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.Slider:
return SliderConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
if kind == kcsolve.BaseJointKind.Screw:
pitch = c_params[0] if c_params else 1.0
return ScrewConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
pitch=pitch,
)
if kind == kcsolve.BaseJointKind.Universal:
return UniversalConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
)
# -- Phase 2: mechanical constraints --------------------------------------
if kind == kcsolve.BaseJointKind.Gear:
radius_i = c_params[0] if len(c_params) > 0 else 1.0
radius_j = c_params[1] if len(c_params) > 1 else 1.0
return GearConstraint(
body_i,
marker_i_quat,
body_j,
marker_j_quat,
radius_i,
radius_j,
)
if kind == kcsolve.BaseJointKind.RackPinion:
pitch_radius = c_params[0] if c_params else 1.0
return RackPinionConstraint(
body_i,
marker_i_pos,
marker_i_quat,
body_j,
marker_j_pos,
marker_j_quat,
pitch_radius=pitch_radius,
)
# -- Stubs (accepted but produce no residuals) ----------------------------
if kind == kcsolve.BaseJointKind.Cam:
return CamConstraint()
if kind == kcsolve.BaseJointKind.Slot:
return SlotConstraint()
if kind == kcsolve.BaseJointKind.DistanceCylSph:
return DistanceCylSphConstraint()
return None return None

70
tests/test_bfgs.py Normal file
View File

@@ -0,0 +1,70 @@
"""Tests for the BFGS fallback solver."""
import math
import pytest
from kindred_solver.bfgs import bfgs_solve
from kindred_solver.expr import Const, Var
from kindred_solver.params import ParamTable
class TestBFGSBasic:
def test_single_linear(self):
"""Solve x - 3 = 0."""
pt = ParamTable()
x = pt.add("x", 0.0)
assert bfgs_solve([x - Const(3.0)], pt) is True
assert abs(pt.get_value("x") - 3.0) < 1e-8
def test_single_quadratic(self):
"""Solve x^2 - 4 = 0 from x=1 → x=2."""
pt = ParamTable()
x = pt.add("x", 1.0)
assert bfgs_solve([x * x - Const(4.0)], pt) is True
assert abs(pt.get_value("x") - 2.0) < 1e-8
def test_two_variables(self):
"""Solve x + y = 5, x - y = 1."""
pt = ParamTable()
x = pt.add("x", 0.0)
y = pt.add("y", 0.0)
assert bfgs_solve([x + y - Const(5.0), x - y - Const(1.0)], pt) is True
assert abs(pt.get_value("x") - 3.0) < 1e-8
assert abs(pt.get_value("y") - 2.0) < 1e-8
def test_empty_system(self):
pt = ParamTable()
assert bfgs_solve([], pt) is True
def test_with_quat_renorm(self):
"""Quaternion re-normalization during BFGS."""
pt = ParamTable()
qw = pt.add("qw", 0.9)
qx = pt.add("qx", 0.1)
qy = pt.add("qy", 0.1)
qz = pt.add("qz", 0.1)
r = qw * qw + qx * qx + qy * qy + qz * qz - Const(1.0)
groups = [("qw", "qx", "qy", "qz")]
assert bfgs_solve([r], pt, quat_groups=groups) is True
w, x, y, z = (pt.get_value(n) for n in ["qw", "qx", "qy", "qz"])
norm = math.sqrt(w**2 + x**2 + y**2 + z**2)
assert abs(norm - 1.0) < 1e-8
class TestBFGSGeometric:
def test_distance_constraint(self):
"""x^2 - 25 = 0 from x=3 → x=5."""
pt = ParamTable()
x = pt.add("x", 3.0)
assert bfgs_solve([x * x - Const(25.0)], pt) is True
assert abs(pt.get_value("x") - 5.0) < 1e-8
def test_difficult_initial_guess(self):
"""BFGS should handle worse initial guesses than Newton."""
pt = ParamTable()
x = pt.add("x", 100.0)
y = pt.add("y", -50.0)
residuals = [x + y - Const(5.0), x - y - Const(1.0)]
assert bfgs_solve(residuals, pt) is True
assert abs(pt.get_value("x") - 3.0) < 1e-6
assert abs(pt.get_value("y") - 2.0) < 1e-6

View File

@@ -0,0 +1,481 @@
"""Tests for Phase 2 constraint residual generation."""
import math
import pytest
from kindred_solver.constraints import (
AngleConstraint,
BallConstraint,
CamConstraint,
ConcentricConstraint,
CylindricalConstraint,
DistanceCylSphConstraint,
GearConstraint,
LineInPlaneConstraint,
ParallelConstraint,
PerpendicularConstraint,
PlanarConstraint,
PointInPlaneConstraint,
PointOnLineConstraint,
RackPinionConstraint,
RevoluteConstraint,
ScrewConstraint,
SliderConstraint,
SlotConstraint,
TangentConstraint,
UniversalConstraint,
)
from kindred_solver.entities import RigidBody
from kindred_solver.params import ParamTable
ID_QUAT = (1.0, 0.0, 0.0, 0.0)
# 90-deg about Y: Z-axis of body rotates to point along X
_c = math.cos(math.pi / 4)
_s = math.sin(math.pi / 4)
ROT_90Y = (_c, 0.0, _s, 0.0)
ROT_90Z = (_c, 0.0, 0.0, _s)
# ── Point constraints ────────────────────────────────────────────────
class TestPointOnLine:
def test_on_line(self):
"""Point at (0,0,5) is on Z-axis line through origin."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
c = PointOnLineConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_off_line(self):
"""Point at (3,0,5) is NOT on Z-axis line through origin."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (3, 0, 5), (1, 0, 0, 0))
c = PointOnLineConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
vals = [r.eval(env) for r in c.residuals()]
assert any(abs(v) > 0.1 for v in vals)
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PointOnLineConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 2
class TestPointInPlane:
def test_in_plane(self):
"""Point at (3,4,0) is in XY plane through origin."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (3, 4, 0), (1, 0, 0, 0))
c = PointInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_above_plane(self):
"""Point at (0,0,7) is 7 above XY plane."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 7), (1, 0, 0, 0))
c = PointInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env) - 7.0) < 1e-10
def test_with_offset(self):
"""Point at (0,0,5) with offset=5 → residual 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
c = PointInPlaneConstraint(
b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT, offset=5.0
)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PointInPlaneConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 1
# ── Orientation constraints ──────────────────────────────────────────
class TestParallel:
def test_parallel_same(self):
"""Both bodies with identity rotation → Z-axes parallel → residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (5, 0, 0), (1, 0, 0, 0))
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_not_parallel(self):
"""One body rotated 90-deg about Y → Z-axes perpendicular."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (5, 0, 0), ROT_90Y)
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
env = pt.get_env()
vals = [r.eval(env) for r in c.residuals()]
assert any(abs(v) > 0.1 for v in vals)
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT)
assert len(c.residuals()) == 2
class TestPerpendicular:
def test_perpendicular(self):
"""One body rotated 90-deg about Y → Z-axes perpendicular."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_not_perpendicular(self):
"""Same orientation → not perpendicular."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
env = pt.get_env()
# dot(z,z) = 1 ≠ 0
assert abs(c.residuals()[0].eval(env) - 1.0) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PerpendicularConstraint(b1, ID_QUAT, b2, ID_QUAT)
assert len(c.residuals()) == 1
class TestAngle:
def test_90_degrees(self):
"""90-deg angle between Z-axes rotated 90-deg about Y."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, math.pi / 2)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_0_degrees(self):
"""0-deg angle, same orientation → cos(0)=1, dot=1 → residual 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, 0.0)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = AngleConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0)
assert len(c.residuals()) == 1
# ── Axis/surface constraints ─────────────────────────────────────────
class TestConcentric:
def test_coaxial(self):
"""Both on Z-axis → coaxial → residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_not_coaxial(self):
"""Offset in X → not coaxial."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (5, 0, 0), (1, 0, 0, 0))
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
vals = [r.eval(env) for r in c.residuals()]
assert any(abs(v) > 0.1 for v in vals)
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = ConcentricConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 4
class TestTangent:
def test_touching(self):
"""Marker origins at same point → tangent."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_separated(self):
"""Separated along normal → non-zero residual."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 5), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env) - 5.0) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = TangentConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 1
class TestPlanar:
def test_coplanar(self):
"""Same plane, same orientation → all residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (5, 3, 0), (1, 0, 0, 0))
c = PlanarConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_with_offset(self):
"""b_i at z=5, b_j at origin, normal=Z, offset=5.
Signed distance = (p_i - p_j).n = 5, offset=5 → 5-5 = 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 5), (1, 0, 0, 0))
c = PlanarConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT, offset=5.0)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = PlanarConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 3
class TestLineInPlane:
def test_in_plane(self):
"""Line along X in XY plane → residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
# b2 has Z-axis = (1,0,0) via 90-deg rotation about Y
b2 = RigidBody("b", pt, (5, 0, 0), ROT_90Y)
# Line = b2's Z-axis (which is world X), plane = b1's XY plane (normal=Z)
c = LineInPlaneConstraint(b2, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = LineInPlaneConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 2
# ── Kinematic joints ─────────────────────────────────────────────────
class TestBall:
def test_same_as_coincident(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = BallConstraint(b1, (0, 0, 0), b2, (0, 0, 0))
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
assert len(c.residuals()) == 3
class TestRevolute:
def test_satisfied(self):
"""Same position, same Z-axis → satisfied."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Z) # rotated about Z — still parallel
c = RevoluteConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = RevoluteConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 5
class TestCylindrical:
def test_on_axis(self):
"""Same axis, displaced along Z → satisfied."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 10), (1, 0, 0, 0))
c = CylindricalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = CylindricalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 4
class TestSlider:
def test_aligned(self):
"""Same axis, no twist, displaced along Z → satisfied."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 10), (1, 0, 0, 0))
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_twisted(self):
"""Rotated about Z → twist residual non-zero."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Z)
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
vals = [r.eval(env) for r in c.residuals()]
# First 4 should be ~0 (parallel + on-line), but twist residual should be ~1
assert abs(vals[4]) > 0.5
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = SliderConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 5
class TestUniversal:
def test_satisfied(self):
"""Same origin, perpendicular Z-axes."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), ROT_90Y)
c = UniversalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = UniversalConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT)
assert len(c.residuals()) == 4
class TestScrew:
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = ScrewConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch=10.0)
assert len(c.residuals()) == 5
def test_zero_displacement_zero_rotation(self):
"""Both at origin with identity rotation → all residuals 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = ScrewConstraint(b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch=10.0)
env = pt.get_env()
for r in c.residuals():
assert abs(r.eval(env)) < 1e-10
# ── Mechanical constraints ───────────────────────────────────────────
class TestGear:
def test_both_at_rest(self):
"""Both at identity rotation → residual 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = GearConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0, 1.0)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = GearConstraint(b1, ID_QUAT, b2, ID_QUAT, 1.0, 2.0)
assert len(c.residuals()) == 1
class TestRackPinion:
def test_at_rest(self):
"""Both at rest → residual 0."""
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0), grounded=True)
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = RackPinionConstraint(
b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch_radius=5.0
)
env = pt.get_env()
assert abs(c.residuals()[0].eval(env)) < 1e-10
def test_residual_count(self):
pt = ParamTable()
b1 = RigidBody("a", pt, (0, 0, 0), (1, 0, 0, 0))
b2 = RigidBody("b", pt, (0, 0, 0), (1, 0, 0, 0))
c = RackPinionConstraint(
b1, (0, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT, pitch_radius=1.0
)
assert len(c.residuals()) == 1
# ── Stubs ────────────────────────────────────────────────────────────
class TestStubs:
def test_cam(self):
assert CamConstraint().residuals() == []
def test_slot(self):
assert SlotConstraint().residuals() == []
def test_distance_cyl_sph(self):
assert DistanceCylSphConstraint().residuals() == []

1052
tests/test_decompose.py Normal file

File diff suppressed because it is too large Load Diff

187
tests/test_geometry.py Normal file
View File

@@ -0,0 +1,187 @@
"""Tests for geometry helpers."""
import math
import pytest
from kindred_solver.entities import RigidBody
from kindred_solver.expr import Const, Var
from kindred_solver.geometry import (
cross3,
dot3,
marker_x_axis,
marker_y_axis,
marker_z_axis,
point_line_perp_components,
point_plane_distance,
sub3,
)
from kindred_solver.params import ParamTable
IDENTITY_QUAT = (1.0, 0.0, 0.0, 0.0)
# 90-deg about Z: (cos45, 0, 0, sin45)
_c = math.cos(math.pi / 4)
_s = math.sin(math.pi / 4)
ROT_90Z_QUAT = (_c, 0.0, 0.0, _s)
class TestDot3:
def test_parallel(self):
a = (Const(1.0), Const(0.0), Const(0.0))
b = (Const(1.0), Const(0.0), Const(0.0))
assert abs(dot3(a, b).eval({}) - 1.0) < 1e-10
def test_perpendicular(self):
a = (Const(1.0), Const(0.0), Const(0.0))
b = (Const(0.0), Const(1.0), Const(0.0))
assert abs(dot3(a, b).eval({})) < 1e-10
def test_general(self):
a = (Const(1.0), Const(2.0), Const(3.0))
b = (Const(4.0), Const(5.0), Const(6.0))
# 1*4 + 2*5 + 3*6 = 32
assert abs(dot3(a, b).eval({}) - 32.0) < 1e-10
class TestCross3:
def test_x_cross_y(self):
x = (Const(1.0), Const(0.0), Const(0.0))
y = (Const(0.0), Const(1.0), Const(0.0))
cx, cy, cz = cross3(x, y)
assert abs(cx.eval({})) < 1e-10
assert abs(cy.eval({})) < 1e-10
assert abs(cz.eval({}) - 1.0) < 1e-10
def test_parallel_is_zero(self):
a = (Const(2.0), Const(3.0), Const(4.0))
b = (Const(4.0), Const(6.0), Const(8.0))
cx, cy, cz = cross3(a, b)
assert abs(cx.eval({})) < 1e-10
assert abs(cy.eval({})) < 1e-10
assert abs(cz.eval({})) < 1e-10
class TestSub3:
def test_basic(self):
a = (Const(5.0), Const(3.0), Const(1.0))
b = (Const(1.0), Const(2.0), Const(3.0))
dx, dy, dz = sub3(a, b)
assert abs(dx.eval({}) - 4.0) < 1e-10
assert abs(dy.eval({}) - 1.0) < 1e-10
assert abs(dz.eval({}) - (-2.0)) < 1e-10
class TestMarkerAxes:
def test_identity_z(self):
"""Identity body + identity marker → Z = (0,0,1)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(zx.eval(env)) < 1e-10
assert abs(zy.eval(env)) < 1e-10
assert abs(zz.eval(env) - 1.0) < 1e-10
def test_identity_x(self):
"""Identity body + identity marker → X = (1,0,0)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
xx, xy, xz = marker_x_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(xx.eval(env) - 1.0) < 1e-10
assert abs(xy.eval(env)) < 1e-10
assert abs(xz.eval(env)) < 1e-10
def test_identity_y(self):
"""Identity body + identity marker → Y = (0,1,0)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
yx, yy, yz = marker_y_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(yx.eval(env)) < 1e-10
assert abs(yy.eval(env) - 1.0) < 1e-10
assert abs(yz.eval(env)) < 1e-10
def test_rotated_body_z(self):
"""Body rotated 90-deg about Z → Z-axis still (0,0,1)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), ROT_90Z_QUAT)
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(zx.eval(env)) < 1e-10
assert abs(zy.eval(env)) < 1e-10
assert abs(zz.eval(env) - 1.0) < 1e-10
def test_rotated_body_x(self):
"""Body rotated 90-deg about Z → X-axis becomes (0,1,0)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), ROT_90Z_QUAT)
xx, xy, xz = marker_x_axis(body, IDENTITY_QUAT)
env = pt.get_env()
assert abs(xx.eval(env)) < 1e-10
assert abs(xy.eval(env) - 1.0) < 1e-10
assert abs(xz.eval(env)) < 1e-10
def test_marker_rotation(self):
"""Identity body + marker rotated 90-deg about Z → Z still (0,0,1)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
zx, zy, zz = marker_z_axis(body, ROT_90Z_QUAT)
env = pt.get_env()
assert abs(zx.eval(env)) < 1e-10
assert abs(zy.eval(env)) < 1e-10
assert abs(zz.eval(env) - 1.0) < 1e-10
def test_marker_rotation_x_axis(self):
"""Identity body + marker rotated 90-deg about Z → X becomes (0,1,0)."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
xx, xy, xz = marker_x_axis(body, ROT_90Z_QUAT)
env = pt.get_env()
assert abs(xx.eval(env)) < 1e-10
assert abs(xy.eval(env) - 1.0) < 1e-10
assert abs(xz.eval(env)) < 1e-10
def test_differentiable(self):
"""Marker axes are differentiable w.r.t. body quat params."""
pt = ParamTable()
body = RigidBody("p", pt, (0, 0, 0), (1, 0, 0, 0))
zx, zy, zz = marker_z_axis(body, IDENTITY_QUAT)
# Should not raise
dzx = zx.diff("p/qz").simplify()
env = pt.get_env()
dzx.eval(env) # Should be evaluable
class TestPointPlaneDistance:
def test_on_plane(self):
pt = (Const(1.0), Const(2.0), Const(0.0))
origin = (Const(0.0), Const(0.0), Const(0.0))
normal = (Const(0.0), Const(0.0), Const(1.0))
d = point_plane_distance(pt, origin, normal)
assert abs(d.eval({})) < 1e-10
def test_above_plane(self):
pt = (Const(1.0), Const(2.0), Const(5.0))
origin = (Const(0.0), Const(0.0), Const(0.0))
normal = (Const(0.0), Const(0.0), Const(1.0))
d = point_plane_distance(pt, origin, normal)
assert abs(d.eval({}) - 5.0) < 1e-10
class TestPointLinePerp:
def test_on_line(self):
pt = (Const(0.0), Const(0.0), Const(5.0))
origin = (Const(0.0), Const(0.0), Const(0.0))
direction = (Const(0.0), Const(0.0), Const(1.0))
cx, cy = point_line_perp_components(pt, origin, direction)
assert abs(cx.eval({})) < 1e-10
assert abs(cy.eval({})) < 1e-10
def test_off_line(self):
pt = (Const(3.0), Const(0.0), Const(0.0))
origin = (Const(0.0), Const(0.0), Const(0.0))
direction = (Const(0.0), Const(0.0), Const(1.0))
cx, cy = point_line_perp_components(pt, origin, direction)
# d = (3,0,0), dir = (0,0,1), d x dir = (0*1-0*0, 0*0-3*1, 3*0-0*0) = (0,-3,0)
assert abs(cx.eval({})) < 1e-10
assert abs(cy.eval({}) - (-3.0)) < 1e-10

612
tests/test_joints.py Normal file
View File

@@ -0,0 +1,612 @@
"""Integration tests for kinematic joint constraints.
These tests exercise the full solve pipeline (constraint → residuals →
pre-pass → Newton / BFGS) for multi-body systems with various joint types.
"""
import math
import pytest
from kindred_solver.constraints import (
BallConstraint,
CoincidentConstraint,
CylindricalConstraint,
GearConstraint,
ParallelConstraint,
PerpendicularConstraint,
PlanarConstraint,
PointInPlaneConstraint,
PointOnLineConstraint,
RackPinionConstraint,
RevoluteConstraint,
ScrewConstraint,
SliderConstraint,
UniversalConstraint,
)
from kindred_solver.dof import count_dof
from kindred_solver.entities import RigidBody
from kindred_solver.newton import newton_solve
from kindred_solver.params import ParamTable
from kindred_solver.prepass import single_equation_pass, substitution_pass
ID_QUAT = (1, 0, 0, 0)
# 90° about Z: (cos(45°), 0, 0, sin(45°))
c45 = math.cos(math.pi / 4)
s45 = math.sin(math.pi / 4)
ROT_90Z = (c45, 0, 0, s45)
# 90° about Y
ROT_90Y = (c45, 0, s45, 0)
# 90° about X
ROT_90X = (c45, s45, 0, 0)
def _solve(bodies, constraint_objs):
"""Run the full solve pipeline. Returns (converged, params, bodies)."""
pt = bodies[0].tx # all bodies share the same ParamTable via Var._name
# Actually, we need the ParamTable object. Get it from the first body.
# The Var objects store names, but we need the table. We'll reconstruct.
# Better approach: caller passes pt.
raise NotImplementedError("Use _solve_with_pt instead")
def _solve_with_pt(pt, bodies, constraint_objs):
"""Run the full solve pipeline with explicit ParamTable."""
all_residuals = []
for c in constraint_objs:
all_residuals.extend(c.residuals())
quat_groups = []
for body in bodies:
if not body.grounded:
all_residuals.append(body.quat_norm_residual())
quat_groups.append(body.quat_param_names())
all_residuals = substitution_pass(all_residuals, pt)
all_residuals = single_equation_pass(all_residuals, pt)
converged = newton_solve(
all_residuals, pt, quat_groups=quat_groups, max_iter=100, tol=1e-10
)
return converged, all_residuals
def _dof(pt, bodies, constraint_objs):
"""Count DOF for a system."""
all_residuals = []
for c in constraint_objs:
all_residuals.extend(c.residuals())
for body in bodies:
if not body.grounded:
all_residuals.append(body.quat_norm_residual())
all_residuals = substitution_pass(all_residuals, pt)
return count_dof(all_residuals, pt)
# ============================================================================
# Single-joint DOF counting tests
# ============================================================================
class TestJointDOF:
"""Verify each joint type removes the expected number of DOF.
Setup: ground body + 1 free body (6 DOF) with a single joint.
"""
def _setup(self, pos_b=(0, 0, 0), quat_b=ID_QUAT):
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, pos_b, quat_b)
return pt, a, b
def test_ball_3dof(self):
"""Ball joint: 6 - 3 = 3 DOF (3 rotation)."""
pt, a, b = self._setup()
constraints = [BallConstraint(a, (0, 0, 0), b, (0, 0, 0))]
assert _dof(pt, [a, b], constraints) == 3
def test_revolute_1dof(self):
"""Revolute: 6 - 5 = 1 DOF (rotation about Z)."""
pt, a, b = self._setup()
constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 1
def test_cylindrical_2dof(self):
"""Cylindrical: 6 - 4 = 2 DOF (rotation + translation along Z)."""
pt, a, b = self._setup()
constraints = [
CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
assert _dof(pt, [a, b], constraints) == 2
def test_slider_1dof(self):
"""Slider: 6 - 5 = 1 DOF (translation along Z)."""
pt, a, b = self._setup()
constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 1
def test_universal_2dof(self):
"""Universal: 6 - 4 = 2 DOF (rotation about each body's Z)."""
pt, a, b = self._setup(quat_b=ROT_90X)
constraints = [
UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
assert _dof(pt, [a, b], constraints) == 2
def test_screw_1dof(self):
"""Screw: 6 - 5 = 1 DOF (helical motion)."""
pt, a, b = self._setup()
constraints = [
ScrewConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT, pitch=10.0)
]
assert _dof(pt, [a, b], constraints) == 1
def test_parallel_4dof(self):
"""Parallel: 6 - 2 = 4 DOF."""
pt, a, b = self._setup()
constraints = [ParallelConstraint(a, ID_QUAT, b, ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 4
def test_perpendicular_5dof(self):
"""Perpendicular: 6 - 1 = 5 DOF."""
pt, a, b = self._setup(quat_b=ROT_90X)
constraints = [PerpendicularConstraint(a, ID_QUAT, b, ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 5
def test_point_on_line_4dof(self):
"""PointOnLine: 6 - 2 = 4 DOF."""
pt, a, b = self._setup()
constraints = [
PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
assert _dof(pt, [a, b], constraints) == 4
def test_point_in_plane_5dof(self):
"""PointInPlane: 6 - 1 = 5 DOF."""
pt, a, b = self._setup()
constraints = [
PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
assert _dof(pt, [a, b], constraints) == 5
def test_planar_3dof(self):
"""Planar: 6 - 3 = 3 DOF (2 translation in plane + 1 rotation about normal)."""
pt, a, b = self._setup()
constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
assert _dof(pt, [a, b], constraints) == 3
# ============================================================================
# Solve convergence tests — single joints from displaced initial conditions
# ============================================================================
class TestJointSolve:
"""Newton converges to a valid configuration from displaced starting points."""
def test_revolute_displaced(self):
"""Revolute joint: body B starts displaced, should converge to hinge position."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (3, 4, 5), ID_QUAT) # displaced
constraints = [RevoluteConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
# Coincident origins → position should be at origin
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
assert abs(pos[2]) < 1e-8
def test_cylindrical_displaced(self):
"""Cylindrical joint: body B can slide along Z but must be on axis."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (3, 4, 7), ID_QUAT) # off-axis
constraints = [
CylindricalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
# X and Y should be zero (on axis), Z can be anything
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
def test_slider_displaced(self):
"""Slider: body B can translate along Z only."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (2, 3, 5), ID_QUAT) # displaced
constraints = [SliderConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
# X and Y should be zero (on axis), Z free
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
def test_ball_displaced(self):
"""Ball joint: body B moves so marker origins coincide.
Ball has 3 rotation DOF free, so we can only verify the
world-frame marker points match, not the body position directly.
"""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (5, 5, 5), ID_QUAT)
constraints = [BallConstraint(a, (1, 0, 0), b, (-1, 0, 0))]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
# Verify marker world points match
wp_a = a.world_point(1, 0, 0)
wp_b = b.world_point(-1, 0, 0)
for ea, eb in zip(wp_a, wp_b):
assert abs(ea.eval(env) - eb.eval(env)) < 1e-8
def test_universal_displaced(self):
"""Universal joint: coincident origins + perpendicular Z-axes."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
# Start B with Z-axis along X (90° about Y) — perpendicular to A's Z
b = RigidBody("b", pt, (3, 4, 5), ROT_90Y)
constraints = [
UniversalConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
assert abs(pos[2]) < 1e-8
def test_point_on_line_solve(self):
"""Point on line: body B's marker origin constrained to line along Z.
Under-constrained system (4 DOF remain), so we verify the constraint
residuals are satisfied rather than expecting specific positions.
"""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (5, 3, 7), ID_QUAT)
constraints = [
PointOnLineConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
converged, residuals = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
for r in residuals:
assert abs(r.eval(env)) < 1e-8
def test_point_in_plane_solve(self):
"""Point in plane: body B's marker origin at z=0 plane.
Under-constrained (5 DOF remain), so verify residuals.
"""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (3, 4, 8), ID_QUAT)
constraints = [
PointInPlaneConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)
]
converged, residuals = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
for r in residuals:
assert abs(r.eval(env)) < 1e-8
def test_planar_solve(self):
"""Planar: coplanar faces — parallel normals + point in plane."""
pt = ParamTable()
a = RigidBody("a", pt, (0, 0, 0), ID_QUAT, grounded=True)
# Start B tilted and displaced
b = RigidBody("b", pt, (3, 4, 8), ID_QUAT)
constraints = [PlanarConstraint(a, (0, 0, 0), ID_QUAT, b, (0, 0, 0), ID_QUAT)]
converged, _ = _solve_with_pt(pt, [a, b], constraints)
assert converged
env = pt.get_env()
pos = b.extract_position(env)
# Z must be zero (in plane), X and Y free
assert abs(pos[2]) < 1e-8
# ============================================================================
# Multi-body integration tests
# ============================================================================
class TestFourBarLinkage:
"""Four-bar linkage: 4 bodies, 4 revolute joints.
In 3D with Z-axis revolutes, this yields 2 DOF: the expected planar
motion plus an out-of-plane fold. A truly planar mechanism would
add Planar constraints on each link to eliminate the fold DOF.
"""
def test_four_bar_dof(self):
"""Four-bar linkage in 3D has 2 DOF (planar + fold)."""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
link1 = RigidBody("l1", pt, (2, 0, 0), ID_QUAT)
link2 = RigidBody("l2", pt, (5, 3, 0), ID_QUAT)
link3 = RigidBody("l3", pt, (8, 0, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT),
]
bodies = [ground, link1, link2, link3]
dof = _dof(pt, bodies, constraints)
assert dof == 2
def test_four_bar_solves(self):
"""Four-bar linkage converges from displaced initial conditions."""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
# Initial positions slightly displaced from valid config
link1 = RigidBody("l1", pt, (2, 1, 0), ID_QUAT)
link2 = RigidBody("l2", pt, (5, 4, 0), ID_QUAT)
link3 = RigidBody("l3", pt, (8, 1, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, link1, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link1, (4, 0, 0), ID_QUAT, link2, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link2, (6, 0, 0), ID_QUAT, link3, (0, 0, 0), ID_QUAT),
RevoluteConstraint(link3, (4, 0, 0), ID_QUAT, ground, (10, 0, 0), ID_QUAT),
]
bodies = [ground, link1, link2, link3]
converged, residuals = _solve_with_pt(pt, bodies, constraints)
assert converged
# Verify all revolute constraints are satisfied
env = pt.get_env()
for r in residuals:
assert abs(r.eval(env)) < 1e-8
class TestSliderCrank:
"""Slider-crank mechanism: crank + connecting rod + piston.
ground --[Revolute]-- crank --[Revolute]-- rod --[Revolute]-- piston --[Slider]-- ground
Using Slider (not Cylindrical) for the piston to also lock rotation,
making it a true prismatic joint. In 3D, out-of-plane folding adds
extra DOF beyond the planar 1-DOF.
3 free bodies × 6 = 18 DOF
Revolute(5) + Revolute(5) + Revolute(5) + Slider(5) = 20
But many constraints share bodies, so effective rank < 20.
In 3D: 3 DOF (planar crank + 2 fold modes).
"""
def test_slider_crank_dof(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
crank = RigidBody("crank", pt, (1, 0, 0), ID_QUAT)
rod = RigidBody("rod", pt, (3, 0, 0), ID_QUAT)
piston = RigidBody("piston", pt, (5, 0, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT),
RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT),
RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT),
SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y),
]
bodies = [ground, crank, rod, piston]
dof = _dof(pt, bodies, constraints)
# 3D slider-crank: planar motion + out-of-plane fold modes
assert dof == 3
def test_slider_crank_solves(self):
"""Slider-crank converges from displaced state."""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
crank = RigidBody("crank", pt, (1, 0.5, 0), ID_QUAT)
rod = RigidBody("rod", pt, (3, 1, 0), ID_QUAT)
piston = RigidBody("piston", pt, (5, 0.5, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, crank, (0, 0, 0), ID_QUAT),
RevoluteConstraint(crank, (2, 0, 0), ID_QUAT, rod, (0, 0, 0), ID_QUAT),
RevoluteConstraint(rod, (4, 0, 0), ID_QUAT, piston, (0, 0, 0), ID_QUAT),
SliderConstraint(piston, (0, 0, 0), ROT_90Y, ground, (0, 0, 0), ROT_90Y),
]
bodies = [ground, crank, rod, piston]
converged, residuals = _solve_with_pt(pt, bodies, constraints)
assert converged
env = pt.get_env()
for r in residuals:
assert abs(r.eval(env)) < 1e-8
class TestRevoluteChain:
"""Chain of revolute joints: ground → body1 → body2.
Each revolute removes 5 DOF. Two free bodies = 12 DOF.
2 revolutes = 10 constraints + 2 quat norms = 12.
Expected: 2 DOF (one rotation per hinge).
"""
def test_chain_dof(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
b1 = RigidBody("b1", pt, (3, 0, 0), ID_QUAT)
b2 = RigidBody("b2", pt, (6, 0, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT),
]
assert _dof(pt, [ground, b1, b2], constraints) == 2
def test_chain_solves(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
b1 = RigidBody("b1", pt, (3, 2, 0), ID_QUAT)
b2 = RigidBody("b2", pt, (6, 3, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
RevoluteConstraint(b1, (3, 0, 0), ID_QUAT, b2, (0, 0, 0), ID_QUAT),
]
converged, residuals = _solve_with_pt(pt, [ground, b1, b2], constraints)
assert converged
env = pt.get_env()
# b1 origin at ground hinge point (0,0,0)
pos1 = b1.extract_position(env)
assert abs(pos1[0]) < 1e-8
assert abs(pos1[1]) < 1e-8
assert abs(pos1[2]) < 1e-8
class TestSliderOnRail:
"""Slider constraint: body translates along ground Z-axis only.
1 free body, 1 slider = 6 - 5 = 1 DOF.
"""
def test_slider_on_rail(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
block = RigidBody("block", pt, (3, 4, 5), ID_QUAT)
constraints = [
SliderConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT)
]
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
assert converged
env = pt.get_env()
pos = block.extract_position(env)
# X, Y must be zero; Z is free
assert abs(pos[0]) < 1e-8
assert abs(pos[1]) < 1e-8
# Z should remain near initial value (minimum-norm solution)
# Check orientation unchanged (no twist)
quat = block.extract_quaternion(env)
assert abs(quat[0] - 1.0) < 1e-6
assert abs(quat[1]) < 1e-6
assert abs(quat[2]) < 1e-6
assert abs(quat[3]) < 1e-6
class TestPlanarOnTable:
"""Planar constraint: body slides on XY plane.
1 free body, 1 planar = 6 - 3 = 3 DOF.
"""
def test_planar_on_table(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
block = RigidBody("block", pt, (3, 4, 5), ID_QUAT)
constraints = [
PlanarConstraint(ground, (0, 0, 0), ID_QUAT, block, (0, 0, 0), ID_QUAT)
]
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
assert converged
env = pt.get_env()
pos = block.extract_position(env)
# Z must be zero, X and Y are free
assert abs(pos[2]) < 1e-8
class TestPlanarWithOffset:
"""Planar with offset: body floats at z=3 above ground."""
def test_planar_offset(self):
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
block = RigidBody("block", pt, (1, 2, 5), ID_QUAT)
# PlanarConstraint residual: (p_i - p_j) . z_j - offset = 0
# body_i=block, body_j=ground: (block_z - 0) * 1 - offset = 0
# For block at z=3: offset = 3
constraints = [
PlanarConstraint(
block, (0, 0, 0), ID_QUAT, ground, (0, 0, 0), ID_QUAT, offset=3.0
)
]
converged, _ = _solve_with_pt(pt, [ground, block], constraints)
assert converged
env = pt.get_env()
pos = block.extract_position(env)
assert abs(pos[2] - 3.0) < 1e-8
class TestMixedConstraints:
"""System with mixed constraint types."""
def test_revolute_plus_parallel(self):
"""Two free bodies: revolute between ground and b1, parallel between b1 and b2.
b1: 6 DOF - 5 (revolute) = 1 DOF
b2: 6 DOF - 2 (parallel) = 4 DOF
Total: 5 DOF
"""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
b1 = RigidBody("b1", pt, (0, 0, 0), ID_QUAT)
b2 = RigidBody("b2", pt, (5, 0, 0), ID_QUAT)
constraints = [
RevoluteConstraint(ground, (0, 0, 0), ID_QUAT, b1, (0, 0, 0), ID_QUAT),
ParallelConstraint(b1, ID_QUAT, b2, ID_QUAT),
]
assert _dof(pt, [ground, b1, b2], constraints) == 5
def test_coincident_plus_perpendicular(self):
"""Coincident + perpendicular = ball + 1 angle constraint.
6 - 3 (coincident) - 1 (perpendicular) = 2 DOF.
"""
pt = ParamTable()
ground = RigidBody("g", pt, (0, 0, 0), ID_QUAT, grounded=True)
b = RigidBody("b", pt, (0, 0, 0), ROT_90X)
constraints = [
CoincidentConstraint(ground, (0, 0, 0), b, (0, 0, 0)),
PerpendicularConstraint(ground, ID_QUAT, b, ID_QUAT),
]
assert _dof(pt, [ground, b], constraints) == 2

View File

@@ -65,3 +65,37 @@ class TestParamTable:
pt.add("b", 0.0, fixed=True) pt.add("b", 0.0, fixed=True)
pt.add("c", 0.0) pt.add("c", 0.0)
assert pt.n_free() == 2 assert pt.n_free() == 2
def test_unfix(self):
pt = ParamTable()
pt.add("a", 1.0)
pt.add("b", 2.0)
pt.fix("a")
assert pt.is_fixed("a")
assert "a" not in pt.free_names()
pt.unfix("a")
assert not pt.is_fixed("a")
assert "a" in pt.free_names()
assert pt.n_free() == 2
def test_fix_unfix_roundtrip(self):
"""Fix then unfix preserves value and makes param free again."""
pt = ParamTable()
pt.add("x", 5.0)
pt.add("y", 3.0)
pt.fix("x")
pt.set_value("x", 10.0)
pt.unfix("x")
assert pt.get_value("x") == 10.0
assert "x" in pt.free_names()
# x moves to end of free list
assert pt.free_names() == ["y", "x"]
def test_unfix_noop_if_already_free(self):
"""Unfixing a free parameter is a no-op."""
pt = ParamTable()
pt.add("a", 1.0)
pt.unfix("a")
assert pt.free_names() == ["a"]
assert pt.n_free() == 1