feat(mates): add mate-level ground truth labels
MateLabel and MateAssemblyLabels dataclasses with label_mate_assembly() that back-attributes joint-level independence to originating mates. Detects redundant and degenerate mates with pattern membership tracking. Closes #15
This commit is contained in:
@@ -9,6 +9,11 @@ from solver.mates.generator import (
|
||||
SyntheticMateGenerator,
|
||||
generate_mate_training_batch,
|
||||
)
|
||||
from solver.mates.labeling import (
|
||||
MateAssemblyLabels,
|
||||
MateLabel,
|
||||
label_mate_assembly,
|
||||
)
|
||||
from solver.mates.patterns import (
|
||||
JointPattern,
|
||||
PatternMatch,
|
||||
@@ -28,6 +33,8 @@ __all__ = [
|
||||
"JointPattern",
|
||||
"Mate",
|
||||
"MateAnalysisResult",
|
||||
"MateAssemblyLabels",
|
||||
"MateLabel",
|
||||
"MateType",
|
||||
"PatternMatch",
|
||||
"SyntheticMateGenerator",
|
||||
@@ -35,5 +42,6 @@ __all__ = [
|
||||
"convert_mates_to_joints",
|
||||
"dof_removed",
|
||||
"generate_mate_training_batch",
|
||||
"label_mate_assembly",
|
||||
"recognize_patterns",
|
||||
]
|
||||
|
||||
224
solver/mates/labeling.py
Normal file
224
solver/mates/labeling.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Mate-level ground truth labels for assembly analysis.
|
||||
|
||||
Back-attributes joint-level independence results to originating mates
|
||||
via the mate-to-joint mapping from conversion.py. Produces per-mate
|
||||
labels indicating whether each mate is independent, redundant, or
|
||||
degenerate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.mates.conversion import analyze_mate_assembly
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from solver.datagen.labeling import AssemblyLabel
|
||||
from solver.datagen.types import ConstraintAnalysis, RigidBody
|
||||
from solver.mates.conversion import MateAnalysisResult
|
||||
from solver.mates.patterns import JointPattern, PatternMatch
|
||||
from solver.mates.primitives import Mate
|
||||
|
||||
__all__ = [
|
||||
"MateAssemblyLabels",
|
||||
"MateLabel",
|
||||
"label_mate_assembly",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Label dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class MateLabel:
|
||||
"""Per-mate ground truth label.
|
||||
|
||||
Attributes:
|
||||
mate_id: The mate this label refers to.
|
||||
is_independent: Contributes non-redundant DOF removal.
|
||||
is_redundant: Fully redundant (removable without DOF change).
|
||||
is_degenerate: Combinatorially independent but geometrically dependent.
|
||||
pattern: Which joint pattern this mate belongs to, if any.
|
||||
issue: Detected issue type, if any.
|
||||
"""
|
||||
|
||||
mate_id: int
|
||||
is_independent: bool = True
|
||||
is_redundant: bool = False
|
||||
is_degenerate: bool = False
|
||||
pattern: JointPattern | None = None
|
||||
issue: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return a JSON-serializable dict."""
|
||||
return {
|
||||
"mate_id": self.mate_id,
|
||||
"is_independent": self.is_independent,
|
||||
"is_redundant": self.is_redundant,
|
||||
"is_degenerate": self.is_degenerate,
|
||||
"pattern": self.pattern.value if self.pattern else None,
|
||||
"issue": self.issue,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class MateAssemblyLabels:
|
||||
"""Complete mate-level ground truth labels for an assembly.
|
||||
|
||||
Attributes:
|
||||
per_mate: Per-mate labels.
|
||||
patterns: Recognized joint patterns.
|
||||
assembly: Assembly-wide summary label.
|
||||
analysis: Constraint analysis from pebble game + Jacobian.
|
||||
"""
|
||||
|
||||
per_mate: list[MateLabel]
|
||||
patterns: list[PatternMatch]
|
||||
assembly: AssemblyLabel
|
||||
analysis: ConstraintAnalysis
|
||||
mate_analysis: MateAnalysisResult | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return a JSON-serializable dict."""
|
||||
return {
|
||||
"per_mate": [ml.to_dict() for ml in self.per_mate],
|
||||
"patterns": [p.to_dict() for p in self.patterns],
|
||||
"assembly": {
|
||||
"classification": self.assembly.classification,
|
||||
"total_dof": self.assembly.total_dof,
|
||||
"redundant_count": self.assembly.redundant_count,
|
||||
"is_rigid": self.assembly.is_rigid,
|
||||
"is_minimally_rigid": self.assembly.is_minimally_rigid,
|
||||
"has_degeneracy": self.assembly.has_degeneracy,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Labeling logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_mate_pattern_map(
|
||||
patterns: list[PatternMatch],
|
||||
) -> dict[int, JointPattern]:
|
||||
"""Map mate_ids to the pattern they belong to (best match)."""
|
||||
result: dict[int, JointPattern] = {}
|
||||
# Sort by confidence descending so best matches win
|
||||
sorted_patterns = sorted(patterns, key=lambda p: -p.confidence)
|
||||
for pm in sorted_patterns:
|
||||
if pm.confidence < 1.0:
|
||||
continue
|
||||
for mate in pm.mates:
|
||||
if mate.mate_id not in result:
|
||||
result[mate.mate_id] = pm.pattern
|
||||
return result
|
||||
|
||||
|
||||
def label_mate_assembly(
|
||||
bodies: list[RigidBody],
|
||||
mates: list[Mate],
|
||||
ground_body: int | None = None,
|
||||
) -> MateAssemblyLabels:
|
||||
"""Produce mate-level ground truth labels for an assembly.
|
||||
|
||||
Runs analyze_mate_assembly() internally, then back-attributes
|
||||
joint-level independence to originating mates via the mate_to_joint
|
||||
mapping.
|
||||
|
||||
A mate is:
|
||||
- **redundant** if ALL joints it contributes to are fully redundant
|
||||
- **degenerate** if any joint it contributes to is geometrically
|
||||
dependent but combinatorially independent
|
||||
- **independent** otherwise
|
||||
|
||||
Args:
|
||||
bodies: Rigid bodies in the assembly.
|
||||
mates: Mate constraints between the bodies.
|
||||
ground_body: If set, this body is fixed to the world.
|
||||
|
||||
Returns:
|
||||
MateAssemblyLabels with per-mate labels and assembly summary.
|
||||
"""
|
||||
mate_result = analyze_mate_assembly(bodies, mates, ground_body)
|
||||
|
||||
# Build per-joint redundancy from labels
|
||||
joint_redundant: dict[int, bool] = {}
|
||||
joint_degenerate: dict[int, bool] = {}
|
||||
|
||||
if mate_result.labels is not None:
|
||||
for jl in mate_result.labels.per_joint:
|
||||
# A joint is fully redundant if all its constraints are redundant
|
||||
joint_redundant[jl.joint_id] = jl.redundant_count == jl.total and jl.total > 0
|
||||
# Joint is degenerate if it has more independent constraints
|
||||
# than Jacobian rank would suggest (geometric degeneracy)
|
||||
joint_degenerate[jl.joint_id] = False
|
||||
|
||||
# Check for geometric degeneracy via per-constraint labels
|
||||
for cl in mate_result.labels.per_constraint:
|
||||
if cl.pebble_independent and not cl.jacobian_independent:
|
||||
joint_degenerate[cl.joint_id] = True
|
||||
|
||||
# Build pattern membership map
|
||||
pattern_map = _build_mate_pattern_map(mate_result.patterns)
|
||||
|
||||
# Back-attribute to mates
|
||||
per_mate: list[MateLabel] = []
|
||||
for mate in mates:
|
||||
mate_joint_ids = mate_result.mate_to_joint.get(mate.mate_id, [])
|
||||
|
||||
if not mate_joint_ids:
|
||||
# Mate wasn't converted to any joint (shouldn't happen, but safe)
|
||||
per_mate.append(
|
||||
MateLabel(
|
||||
mate_id=mate.mate_id,
|
||||
is_independent=False,
|
||||
is_redundant=True,
|
||||
issue="unmapped",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
# Redundant if ALL contributed joints are redundant
|
||||
all_redundant = all(joint_redundant.get(jid, False) for jid in mate_joint_ids)
|
||||
|
||||
# Degenerate if ANY contributed joint is degenerate
|
||||
any_degenerate = any(joint_degenerate.get(jid, False) for jid in mate_joint_ids)
|
||||
|
||||
is_independent = not all_redundant
|
||||
pattern = pattern_map.get(mate.mate_id)
|
||||
|
||||
# Determine issue string
|
||||
issue: str | None = None
|
||||
if all_redundant:
|
||||
issue = "redundant"
|
||||
elif any_degenerate:
|
||||
issue = "degenerate"
|
||||
|
||||
per_mate.append(
|
||||
MateLabel(
|
||||
mate_id=mate.mate_id,
|
||||
is_independent=is_independent,
|
||||
is_redundant=all_redundant,
|
||||
is_degenerate=any_degenerate,
|
||||
pattern=pattern,
|
||||
issue=issue,
|
||||
)
|
||||
)
|
||||
|
||||
# Assembly label
|
||||
assert mate_result.labels is not None
|
||||
assembly_label = mate_result.labels.assembly
|
||||
|
||||
return MateAssemblyLabels(
|
||||
per_mate=per_mate,
|
||||
patterns=mate_result.patterns,
|
||||
assembly=assembly_label,
|
||||
analysis=mate_result.labels.analysis,
|
||||
mate_analysis=mate_result,
|
||||
)
|
||||
224
tests/mates/test_labeling.py
Normal file
224
tests/mates/test_labeling.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Tests for solver.mates.labeling -- mate-level ground truth labels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.types import RigidBody
|
||||
from solver.mates.labeling import MateAssemblyLabels, MateLabel, label_mate_assembly
|
||||
from solver.mates.patterns import JointPattern
|
||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ref(
|
||||
body_id: int,
|
||||
geom_type: GeometryType,
|
||||
*,
|
||||
origin: np.ndarray | None = None,
|
||||
direction: np.ndarray | None = None,
|
||||
) -> GeometryRef:
|
||||
"""Factory for GeometryRef with sensible defaults."""
|
||||
if origin is None:
|
||||
origin = np.zeros(3)
|
||||
if direction is None and geom_type in {
|
||||
GeometryType.FACE,
|
||||
GeometryType.AXIS,
|
||||
GeometryType.PLANE,
|
||||
}:
|
||||
direction = np.array([0.0, 0.0, 1.0])
|
||||
return GeometryRef(
|
||||
body_id=body_id,
|
||||
geometry_type=geom_type,
|
||||
geometry_id="Geom001",
|
||||
origin=origin,
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
|
||||
def _make_bodies(n: int) -> list[RigidBody]:
|
||||
"""Create n bodies at distinct positions."""
|
||||
return [RigidBody(body_id=i, position=np.array([float(i), 0.0, 0.0])) for i in range(n)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MateLabel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMateLabel:
|
||||
"""MateLabel dataclass."""
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
ml = MateLabel(mate_id=0)
|
||||
assert ml.is_independent is True
|
||||
assert ml.is_redundant is False
|
||||
assert ml.is_degenerate is False
|
||||
assert ml.pattern is None
|
||||
assert ml.issue is None
|
||||
|
||||
def test_to_dict(self) -> None:
|
||||
ml = MateLabel(
|
||||
mate_id=5,
|
||||
is_independent=False,
|
||||
is_redundant=True,
|
||||
pattern=JointPattern.HINGE,
|
||||
issue="redundant",
|
||||
)
|
||||
d = ml.to_dict()
|
||||
assert d["mate_id"] == 5
|
||||
assert d["is_redundant"] is True
|
||||
assert d["pattern"] == "hinge"
|
||||
assert d["issue"] == "redundant"
|
||||
|
||||
def test_to_dict_none_pattern(self) -> None:
|
||||
ml = MateLabel(mate_id=0)
|
||||
d = ml.to_dict()
|
||||
assert d["pattern"] is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MateAssemblyLabels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMateAssemblyLabels:
|
||||
"""MateAssemblyLabels dataclass."""
|
||||
|
||||
def test_to_dict_structure(self) -> None:
|
||||
"""to_dict produces expected keys."""
|
||||
bodies = _make_bodies(2)
|
||||
mates = [
|
||||
Mate(
|
||||
mate_id=0,
|
||||
mate_type=MateType.LOCK,
|
||||
ref_a=_make_ref(0, GeometryType.FACE),
|
||||
ref_b=_make_ref(1, GeometryType.FACE),
|
||||
),
|
||||
]
|
||||
result = label_mate_assembly(bodies, mates)
|
||||
d = result.to_dict()
|
||||
assert "per_mate" in d
|
||||
assert "patterns" in d
|
||||
assert "assembly" in d
|
||||
assert isinstance(d["per_mate"], list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# label_mate_assembly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLabelMateAssembly:
|
||||
"""Full labeling pipeline."""
|
||||
|
||||
def test_clean_assembly_no_redundancy(self) -> None:
|
||||
"""Two bodies with lock mate -> clean, no redundancy."""
|
||||
bodies = _make_bodies(2)
|
||||
mates = [
|
||||
Mate(
|
||||
mate_id=0,
|
||||
mate_type=MateType.LOCK,
|
||||
ref_a=_make_ref(0, GeometryType.FACE),
|
||||
ref_b=_make_ref(1, GeometryType.FACE),
|
||||
),
|
||||
]
|
||||
result = label_mate_assembly(bodies, mates)
|
||||
assert isinstance(result, MateAssemblyLabels)
|
||||
assert len(result.per_mate) == 1
|
||||
ml = result.per_mate[0]
|
||||
assert ml.mate_id == 0
|
||||
assert ml.is_independent is True
|
||||
assert ml.is_redundant is False
|
||||
assert ml.issue is None
|
||||
|
||||
def test_redundant_assembly(self) -> None:
|
||||
"""Two lock mates on same body pair -> one is redundant."""
|
||||
bodies = _make_bodies(2)
|
||||
mates = [
|
||||
Mate(
|
||||
mate_id=0,
|
||||
mate_type=MateType.LOCK,
|
||||
ref_a=_make_ref(0, GeometryType.FACE),
|
||||
ref_b=_make_ref(1, GeometryType.FACE),
|
||||
),
|
||||
Mate(
|
||||
mate_id=1,
|
||||
mate_type=MateType.LOCK,
|
||||
ref_a=_make_ref(0, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])),
|
||||
ref_b=_make_ref(1, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])),
|
||||
),
|
||||
]
|
||||
result = label_mate_assembly(bodies, mates)
|
||||
assert len(result.per_mate) == 2
|
||||
redundant_count = sum(1 for ml in result.per_mate if ml.is_redundant)
|
||||
# At least one should be redundant
|
||||
assert redundant_count >= 1
|
||||
assert result.assembly.redundant_count > 0
|
||||
|
||||
def test_hinge_pattern_labeling(self) -> None:
|
||||
"""Hinge mates get pattern membership."""
|
||||
bodies = _make_bodies(2)
|
||||
mates = [
|
||||
Mate(
|
||||
mate_id=0,
|
||||
mate_type=MateType.CONCENTRIC,
|
||||
ref_a=_make_ref(0, GeometryType.AXIS),
|
||||
ref_b=_make_ref(1, GeometryType.AXIS),
|
||||
),
|
||||
Mate(
|
||||
mate_id=1,
|
||||
mate_type=MateType.COINCIDENT,
|
||||
ref_a=_make_ref(0, GeometryType.PLANE),
|
||||
ref_b=_make_ref(1, GeometryType.PLANE),
|
||||
),
|
||||
]
|
||||
result = label_mate_assembly(bodies, mates)
|
||||
assert len(result.per_mate) == 2
|
||||
# Both mates should be part of the hinge pattern
|
||||
for ml in result.per_mate:
|
||||
assert ml.pattern is JointPattern.HINGE
|
||||
assert ml.is_independent is True
|
||||
|
||||
def test_grounded_assembly(self) -> None:
|
||||
"""Grounded assembly labeling works."""
|
||||
bodies = _make_bodies(2)
|
||||
mates = [
|
||||
Mate(
|
||||
mate_id=0,
|
||||
mate_type=MateType.LOCK,
|
||||
ref_a=_make_ref(0, GeometryType.FACE),
|
||||
ref_b=_make_ref(1, GeometryType.FACE),
|
||||
),
|
||||
]
|
||||
result = label_mate_assembly(bodies, mates, ground_body=0)
|
||||
assert result.assembly.is_rigid
|
||||
|
||||
def test_empty_mates(self) -> None:
|
||||
"""No mates -> no per_mate labels, underconstrained."""
|
||||
bodies = _make_bodies(2)
|
||||
result = label_mate_assembly(bodies, [])
|
||||
assert len(result.per_mate) == 0
|
||||
assert result.assembly.classification == "underconstrained"
|
||||
|
||||
def test_assembly_classification(self) -> None:
|
||||
"""Assembly classification is present."""
|
||||
bodies = _make_bodies(2)
|
||||
mates = [
|
||||
Mate(
|
||||
mate_id=0,
|
||||
mate_type=MateType.LOCK,
|
||||
ref_a=_make_ref(0, GeometryType.FACE),
|
||||
ref_b=_make_ref(1, GeometryType.FACE),
|
||||
),
|
||||
]
|
||||
result = label_mate_assembly(bodies, mates)
|
||||
assert result.assembly.classification in {
|
||||
"well-constrained",
|
||||
"overconstrained",
|
||||
"underconstrained",
|
||||
"mixed",
|
||||
}
|
||||
Reference in New Issue
Block a user