diff --git a/solver/mates/__init__.py b/solver/mates/__init__.py index 994e51d..e2b66ad 100644 --- a/solver/mates/__init__.py +++ b/solver/mates/__init__.py @@ -1,5 +1,10 @@ """Mate-level constraint types for assembly analysis.""" +from solver.mates.patterns import ( + JointPattern, + PatternMatch, + recognize_patterns, +) from solver.mates.primitives import ( GeometryRef, GeometryType, @@ -11,7 +16,10 @@ from solver.mates.primitives import ( __all__ = [ "GeometryRef", "GeometryType", + "JointPattern", "Mate", "MateType", + "PatternMatch", "dof_removed", + "recognize_patterns", ] diff --git a/solver/mates/patterns.py b/solver/mates/patterns.py new file mode 100644 index 0000000..6844365 --- /dev/null +++ b/solver/mates/patterns.py @@ -0,0 +1,284 @@ +"""Joint pattern recognition from mate combinations. + +Groups mates by body pair and matches them against canonical joint +patterns (hinge, slider, ball, etc.). Each pattern is a known +combination of mate types that together constrain motion equivalently +to a single mechanical joint. +""" + +from __future__ import annotations + +import enum +from collections import defaultdict +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +from solver.datagen.types import JointType +from solver.mates.primitives import GeometryType, Mate, MateType + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "JointPattern", + "PatternMatch", + "recognize_patterns", +] + + +# --------------------------------------------------------------------------- +# Enums +# --------------------------------------------------------------------------- + + +class JointPattern(enum.Enum): + """Canonical joint patterns formed by mate combinations.""" + + HINGE = "hinge" + SLIDER = "slider" + CYLINDER = "cylinder" + BALL = "ball" + PLANAR = "planar" + FIXED = "fixed" + GEAR = "gear" + RACK_PINION = "rack_pinion" + UNKNOWN = "unknown" + + +# --------------------------------------------------------------------------- +# Pattern match result +# --------------------------------------------------------------------------- + + +@dataclass +class PatternMatch: + """Result of matching a group of mates to a joint pattern. + + Attributes: + pattern: The identified joint pattern. + mates: The mates that form this pattern. + body_a: First body in the pair. + body_b: Second body in the pair. + confidence: How well the mates match the canonical pattern (0-1). + equivalent_joint_type: The JointType this pattern maps to. + missing_mates: Descriptions of mates absent for a full match. + """ + + pattern: JointPattern + mates: list[Mate] + body_a: int + body_b: int + confidence: float + equivalent_joint_type: JointType + missing_mates: list[str] = field(default_factory=list) + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "pattern": self.pattern.value, + "body_a": self.body_a, + "body_b": self.body_b, + "confidence": self.confidence, + "equivalent_joint_type": self.equivalent_joint_type.name, + "mate_ids": [m.mate_id for m in self.mates], + "missing_mates": self.missing_mates, + } + + +# --------------------------------------------------------------------------- +# Pattern rules (data-driven) +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _MateRequirement: + """A single mate requirement within a pattern rule.""" + + mate_type: MateType + geometry_a: GeometryType | None = None + geometry_b: GeometryType | None = None + + +@dataclass(frozen=True) +class _PatternRule: + """Defines a canonical pattern as a set of required mates.""" + + pattern: JointPattern + joint_type: JointType + required: tuple[_MateRequirement, ...] + description: str = "" + + +_PATTERN_RULES: list[_PatternRule] = [ + _PatternRule( + pattern=JointPattern.HINGE, + joint_type=JointType.REVOLUTE, + required=( + _MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS), + _MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE), + ), + description="Concentric axes + coincident plane", + ), + _PatternRule( + pattern=JointPattern.SLIDER, + joint_type=JointType.SLIDER, + required=( + _MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE), + _MateRequirement(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS), + ), + description="Coincident plane + parallel axis", + ), + _PatternRule( + pattern=JointPattern.CYLINDER, + joint_type=JointType.CYLINDRICAL, + required=(_MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),), + description="Concentric axes only", + ), + _PatternRule( + pattern=JointPattern.BALL, + joint_type=JointType.BALL, + required=(_MateRequirement(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),), + description="Coincident points", + ), + _PatternRule( + pattern=JointPattern.PLANAR, + joint_type=JointType.PLANAR, + required=(_MateRequirement(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),), + description="Coincident faces", + ), + _PatternRule( + pattern=JointPattern.PLANAR, + joint_type=JointType.PLANAR, + required=(_MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),), + description="Coincident planes (alternate planar)", + ), + _PatternRule( + pattern=JointPattern.FIXED, + joint_type=JointType.FIXED, + required=(_MateRequirement(MateType.LOCK),), + description="Lock mate", + ), +] + + +# --------------------------------------------------------------------------- +# Matching logic +# --------------------------------------------------------------------------- + + +def _mate_matches_requirement(mate: Mate, req: _MateRequirement) -> bool: + """Check if a mate satisfies a requirement.""" + if mate.mate_type is not req.mate_type: + return False + if req.geometry_a is not None and mate.ref_a.geometry_type is not req.geometry_a: + return False + return not (req.geometry_b is not None and mate.ref_b.geometry_type is not req.geometry_b) + + +def _try_match_rule( + rule: _PatternRule, + mates: list[Mate], +) -> tuple[float, list[Mate], list[str]]: + """Try to match a rule against a group of mates. + + Returns: + (confidence, matched_mates, missing_descriptions) + """ + matched: list[Mate] = [] + missing: list[str] = [] + + for req in rule.required: + found = False + for mate in mates: + if mate in matched: + continue + if _mate_matches_requirement(mate, req): + matched.append(mate) + found = True + break + if not found: + geom_desc = "" + if req.geometry_a is not None: + geom_b = req.geometry_b.value if req.geometry_b else "*" + geom_desc = f" ({req.geometry_a.value}-{geom_b})" + missing.append(f"{req.mate_type.name}{geom_desc}") + + total_required = len(rule.required) + if total_required == 0: + return 0.0, [], [] + + matched_count = len(matched) + confidence = matched_count / total_required + + return confidence, matched, missing + + +def _normalize_body_pair(body_a: int, body_b: int) -> tuple[int, int]: + """Normalize a body pair so the smaller ID comes first.""" + return (min(body_a, body_b), max(body_a, body_b)) + + +def recognize_patterns(mates: list[Mate]) -> list[PatternMatch]: + """Identify joint patterns from a list of mates. + + Groups mates by body pair, then checks each group against + canonical pattern rules. Returns matches sorted by confidence + descending. + + Args: + mates: List of mate constraints to analyze. + + Returns: + List of PatternMatch results, highest confidence first. + """ + if not mates: + return [] + + # Group mates by normalized body pair + groups: dict[tuple[int, int], list[Mate]] = defaultdict(list) + for mate in mates: + pair = _normalize_body_pair(mate.ref_a.body_id, mate.ref_b.body_id) + groups[pair].append(mate) + + results: list[PatternMatch] = [] + + for (body_a, body_b), group_mates in groups.items(): + group_matches: list[PatternMatch] = [] + + for rule in _PATTERN_RULES: + confidence, matched, missing = _try_match_rule(rule, group_mates) + + if confidence > 0: + group_matches.append( + PatternMatch( + pattern=rule.pattern, + mates=matched if matched else group_mates, + body_a=body_a, + body_b=body_b, + confidence=confidence, + equivalent_joint_type=rule.joint_type, + missing_mates=missing, + ) + ) + + if group_matches: + # Sort by confidence descending, prefer more-specific patterns + group_matches.sort(key=lambda m: (-m.confidence, -len(m.mates))) + results.extend(group_matches) + else: + # No pattern matched at all + results.append( + PatternMatch( + pattern=JointPattern.UNKNOWN, + mates=group_mates, + body_a=body_a, + body_b=body_b, + confidence=0.0, + equivalent_joint_type=JointType.DISTANCE, + missing_mates=[], + ) + ) + + # Global sort by confidence descending + results.sort(key=lambda m: -m.confidence) + return results diff --git a/tests/mates/test_patterns.py b/tests/mates/test_patterns.py new file mode 100644 index 0000000..2cb4958 --- /dev/null +++ b/tests/mates/test_patterns.py @@ -0,0 +1,285 @@ +"""Tests for solver.mates.patterns -- joint pattern recognition.""" + +from __future__ import annotations + +import numpy as np + +from solver.datagen.types import JointType +from solver.mates.patterns import JointPattern, PatternMatch, recognize_patterns +from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_ref( + body_id: int, + geom_type: GeometryType, + *, + geometry_id: str = "Geom001", + origin: np.ndarray | None = None, + direction: np.ndarray | None = None, +) -> GeometryRef: + """Factory for GeometryRef with sensible defaults.""" + if origin is None: + origin = np.zeros(3) + if direction is None and geom_type in { + GeometryType.FACE, + GeometryType.AXIS, + GeometryType.PLANE, + }: + direction = np.array([0.0, 0.0, 1.0]) + return GeometryRef( + body_id=body_id, + geometry_type=geom_type, + geometry_id=geometry_id, + origin=origin, + direction=direction, + ) + + +def _make_mate( + mate_id: int, + mate_type: MateType, + body_a: int, + body_b: int, + geom_a: GeometryType = GeometryType.FACE, + geom_b: GeometryType = GeometryType.FACE, +) -> Mate: + """Factory for Mate with body pair and geometry types.""" + return Mate( + mate_id=mate_id, + mate_type=mate_type, + ref_a=_make_ref(body_a, geom_a), + ref_b=_make_ref(body_b, geom_b), + ) + + +# --------------------------------------------------------------------------- +# JointPattern enum +# --------------------------------------------------------------------------- + + +class TestJointPattern: + """JointPattern enum.""" + + def test_member_count(self) -> None: + assert len(JointPattern) == 9 + + def test_string_values(self) -> None: + for jp in JointPattern: + assert isinstance(jp.value, str) + + def test_access_by_name(self) -> None: + assert JointPattern["HINGE"] is JointPattern.HINGE + + +# --------------------------------------------------------------------------- +# PatternMatch +# --------------------------------------------------------------------------- + + +class TestPatternMatch: + """PatternMatch dataclass.""" + + def test_construction(self) -> None: + mate = _make_mate(0, MateType.LOCK, 0, 1) + pm = PatternMatch( + pattern=JointPattern.FIXED, + mates=[mate], + body_a=0, + body_b=1, + confidence=1.0, + equivalent_joint_type=JointType.FIXED, + ) + assert pm.pattern is JointPattern.FIXED + assert pm.confidence == 1.0 + assert pm.missing_mates == [] + + def test_to_dict(self) -> None: + mate = _make_mate(5, MateType.LOCK, 0, 1) + pm = PatternMatch( + pattern=JointPattern.FIXED, + mates=[mate], + body_a=0, + body_b=1, + confidence=1.0, + equivalent_joint_type=JointType.FIXED, + ) + d = pm.to_dict() + assert d["pattern"] == "fixed" + assert d["mate_ids"] == [5] + assert d["equivalent_joint_type"] == "FIXED" + + +# --------------------------------------------------------------------------- +# recognize_patterns — canonical patterns +# --------------------------------------------------------------------------- + + +class TestRecognizeCanonical: + """Full-confidence canonical pattern recognition.""" + + def test_empty_input(self) -> None: + assert recognize_patterns([]) == [] + + def test_hinge(self) -> None: + """Concentric(axis) + Coincident(plane) -> Hinge.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + _make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.HINGE + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.REVOLUTE + assert top.missing_mates == [] + + def test_slider(self) -> None: + """Coincident(plane) + Parallel(axis) -> Slider.""" + mates = [ + _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE), + _make_mate(1, MateType.PARALLEL, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.SLIDER + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.SLIDER + + def test_cylinder(self) -> None: + """Concentric(axis) only -> Cylinder.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + ] + results = recognize_patterns(mates) + # Should match cylinder at confidence 1.0 + cylinder = [r for r in results if r.pattern is JointPattern.CYLINDER] + assert len(cylinder) >= 1 + assert cylinder[0].confidence == 1.0 + assert cylinder[0].equivalent_joint_type is JointType.CYLINDRICAL + + def test_ball(self) -> None: + """Coincident(point) -> Ball.""" + mates = [ + _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.POINT, GeometryType.POINT), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.BALL + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.BALL + + def test_planar_face(self) -> None: + """Coincident(face) -> Planar.""" + mates = [ + _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.FACE, GeometryType.FACE), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.PLANAR + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.PLANAR + + def test_fixed(self) -> None: + """Lock -> Fixed.""" + mates = [ + _make_mate(0, MateType.LOCK, 0, 1, GeometryType.FACE, GeometryType.FACE), + ] + results = recognize_patterns(mates) + top = results[0] + assert top.pattern is JointPattern.FIXED + assert top.confidence == 1.0 + assert top.equivalent_joint_type is JointType.FIXED + + +# --------------------------------------------------------------------------- +# recognize_patterns — partial matches +# --------------------------------------------------------------------------- + + +class TestRecognizePartial: + """Partial pattern matches and hints.""" + + def test_concentric_without_plane_hints_hinge(self) -> None: + """Concentric alone matches hinge at 0.5 confidence with missing hint.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + ] + results = recognize_patterns(mates) + hinge_matches = [r for r in results if r.pattern is JointPattern.HINGE] + assert len(hinge_matches) >= 1 + hinge = hinge_matches[0] + assert hinge.confidence == 0.5 + assert len(hinge.missing_mates) > 0 + + def test_coincident_plane_without_parallel_hints_slider(self) -> None: + """Coincident(plane) alone matches slider at 0.5 confidence.""" + mates = [ + _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE), + ] + results = recognize_patterns(mates) + slider_matches = [r for r in results if r.pattern is JointPattern.SLIDER] + assert len(slider_matches) >= 1 + assert slider_matches[0].confidence == 0.5 + + +# --------------------------------------------------------------------------- +# recognize_patterns — ambiguous / multi-body +# --------------------------------------------------------------------------- + + +class TestRecognizeAmbiguous: + """Ambiguous patterns and multi-body-pair assemblies.""" + + def test_concentric_matches_both_hinge_and_cylinder(self) -> None: + """A single concentric mate produces both hinge (partial) and cylinder matches.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + ] + results = recognize_patterns(mates) + patterns = {r.pattern for r in results} + assert JointPattern.HINGE in patterns + assert JointPattern.CYLINDER in patterns + + def test_multiple_body_pairs(self) -> None: + """Mates across different body pairs produce separate pattern matches.""" + mates = [ + _make_mate(0, MateType.LOCK, 0, 1), + _make_mate(1, MateType.COINCIDENT, 2, 3, GeometryType.POINT, GeometryType.POINT), + ] + results = recognize_patterns(mates) + pairs = {(r.body_a, r.body_b) for r in results} + assert (0, 1) in pairs + assert (2, 3) in pairs + + def test_results_sorted_by_confidence(self) -> None: + """All results should be sorted by confidence descending.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS), + _make_mate(1, MateType.LOCK, 2, 3), + ] + results = recognize_patterns(mates) + confidences = [r.confidence for r in results] + assert confidences == sorted(confidences, reverse=True) + + def test_unknown_pattern(self) -> None: + """A mate type that matches no rule returns UNKNOWN.""" + mates = [ + _make_mate(0, MateType.ANGLE, 0, 1, GeometryType.FACE, GeometryType.FACE), + ] + results = recognize_patterns(mates) + assert any(r.pattern is JointPattern.UNKNOWN for r in results) + + def test_body_pair_normalization(self) -> None: + """Mates with reversed body order should be grouped together.""" + mates = [ + _make_mate(0, MateType.CONCENTRIC, 1, 0, GeometryType.AXIS, GeometryType.AXIS), + _make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE), + ] + results = recognize_patterns(mates) + hinge_matches = [r for r in results if r.pattern is JointPattern.HINGE] + assert len(hinge_matches) >= 1 + assert hinge_matches[0].confidence == 1.0