From 9f53fdb15428e922b61b108fc6df896e4955d64b Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Tue, 3 Feb 2026 12:55:37 -0600 Subject: [PATCH 1/5] feat(mates): add mate type definitions and geometry references MateType enum (8 types), GeometryType enum (5 types), GeometryRef and Mate dataclasses with validation, serialization, and context-dependent DOF removal via dof_removed(). Closes #11 --- solver/mates/__init__.py | 17 ++ solver/mates/primitives.py | 279 ++++++++++++++++++++++++++++ tests/mates/__init__.py | 0 tests/mates/test_primitives.py | 329 +++++++++++++++++++++++++++++++++ 4 files changed, 625 insertions(+) create mode 100644 solver/mates/__init__.py create mode 100644 solver/mates/primitives.py create mode 100644 tests/mates/__init__.py create mode 100644 tests/mates/test_primitives.py diff --git a/solver/mates/__init__.py b/solver/mates/__init__.py new file mode 100644 index 0000000..994e51d --- /dev/null +++ b/solver/mates/__init__.py @@ -0,0 +1,17 @@ +"""Mate-level constraint types for assembly analysis.""" + +from solver.mates.primitives import ( + GeometryRef, + GeometryType, + Mate, + MateType, + dof_removed, +) + +__all__ = [ + "GeometryRef", + "GeometryType", + "Mate", + "MateType", + "dof_removed", +] diff --git a/solver/mates/primitives.py b/solver/mates/primitives.py new file mode 100644 index 0000000..9a3b0d2 --- /dev/null +++ b/solver/mates/primitives.py @@ -0,0 +1,279 @@ +"""Mate type definitions and geometry references for assembly constraints. + +Mates are the user-facing constraint primitives in CAD (e.g. SolidWorks-style +Coincident, Concentric, Parallel). Each mate references geometry on two bodies +and removes a context-dependent number of degrees of freedom. +""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "GeometryRef", + "GeometryType", + "Mate", + "MateType", + "dof_removed", +] + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class MateType(enum.Enum): + """CAD mate types with default DOF-removal counts. + + Values are ``(ordinal, default_dof)`` tuples so that mate types + sharing the same DOF count remain distinct enum members. Use the + :attr:`default_dof` property to get the scalar constraint count. + + The actual DOF removed can be context-dependent (e.g. COINCIDENT + removes 3 DOF for face-face but only 1 for face-point). Use + :func:`dof_removed` for the context-aware count. + """ + + COINCIDENT = (0, 3) + CONCENTRIC = (1, 2) + PARALLEL = (2, 2) + PERPENDICULAR = (3, 1) + TANGENT = (4, 1) + DISTANCE = (5, 1) + ANGLE = (6, 1) + LOCK = (7, 6) + + @property + def default_dof(self) -> int: + """Default number of DOF removed by this mate type.""" + return self.value[1] + + +class GeometryType(enum.Enum): + """Types of geometric references used by mates.""" + + FACE = "face" + EDGE = "edge" + POINT = "point" + AXIS = "axis" + PLANE = "plane" + + +# Geometry types that require a direction vector. +_DIRECTIONAL_TYPES = frozenset( + { + GeometryType.FACE, + GeometryType.AXIS, + GeometryType.PLANE, + } +) + + +# --------------------------------------------------------------------------- +# Dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class GeometryRef: + """A reference to a specific geometric entity on a body. + + Attributes: + body_id: Index of the body this geometry belongs to. + geometry_type: What kind of geometry (face, edge, etc.). + geometry_id: CAD identifier string (e.g. ``"Face001"``). + origin: 3D position of the geometry reference point. + direction: Unit direction vector. Required for FACE, AXIS, PLANE; + ``None`` for POINT. + """ + + body_id: int + geometry_type: GeometryType + geometry_id: str + origin: np.ndarray = field(default_factory=lambda: np.zeros(3)) + direction: np.ndarray | None = None + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "body_id": self.body_id, + "geometry_type": self.geometry_type.value, + "geometry_id": self.geometry_id, + "origin": self.origin.tolist(), + "direction": self.direction.tolist() if self.direction is not None else None, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> GeometryRef: + """Construct from a dict produced by :meth:`to_dict`.""" + direction_raw = data.get("direction") + return cls( + body_id=data["body_id"], + geometry_type=GeometryType(data["geometry_type"]), + geometry_id=data["geometry_id"], + origin=np.asarray(data["origin"], dtype=np.float64), + direction=( + np.asarray(direction_raw, dtype=np.float64) if direction_raw is not None else None + ), + ) + + +@dataclass +class Mate: + """A mate constraint between geometry on two bodies. + + Attributes: + mate_id: Unique identifier for this mate. + mate_type: The type of constraint (Coincident, Concentric, etc.). + ref_a: Geometry reference on the first body. + ref_b: Geometry reference on the second body. + value: Scalar parameter for DISTANCE and ANGLE mates (0 otherwise). + tolerance: Numeric tolerance for constraint satisfaction. + """ + + mate_id: int + mate_type: MateType + ref_a: GeometryRef + ref_b: GeometryRef + value: float = 0.0 + tolerance: float = 1e-6 + + def validate(self) -> None: + """Raise ``ValueError`` if this mate has incompatible geometry. + + Checks: + - Self-mate (both refs on same body) + - CONCENTRIC requires AXIS geometry on both refs + - PARALLEL requires directional geometry (not POINT) + - TANGENT requires surface geometry (FACE or EDGE) + - Directional geometry types must have a direction vector + """ + if self.ref_a.body_id == self.ref_b.body_id: + msg = f"Self-mate: ref_a and ref_b both reference body {self.ref_a.body_id}" + raise ValueError(msg) + + for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]: + if ref.geometry_type in _DIRECTIONAL_TYPES and ref.direction is None: + msg = ( + f"{label}: geometry type {ref.geometry_type.value} requires a direction vector" + ) + raise ValueError(msg) + + if self.mate_type is MateType.CONCENTRIC: + for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]: + if ref.geometry_type is not GeometryType.AXIS: + msg = ( + f"CONCENTRIC mate requires AXIS geometry, " + f"got {ref.geometry_type.value} on {label}" + ) + raise ValueError(msg) + + if self.mate_type is MateType.PARALLEL: + for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]: + if ref.geometry_type is GeometryType.POINT: + msg = f"PARALLEL mate requires directional geometry, got POINT on {label}" + raise ValueError(msg) + + if self.mate_type is MateType.TANGENT: + _surface = frozenset({GeometryType.FACE, GeometryType.EDGE}) + for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]: + if ref.geometry_type not in _surface: + msg = ( + f"TANGENT mate requires surface geometry " + f"(FACE or EDGE), got {ref.geometry_type.value} " + f"on {label}" + ) + raise ValueError(msg) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "mate_id": self.mate_id, + "mate_type": self.mate_type.name, + "ref_a": self.ref_a.to_dict(), + "ref_b": self.ref_b.to_dict(), + "value": self.value, + "tolerance": self.tolerance, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Mate: + """Construct from a dict produced by :meth:`to_dict`.""" + return cls( + mate_id=data["mate_id"], + mate_type=MateType[data["mate_type"]], + ref_a=GeometryRef.from_dict(data["ref_a"]), + ref_b=GeometryRef.from_dict(data["ref_b"]), + value=data.get("value", 0.0), + tolerance=data.get("tolerance", 1e-6), + ) + + +# --------------------------------------------------------------------------- +# Context-dependent DOF removal +# --------------------------------------------------------------------------- + +# Lookup table: (MateType, ref_a GeometryType, ref_b GeometryType) -> DOF removed. +# Entries with None match any geometry type for that position. +_DOF_TABLE: dict[tuple[MateType, GeometryType | None, GeometryType | None], int] = { + # COINCIDENT — context-dependent + (MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE): 3, + (MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT): 3, + (MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE): 3, + (MateType.COINCIDENT, GeometryType.EDGE, GeometryType.EDGE): 2, + (MateType.COINCIDENT, GeometryType.FACE, GeometryType.POINT): 1, + (MateType.COINCIDENT, GeometryType.POINT, GeometryType.FACE): 1, + # CONCENTRIC + (MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS): 2, + # PARALLEL + (MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS): 2, + (MateType.PARALLEL, GeometryType.FACE, GeometryType.FACE): 2, + (MateType.PARALLEL, GeometryType.PLANE, GeometryType.PLANE): 2, + # TANGENT + (MateType.TANGENT, GeometryType.FACE, GeometryType.FACE): 1, + (MateType.TANGENT, GeometryType.FACE, GeometryType.EDGE): 1, + (MateType.TANGENT, GeometryType.EDGE, GeometryType.FACE): 1, + # Types where DOF is always the same regardless of geometry + (MateType.PERPENDICULAR, None, None): 1, + (MateType.DISTANCE, None, None): 1, + (MateType.ANGLE, None, None): 1, + (MateType.LOCK, None, None): 6, +} + + +def dof_removed( + mate_type: MateType, + ref_a: GeometryRef, + ref_b: GeometryRef, +) -> int: + """Return the number of DOF removed by a mate given its geometry context. + + Looks up the exact ``(mate_type, ref_a.geometry_type, ref_b.geometry_type)`` + combination first, then falls back to a wildcard ``(mate_type, None, None)`` + entry, and finally to :attr:`MateType.default_dof`. + + Args: + mate_type: The mate constraint type. + ref_a: Geometry reference on the first body. + ref_b: Geometry reference on the second body. + + Returns: + Number of scalar DOF removed by this mate. + """ + key = (mate_type, ref_a.geometry_type, ref_b.geometry_type) + if key in _DOF_TABLE: + return _DOF_TABLE[key] + + wildcard = (mate_type, None, None) + if wildcard in _DOF_TABLE: + return _DOF_TABLE[wildcard] + + return mate_type.default_dof diff --git a/tests/mates/__init__.py b/tests/mates/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/mates/test_primitives.py b/tests/mates/test_primitives.py new file mode 100644 index 0000000..c28fff8 --- /dev/null +++ b/tests/mates/test_primitives.py @@ -0,0 +1,329 @@ +"""Tests for solver.mates.primitives -- mate type definitions.""" + +from __future__ import annotations + +from typing import ClassVar + +import numpy as np +import pytest + +from solver.mates.primitives import ( + GeometryRef, + GeometryType, + Mate, + MateType, + dof_removed, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_ref( + body_id: int, + geom_type: GeometryType, + *, + geometry_id: str = "Geom001", + origin: np.ndarray | None = None, + direction: np.ndarray | None = None, +) -> GeometryRef: + """Factory for GeometryRef with sensible defaults.""" + if origin is None: + origin = np.zeros(3) + if direction is None and geom_type in { + GeometryType.FACE, + GeometryType.AXIS, + GeometryType.PLANE, + }: + direction = np.array([0.0, 0.0, 1.0]) + return GeometryRef( + body_id=body_id, + geometry_type=geom_type, + geometry_id=geometry_id, + origin=origin, + direction=direction, + ) + + +# --------------------------------------------------------------------------- +# MateType +# --------------------------------------------------------------------------- + + +class TestMateType: + """MateType enum construction and DOF values.""" + + EXPECTED_DOF: ClassVar[dict[str, int]] = { + "COINCIDENT": 3, + "CONCENTRIC": 2, + "PARALLEL": 2, + "PERPENDICULAR": 1, + "TANGENT": 1, + "DISTANCE": 1, + "ANGLE": 1, + "LOCK": 6, + } + + def test_member_count(self) -> None: + assert len(MateType) == 8 + + @pytest.mark.parametrize("name,dof", EXPECTED_DOF.items()) + def test_default_dof_values(self, name: str, dof: int) -> None: + assert MateType[name].default_dof == dof + + def test_value_is_tuple(self) -> None: + assert MateType.COINCIDENT.value == (0, 3) + assert MateType.COINCIDENT.default_dof == 3 + + def test_access_by_name(self) -> None: + assert MateType["LOCK"] is MateType.LOCK + + def test_no_alias_collision(self) -> None: + ordinals = [m.value[0] for m in MateType] + assert len(ordinals) == len(set(ordinals)) + + +# --------------------------------------------------------------------------- +# GeometryType +# --------------------------------------------------------------------------- + + +class TestGeometryType: + """GeometryType enum.""" + + def test_member_count(self) -> None: + assert len(GeometryType) == 5 + + def test_string_values(self) -> None: + for gt in GeometryType: + assert isinstance(gt.value, str) + assert gt.value == gt.name.lower() + + def test_access_by_name(self) -> None: + assert GeometryType["FACE"] is GeometryType.FACE + + +# --------------------------------------------------------------------------- +# GeometryRef +# --------------------------------------------------------------------------- + + +class TestGeometryRef: + """GeometryRef dataclass.""" + + def test_construction(self) -> None: + ref = _make_ref(0, GeometryType.AXIS, geometry_id="Axis001") + assert ref.body_id == 0 + assert ref.geometry_type is GeometryType.AXIS + assert ref.geometry_id == "Axis001" + np.testing.assert_array_equal(ref.origin, np.zeros(3)) + assert ref.direction is not None + + def test_default_direction_none(self) -> None: + ref = GeometryRef( + body_id=0, + geometry_type=GeometryType.POINT, + geometry_id="Point001", + ) + assert ref.direction is None + + def test_to_dict_round_trip(self) -> None: + ref = _make_ref( + 1, + GeometryType.FACE, + origin=np.array([1.0, 2.0, 3.0]), + direction=np.array([0.0, 1.0, 0.0]), + ) + d = ref.to_dict() + restored = GeometryRef.from_dict(d) + assert restored.body_id == ref.body_id + assert restored.geometry_type is ref.geometry_type + assert restored.geometry_id == ref.geometry_id + np.testing.assert_array_almost_equal(restored.origin, ref.origin) + assert restored.direction is not None + np.testing.assert_array_almost_equal(restored.direction, ref.direction) + + def test_to_dict_with_none_direction(self) -> None: + ref = GeometryRef( + body_id=2, + geometry_type=GeometryType.POINT, + geometry_id="Point002", + origin=np.array([5.0, 6.0, 7.0]), + ) + d = ref.to_dict() + assert d["direction"] is None + restored = GeometryRef.from_dict(d) + assert restored.direction is None + + +# --------------------------------------------------------------------------- +# Mate +# --------------------------------------------------------------------------- + + +class TestMate: + """Mate dataclass.""" + + def test_construction(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(1, GeometryType.FACE) + m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b) + assert m.mate_id == 0 + assert m.mate_type is MateType.COINCIDENT + + def test_value_default_zero(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(1, GeometryType.FACE) + m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b) + assert m.value == 0.0 + + def test_tolerance_default(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(1, GeometryType.FACE) + m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b) + assert m.tolerance == 1e-6 + + def test_to_dict_round_trip(self) -> None: + ref_a = _make_ref(0, GeometryType.AXIS, origin=np.array([1.0, 0.0, 0.0])) + ref_b = _make_ref(1, GeometryType.AXIS, origin=np.array([2.0, 0.0, 0.0])) + m = Mate( + mate_id=5, + mate_type=MateType.CONCENTRIC, + ref_a=ref_a, + ref_b=ref_b, + value=0.0, + tolerance=1e-8, + ) + d = m.to_dict() + restored = Mate.from_dict(d) + assert restored.mate_id == m.mate_id + assert restored.mate_type is m.mate_type + assert restored.ref_a.body_id == m.ref_a.body_id + assert restored.ref_b.body_id == m.ref_b.body_id + assert restored.value == m.value + assert restored.tolerance == m.tolerance + + def test_from_dict_missing_optional(self) -> None: + d = { + "mate_id": 1, + "mate_type": "DISTANCE", + "ref_a": _make_ref(0, GeometryType.POINT).to_dict(), + "ref_b": _make_ref(1, GeometryType.POINT).to_dict(), + } + m = Mate.from_dict(d) + assert m.value == 0.0 + assert m.tolerance == 1e-6 + + +# --------------------------------------------------------------------------- +# dof_removed +# --------------------------------------------------------------------------- + + +class TestDofRemoved: + """Context-dependent DOF removal counts.""" + + def test_coincident_face_face(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(1, GeometryType.FACE) + assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 3 + + def test_coincident_point_point(self) -> None: + ref_a = _make_ref(0, GeometryType.POINT) + ref_b = _make_ref(1, GeometryType.POINT) + assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 3 + + def test_coincident_edge_edge(self) -> None: + ref_a = _make_ref(0, GeometryType.EDGE) + ref_b = _make_ref(1, GeometryType.EDGE) + assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 2 + + def test_coincident_face_point(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(1, GeometryType.POINT) + assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 1 + + def test_concentric_axis_axis(self) -> None: + ref_a = _make_ref(0, GeometryType.AXIS) + ref_b = _make_ref(1, GeometryType.AXIS) + assert dof_removed(MateType.CONCENTRIC, ref_a, ref_b) == 2 + + def test_lock_any(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(1, GeometryType.POINT) + assert dof_removed(MateType.LOCK, ref_a, ref_b) == 6 + + def test_distance_any(self) -> None: + ref_a = _make_ref(0, GeometryType.POINT) + ref_b = _make_ref(1, GeometryType.EDGE) + assert dof_removed(MateType.DISTANCE, ref_a, ref_b) == 1 + + def test_unknown_combo_uses_default(self) -> None: + """Unlisted geometry combos fall back to default_dof.""" + ref_a = _make_ref(0, GeometryType.EDGE) + ref_b = _make_ref(1, GeometryType.POINT) + result = dof_removed(MateType.COINCIDENT, ref_a, ref_b) + assert result == MateType.COINCIDENT.default_dof + + +# --------------------------------------------------------------------------- +# Mate.validate +# --------------------------------------------------------------------------- + + +class TestMateValidation: + """Mate.validate() compatibility checks.""" + + def test_valid_concentric(self) -> None: + ref_a = _make_ref(0, GeometryType.AXIS) + ref_b = _make_ref(1, GeometryType.AXIS) + m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b) + m.validate() # should not raise + + def test_invalid_concentric_face(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(1, GeometryType.AXIS) + m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b) + with pytest.raises(ValueError, match="CONCENTRIC"): + m.validate() + + def test_valid_coincident_face_face(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(1, GeometryType.FACE) + m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b) + m.validate() # should not raise + + def test_invalid_self_mate(self) -> None: + ref_a = _make_ref(0, GeometryType.FACE) + ref_b = _make_ref(0, GeometryType.FACE, geometry_id="Face002") + m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b) + with pytest.raises(ValueError, match="Self-mate"): + m.validate() + + def test_invalid_parallel_point(self) -> None: + ref_a = _make_ref(0, GeometryType.POINT) + ref_b = _make_ref(1, GeometryType.AXIS) + m = Mate(mate_id=0, mate_type=MateType.PARALLEL, ref_a=ref_a, ref_b=ref_b) + with pytest.raises(ValueError, match="PARALLEL"): + m.validate() + + def test_invalid_tangent_axis(self) -> None: + ref_a = _make_ref(0, GeometryType.AXIS) + ref_b = _make_ref(1, GeometryType.FACE) + m = Mate(mate_id=0, mate_type=MateType.TANGENT, ref_a=ref_a, ref_b=ref_b) + with pytest.raises(ValueError, match="TANGENT"): + m.validate() + + def test_missing_direction_for_axis(self) -> None: + ref_a = GeometryRef( + body_id=0, + geometry_type=GeometryType.AXIS, + geometry_id="Axis001", + origin=np.zeros(3), + direction=None, # missing! + ) + ref_b = _make_ref(1, GeometryType.AXIS) + m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b) + with pytest.raises(ValueError, match="direction"): + m.validate() From e8143cf64c587a161d162746f4951f6029d9763e Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Tue, 3 Feb 2026 12:59:53 -0600 Subject: [PATCH 2/5] feat(mates): add joint pattern recognition JointPattern enum (9 patterns), PatternMatch dataclass, and recognize_patterns() function with data-driven pattern rules. Supports canonical, partial, and ambiguous pattern matching. Closes #12 --- solver/mates/__init__.py | 8 + solver/mates/patterns.py | 284 ++++++++++++++++++++++++++++++++++ tests/mates/test_patterns.py | 285 +++++++++++++++++++++++++++++++++++ 3 files changed, 577 insertions(+) create mode 100644 solver/mates/patterns.py create mode 100644 tests/mates/test_patterns.py diff --git a/solver/mates/__init__.py b/solver/mates/__init__.py index 994e51d..e2b66ad 100644 --- a/solver/mates/__init__.py +++ b/solver/mates/__init__.py @@ -1,5 +1,10 @@ """Mate-level constraint types for assembly analysis.""" +from solver.mates.patterns import ( + JointPattern, + PatternMatch, + recognize_patterns, +) from solver.mates.primitives import ( GeometryRef, GeometryType, @@ -11,7 +16,10 @@ from solver.mates.primitives import ( __all__ = [ "GeometryRef", "GeometryType", + "JointPattern", "Mate", "MateType", + "PatternMatch", "dof_removed", + "recognize_patterns", ] diff --git a/solver/mates/patterns.py b/solver/mates/patterns.py new file mode 100644 index 0000000..6844365 --- /dev/null +++ b/solver/mates/patterns.py @@ -0,0 +1,284 @@ +"""Joint pattern recognition from mate combinations. + +Groups mates by body pair and matches them against canonical joint +patterns (hinge, slider, ball, etc.). Each pattern is a known +combination of mate types that together constrain motion equivalently +to a single mechanical joint. +""" + +from __future__ import annotations + +import enum +from collections import defaultdict +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from solver.datagen.types import JointType +from solver.mates.primitives import GeometryType, Mate, MateType + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "JointPattern", + "PatternMatch", + "recognize_patterns", +] + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class JointPattern(enum.Enum): + """Canonical joint patterns formed by mate combinations.""" + + HINGE = "hinge" + SLIDER = "slider" + CYLINDER = "cylinder" + BALL = "ball" + PLANAR = "planar" + FIXED = "fixed" + GEAR = "gear" + RACK_PINION = "rack_pinion" + UNKNOWN = "unknown" + + +# --------------------------------------------------------------------------- +# Pattern match result +# --------------------------------------------------------------------------- + + +@dataclass +class PatternMatch: + """Result of matching a group of mates to a joint pattern. + + Attributes: + pattern: The identified joint pattern. + mates: The mates that form this pattern. + body_a: First body in the pair. + body_b: Second body in the pair. + confidence: How well the mates match the canonical pattern (0-1). + equivalent_joint_type: The JointType this pattern maps to. + missing_mates: Descriptions of mates absent for a full match. + """ + + pattern: JointPattern + mates: list[Mate] + body_a: int + body_b: int + confidence: float + equivalent_joint_type: JointType + missing_mates: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "pattern": self.pattern.value, + "body_a": self.body_a, + "body_b": self.body_b, + "confidence": self.confidence, + "equivalent_joint_type": self.equivalent_joint_type.name, + "mate_ids": [m.mate_id for m in self.mates], + "missing_mates": self.missing_mates, + } + + +# --------------------------------------------------------------------------- +# Pattern rules (data-driven) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _MateRequirement: + """A single mate requirement within a pattern rule.""" + + mate_type: MateType + geometry_a: GeometryType | None = None + geometry_b: GeometryType | None = None + + +@dataclass(frozen=True) +class _PatternRule: + """Defines a canonical pattern as a set of required mates.""" + + pattern: JointPattern + joint_type: JointType + required: tuple[_MateRequirement, ...] + description: str = "" + + +_PATTERN_RULES: list[_PatternRule] = [ + _PatternRule( + pattern=JointPattern.HINGE, + joint_type=JointType.REVOLUTE, + required=( + _MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS), + _MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE), + ), + description="Concentric axes + coincident plane", + ), + _PatternRule( + pattern=JointPattern.SLIDER, + joint_type=JointType.SLIDER, + required=( + _MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE), + _MateRequirement(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS), + ), + description="Coincident plane + parallel axis", + ), + _PatternRule( + pattern=JointPattern.CYLINDER, + joint_type=JointType.CYLINDRICAL, + required=(_MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),), + description="Concentric axes only", + ), + _PatternRule( + pattern=JointPattern.BALL, + joint_type=JointType.BALL, + required=(_MateRequirement(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),), + description="Coincident points", + ), + _PatternRule( + pattern=JointPattern.PLANAR, + joint_type=JointType.PLANAR, + required=(_MateRequirement(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),), + description="Coincident faces", + ), + _PatternRule( + pattern=JointPattern.PLANAR, + joint_type=JointType.PLANAR, + required=(_MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),), + description="Coincident planes (alternate planar)", + ), + _PatternRule( + pattern=JointPattern.FIXED, + joint_type=JointType.FIXED, + required=(_MateRequirement(MateType.LOCK),), + description="Lock mate", + ), +] + + +# --------------------------------------------------------------------------- +# Matching logic +# --------------------------------------------------------------------------- + + +def _mate_matches_requirement(mate: Mate, req: _MateRequirement) -> bool: + """Check if a mate satisfies a requirement.""" + if mate.mate_type is not req.mate_type: + return False + if req.geometry_a is not None and mate.ref_a.geometry_type is not req.geometry_a: + return False + return not (req.geometry_b is not None and mate.ref_b.geometry_type is not req.geometry_b) + + +def _try_match_rule( + rule: _PatternRule, + mates: list[Mate], +) -> tuple[float, list[Mate], list[str]]: + """Try to match a rule against a group of mates. + + Returns: + (confidence, matched_mates, missing_descriptions) + """ + matched: list[Mate] = [] + missing: list[str] = [] + + for req in rule.required: + found = False + for mate in mates: + if mate in matched: + continue + if _mate_matches_requirement(mate, req): + matched.append(mate) + found = True + break + if not found: + geom_desc = "" + if req.geometry_a is not None: + geom_b = req.geometry_b.value if req.geometry_b else "*" + geom_desc = f" ({req.geometry_a.value}-{geom_b})" + missing.append(f"{req.mate_type.name}{geom_desc}") + + total_required = len(rule.required) + if total_required == 0: + return 0.0, [], [] + + matched_count = len(matched) + confidence = matched_count / total_required + + return confidence, matched, missing + + +def _normalize_body_pair(body_a: int, body_b: int) -> tuple[int, int]: + """Normalize a body pair so the smaller ID comes first.""" + return (min(body_a, body_b), max(body_a, body_b)) + + +def recognize_patterns(mates: list[Mate]) -> list[PatternMatch]: + """Identify joint patterns from a list of mates. + + Groups mates by body pair, then checks each group against + canonical pattern rules. Returns matches sorted by confidence + descending. + + Args: + mates: List of mate constraints to analyze. + + Returns: + List of PatternMatch results, highest confidence first. + """ + if not mates: + return [] + + # Group mates by normalized body pair + groups: dict[tuple[int, int], list[Mate]] = defaultdict(list) + for mate in mates: + pair = _normalize_body_pair(mate.ref_a.body_id, mate.ref_b.body_id) + groups[pair].append(mate) + + results: list[PatternMatch] = [] + + for (body_a, body_b), group_mates in groups.items(): + group_matches: list[PatternMatch] = [] + + for rule in _PATTERN_RULES: + confidence, matched, missing = _try_match_rule(rule, group_mates) + + if confidence > 0: + group_matches.append( + PatternMatch( + pattern=rule.pattern, + mates=matched if matched else group_mates, + body_a=body_a, + body_b=body_b, + confidence=confidence, + equivalent_joint_type=rule.joint_type, + missing_mates=missing, + ) + ) + + if group_matches: + # Sort by confidence descending, prefer more-specific patterns + group_matches.sort(key=lambda m: (-m.confidence, -len(m.mates))) + results.extend(group_matches) + else: + # No pattern matched at all + results.append( + PatternMatch( + pattern=JointPattern.UNKNOWN, + mates=group_mates, + body_a=body_a, + body_b=body_b, + confidence=0.0, + equivalent_joint_type=JointType.DISTANCE, + missing_mates=[], + ) + ) + + # Global sort by confidence descending + results.sort(key=lambda m: -m.confidence) + return results diff --git a/tests/mates/test_patterns.py b/tests/mates/test_patterns.py new file mode 100644 index 0000000..2cb4958 --- /dev/null +++ b/tests/mates/test_patterns.py @@ -0,0 +1,285 @@ +"""Tests for solver.mates.patterns -- joint pattern recognition.""" + +from __future__ import annotations + +import numpy as np + +from solver.datagen.types import JointType +from solver.mates.patterns import JointPattern, PatternMatch, recognize_patterns +from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_ref( + body_id: int, + geom_type: GeometryType, + *, + geometry_id: str = "Geom001", + origin: np.ndarray | None = None, + direction: np.ndarray | None = None, +) -> GeometryRef: + """Factory for GeometryRef with sensible defaults.""" + if origin is None: + origin = np.zeros(3) + if direction is None and geom_type in { + GeometryType.FACE, + GeometryType.AXIS, + GeometryType.PLANE, + }: + direction = np.array([0.0, 0.0, 1.0]) + return GeometryRef( + body_id=body_id, + geometry_type=geom_type, + geometry_id=geometry_id, + origin=origin, + direction=direction, + ) + + +def _make_mate( + mate_id: int, + mate_type: MateType, + body_a: int, + body_b: int, + geom_a: GeometryType = GeometryType.FACE, + geom_b: GeometryType = GeometryType.FACE, +) -> Mate: + """Factory for Mate with body pair and geometry types.""" + return Mate( + mate_id=mate_id, + mate_type=mate_type, + ref_a=_make_ref(body_a, geom_a), + ref_b=_make_ref(body_b, geom_b), + ) + + +# --------------------------------------------------------------------------- +# JointPattern enum +# --------------------------------------------------------------------------- + + +class TestJointPattern: + """JointPattern enum.""" + + def test_member_count(self) -> None: + assert len(JointPattern) == 9 + + def test_string_values(self) -> None: + for jp in JointPattern: + assert isinstance(jp.value, str) + + def test_access_by_name(self) -> None: + assert JointPattern["HINGE"] is JointPattern.HINGE + + +# --------------------------------------------------------------------------- +# PatternMatch +# --------------------------------------------------------------------------- + + +class TestPatternMatch: + """PatternMatch dataclass.""" + + def test_construction(self) -> None: + mate = _make_mate(0, MateType.LOCK, 0, 1) + pm = PatternMatch( + pattern=JointPattern.FIXED, + mates=[mate], + body_a=0, + body_b=1, + confidence=1.0, + equivalent_joint_type=JointType.FIXED, + ) + assert pm.pattern is JointPattern.FIXED + assert pm.confidence == 1.0 + assert pm.missing_mates == [] + + def test_to_dict(self) -> None: + mate = _make_mate(5, MateType.LOCK, 0, 1) + pm = PatternMatch( + pattern=JointPattern.FIXED, + mates=[mate], + body_a=0, + body_b=1, + confidence=1.0, + equivalent_joint_type=JointType.FIXED, + ) + d = pm.to_dict() + assert d["pattern"] == "fixed" + assert d["mate_ids"] == [5] + assert d["equivalent_joint_type"] == "FIXED" + + +# --------------------------------------------------------------------------- +# recognize_patterns — canonical patterns +# --------------------------------------------------------------------------- + + +class TestRecognizeCanonical: + """Full-confidence canonical pattern recognition.""" + + def test_empty_input(self) -> None: + assert recognize_patterns([]) == [] + + def test_hinge(self) -> None: + """Concentric(axis) + Coincident(plane) -> Hinge.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + _make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.HINGE + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.REVOLUTE + assert top.missing_mates == [] + + def test_slider(self) -> None: + """Coincident(plane) + Parallel(axis) -> Slider.""" + mates = [ + _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE), + _make_mate(1, MateType.PARALLEL, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.SLIDER + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.SLIDER + + def test_cylinder(self) -> None: + """Concentric(axis) only -> Cylinder.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + ] + results = recognize_patterns(mates) + # Should match cylinder at confidence 1.0 + cylinder = [r for r in results if r.pattern is JointPattern.CYLINDER] + assert len(cylinder) >= 1 + assert cylinder[0].confidence == 1.0 + assert cylinder[0].equivalent_joint_type is JointType.CYLINDRICAL + + def test_ball(self) -> None: + """Coincident(point) -> Ball.""" + mates = [ + _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.POINT, GeometryType.POINT), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.BALL + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.BALL + + def test_planar_face(self) -> None: + """Coincident(face) -> Planar.""" + mates = [ + _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.FACE, GeometryType.FACE), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.PLANAR + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.PLANAR + + def test_fixed(self) -> None: + """Lock -> Fixed.""" + mates = [ + _make_mate(0, MateType.LOCK, 0, 1, GeometryType.FACE, GeometryType.FACE), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.FIXED + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.FIXED + + +# --------------------------------------------------------------------------- +# recognize_patterns — partial matches +# --------------------------------------------------------------------------- + + +class TestRecognizePartial: + """Partial pattern matches and hints.""" + + def test_concentric_without_plane_hints_hinge(self) -> None: + """Concentric alone matches hinge at 0.5 confidence with missing hint.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + ] + results = recognize_patterns(mates) + hinge_matches = [r for r in results if r.pattern is JointPattern.HINGE] + assert len(hinge_matches) >= 1 + hinge = hinge_matches[0] + assert hinge.confidence == 0.5 + assert len(hinge.missing_mates) > 0 + + def test_coincident_plane_without_parallel_hints_slider(self) -> None: + """Coincident(plane) alone matches slider at 0.5 confidence.""" + mates = [ + _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE), + ] + results = recognize_patterns(mates) + slider_matches = [r for r in results if r.pattern is JointPattern.SLIDER] + assert len(slider_matches) >= 1 + assert slider_matches[0].confidence == 0.5 + + +# --------------------------------------------------------------------------- +# recognize_patterns — ambiguous / multi-body +# --------------------------------------------------------------------------- + + +class TestRecognizeAmbiguous: + """Ambiguous patterns and multi-body-pair assemblies.""" + + def test_concentric_matches_both_hinge_and_cylinder(self) -> None: + """A single concentric mate produces both hinge (partial) and cylinder matches.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + ] + results = recognize_patterns(mates) + patterns = {r.pattern for r in results} + assert JointPattern.HINGE in patterns + assert JointPattern.CYLINDER in patterns + + def test_multiple_body_pairs(self) -> None: + """Mates across different body pairs produce separate pattern matches.""" + mates = [ + _make_mate(0, MateType.LOCK, 0, 1), + _make_mate(1, MateType.COINCIDENT, 2, 3, GeometryType.POINT, GeometryType.POINT), + ] + results = recognize_patterns(mates) + pairs = {(r.body_a, r.body_b) for r in results} + assert (0, 1) in pairs + assert (2, 3) in pairs + + def test_results_sorted_by_confidence(self) -> None: + """All results should be sorted by confidence descending.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + _make_mate(1, MateType.LOCK, 2, 3), + ] + results = recognize_patterns(mates) + confidences = [r.confidence for r in results] + assert confidences == sorted(confidences, reverse=True) + + def test_unknown_pattern(self) -> None: + """A mate type that matches no rule returns UNKNOWN.""" + mates = [ + _make_mate(0, MateType.ANGLE, 0, 1, GeometryType.FACE, GeometryType.FACE), + ] + results = recognize_patterns(mates) + assert any(r.pattern is JointPattern.UNKNOWN for r in results) + + def test_body_pair_normalization(self) -> None: + """Mates with reversed body order should be grouped together.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 1, 0, GeometryType.AXIS, GeometryType.AXIS), + _make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE), + ] + results = recognize_patterns(mates) + hinge_matches = [r for r in results if r.pattern is JointPattern.HINGE] + assert len(hinge_matches) >= 1 + assert hinge_matches[0].confidence == 1.0 From 118474f89212c3d1055bb3281884f0f8713c8e09 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Tue, 3 Feb 2026 13:03:13 -0600 Subject: [PATCH 3/5] feat(mates): add mate-to-joint conversion and assembly analysis convert_mates_to_joints() bridges mate-level constraints to the existing joint-based analysis pipeline. analyze_mate_assembly() orchestrates the full pipeline with bidirectional mate-joint traceability. Closes #13 --- solver/mates/__init__.py | 8 + solver/mates/conversion.py | 276 +++++++++++++++++++++++++++++++ tests/mates/test_conversion.py | 287 +++++++++++++++++++++++++++++++++ 3 files changed, 571 insertions(+) create mode 100644 solver/mates/conversion.py create mode 100644 tests/mates/test_conversion.py diff --git a/solver/mates/__init__.py b/solver/mates/__init__.py index e2b66ad..fa8a5d5 100644 --- a/solver/mates/__init__.py +++ b/solver/mates/__init__.py @@ -1,5 +1,10 @@ """Mate-level constraint types for assembly analysis.""" +from solver.mates.conversion import ( + MateAnalysisResult, + analyze_mate_assembly, + convert_mates_to_joints, +) from solver.mates.patterns import ( JointPattern, PatternMatch, @@ -18,8 +23,11 @@ __all__ = [ "GeometryType", "JointPattern", "Mate", + "MateAnalysisResult", "MateType", "PatternMatch", + "analyze_mate_assembly", + "convert_mates_to_joints", "dof_removed", "recognize_patterns", ] diff --git a/solver/mates/conversion.py b/solver/mates/conversion.py new file mode 100644 index 0000000..c8016ba --- /dev/null +++ b/solver/mates/conversion.py @@ -0,0 +1,276 @@ +"""Mate-to-joint conversion and assembly analysis. + +Bridges the mate-level constraint representation to the existing +joint-based analysis pipeline. Converts recognized mate patterns +to Joint objects, then runs the pebble game and Jacobian analysis, +maintaining bidirectional traceability between mates and joints. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import numpy as np + +from solver.datagen.labeling import AssemblyLabels, label_assembly +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) +from solver.mates.patterns import PatternMatch, recognize_patterns + +if TYPE_CHECKING: + from typing import Any + + from solver.mates.primitives import Mate + +__all__ = [ + "MateAnalysisResult", + "analyze_mate_assembly", + "convert_mates_to_joints", +] + + +# --------------------------------------------------------------------------- +# Result dataclass +# --------------------------------------------------------------------------- + + +@dataclass +class MateAnalysisResult: + """Combined result of mate-based assembly analysis. + + Attributes: + patterns: Recognized joint patterns from mate grouping. + joints: Joint objects produced by conversion. + mate_to_joint: Mapping from mate_id to list of joint_ids. + joint_to_mates: Mapping from joint_id to list of mate_ids. + analysis: Constraint analysis from pebble game + Jacobian. + labels: Full ground truth labels from label_assembly. + """ + + patterns: list[PatternMatch] + joints: list[Joint] + mate_to_joint: dict[int, list[int]] = field(default_factory=dict) + joint_to_mates: dict[int, list[int]] = field(default_factory=dict) + analysis: ConstraintAnalysis | None = None + labels: AssemblyLabels | None = None + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "patterns": [p.to_dict() for p in self.patterns], + "joints": [ + { + "joint_id": j.joint_id, + "body_a": j.body_a, + "body_b": j.body_b, + "joint_type": j.joint_type.name, + } + for j in self.joints + ], + "mate_to_joint": self.mate_to_joint, + "joint_to_mates": self.joint_to_mates, + "labels": self.labels.to_dict() if self.labels else None, + } + + +# --------------------------------------------------------------------------- +# Pattern-to-JointType mapping +# --------------------------------------------------------------------------- + +# Maps (JointPattern value) to JointType for known patterns. +# Used by convert_mates_to_joints when a full pattern is recognized. +_PATTERN_JOINT_MAP: dict[str, JointType] = { + "hinge": JointType.REVOLUTE, + "slider": JointType.SLIDER, + "cylinder": JointType.CYLINDRICAL, + "ball": JointType.BALL, + "planar": JointType.PLANAR, + "fixed": JointType.FIXED, +} + +# Fallback mapping for individual mate types when no pattern is recognized. +_MATE_JOINT_FALLBACK: dict[str, JointType] = { + "COINCIDENT": JointType.PLANAR, + "CONCENTRIC": JointType.CYLINDRICAL, + "PARALLEL": JointType.PARALLEL, + "PERPENDICULAR": JointType.PERPENDICULAR, + "TANGENT": JointType.DISTANCE, + "DISTANCE": JointType.DISTANCE, + "ANGLE": JointType.PERPENDICULAR, + "LOCK": JointType.FIXED, +} + + +# --------------------------------------------------------------------------- +# Conversion +# --------------------------------------------------------------------------- + + +def _compute_joint_params( + pattern: PatternMatch, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extract anchor and axis from pattern mates. + + Returns: + (anchor_a, anchor_b, axis) + """ + anchor_a = np.zeros(3) + anchor_b = np.zeros(3) + axis = np.array([0.0, 0.0, 1.0]) + + for mate in pattern.mates: + ref_a = mate.ref_a + ref_b = mate.ref_b + anchor_a = ref_a.origin.copy() + anchor_b = ref_b.origin.copy() + if ref_a.direction is not None: + axis = ref_a.direction.copy() + break + + return anchor_a, anchor_b, axis + + +def _convert_single_mate( + mate: Mate, + joint_id: int, +) -> Joint: + """Convert a single unmatched mate to a Joint.""" + joint_type = _MATE_JOINT_FALLBACK.get(mate.mate_type.name, JointType.DISTANCE) + + anchor_a = mate.ref_a.origin.copy() + anchor_b = mate.ref_b.origin.copy() + axis = np.array([0.0, 0.0, 1.0]) + if mate.ref_a.direction is not None: + axis = mate.ref_a.direction.copy() + + return Joint( + joint_id=joint_id, + body_a=mate.ref_a.body_id, + body_b=mate.ref_b.body_id, + joint_type=joint_type, + anchor_a=anchor_a, + anchor_b=anchor_b, + axis=axis, + ) + + +def convert_mates_to_joints( + mates: list[Mate], + bodies: list[RigidBody] | None = None, +) -> tuple[list[Joint], dict[int, list[int]], dict[int, list[int]]]: + """Convert mates to Joint objects via pattern recognition. + + For each body pair: + - If mates form a recognized pattern, emit the equivalent joint. + - Otherwise, emit individual joints for each unmatched mate. + + Args: + mates: Mate constraints to convert. + bodies: Optional body list (unused currently, reserved for + future geometry lookups). + + Returns: + (joints, mate_to_joint, joint_to_mates) tuple. + """ + if not mates: + return [], {}, {} + + patterns = recognize_patterns(mates) + joints: list[Joint] = [] + mate_to_joint: dict[int, list[int]] = {} + joint_to_mates: dict[int, list[int]] = {} + + # Track which mates have been consumed by full-confidence patterns + consumed_mate_ids: set[int] = set() + next_joint_id = 0 + + # First pass: emit joints for full-confidence patterns + for pattern in patterns: + if pattern.confidence < 1.0: + continue + if pattern.pattern.value not in _PATTERN_JOINT_MAP: + continue + + # Check if any of these mates were already consumed + mate_ids = [m.mate_id for m in pattern.mates] + if any(mid in consumed_mate_ids for mid in mate_ids): + continue + + joint_type = _PATTERN_JOINT_MAP[pattern.pattern.value] + anchor_a, anchor_b, axis = _compute_joint_params(pattern) + + joint = Joint( + joint_id=next_joint_id, + body_a=pattern.body_a, + body_b=pattern.body_b, + joint_type=joint_type, + anchor_a=anchor_a, + anchor_b=anchor_b, + axis=axis, + ) + joints.append(joint) + + joint_to_mates[next_joint_id] = mate_ids + for mid in mate_ids: + mate_to_joint.setdefault(mid, []).append(next_joint_id) + consumed_mate_ids.add(mid) + + next_joint_id += 1 + + # Second pass: emit individual joints for unconsumed mates + for mate in mates: + if mate.mate_id in consumed_mate_ids: + continue + + joint = _convert_single_mate(mate, next_joint_id) + joints.append(joint) + + joint_to_mates[next_joint_id] = [mate.mate_id] + mate_to_joint.setdefault(mate.mate_id, []).append(next_joint_id) + + next_joint_id += 1 + + return joints, mate_to_joint, joint_to_mates + + +# --------------------------------------------------------------------------- +# Full analysis pipeline +# --------------------------------------------------------------------------- + + +def analyze_mate_assembly( + bodies: list[RigidBody], + mates: list[Mate], + ground_body: int | None = None, +) -> MateAnalysisResult: + """Run the full analysis pipeline on a mate-based assembly. + + Orchestrates: recognize_patterns -> convert_mates_to_joints -> + label_assembly, returning a combined result with full traceability. + + Args: + bodies: Rigid bodies in the assembly. + mates: Mate constraints between the bodies. + ground_body: If set, this body is fixed to the world. + + Returns: + MateAnalysisResult with patterns, joints, mappings, and labels. + """ + patterns = recognize_patterns(mates) + joints, mate_to_joint, joint_to_mates = convert_mates_to_joints(mates, bodies) + + labels = label_assembly(bodies, joints, ground_body) + + return MateAnalysisResult( + patterns=patterns, + joints=joints, + mate_to_joint=mate_to_joint, + joint_to_mates=joint_to_mates, + analysis=labels.analysis, + labels=labels, + ) diff --git a/tests/mates/test_conversion.py b/tests/mates/test_conversion.py new file mode 100644 index 0000000..b2740ca --- /dev/null +++ b/tests/mates/test_conversion.py @@ -0,0 +1,287 @@ +"""Tests for solver.mates.conversion -- mate-to-joint conversion.""" + +from __future__ import annotations + +import numpy as np + +from solver.datagen.types import JointType, RigidBody +from solver.mates.conversion import ( + MateAnalysisResult, + analyze_mate_assembly, + convert_mates_to_joints, +) +from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_ref( + body_id: int, + geom_type: GeometryType, + *, + origin: np.ndarray | None = None, + direction: np.ndarray | None = None, +) -> GeometryRef: + """Factory for GeometryRef with sensible defaults.""" + if origin is None: + origin = np.zeros(3) + if direction is None and geom_type in { + GeometryType.FACE, + GeometryType.AXIS, + GeometryType.PLANE, + }: + direction = np.array([0.0, 0.0, 1.0]) + return GeometryRef( + body_id=body_id, + geometry_type=geom_type, + geometry_id="Geom001", + origin=origin, + direction=direction, + ) + + +def _make_bodies(n: int) -> list[RigidBody]: + """Create n bodies at distinct positions.""" + return [RigidBody(body_id=i, position=np.array([float(i), 0.0, 0.0])) for i in range(n)] + + +# --------------------------------------------------------------------------- +# convert_mates_to_joints +# --------------------------------------------------------------------------- + + +class TestConvertMatesToJoints: + """convert_mates_to_joints function.""" + + def test_empty_input(self) -> None: + joints, m2j, j2m = convert_mates_to_joints([]) + assert joints == [] + assert m2j == {} + assert j2m == {} + + def test_hinge_pattern(self) -> None: + """Concentric + Coincident(plane) -> single REVOLUTE joint.""" + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + joints, m2j, j2m = convert_mates_to_joints(mates) + assert len(joints) == 1 + assert joints[0].joint_type is JointType.REVOLUTE + assert joints[0].body_a == 0 + assert joints[0].body_b == 1 + # Both mates map to the single joint + assert 0 in m2j + assert 1 in m2j + assert j2m[joints[0].joint_id] == [0, 1] + + def test_lock_pattern(self) -> None: + """Lock -> FIXED joint.""" + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + joints, _m2j, _j2m = convert_mates_to_joints(mates) + assert len(joints) == 1 + assert joints[0].joint_type is JointType.FIXED + + def test_unmatched_mate_fallback(self) -> None: + """A single ANGLE mate with no pattern -> individual joint.""" + mates = [ + Mate( + mate_id=0, + mate_type=MateType.ANGLE, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + joints, _m2j, _j2m = convert_mates_to_joints(mates) + assert len(joints) == 1 + assert joints[0].joint_type is JointType.PERPENDICULAR + + def test_mapping_consistency(self) -> None: + """mate_to_joint and joint_to_mates are consistent.""" + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + Mate( + mate_id=2, + mate_type=MateType.DISTANCE, + ref_a=_make_ref(2, GeometryType.POINT), + ref_b=_make_ref(3, GeometryType.POINT), + ), + ] + joints, m2j, j2m = convert_mates_to_joints(mates) + # Every mate should be in m2j + for mate in mates: + assert mate.mate_id in m2j + # Every joint should be in j2m + for joint in joints: + assert joint.joint_id in j2m + + def test_joint_axis_from_geometry(self) -> None: + """Joint axis should come from mate geometry direction.""" + axis_dir = np.array([1.0, 0.0, 0.0]) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS, direction=axis_dir), + ref_b=_make_ref(1, GeometryType.AXIS, direction=axis_dir), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + joints, _, _ = convert_mates_to_joints(mates) + np.testing.assert_array_almost_equal(joints[0].axis, axis_dir) + + +# --------------------------------------------------------------------------- +# MateAnalysisResult +# --------------------------------------------------------------------------- + + +class TestMateAnalysisResult: + """MateAnalysisResult dataclass.""" + + def test_to_dict(self) -> None: + result = MateAnalysisResult( + patterns=[], + joints=[], + ) + d = result.to_dict() + assert d["patterns"] == [] + assert d["joints"] == [] + assert d["labels"] is None + + +# --------------------------------------------------------------------------- +# analyze_mate_assembly +# --------------------------------------------------------------------------- + + +class TestAnalyzeMateAssembly: + """Full pipeline: mates -> joints -> analysis.""" + + def test_two_bodies_hinge(self) -> None: + """Two bodies connected by hinge mates -> underconstrained (1 DOF).""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + result = analyze_mate_assembly(bodies, mates) + assert result.analysis is not None + assert result.labels is not None + # A revolute joint removes 5 DOF, leaving 1 internal DOF + assert result.analysis.combinatorial_internal_dof == 1 + assert len(result.joints) == 1 + assert result.joints[0].joint_type is JointType.REVOLUTE + + def test_two_bodies_fixed(self) -> None: + """Two bodies with lock mate -> well-constrained.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = analyze_mate_assembly(bodies, mates) + assert result.analysis is not None + assert result.analysis.combinatorial_internal_dof == 0 + assert result.analysis.is_rigid + + def test_grounded_assembly(self) -> None: + """Grounded assembly analysis works.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = analyze_mate_assembly(bodies, mates, ground_body=0) + assert result.analysis is not None + assert result.analysis.is_rigid + + def test_no_mates(self) -> None: + """Assembly with no mates should be fully underconstrained.""" + bodies = _make_bodies(2) + result = analyze_mate_assembly(bodies, []) + assert result.analysis is not None + assert result.analysis.combinatorial_internal_dof == 6 + assert len(result.joints) == 0 + + def test_single_body(self) -> None: + """Single body, no mates.""" + bodies = _make_bodies(1) + result = analyze_mate_assembly(bodies, []) + assert result.analysis is not None + assert len(result.joints) == 0 + + def test_result_traceability(self) -> None: + """mate_to_joint and joint_to_mates populated in result.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + result = analyze_mate_assembly(bodies, mates) + assert 0 in result.mate_to_joint + assert 1 in result.mate_to_joint + assert len(result.joint_to_mates) > 0 From 239e45c7f90077d40b74e49d0848b6a7338c3a2e Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Tue, 3 Feb 2026 13:05:58 -0600 Subject: [PATCH 4/5] feat(mates): add mate-based synthetic assembly generator SyntheticMateGenerator wraps existing joint generator with reverse mapping (joint->mates) and configurable noise injection (redundant, missing, incompatible mates). Batch generation via generate_mate_training_batch(). Closes #14 --- solver/mates/__init__.py | 6 + solver/mates/generator.py | 315 ++++++++++++++++++++++++++++++++++ tests/mates/test_generator.py | 155 +++++++++++++++++ 3 files changed, 476 insertions(+) create mode 100644 solver/mates/generator.py create mode 100644 tests/mates/test_generator.py diff --git a/solver/mates/__init__.py b/solver/mates/__init__.py index fa8a5d5..0496bf4 100644 --- a/solver/mates/__init__.py +++ b/solver/mates/__init__.py @@ -5,6 +5,10 @@ from solver.mates.conversion import ( analyze_mate_assembly, convert_mates_to_joints, ) +from solver.mates.generator import ( + SyntheticMateGenerator, + generate_mate_training_batch, +) from solver.mates.patterns import ( JointPattern, PatternMatch, @@ -26,8 +30,10 @@ __all__ = [ "MateAnalysisResult", "MateType", "PatternMatch", + "SyntheticMateGenerator", "analyze_mate_assembly", "convert_mates_to_joints", "dof_removed", + "generate_mate_training_batch", "recognize_patterns", ] diff --git a/solver/mates/generator.py b/solver/mates/generator.py new file mode 100644 index 0000000..26d3834 --- /dev/null +++ b/solver/mates/generator.py @@ -0,0 +1,315 @@ +"""Mate-based synthetic assembly generator. + +Wraps SyntheticAssemblyGenerator to produce mate-level training data. +Generates joint-based assemblies via the existing generator, then +reverse-maps joints to plausible mate combinations. Supports noise +injection (redundant, missing, incompatible mates) for robust training. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np + +from solver.datagen.generator import SyntheticAssemblyGenerator +from solver.datagen.types import Joint, JointType, RigidBody +from solver.mates.conversion import MateAnalysisResult, analyze_mate_assembly +from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "SyntheticMateGenerator", + "generate_mate_training_batch", +] + + +# --------------------------------------------------------------------------- +# Reverse mapping: JointType -> list of (MateType, geom_a, geom_b) combos +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _MateSpec: + """Specification for a mate to generate from a joint.""" + + mate_type: MateType + geom_a: GeometryType + geom_b: GeometryType + + +_JOINT_TO_MATES: dict[JointType, list[_MateSpec]] = { + JointType.REVOLUTE: [ + _MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS), + _MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE), + ], + JointType.CYLINDRICAL: [ + _MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS), + ], + JointType.BALL: [ + _MateSpec(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT), + ], + JointType.FIXED: [ + _MateSpec(MateType.LOCK, GeometryType.FACE, GeometryType.FACE), + ], + JointType.SLIDER: [ + _MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE), + _MateSpec(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS), + ], + JointType.PLANAR: [ + _MateSpec(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE), + ], +} + + +# --------------------------------------------------------------------------- +# Generator +# --------------------------------------------------------------------------- + + +class SyntheticMateGenerator: + """Generates mate-based assemblies for training data. + + Wraps SyntheticAssemblyGenerator to produce joint-based assemblies, + then reverse-maps each joint to a plausible set of mate constraints. + + Args: + seed: Random seed for reproducibility. + redundant_prob: Probability of injecting a redundant mate per joint. + missing_prob: Probability of dropping a mate from a multi-mate pattern. + incompatible_prob: Probability of injecting a mate with wrong geometry. + """ + + def __init__( + self, + seed: int = 42, + *, + redundant_prob: float = 0.0, + missing_prob: float = 0.0, + incompatible_prob: float = 0.0, + ) -> None: + self._joint_gen = SyntheticAssemblyGenerator(seed=seed) + self._rng = np.random.default_rng(seed) + self.redundant_prob = redundant_prob + self.missing_prob = missing_prob + self.incompatible_prob = incompatible_prob + + def _make_geometry_ref( + self, + body_id: int, + geom_type: GeometryType, + joint: Joint, + *, + is_ref_a: bool = True, + ) -> GeometryRef: + """Create a GeometryRef from joint geometry. + + Uses joint anchor, axis, and body_id to produce a ref + with realistic geometry for the given type. + """ + origin = joint.anchor_a if is_ref_a else joint.anchor_b + + direction: np.ndarray | None = None + if geom_type in {GeometryType.AXIS, GeometryType.PLANE, GeometryType.FACE}: + direction = joint.axis.copy() + + geom_id = f"{geom_type.value.capitalize()}001" + + return GeometryRef( + body_id=body_id, + geometry_type=geom_type, + geometry_id=geom_id, + origin=origin.copy(), + direction=direction, + ) + + def _reverse_map_joint( + self, + joint: Joint, + next_mate_id: int, + ) -> list[Mate]: + """Convert a joint to its mate representation.""" + specs = _JOINT_TO_MATES.get(joint.joint_type, []) + if not specs: + # Fallback: emit a single DISTANCE mate + specs = [_MateSpec(MateType.DISTANCE, GeometryType.POINT, GeometryType.POINT)] + + mates: list[Mate] = [] + for spec in specs: + ref_a = self._make_geometry_ref(joint.body_a, spec.geom_a, joint, is_ref_a=True) + ref_b = self._make_geometry_ref(joint.body_b, spec.geom_b, joint, is_ref_a=False) + mates.append( + Mate( + mate_id=next_mate_id + len(mates), + mate_type=spec.mate_type, + ref_a=ref_a, + ref_b=ref_b, + ) + ) + return mates + + def _inject_noise( + self, + mates: list[Mate], + next_mate_id: int, + ) -> list[Mate]: + """Apply noise injection to the mate list. + + Modifies the list in-place and may add new mates. + Returns the (possibly extended) list. + """ + result = list(mates) + extra: list[Mate] = [] + + for mate in mates: + # Redundant: duplicate a mate + if self._rng.random() < self.redundant_prob: + dup = Mate( + mate_id=next_mate_id + len(extra), + mate_type=mate.mate_type, + ref_a=mate.ref_a, + ref_b=mate.ref_b, + value=mate.value, + tolerance=mate.tolerance, + ) + extra.append(dup) + + # Incompatible: wrong geometry type + if self._rng.random() < self.incompatible_prob: + bad_geom = GeometryType.POINT + bad_ref = GeometryRef( + body_id=mate.ref_a.body_id, + geometry_type=bad_geom, + geometry_id="BadGeom001", + origin=mate.ref_a.origin.copy(), + direction=None, + ) + extra.append( + Mate( + mate_id=next_mate_id + len(extra), + mate_type=MateType.CONCENTRIC, + ref_a=bad_ref, + ref_b=mate.ref_b, + ) + ) + + result.extend(extra) + + # Missing: drop mates from multi-mate patterns (only if > 1 mate + # for same body pair) + if self.missing_prob > 0: + filtered: list[Mate] = [] + for mate in result: + if self._rng.random() < self.missing_prob: + continue + filtered.append(mate) + # Ensure at least one mate remains + if not filtered and result: + filtered = [result[0]] + result = filtered + + return result + + def generate( + self, + n_bodies: int = 4, + *, + grounded: bool = False, + ) -> tuple[list[RigidBody], list[Mate], MateAnalysisResult]: + """Generate a mate-based assembly. + + Args: + n_bodies: Number of rigid bodies. + grounded: Whether to ground the first body. + + Returns: + (bodies, mates, analysis_result) tuple. + """ + bodies, joints, _analysis = self._joint_gen.generate_chain_assembly( + n_bodies, + joint_type=JointType.REVOLUTE, + grounded=grounded, + ) + + mates: list[Mate] = [] + next_id = 0 + for joint in joints: + joint_mates = self._reverse_map_joint(joint, next_id) + mates.extend(joint_mates) + next_id += len(joint_mates) + + # Apply noise + mates = self._inject_noise(mates, next_id) + + ground_body = bodies[0].body_id if grounded else None + result = analyze_mate_assembly(bodies, mates, ground_body) + + return bodies, mates, result + + +# --------------------------------------------------------------------------- +# Batch generation +# --------------------------------------------------------------------------- + + +def generate_mate_training_batch( + batch_size: int = 100, + n_bodies_range: tuple[int, int] = (3, 8), + seed: int = 42, + *, + redundant_prob: float = 0.0, + missing_prob: float = 0.0, + incompatible_prob: float = 0.0, + grounded_ratio: float = 1.0, +) -> list[dict[str, Any]]: + """Produce a batch of mate-level training examples. + + Args: + batch_size: Number of assemblies to generate. + n_bodies_range: (min, max_exclusive) body count. + seed: Random seed. + redundant_prob: Probability of redundant mate injection. + missing_prob: Probability of missing mate injection. + incompatible_prob: Probability of incompatible mate injection. + grounded_ratio: Fraction of assemblies that are grounded. + + Returns: + List of dicts with bodies, mates, patterns, and labels. + """ + rng = np.random.default_rng(seed) + examples: list[dict[str, Any]] = [] + + for i in range(batch_size): + gen = SyntheticMateGenerator( + seed=seed + i, + redundant_prob=redundant_prob, + missing_prob=missing_prob, + incompatible_prob=incompatible_prob, + ) + n = int(rng.integers(*n_bodies_range)) + grounded = bool(rng.random() < grounded_ratio) + + bodies, mates, result = gen.generate(n, grounded=grounded) + + examples.append( + { + "bodies": [ + { + "body_id": b.body_id, + "position": b.position.tolist(), + } + for b in bodies + ], + "mates": [m.to_dict() for m in mates], + "patterns": [p.to_dict() for p in result.patterns], + "labels": result.labels.to_dict() if result.labels else None, + "n_bodies": len(bodies), + "n_mates": len(mates), + "n_joints": len(result.joints), + } + ) + + return examples diff --git a/tests/mates/test_generator.py b/tests/mates/test_generator.py new file mode 100644 index 0000000..bb8dd71 --- /dev/null +++ b/tests/mates/test_generator.py @@ -0,0 +1,155 @@ +"""Tests for solver.mates.generator -- synthetic mate generator.""" + +from __future__ import annotations + +from solver.mates.generator import SyntheticMateGenerator, generate_mate_training_batch +from solver.mates.primitives import MateType + +# --------------------------------------------------------------------------- +# SyntheticMateGenerator +# --------------------------------------------------------------------------- + + +class TestSyntheticMateGenerator: + """SyntheticMateGenerator core functionality.""" + + def test_generate_basic(self) -> None: + """Generate a simple assembly with mates.""" + gen = SyntheticMateGenerator(seed=42) + bodies, mates, result = gen.generate(3) + assert len(bodies) == 3 + assert len(mates) > 0 + assert result.analysis is not None + + def test_deterministic_with_seed(self) -> None: + """Same seed produces same output.""" + gen1 = SyntheticMateGenerator(seed=123) + _, mates1, _ = gen1.generate(3) + + gen2 = SyntheticMateGenerator(seed=123) + _, mates2, _ = gen2.generate(3) + + assert len(mates1) == len(mates2) + for m1, m2 in zip(mates1, mates2, strict=True): + assert m1.mate_type == m2.mate_type + assert m1.ref_a.body_id == m2.ref_a.body_id + + def test_grounded(self) -> None: + """Grounded assembly should work.""" + gen = SyntheticMateGenerator(seed=42) + bodies, _mates, result = gen.generate(3, grounded=True) + assert len(bodies) == 3 + assert result.analysis is not None + + def test_revolute_produces_two_mates(self) -> None: + """A revolute joint should reverse-map to 2 mates.""" + gen = SyntheticMateGenerator(seed=42) + _bodies, mates, _result = gen.generate(2) + # 2 bodies -> 1 revolute joint -> 2 mates (concentric + coincident) + assert len(mates) == 2 + mate_types = {m.mate_type for m in mates} + assert MateType.CONCENTRIC in mate_types + assert MateType.COINCIDENT in mate_types + + +class TestReverseMapping: + """Reverse mapping from joints to mates.""" + + def test_revolute_mapping(self) -> None: + """REVOLUTE -> Concentric + Coincident.""" + gen = SyntheticMateGenerator(seed=42) + _bodies, mates, _result = gen.generate(2) + types = [m.mate_type for m in mates] + assert MateType.CONCENTRIC in types + assert MateType.COINCIDENT in types + + def test_round_trip_analysis(self) -> None: + """Generated mates round-trip through analysis successfully.""" + gen = SyntheticMateGenerator(seed=42) + _bodies, _mates, result = gen.generate(4) + assert result.analysis is not None + assert result.labels is not None + # Should produce joints from the mates + assert len(result.joints) > 0 + + +class TestNoiseInjection: + """Noise injection mechanisms.""" + + def test_redundant_injection(self) -> None: + """Redundant prob > 0 produces more mates than clean version.""" + gen_clean = SyntheticMateGenerator(seed=42, redundant_prob=0.0) + _, mates_clean, _ = gen_clean.generate(4) + + gen_noisy = SyntheticMateGenerator(seed=42, redundant_prob=1.0) + _, mates_noisy, _ = gen_noisy.generate(4) + + assert len(mates_noisy) > len(mates_clean) + + def test_missing_injection(self) -> None: + """Missing prob > 0 produces fewer mates than clean version.""" + gen_clean = SyntheticMateGenerator(seed=42, missing_prob=0.0) + _, mates_clean, _ = gen_clean.generate(4) + + gen_noisy = SyntheticMateGenerator(seed=42, missing_prob=0.5) + _, mates_noisy, _ = gen_noisy.generate(4) + + # With 50% drop rate on 6 mates, very likely to drop at least one + assert len(mates_noisy) <= len(mates_clean) + + def test_incompatible_injection(self) -> None: + """Incompatible prob > 0 adds mates with wrong geometry.""" + gen = SyntheticMateGenerator(seed=42, incompatible_prob=1.0) + _, mates, _ = gen.generate(3) + # Should have extra mates beyond the clean count + gen_clean = SyntheticMateGenerator(seed=42) + _, mates_clean, _ = gen_clean.generate(3) + assert len(mates) > len(mates_clean) + + +# --------------------------------------------------------------------------- +# generate_mate_training_batch +# --------------------------------------------------------------------------- + + +class TestGenerateMateTrainingBatch: + """Batch generation function.""" + + def test_batch_structure(self) -> None: + """Each example has required keys.""" + examples = generate_mate_training_batch(batch_size=3, seed=42) + assert len(examples) == 3 + for ex in examples: + assert "bodies" in ex + assert "mates" in ex + assert "patterns" in ex + assert "labels" in ex + assert "n_bodies" in ex + assert "n_mates" in ex + assert "n_joints" in ex + + def test_batch_deterministic(self) -> None: + """Same seed produces same batch.""" + batch1 = generate_mate_training_batch(batch_size=5, seed=99) + batch2 = generate_mate_training_batch(batch_size=5, seed=99) + for ex1, ex2 in zip(batch1, batch2, strict=True): + assert ex1["n_bodies"] == ex2["n_bodies"] + assert ex1["n_mates"] == ex2["n_mates"] + + def test_batch_grounded_ratio(self) -> None: + """Batch respects grounded_ratio parameter.""" + # All grounded + examples = generate_mate_training_batch(batch_size=5, seed=42, grounded_ratio=1.0) + assert len(examples) == 5 + + def test_batch_with_noise(self) -> None: + """Batch with noise injection runs without error.""" + examples = generate_mate_training_batch( + batch_size=3, + seed=42, + redundant_prob=0.3, + missing_prob=0.1, + ) + assert len(examples) == 3 + for ex in examples: + assert ex["n_mates"] >= 0 From 93bda28f67161d66ebbf6fd51af0595c99276d31 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Tue, 3 Feb 2026 13:08:23 -0600 Subject: [PATCH 5/5] feat(mates): add mate-level ground truth labels MateLabel and MateAssemblyLabels dataclasses with label_mate_assembly() that back-attributes joint-level independence to originating mates. Detects redundant and degenerate mates with pattern membership tracking. Closes #15 --- solver/mates/__init__.py | 8 ++ solver/mates/labeling.py | 224 +++++++++++++++++++++++++++++++++++ tests/mates/test_labeling.py | 224 +++++++++++++++++++++++++++++++++++ 3 files changed, 456 insertions(+) create mode 100644 solver/mates/labeling.py create mode 100644 tests/mates/test_labeling.py diff --git a/solver/mates/__init__.py b/solver/mates/__init__.py index 0496bf4..884e09d 100644 --- a/solver/mates/__init__.py +++ b/solver/mates/__init__.py @@ -9,6 +9,11 @@ from solver.mates.generator import ( SyntheticMateGenerator, generate_mate_training_batch, ) +from solver.mates.labeling import ( + MateAssemblyLabels, + MateLabel, + label_mate_assembly, +) from solver.mates.patterns import ( JointPattern, PatternMatch, @@ -28,6 +33,8 @@ __all__ = [ "JointPattern", "Mate", "MateAnalysisResult", + "MateAssemblyLabels", + "MateLabel", "MateType", "PatternMatch", "SyntheticMateGenerator", @@ -35,5 +42,6 @@ __all__ = [ "convert_mates_to_joints", "dof_removed", "generate_mate_training_batch", + "label_mate_assembly", "recognize_patterns", ] diff --git a/solver/mates/labeling.py b/solver/mates/labeling.py new file mode 100644 index 0000000..deb7507 --- /dev/null +++ b/solver/mates/labeling.py @@ -0,0 +1,224 @@ +"""Mate-level ground truth labels for assembly analysis. + +Back-attributes joint-level independence results to originating mates +via the mate-to-joint mapping from conversion.py. Produces per-mate +labels indicating whether each mate is independent, redundant, or +degenerate. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from solver.mates.conversion import analyze_mate_assembly + +if TYPE_CHECKING: + from typing import Any + + from solver.datagen.labeling import AssemblyLabel + from solver.datagen.types import ConstraintAnalysis, RigidBody + from solver.mates.conversion import MateAnalysisResult + from solver.mates.patterns import JointPattern, PatternMatch + from solver.mates.primitives import Mate + +__all__ = [ + "MateAssemblyLabels", + "MateLabel", + "label_mate_assembly", +] + + +# --------------------------------------------------------------------------- +# Label dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class MateLabel: + """Per-mate ground truth label. + + Attributes: + mate_id: The mate this label refers to. + is_independent: Contributes non-redundant DOF removal. + is_redundant: Fully redundant (removable without DOF change). + is_degenerate: Combinatorially independent but geometrically dependent. + pattern: Which joint pattern this mate belongs to, if any. + issue: Detected issue type, if any. + """ + + mate_id: int + is_independent: bool = True + is_redundant: bool = False + is_degenerate: bool = False + pattern: JointPattern | None = None + issue: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "mate_id": self.mate_id, + "is_independent": self.is_independent, + "is_redundant": self.is_redundant, + "is_degenerate": self.is_degenerate, + "pattern": self.pattern.value if self.pattern else None, + "issue": self.issue, + } + + +@dataclass +class MateAssemblyLabels: + """Complete mate-level ground truth labels for an assembly. + + Attributes: + per_mate: Per-mate labels. + patterns: Recognized joint patterns. + assembly: Assembly-wide summary label. + analysis: Constraint analysis from pebble game + Jacobian. + """ + + per_mate: list[MateLabel] + patterns: list[PatternMatch] + assembly: AssemblyLabel + analysis: ConstraintAnalysis + mate_analysis: MateAnalysisResult | None = None + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "per_mate": [ml.to_dict() for ml in self.per_mate], + "patterns": [p.to_dict() for p in self.patterns], + "assembly": { + "classification": self.assembly.classification, + "total_dof": self.assembly.total_dof, + "redundant_count": self.assembly.redundant_count, + "is_rigid": self.assembly.is_rigid, + "is_minimally_rigid": self.assembly.is_minimally_rigid, + "has_degeneracy": self.assembly.has_degeneracy, + }, + } + + +# --------------------------------------------------------------------------- +# Labeling logic +# --------------------------------------------------------------------------- + + +def _build_mate_pattern_map( + patterns: list[PatternMatch], +) -> dict[int, JointPattern]: + """Map mate_ids to the pattern they belong to (best match).""" + result: dict[int, JointPattern] = {} + # Sort by confidence descending so best matches win + sorted_patterns = sorted(patterns, key=lambda p: -p.confidence) + for pm in sorted_patterns: + if pm.confidence < 1.0: + continue + for mate in pm.mates: + if mate.mate_id not in result: + result[mate.mate_id] = pm.pattern + return result + + +def label_mate_assembly( + bodies: list[RigidBody], + mates: list[Mate], + ground_body: int | None = None, +) -> MateAssemblyLabels: + """Produce mate-level ground truth labels for an assembly. + + Runs analyze_mate_assembly() internally, then back-attributes + joint-level independence to originating mates via the mate_to_joint + mapping. + + A mate is: + - **redundant** if ALL joints it contributes to are fully redundant + - **degenerate** if any joint it contributes to is geometrically + dependent but combinatorially independent + - **independent** otherwise + + Args: + bodies: Rigid bodies in the assembly. + mates: Mate constraints between the bodies. + ground_body: If set, this body is fixed to the world. + + Returns: + MateAssemblyLabels with per-mate labels and assembly summary. + """ + mate_result = analyze_mate_assembly(bodies, mates, ground_body) + + # Build per-joint redundancy from labels + joint_redundant: dict[int, bool] = {} + joint_degenerate: dict[int, bool] = {} + + if mate_result.labels is not None: + for jl in mate_result.labels.per_joint: + # A joint is fully redundant if all its constraints are redundant + joint_redundant[jl.joint_id] = jl.redundant_count == jl.total and jl.total > 0 + # Joint is degenerate if it has more independent constraints + # than Jacobian rank would suggest (geometric degeneracy) + joint_degenerate[jl.joint_id] = False + + # Check for geometric degeneracy via per-constraint labels + for cl in mate_result.labels.per_constraint: + if cl.pebble_independent and not cl.jacobian_independent: + joint_degenerate[cl.joint_id] = True + + # Build pattern membership map + pattern_map = _build_mate_pattern_map(mate_result.patterns) + + # Back-attribute to mates + per_mate: list[MateLabel] = [] + for mate in mates: + mate_joint_ids = mate_result.mate_to_joint.get(mate.mate_id, []) + + if not mate_joint_ids: + # Mate wasn't converted to any joint (shouldn't happen, but safe) + per_mate.append( + MateLabel( + mate_id=mate.mate_id, + is_independent=False, + is_redundant=True, + issue="unmapped", + ) + ) + continue + + # Redundant if ALL contributed joints are redundant + all_redundant = all(joint_redundant.get(jid, False) for jid in mate_joint_ids) + + # Degenerate if ANY contributed joint is degenerate + any_degenerate = any(joint_degenerate.get(jid, False) for jid in mate_joint_ids) + + is_independent = not all_redundant + pattern = pattern_map.get(mate.mate_id) + + # Determine issue string + issue: str | None = None + if all_redundant: + issue = "redundant" + elif any_degenerate: + issue = "degenerate" + + per_mate.append( + MateLabel( + mate_id=mate.mate_id, + is_independent=is_independent, + is_redundant=all_redundant, + is_degenerate=any_degenerate, + pattern=pattern, + issue=issue, + ) + ) + + # Assembly label + assert mate_result.labels is not None + assembly_label = mate_result.labels.assembly + + return MateAssemblyLabels( + per_mate=per_mate, + patterns=mate_result.patterns, + assembly=assembly_label, + analysis=mate_result.labels.analysis, + mate_analysis=mate_result, + ) diff --git a/tests/mates/test_labeling.py b/tests/mates/test_labeling.py new file mode 100644 index 0000000..1ed9f63 --- /dev/null +++ b/tests/mates/test_labeling.py @@ -0,0 +1,224 @@ +"""Tests for solver.mates.labeling -- mate-level ground truth labels.""" + +from __future__ import annotations + +import numpy as np + +from solver.datagen.types import RigidBody +from solver.mates.labeling import MateAssemblyLabels, MateLabel, label_mate_assembly +from solver.mates.patterns import JointPattern +from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_ref( + body_id: int, + geom_type: GeometryType, + *, + origin: np.ndarray | None = None, + direction: np.ndarray | None = None, +) -> GeometryRef: + """Factory for GeometryRef with sensible defaults.""" + if origin is None: + origin = np.zeros(3) + if direction is None and geom_type in { + GeometryType.FACE, + GeometryType.AXIS, + GeometryType.PLANE, + }: + direction = np.array([0.0, 0.0, 1.0]) + return GeometryRef( + body_id=body_id, + geometry_type=geom_type, + geometry_id="Geom001", + origin=origin, + direction=direction, + ) + + +def _make_bodies(n: int) -> list[RigidBody]: + """Create n bodies at distinct positions.""" + return [RigidBody(body_id=i, position=np.array([float(i), 0.0, 0.0])) for i in range(n)] + + +# --------------------------------------------------------------------------- +# MateLabel +# --------------------------------------------------------------------------- + + +class TestMateLabel: + """MateLabel dataclass.""" + + def test_defaults(self) -> None: + ml = MateLabel(mate_id=0) + assert ml.is_independent is True + assert ml.is_redundant is False + assert ml.is_degenerate is False + assert ml.pattern is None + assert ml.issue is None + + def test_to_dict(self) -> None: + ml = MateLabel( + mate_id=5, + is_independent=False, + is_redundant=True, + pattern=JointPattern.HINGE, + issue="redundant", + ) + d = ml.to_dict() + assert d["mate_id"] == 5 + assert d["is_redundant"] is True + assert d["pattern"] == "hinge" + assert d["issue"] == "redundant" + + def test_to_dict_none_pattern(self) -> None: + ml = MateLabel(mate_id=0) + d = ml.to_dict() + assert d["pattern"] is None + + +# --------------------------------------------------------------------------- +# MateAssemblyLabels +# --------------------------------------------------------------------------- + + +class TestMateAssemblyLabels: + """MateAssemblyLabels dataclass.""" + + def test_to_dict_structure(self) -> None: + """to_dict produces expected keys.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = label_mate_assembly(bodies, mates) + d = result.to_dict() + assert "per_mate" in d + assert "patterns" in d + assert "assembly" in d + assert isinstance(d["per_mate"], list) + + +# --------------------------------------------------------------------------- +# label_mate_assembly +# --------------------------------------------------------------------------- + + +class TestLabelMateAssembly: + """Full labeling pipeline.""" + + def test_clean_assembly_no_redundancy(self) -> None: + """Two bodies with lock mate -> clean, no redundancy.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = label_mate_assembly(bodies, mates) + assert isinstance(result, MateAssemblyLabels) + assert len(result.per_mate) == 1 + ml = result.per_mate[0] + assert ml.mate_id == 0 + assert ml.is_independent is True + assert ml.is_redundant is False + assert ml.issue is None + + def test_redundant_assembly(self) -> None: + """Two lock mates on same body pair -> one is redundant.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + Mate( + mate_id=1, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])), + ref_b=_make_ref(1, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])), + ), + ] + result = label_mate_assembly(bodies, mates) + assert len(result.per_mate) == 2 + redundant_count = sum(1 for ml in result.per_mate if ml.is_redundant) + # At least one should be redundant + assert redundant_count >= 1 + assert result.assembly.redundant_count > 0 + + def test_hinge_pattern_labeling(self) -> None: + """Hinge mates get pattern membership.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + result = label_mate_assembly(bodies, mates) + assert len(result.per_mate) == 2 + # Both mates should be part of the hinge pattern + for ml in result.per_mate: + assert ml.pattern is JointPattern.HINGE + assert ml.is_independent is True + + def test_grounded_assembly(self) -> None: + """Grounded assembly labeling works.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = label_mate_assembly(bodies, mates, ground_body=0) + assert result.assembly.is_rigid + + def test_empty_mates(self) -> None: + """No mates -> no per_mate labels, underconstrained.""" + bodies = _make_bodies(2) + result = label_mate_assembly(bodies, []) + assert len(result.per_mate) == 0 + assert result.assembly.classification == "underconstrained" + + def test_assembly_classification(self) -> None: + """Assembly classification is present.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = label_mate_assembly(bodies, mates) + assert result.assembly.classification in { + "well-constrained", + "overconstrained", + "underconstrained", + "mixed", + }