diff --git a/solver/mates/__init__.py b/solver/mates/__init__.py index e2b66ad..fa8a5d5 100644 --- a/solver/mates/__init__.py +++ b/solver/mates/__init__.py @@ -1,5 +1,10 @@ """Mate-level constraint types for assembly analysis.""" +from solver.mates.conversion import ( + MateAnalysisResult, + analyze_mate_assembly, + convert_mates_to_joints, +) from solver.mates.patterns import ( JointPattern, PatternMatch, @@ -18,8 +23,11 @@ __all__ = [ "GeometryType", "JointPattern", "Mate", + "MateAnalysisResult", "MateType", "PatternMatch", + "analyze_mate_assembly", + "convert_mates_to_joints", "dof_removed", "recognize_patterns", ] diff --git a/solver/mates/conversion.py b/solver/mates/conversion.py new file mode 100644 index 0000000..c8016ba --- /dev/null +++ b/solver/mates/conversion.py @@ -0,0 +1,276 @@ +"""Mate-to-joint conversion and assembly analysis. + +Bridges the mate-level constraint representation to the existing +joint-based analysis pipeline. Converts recognized mate patterns +to Joint objects, then runs the pebble game and Jacobian analysis, +maintaining bidirectional traceability between mates and joints. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import numpy as np + +from solver.datagen.labeling import AssemblyLabels, label_assembly +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) +from solver.mates.patterns import PatternMatch, recognize_patterns + +if TYPE_CHECKING: + from typing import Any + + from solver.mates.primitives import Mate + +__all__ = [ + "MateAnalysisResult", + "analyze_mate_assembly", + "convert_mates_to_joints", +] + + +# --------------------------------------------------------------------------- +# Result dataclass +# --------------------------------------------------------------------------- + + +@dataclass +class MateAnalysisResult: + """Combined result of mate-based assembly analysis. + + Attributes: + patterns: Recognized joint patterns from mate grouping. + joints: Joint objects produced by conversion. + mate_to_joint: Mapping from mate_id to list of joint_ids. + joint_to_mates: Mapping from joint_id to list of mate_ids. + analysis: Constraint analysis from pebble game + Jacobian. + labels: Full ground truth labels from label_assembly. + """ + + patterns: list[PatternMatch] + joints: list[Joint] + mate_to_joint: dict[int, list[int]] = field(default_factory=dict) + joint_to_mates: dict[int, list[int]] = field(default_factory=dict) + analysis: ConstraintAnalysis | None = None + labels: AssemblyLabels | None = None + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "patterns": [p.to_dict() for p in self.patterns], + "joints": [ + { + "joint_id": j.joint_id, + "body_a": j.body_a, + "body_b": j.body_b, + "joint_type": j.joint_type.name, + } + for j in self.joints + ], + "mate_to_joint": self.mate_to_joint, + "joint_to_mates": self.joint_to_mates, + "labels": self.labels.to_dict() if self.labels else None, + } + + +# --------------------------------------------------------------------------- +# Pattern-to-JointType mapping +# --------------------------------------------------------------------------- + +# Maps (JointPattern value) to JointType for known patterns. +# Used by convert_mates_to_joints when a full pattern is recognized. +_PATTERN_JOINT_MAP: dict[str, JointType] = { + "hinge": JointType.REVOLUTE, + "slider": JointType.SLIDER, + "cylinder": JointType.CYLINDRICAL, + "ball": JointType.BALL, + "planar": JointType.PLANAR, + "fixed": JointType.FIXED, +} + +# Fallback mapping for individual mate types when no pattern is recognized. +_MATE_JOINT_FALLBACK: dict[str, JointType] = { + "COINCIDENT": JointType.PLANAR, + "CONCENTRIC": JointType.CYLINDRICAL, + "PARALLEL": JointType.PARALLEL, + "PERPENDICULAR": JointType.PERPENDICULAR, + "TANGENT": JointType.DISTANCE, + "DISTANCE": JointType.DISTANCE, + "ANGLE": JointType.PERPENDICULAR, + "LOCK": JointType.FIXED, +} + + +# --------------------------------------------------------------------------- +# Conversion +# --------------------------------------------------------------------------- + + +def _compute_joint_params( + pattern: PatternMatch, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Extract anchor and axis from pattern mates. + + Returns: + (anchor_a, anchor_b, axis) + """ + anchor_a = np.zeros(3) + anchor_b = np.zeros(3) + axis = np.array([0.0, 0.0, 1.0]) + + for mate in pattern.mates: + ref_a = mate.ref_a + ref_b = mate.ref_b + anchor_a = ref_a.origin.copy() + anchor_b = ref_b.origin.copy() + if ref_a.direction is not None: + axis = ref_a.direction.copy() + break + + return anchor_a, anchor_b, axis + + +def _convert_single_mate( + mate: Mate, + joint_id: int, +) -> Joint: + """Convert a single unmatched mate to a Joint.""" + joint_type = _MATE_JOINT_FALLBACK.get(mate.mate_type.name, JointType.DISTANCE) + + anchor_a = mate.ref_a.origin.copy() + anchor_b = mate.ref_b.origin.copy() + axis = np.array([0.0, 0.0, 1.0]) + if mate.ref_a.direction is not None: + axis = mate.ref_a.direction.copy() + + return Joint( + joint_id=joint_id, + body_a=mate.ref_a.body_id, + body_b=mate.ref_b.body_id, + joint_type=joint_type, + anchor_a=anchor_a, + anchor_b=anchor_b, + axis=axis, + ) + + +def convert_mates_to_joints( + mates: list[Mate], + bodies: list[RigidBody] | None = None, +) -> tuple[list[Joint], dict[int, list[int]], dict[int, list[int]]]: + """Convert mates to Joint objects via pattern recognition. + + For each body pair: + - If mates form a recognized pattern, emit the equivalent joint. + - Otherwise, emit individual joints for each unmatched mate. + + Args: + mates: Mate constraints to convert. + bodies: Optional body list (unused currently, reserved for + future geometry lookups). + + Returns: + (joints, mate_to_joint, joint_to_mates) tuple. + """ + if not mates: + return [], {}, {} + + patterns = recognize_patterns(mates) + joints: list[Joint] = [] + mate_to_joint: dict[int, list[int]] = {} + joint_to_mates: dict[int, list[int]] = {} + + # Track which mates have been consumed by full-confidence patterns + consumed_mate_ids: set[int] = set() + next_joint_id = 0 + + # First pass: emit joints for full-confidence patterns + for pattern in patterns: + if pattern.confidence < 1.0: + continue + if pattern.pattern.value not in _PATTERN_JOINT_MAP: + continue + + # Check if any of these mates were already consumed + mate_ids = [m.mate_id for m in pattern.mates] + if any(mid in consumed_mate_ids for mid in mate_ids): + continue + + joint_type = _PATTERN_JOINT_MAP[pattern.pattern.value] + anchor_a, anchor_b, axis = _compute_joint_params(pattern) + + joint = Joint( + joint_id=next_joint_id, + body_a=pattern.body_a, + body_b=pattern.body_b, + joint_type=joint_type, + anchor_a=anchor_a, + anchor_b=anchor_b, + axis=axis, + ) + joints.append(joint) + + joint_to_mates[next_joint_id] = mate_ids + for mid in mate_ids: + mate_to_joint.setdefault(mid, []).append(next_joint_id) + consumed_mate_ids.add(mid) + + next_joint_id += 1 + + # Second pass: emit individual joints for unconsumed mates + for mate in mates: + if mate.mate_id in consumed_mate_ids: + continue + + joint = _convert_single_mate(mate, next_joint_id) + joints.append(joint) + + joint_to_mates[next_joint_id] = [mate.mate_id] + mate_to_joint.setdefault(mate.mate_id, []).append(next_joint_id) + + next_joint_id += 1 + + return joints, mate_to_joint, joint_to_mates + + +# --------------------------------------------------------------------------- +# Full analysis pipeline +# --------------------------------------------------------------------------- + + +def analyze_mate_assembly( + bodies: list[RigidBody], + mates: list[Mate], + ground_body: int | None = None, +) -> MateAnalysisResult: + """Run the full analysis pipeline on a mate-based assembly. + + Orchestrates: recognize_patterns -> convert_mates_to_joints -> + label_assembly, returning a combined result with full traceability. + + Args: + bodies: Rigid bodies in the assembly. + mates: Mate constraints between the bodies. + ground_body: If set, this body is fixed to the world. + + Returns: + MateAnalysisResult with patterns, joints, mappings, and labels. + """ + patterns = recognize_patterns(mates) + joints, mate_to_joint, joint_to_mates = convert_mates_to_joints(mates, bodies) + + labels = label_assembly(bodies, joints, ground_body) + + return MateAnalysisResult( + patterns=patterns, + joints=joints, + mate_to_joint=mate_to_joint, + joint_to_mates=joint_to_mates, + analysis=labels.analysis, + labels=labels, + ) diff --git a/tests/mates/test_conversion.py b/tests/mates/test_conversion.py new file mode 100644 index 0000000..b2740ca --- /dev/null +++ b/tests/mates/test_conversion.py @@ -0,0 +1,287 @@ +"""Tests for solver.mates.conversion -- mate-to-joint conversion.""" + +from __future__ import annotations + +import numpy as np + +from solver.datagen.types import JointType, RigidBody +from solver.mates.conversion import ( + MateAnalysisResult, + analyze_mate_assembly, + convert_mates_to_joints, +) +from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_ref( + body_id: int, + geom_type: GeometryType, + *, + origin: np.ndarray | None = None, + direction: np.ndarray | None = None, +) -> GeometryRef: + """Factory for GeometryRef with sensible defaults.""" + if origin is None: + origin = np.zeros(3) + if direction is None and geom_type in { + GeometryType.FACE, + GeometryType.AXIS, + GeometryType.PLANE, + }: + direction = np.array([0.0, 0.0, 1.0]) + return GeometryRef( + body_id=body_id, + geometry_type=geom_type, + geometry_id="Geom001", + origin=origin, + direction=direction, + ) + + +def _make_bodies(n: int) -> list[RigidBody]: + """Create n bodies at distinct positions.""" + return [RigidBody(body_id=i, position=np.array([float(i), 0.0, 0.0])) for i in range(n)] + + +# --------------------------------------------------------------------------- +# convert_mates_to_joints +# --------------------------------------------------------------------------- + + +class TestConvertMatesToJoints: + """convert_mates_to_joints function.""" + + def test_empty_input(self) -> None: + joints, m2j, j2m = convert_mates_to_joints([]) + assert joints == [] + assert m2j == {} + assert j2m == {} + + def test_hinge_pattern(self) -> None: + """Concentric + Coincident(plane) -> single REVOLUTE joint.""" + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + joints, m2j, j2m = convert_mates_to_joints(mates) + assert len(joints) == 1 + assert joints[0].joint_type is JointType.REVOLUTE + assert joints[0].body_a == 0 + assert joints[0].body_b == 1 + # Both mates map to the single joint + assert 0 in m2j + assert 1 in m2j + assert j2m[joints[0].joint_id] == [0, 1] + + def test_lock_pattern(self) -> None: + """Lock -> FIXED joint.""" + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + joints, _m2j, _j2m = convert_mates_to_joints(mates) + assert len(joints) == 1 + assert joints[0].joint_type is JointType.FIXED + + def test_unmatched_mate_fallback(self) -> None: + """A single ANGLE mate with no pattern -> individual joint.""" + mates = [ + Mate( + mate_id=0, + mate_type=MateType.ANGLE, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + joints, _m2j, _j2m = convert_mates_to_joints(mates) + assert len(joints) == 1 + assert joints[0].joint_type is JointType.PERPENDICULAR + + def test_mapping_consistency(self) -> None: + """mate_to_joint and joint_to_mates are consistent.""" + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + Mate( + mate_id=2, + mate_type=MateType.DISTANCE, + ref_a=_make_ref(2, GeometryType.POINT), + ref_b=_make_ref(3, GeometryType.POINT), + ), + ] + joints, m2j, j2m = convert_mates_to_joints(mates) + # Every mate should be in m2j + for mate in mates: + assert mate.mate_id in m2j + # Every joint should be in j2m + for joint in joints: + assert joint.joint_id in j2m + + def test_joint_axis_from_geometry(self) -> None: + """Joint axis should come from mate geometry direction.""" + axis_dir = np.array([1.0, 0.0, 0.0]) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS, direction=axis_dir), + ref_b=_make_ref(1, GeometryType.AXIS, direction=axis_dir), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + joints, _, _ = convert_mates_to_joints(mates) + np.testing.assert_array_almost_equal(joints[0].axis, axis_dir) + + +# --------------------------------------------------------------------------- +# MateAnalysisResult +# --------------------------------------------------------------------------- + + +class TestMateAnalysisResult: + """MateAnalysisResult dataclass.""" + + def test_to_dict(self) -> None: + result = MateAnalysisResult( + patterns=[], + joints=[], + ) + d = result.to_dict() + assert d["patterns"] == [] + assert d["joints"] == [] + assert d["labels"] is None + + +# --------------------------------------------------------------------------- +# analyze_mate_assembly +# --------------------------------------------------------------------------- + + +class TestAnalyzeMateAssembly: + """Full pipeline: mates -> joints -> analysis.""" + + def test_two_bodies_hinge(self) -> None: + """Two bodies connected by hinge mates -> underconstrained (1 DOF).""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + result = analyze_mate_assembly(bodies, mates) + assert result.analysis is not None + assert result.labels is not None + # A revolute joint removes 5 DOF, leaving 1 internal DOF + assert result.analysis.combinatorial_internal_dof == 1 + assert len(result.joints) == 1 + assert result.joints[0].joint_type is JointType.REVOLUTE + + def test_two_bodies_fixed(self) -> None: + """Two bodies with lock mate -> well-constrained.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = analyze_mate_assembly(bodies, mates) + assert result.analysis is not None + assert result.analysis.combinatorial_internal_dof == 0 + assert result.analysis.is_rigid + + def test_grounded_assembly(self) -> None: + """Grounded assembly analysis works.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = analyze_mate_assembly(bodies, mates, ground_body=0) + assert result.analysis is not None + assert result.analysis.is_rigid + + def test_no_mates(self) -> None: + """Assembly with no mates should be fully underconstrained.""" + bodies = _make_bodies(2) + result = analyze_mate_assembly(bodies, []) + assert result.analysis is not None + assert result.analysis.combinatorial_internal_dof == 6 + assert len(result.joints) == 0 + + def test_single_body(self) -> None: + """Single body, no mates.""" + bodies = _make_bodies(1) + result = analyze_mate_assembly(bodies, []) + assert result.analysis is not None + assert len(result.joints) == 0 + + def test_result_traceability(self) -> None: + """mate_to_joint and joint_to_mates populated in result.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + result = analyze_mate_assembly(bodies, mates) + assert 0 in result.mate_to_joint + assert 1 in result.mate_to_joint + assert len(result.joint_to_mates) > 0