diff --git a/solver/mates/__init__.py b/solver/mates/__init__.py index 0496bf4..884e09d 100644 --- a/solver/mates/__init__.py +++ b/solver/mates/__init__.py @@ -9,6 +9,11 @@ from solver.mates.generator import ( SyntheticMateGenerator, generate_mate_training_batch, ) +from solver.mates.labeling import ( + MateAssemblyLabels, + MateLabel, + label_mate_assembly, +) from solver.mates.patterns import ( JointPattern, PatternMatch, @@ -28,6 +33,8 @@ __all__ = [ "JointPattern", "Mate", "MateAnalysisResult", + "MateAssemblyLabels", + "MateLabel", "MateType", "PatternMatch", "SyntheticMateGenerator", @@ -35,5 +42,6 @@ __all__ = [ "convert_mates_to_joints", "dof_removed", "generate_mate_training_batch", + "label_mate_assembly", "recognize_patterns", ] diff --git a/solver/mates/labeling.py b/solver/mates/labeling.py new file mode 100644 index 0000000..deb7507 --- /dev/null +++ b/solver/mates/labeling.py @@ -0,0 +1,224 @@ +"""Mate-level ground truth labels for assembly analysis. + +Back-attributes joint-level independence results to originating mates +via the mate-to-joint mapping from conversion.py. Produces per-mate +labels indicating whether each mate is independent, redundant, or +degenerate. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from solver.mates.conversion import analyze_mate_assembly + +if TYPE_CHECKING: + from typing import Any + + from solver.datagen.labeling import AssemblyLabel + from solver.datagen.types import ConstraintAnalysis, RigidBody + from solver.mates.conversion import MateAnalysisResult + from solver.mates.patterns import JointPattern, PatternMatch + from solver.mates.primitives import Mate + +__all__ = [ + "MateAssemblyLabels", + "MateLabel", + "label_mate_assembly", +] + + +# --------------------------------------------------------------------------- +# Label dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class MateLabel: + """Per-mate ground truth label. + + Attributes: + mate_id: The mate this label refers to. + is_independent: Contributes non-redundant DOF removal. + is_redundant: Fully redundant (removable without DOF change). + is_degenerate: Combinatorially independent but geometrically dependent. + pattern: Which joint pattern this mate belongs to, if any. + issue: Detected issue type, if any. + """ + + mate_id: int + is_independent: bool = True + is_redundant: bool = False + is_degenerate: bool = False + pattern: JointPattern | None = None + issue: str | None = None + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "mate_id": self.mate_id, + "is_independent": self.is_independent, + "is_redundant": self.is_redundant, + "is_degenerate": self.is_degenerate, + "pattern": self.pattern.value if self.pattern else None, + "issue": self.issue, + } + + +@dataclass +class MateAssemblyLabels: + """Complete mate-level ground truth labels for an assembly. + + Attributes: + per_mate: Per-mate labels. + patterns: Recognized joint patterns. + assembly: Assembly-wide summary label. + analysis: Constraint analysis from pebble game + Jacobian. + """ + + per_mate: list[MateLabel] + patterns: list[PatternMatch] + assembly: AssemblyLabel + analysis: ConstraintAnalysis + mate_analysis: MateAnalysisResult | None = None + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "per_mate": [ml.to_dict() for ml in self.per_mate], + "patterns": [p.to_dict() for p in self.patterns], + "assembly": { + "classification": self.assembly.classification, + "total_dof": self.assembly.total_dof, + "redundant_count": self.assembly.redundant_count, + "is_rigid": self.assembly.is_rigid, + "is_minimally_rigid": self.assembly.is_minimally_rigid, + "has_degeneracy": self.assembly.has_degeneracy, + }, + } + + +# --------------------------------------------------------------------------- +# Labeling logic +# --------------------------------------------------------------------------- + + +def _build_mate_pattern_map( + patterns: list[PatternMatch], +) -> dict[int, JointPattern]: + """Map mate_ids to the pattern they belong to (best match).""" + result: dict[int, JointPattern] = {} + # Sort by confidence descending so best matches win + sorted_patterns = sorted(patterns, key=lambda p: -p.confidence) + for pm in sorted_patterns: + if pm.confidence < 1.0: + continue + for mate in pm.mates: + if mate.mate_id not in result: + result[mate.mate_id] = pm.pattern + return result + + +def label_mate_assembly( + bodies: list[RigidBody], + mates: list[Mate], + ground_body: int | None = None, +) -> MateAssemblyLabels: + """Produce mate-level ground truth labels for an assembly. + + Runs analyze_mate_assembly() internally, then back-attributes + joint-level independence to originating mates via the mate_to_joint + mapping. + + A mate is: + - **redundant** if ALL joints it contributes to are fully redundant + - **degenerate** if any joint it contributes to is geometrically + dependent but combinatorially independent + - **independent** otherwise + + Args: + bodies: Rigid bodies in the assembly. + mates: Mate constraints between the bodies. + ground_body: If set, this body is fixed to the world. + + Returns: + MateAssemblyLabels with per-mate labels and assembly summary. + """ + mate_result = analyze_mate_assembly(bodies, mates, ground_body) + + # Build per-joint redundancy from labels + joint_redundant: dict[int, bool] = {} + joint_degenerate: dict[int, bool] = {} + + if mate_result.labels is not None: + for jl in mate_result.labels.per_joint: + # A joint is fully redundant if all its constraints are redundant + joint_redundant[jl.joint_id] = jl.redundant_count == jl.total and jl.total > 0 + # Joint is degenerate if it has more independent constraints + # than Jacobian rank would suggest (geometric degeneracy) + joint_degenerate[jl.joint_id] = False + + # Check for geometric degeneracy via per-constraint labels + for cl in mate_result.labels.per_constraint: + if cl.pebble_independent and not cl.jacobian_independent: + joint_degenerate[cl.joint_id] = True + + # Build pattern membership map + pattern_map = _build_mate_pattern_map(mate_result.patterns) + + # Back-attribute to mates + per_mate: list[MateLabel] = [] + for mate in mates: + mate_joint_ids = mate_result.mate_to_joint.get(mate.mate_id, []) + + if not mate_joint_ids: + # Mate wasn't converted to any joint (shouldn't happen, but safe) + per_mate.append( + MateLabel( + mate_id=mate.mate_id, + is_independent=False, + is_redundant=True, + issue="unmapped", + ) + ) + continue + + # Redundant if ALL contributed joints are redundant + all_redundant = all(joint_redundant.get(jid, False) for jid in mate_joint_ids) + + # Degenerate if ANY contributed joint is degenerate + any_degenerate = any(joint_degenerate.get(jid, False) for jid in mate_joint_ids) + + is_independent = not all_redundant + pattern = pattern_map.get(mate.mate_id) + + # Determine issue string + issue: str | None = None + if all_redundant: + issue = "redundant" + elif any_degenerate: + issue = "degenerate" + + per_mate.append( + MateLabel( + mate_id=mate.mate_id, + is_independent=is_independent, + is_redundant=all_redundant, + is_degenerate=any_degenerate, + pattern=pattern, + issue=issue, + ) + ) + + # Assembly label + assert mate_result.labels is not None + assembly_label = mate_result.labels.assembly + + return MateAssemblyLabels( + per_mate=per_mate, + patterns=mate_result.patterns, + assembly=assembly_label, + analysis=mate_result.labels.analysis, + mate_analysis=mate_result, + ) diff --git a/tests/mates/test_labeling.py b/tests/mates/test_labeling.py new file mode 100644 index 0000000..1ed9f63 --- /dev/null +++ b/tests/mates/test_labeling.py @@ -0,0 +1,224 @@ +"""Tests for solver.mates.labeling -- mate-level ground truth labels.""" + +from __future__ import annotations + +import numpy as np + +from solver.datagen.types import RigidBody +from solver.mates.labeling import MateAssemblyLabels, MateLabel, label_mate_assembly +from solver.mates.patterns import JointPattern +from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_ref( + body_id: int, + geom_type: GeometryType, + *, + origin: np.ndarray | None = None, + direction: np.ndarray | None = None, +) -> GeometryRef: + """Factory for GeometryRef with sensible defaults.""" + if origin is None: + origin = np.zeros(3) + if direction is None and geom_type in { + GeometryType.FACE, + GeometryType.AXIS, + GeometryType.PLANE, + }: + direction = np.array([0.0, 0.0, 1.0]) + return GeometryRef( + body_id=body_id, + geometry_type=geom_type, + geometry_id="Geom001", + origin=origin, + direction=direction, + ) + + +def _make_bodies(n: int) -> list[RigidBody]: + """Create n bodies at distinct positions.""" + return [RigidBody(body_id=i, position=np.array([float(i), 0.0, 0.0])) for i in range(n)] + + +# --------------------------------------------------------------------------- +# MateLabel +# --------------------------------------------------------------------------- + + +class TestMateLabel: + """MateLabel dataclass.""" + + def test_defaults(self) -> None: + ml = MateLabel(mate_id=0) + assert ml.is_independent is True + assert ml.is_redundant is False + assert ml.is_degenerate is False + assert ml.pattern is None + assert ml.issue is None + + def test_to_dict(self) -> None: + ml = MateLabel( + mate_id=5, + is_independent=False, + is_redundant=True, + pattern=JointPattern.HINGE, + issue="redundant", + ) + d = ml.to_dict() + assert d["mate_id"] == 5 + assert d["is_redundant"] is True + assert d["pattern"] == "hinge" + assert d["issue"] == "redundant" + + def test_to_dict_none_pattern(self) -> None: + ml = MateLabel(mate_id=0) + d = ml.to_dict() + assert d["pattern"] is None + + +# --------------------------------------------------------------------------- +# MateAssemblyLabels +# --------------------------------------------------------------------------- + + +class TestMateAssemblyLabels: + """MateAssemblyLabels dataclass.""" + + def test_to_dict_structure(self) -> None: + """to_dict produces expected keys.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = label_mate_assembly(bodies, mates) + d = result.to_dict() + assert "per_mate" in d + assert "patterns" in d + assert "assembly" in d + assert isinstance(d["per_mate"], list) + + +# --------------------------------------------------------------------------- +# label_mate_assembly +# --------------------------------------------------------------------------- + + +class TestLabelMateAssembly: + """Full labeling pipeline.""" + + def test_clean_assembly_no_redundancy(self) -> None: + """Two bodies with lock mate -> clean, no redundancy.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = label_mate_assembly(bodies, mates) + assert isinstance(result, MateAssemblyLabels) + assert len(result.per_mate) == 1 + ml = result.per_mate[0] + assert ml.mate_id == 0 + assert ml.is_independent is True + assert ml.is_redundant is False + assert ml.issue is None + + def test_redundant_assembly(self) -> None: + """Two lock mates on same body pair -> one is redundant.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + Mate( + mate_id=1, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])), + ref_b=_make_ref(1, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])), + ), + ] + result = label_mate_assembly(bodies, mates) + assert len(result.per_mate) == 2 + redundant_count = sum(1 for ml in result.per_mate if ml.is_redundant) + # At least one should be redundant + assert redundant_count >= 1 + assert result.assembly.redundant_count > 0 + + def test_hinge_pattern_labeling(self) -> None: + """Hinge mates get pattern membership.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.CONCENTRIC, + ref_a=_make_ref(0, GeometryType.AXIS), + ref_b=_make_ref(1, GeometryType.AXIS), + ), + Mate( + mate_id=1, + mate_type=MateType.COINCIDENT, + ref_a=_make_ref(0, GeometryType.PLANE), + ref_b=_make_ref(1, GeometryType.PLANE), + ), + ] + result = label_mate_assembly(bodies, mates) + assert len(result.per_mate) == 2 + # Both mates should be part of the hinge pattern + for ml in result.per_mate: + assert ml.pattern is JointPattern.HINGE + assert ml.is_independent is True + + def test_grounded_assembly(self) -> None: + """Grounded assembly labeling works.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = label_mate_assembly(bodies, mates, ground_body=0) + assert result.assembly.is_rigid + + def test_empty_mates(self) -> None: + """No mates -> no per_mate labels, underconstrained.""" + bodies = _make_bodies(2) + result = label_mate_assembly(bodies, []) + assert len(result.per_mate) == 0 + assert result.assembly.classification == "underconstrained" + + def test_assembly_classification(self) -> None: + """Assembly classification is present.""" + bodies = _make_bodies(2) + mates = [ + Mate( + mate_id=0, + mate_type=MateType.LOCK, + ref_a=_make_ref(0, GeometryType.FACE), + ref_b=_make_ref(1, GeometryType.FACE), + ), + ] + result = label_mate_assembly(bodies, mates) + assert result.assembly.classification in { + "well-constrained", + "overconstrained", + "underconstrained", + "mixed", + }