feat(mates): add mate-based synthetic assembly generator
SyntheticMateGenerator wraps existing joint generator with reverse mapping (joint->mates) and configurable noise injection (redundant, missing, incompatible mates). Batch generation via generate_mate_training_batch(). Closes #14
This commit is contained in:
@@ -5,6 +5,10 @@ from solver.mates.conversion import (
|
||||
analyze_mate_assembly,
|
||||
convert_mates_to_joints,
|
||||
)
|
||||
from solver.mates.generator import (
|
||||
SyntheticMateGenerator,
|
||||
generate_mate_training_batch,
|
||||
)
|
||||
from solver.mates.patterns import (
|
||||
JointPattern,
|
||||
PatternMatch,
|
||||
@@ -26,8 +30,10 @@ __all__ = [
|
||||
"MateAnalysisResult",
|
||||
"MateType",
|
||||
"PatternMatch",
|
||||
"SyntheticMateGenerator",
|
||||
"analyze_mate_assembly",
|
||||
"convert_mates_to_joints",
|
||||
"dof_removed",
|
||||
"generate_mate_training_batch",
|
||||
"recognize_patterns",
|
||||
]
|
||||
|
||||
315
solver/mates/generator.py
Normal file
315
solver/mates/generator.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Mate-based synthetic assembly generator.
|
||||
|
||||
Wraps SyntheticAssemblyGenerator to produce mate-level training data.
|
||||
Generates joint-based assemblies via the existing generator, then
|
||||
reverse-maps joints to plausible mate combinations. Supports noise
|
||||
injection (redundant, missing, incompatible mates) for robust training.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
from solver.mates.conversion import MateAnalysisResult, analyze_mate_assembly
|
||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"SyntheticMateGenerator",
|
||||
"generate_mate_training_batch",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reverse mapping: JointType -> list of (MateType, geom_a, geom_b) combos
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _MateSpec:
|
||||
"""Specification for a mate to generate from a joint."""
|
||||
|
||||
mate_type: MateType
|
||||
geom_a: GeometryType
|
||||
geom_b: GeometryType
|
||||
|
||||
|
||||
_JOINT_TO_MATES: dict[JointType, list[_MateSpec]] = {
|
||||
JointType.REVOLUTE: [
|
||||
_MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
|
||||
_MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
|
||||
],
|
||||
JointType.CYLINDRICAL: [
|
||||
_MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
|
||||
],
|
||||
JointType.BALL: [
|
||||
_MateSpec(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),
|
||||
],
|
||||
JointType.FIXED: [
|
||||
_MateSpec(MateType.LOCK, GeometryType.FACE, GeometryType.FACE),
|
||||
],
|
||||
JointType.SLIDER: [
|
||||
_MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
|
||||
_MateSpec(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS),
|
||||
],
|
||||
JointType.PLANAR: [
|
||||
_MateSpec(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SyntheticMateGenerator:
|
||||
"""Generates mate-based assemblies for training data.
|
||||
|
||||
Wraps SyntheticAssemblyGenerator to produce joint-based assemblies,
|
||||
then reverse-maps each joint to a plausible set of mate constraints.
|
||||
|
||||
Args:
|
||||
seed: Random seed for reproducibility.
|
||||
redundant_prob: Probability of injecting a redundant mate per joint.
|
||||
missing_prob: Probability of dropping a mate from a multi-mate pattern.
|
||||
incompatible_prob: Probability of injecting a mate with wrong geometry.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
seed: int = 42,
|
||||
*,
|
||||
redundant_prob: float = 0.0,
|
||||
missing_prob: float = 0.0,
|
||||
incompatible_prob: float = 0.0,
|
||||
) -> None:
|
||||
self._joint_gen = SyntheticAssemblyGenerator(seed=seed)
|
||||
self._rng = np.random.default_rng(seed)
|
||||
self.redundant_prob = redundant_prob
|
||||
self.missing_prob = missing_prob
|
||||
self.incompatible_prob = incompatible_prob
|
||||
|
||||
def _make_geometry_ref(
|
||||
self,
|
||||
body_id: int,
|
||||
geom_type: GeometryType,
|
||||
joint: Joint,
|
||||
*,
|
||||
is_ref_a: bool = True,
|
||||
) -> GeometryRef:
|
||||
"""Create a GeometryRef from joint geometry.
|
||||
|
||||
Uses joint anchor, axis, and body_id to produce a ref
|
||||
with realistic geometry for the given type.
|
||||
"""
|
||||
origin = joint.anchor_a if is_ref_a else joint.anchor_b
|
||||
|
||||
direction: np.ndarray | None = None
|
||||
if geom_type in {GeometryType.AXIS, GeometryType.PLANE, GeometryType.FACE}:
|
||||
direction = joint.axis.copy()
|
||||
|
||||
geom_id = f"{geom_type.value.capitalize()}001"
|
||||
|
||||
return GeometryRef(
|
||||
body_id=body_id,
|
||||
geometry_type=geom_type,
|
||||
geometry_id=geom_id,
|
||||
origin=origin.copy(),
|
||||
direction=direction,
|
||||
)
|
||||
|
||||
def _reverse_map_joint(
|
||||
self,
|
||||
joint: Joint,
|
||||
next_mate_id: int,
|
||||
) -> list[Mate]:
|
||||
"""Convert a joint to its mate representation."""
|
||||
specs = _JOINT_TO_MATES.get(joint.joint_type, [])
|
||||
if not specs:
|
||||
# Fallback: emit a single DISTANCE mate
|
||||
specs = [_MateSpec(MateType.DISTANCE, GeometryType.POINT, GeometryType.POINT)]
|
||||
|
||||
mates: list[Mate] = []
|
||||
for spec in specs:
|
||||
ref_a = self._make_geometry_ref(joint.body_a, spec.geom_a, joint, is_ref_a=True)
|
||||
ref_b = self._make_geometry_ref(joint.body_b, spec.geom_b, joint, is_ref_a=False)
|
||||
mates.append(
|
||||
Mate(
|
||||
mate_id=next_mate_id + len(mates),
|
||||
mate_type=spec.mate_type,
|
||||
ref_a=ref_a,
|
||||
ref_b=ref_b,
|
||||
)
|
||||
)
|
||||
return mates
|
||||
|
||||
def _inject_noise(
|
||||
self,
|
||||
mates: list[Mate],
|
||||
next_mate_id: int,
|
||||
) -> list[Mate]:
|
||||
"""Apply noise injection to the mate list.
|
||||
|
||||
Modifies the list in-place and may add new mates.
|
||||
Returns the (possibly extended) list.
|
||||
"""
|
||||
result = list(mates)
|
||||
extra: list[Mate] = []
|
||||
|
||||
for mate in mates:
|
||||
# Redundant: duplicate a mate
|
||||
if self._rng.random() < self.redundant_prob:
|
||||
dup = Mate(
|
||||
mate_id=next_mate_id + len(extra),
|
||||
mate_type=mate.mate_type,
|
||||
ref_a=mate.ref_a,
|
||||
ref_b=mate.ref_b,
|
||||
value=mate.value,
|
||||
tolerance=mate.tolerance,
|
||||
)
|
||||
extra.append(dup)
|
||||
|
||||
# Incompatible: wrong geometry type
|
||||
if self._rng.random() < self.incompatible_prob:
|
||||
bad_geom = GeometryType.POINT
|
||||
bad_ref = GeometryRef(
|
||||
body_id=mate.ref_a.body_id,
|
||||
geometry_type=bad_geom,
|
||||
geometry_id="BadGeom001",
|
||||
origin=mate.ref_a.origin.copy(),
|
||||
direction=None,
|
||||
)
|
||||
extra.append(
|
||||
Mate(
|
||||
mate_id=next_mate_id + len(extra),
|
||||
mate_type=MateType.CONCENTRIC,
|
||||
ref_a=bad_ref,
|
||||
ref_b=mate.ref_b,
|
||||
)
|
||||
)
|
||||
|
||||
result.extend(extra)
|
||||
|
||||
# Missing: drop mates from multi-mate patterns (only if > 1 mate
|
||||
# for same body pair)
|
||||
if self.missing_prob > 0:
|
||||
filtered: list[Mate] = []
|
||||
for mate in result:
|
||||
if self._rng.random() < self.missing_prob:
|
||||
continue
|
||||
filtered.append(mate)
|
||||
# Ensure at least one mate remains
|
||||
if not filtered and result:
|
||||
filtered = [result[0]]
|
||||
result = filtered
|
||||
|
||||
return result
|
||||
|
||||
def generate(
|
||||
self,
|
||||
n_bodies: int = 4,
|
||||
*,
|
||||
grounded: bool = False,
|
||||
) -> tuple[list[RigidBody], list[Mate], MateAnalysisResult]:
|
||||
"""Generate a mate-based assembly.
|
||||
|
||||
Args:
|
||||
n_bodies: Number of rigid bodies.
|
||||
grounded: Whether to ground the first body.
|
||||
|
||||
Returns:
|
||||
(bodies, mates, analysis_result) tuple.
|
||||
"""
|
||||
bodies, joints, _analysis = self._joint_gen.generate_chain_assembly(
|
||||
n_bodies,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
grounded=grounded,
|
||||
)
|
||||
|
||||
mates: list[Mate] = []
|
||||
next_id = 0
|
||||
for joint in joints:
|
||||
joint_mates = self._reverse_map_joint(joint, next_id)
|
||||
mates.extend(joint_mates)
|
||||
next_id += len(joint_mates)
|
||||
|
||||
# Apply noise
|
||||
mates = self._inject_noise(mates, next_id)
|
||||
|
||||
ground_body = bodies[0].body_id if grounded else None
|
||||
result = analyze_mate_assembly(bodies, mates, ground_body)
|
||||
|
||||
return bodies, mates, result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def generate_mate_training_batch(
|
||||
batch_size: int = 100,
|
||||
n_bodies_range: tuple[int, int] = (3, 8),
|
||||
seed: int = 42,
|
||||
*,
|
||||
redundant_prob: float = 0.0,
|
||||
missing_prob: float = 0.0,
|
||||
incompatible_prob: float = 0.0,
|
||||
grounded_ratio: float = 1.0,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Produce a batch of mate-level training examples.
|
||||
|
||||
Args:
|
||||
batch_size: Number of assemblies to generate.
|
||||
n_bodies_range: (min, max_exclusive) body count.
|
||||
seed: Random seed.
|
||||
redundant_prob: Probability of redundant mate injection.
|
||||
missing_prob: Probability of missing mate injection.
|
||||
incompatible_prob: Probability of incompatible mate injection.
|
||||
grounded_ratio: Fraction of assemblies that are grounded.
|
||||
|
||||
Returns:
|
||||
List of dicts with bodies, mates, patterns, and labels.
|
||||
"""
|
||||
rng = np.random.default_rng(seed)
|
||||
examples: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
gen = SyntheticMateGenerator(
|
||||
seed=seed + i,
|
||||
redundant_prob=redundant_prob,
|
||||
missing_prob=missing_prob,
|
||||
incompatible_prob=incompatible_prob,
|
||||
)
|
||||
n = int(rng.integers(*n_bodies_range))
|
||||
grounded = bool(rng.random() < grounded_ratio)
|
||||
|
||||
bodies, mates, result = gen.generate(n, grounded=grounded)
|
||||
|
||||
examples.append(
|
||||
{
|
||||
"bodies": [
|
||||
{
|
||||
"body_id": b.body_id,
|
||||
"position": b.position.tolist(),
|
||||
}
|
||||
for b in bodies
|
||||
],
|
||||
"mates": [m.to_dict() for m in mates],
|
||||
"patterns": [p.to_dict() for p in result.patterns],
|
||||
"labels": result.labels.to_dict() if result.labels else None,
|
||||
"n_bodies": len(bodies),
|
||||
"n_mates": len(mates),
|
||||
"n_joints": len(result.joints),
|
||||
}
|
||||
)
|
||||
|
||||
return examples
|
||||
155
tests/mates/test_generator.py
Normal file
155
tests/mates/test_generator.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Tests for solver.mates.generator -- synthetic mate generator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from solver.mates.generator import SyntheticMateGenerator, generate_mate_training_batch
|
||||
from solver.mates.primitives import MateType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SyntheticMateGenerator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyntheticMateGenerator:
|
||||
"""SyntheticMateGenerator core functionality."""
|
||||
|
||||
def test_generate_basic(self) -> None:
|
||||
"""Generate a simple assembly with mates."""
|
||||
gen = SyntheticMateGenerator(seed=42)
|
||||
bodies, mates, result = gen.generate(3)
|
||||
assert len(bodies) == 3
|
||||
assert len(mates) > 0
|
||||
assert result.analysis is not None
|
||||
|
||||
def test_deterministic_with_seed(self) -> None:
|
||||
"""Same seed produces same output."""
|
||||
gen1 = SyntheticMateGenerator(seed=123)
|
||||
_, mates1, _ = gen1.generate(3)
|
||||
|
||||
gen2 = SyntheticMateGenerator(seed=123)
|
||||
_, mates2, _ = gen2.generate(3)
|
||||
|
||||
assert len(mates1) == len(mates2)
|
||||
for m1, m2 in zip(mates1, mates2, strict=True):
|
||||
assert m1.mate_type == m2.mate_type
|
||||
assert m1.ref_a.body_id == m2.ref_a.body_id
|
||||
|
||||
def test_grounded(self) -> None:
|
||||
"""Grounded assembly should work."""
|
||||
gen = SyntheticMateGenerator(seed=42)
|
||||
bodies, _mates, result = gen.generate(3, grounded=True)
|
||||
assert len(bodies) == 3
|
||||
assert result.analysis is not None
|
||||
|
||||
def test_revolute_produces_two_mates(self) -> None:
|
||||
"""A revolute joint should reverse-map to 2 mates."""
|
||||
gen = SyntheticMateGenerator(seed=42)
|
||||
_bodies, mates, _result = gen.generate(2)
|
||||
# 2 bodies -> 1 revolute joint -> 2 mates (concentric + coincident)
|
||||
assert len(mates) == 2
|
||||
mate_types = {m.mate_type for m in mates}
|
||||
assert MateType.CONCENTRIC in mate_types
|
||||
assert MateType.COINCIDENT in mate_types
|
||||
|
||||
|
||||
class TestReverseMapping:
|
||||
"""Reverse mapping from joints to mates."""
|
||||
|
||||
def test_revolute_mapping(self) -> None:
|
||||
"""REVOLUTE -> Concentric + Coincident."""
|
||||
gen = SyntheticMateGenerator(seed=42)
|
||||
_bodies, mates, _result = gen.generate(2)
|
||||
types = [m.mate_type for m in mates]
|
||||
assert MateType.CONCENTRIC in types
|
||||
assert MateType.COINCIDENT in types
|
||||
|
||||
def test_round_trip_analysis(self) -> None:
|
||||
"""Generated mates round-trip through analysis successfully."""
|
||||
gen = SyntheticMateGenerator(seed=42)
|
||||
_bodies, _mates, result = gen.generate(4)
|
||||
assert result.analysis is not None
|
||||
assert result.labels is not None
|
||||
# Should produce joints from the mates
|
||||
assert len(result.joints) > 0
|
||||
|
||||
|
||||
class TestNoiseInjection:
|
||||
"""Noise injection mechanisms."""
|
||||
|
||||
def test_redundant_injection(self) -> None:
|
||||
"""Redundant prob > 0 produces more mates than clean version."""
|
||||
gen_clean = SyntheticMateGenerator(seed=42, redundant_prob=0.0)
|
||||
_, mates_clean, _ = gen_clean.generate(4)
|
||||
|
||||
gen_noisy = SyntheticMateGenerator(seed=42, redundant_prob=1.0)
|
||||
_, mates_noisy, _ = gen_noisy.generate(4)
|
||||
|
||||
assert len(mates_noisy) > len(mates_clean)
|
||||
|
||||
def test_missing_injection(self) -> None:
|
||||
"""Missing prob > 0 produces fewer mates than clean version."""
|
||||
gen_clean = SyntheticMateGenerator(seed=42, missing_prob=0.0)
|
||||
_, mates_clean, _ = gen_clean.generate(4)
|
||||
|
||||
gen_noisy = SyntheticMateGenerator(seed=42, missing_prob=0.5)
|
||||
_, mates_noisy, _ = gen_noisy.generate(4)
|
||||
|
||||
# With 50% drop rate on 6 mates, very likely to drop at least one
|
||||
assert len(mates_noisy) <= len(mates_clean)
|
||||
|
||||
def test_incompatible_injection(self) -> None:
|
||||
"""Incompatible prob > 0 adds mates with wrong geometry."""
|
||||
gen = SyntheticMateGenerator(seed=42, incompatible_prob=1.0)
|
||||
_, mates, _ = gen.generate(3)
|
||||
# Should have extra mates beyond the clean count
|
||||
gen_clean = SyntheticMateGenerator(seed=42)
|
||||
_, mates_clean, _ = gen_clean.generate(3)
|
||||
assert len(mates) > len(mates_clean)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# generate_mate_training_batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGenerateMateTrainingBatch:
|
||||
"""Batch generation function."""
|
||||
|
||||
def test_batch_structure(self) -> None:
|
||||
"""Each example has required keys."""
|
||||
examples = generate_mate_training_batch(batch_size=3, seed=42)
|
||||
assert len(examples) == 3
|
||||
for ex in examples:
|
||||
assert "bodies" in ex
|
||||
assert "mates" in ex
|
||||
assert "patterns" in ex
|
||||
assert "labels" in ex
|
||||
assert "n_bodies" in ex
|
||||
assert "n_mates" in ex
|
||||
assert "n_joints" in ex
|
||||
|
||||
def test_batch_deterministic(self) -> None:
|
||||
"""Same seed produces same batch."""
|
||||
batch1 = generate_mate_training_batch(batch_size=5, seed=99)
|
||||
batch2 = generate_mate_training_batch(batch_size=5, seed=99)
|
||||
for ex1, ex2 in zip(batch1, batch2, strict=True):
|
||||
assert ex1["n_bodies"] == ex2["n_bodies"]
|
||||
assert ex1["n_mates"] == ex2["n_mates"]
|
||||
|
||||
def test_batch_grounded_ratio(self) -> None:
|
||||
"""Batch respects grounded_ratio parameter."""
|
||||
# All grounded
|
||||
examples = generate_mate_training_batch(batch_size=5, seed=42, grounded_ratio=1.0)
|
||||
assert len(examples) == 5
|
||||
|
||||
def test_batch_with_noise(self) -> None:
|
||||
"""Batch with noise injection runs without error."""
|
||||
examples = generate_mate_training_batch(
|
||||
batch_size=3,
|
||||
seed=42,
|
||||
redundant_prob=0.3,
|
||||
missing_prob=0.1,
|
||||
)
|
||||
assert len(examples) == 3
|
||||
for ex in examples:
|
||||
assert ex["n_mates"] >= 0
|
||||
Reference in New Issue
Block a user