Files
solver/tests/mates/test_generator.py
forbes-0023 239e45c7f9
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
feat(mates): add mate-based synthetic assembly generator
SyntheticMateGenerator wraps the existing joint generator with a reverse
mapping (joint -> mates) and configurable noise injection (redundant,
missing, and incompatible mates). Batch generation is provided by
generate_mate_training_batch().

Closes #14
2026-02-03 13:05:58 -06:00

156 lines
5.8 KiB
Python

"""Tests for solver.mates.generator -- synthetic mate generator."""
from __future__ import annotations
from solver.mates.generator import SyntheticMateGenerator, generate_mate_training_batch
from solver.mates.primitives import MateType
# ---------------------------------------------------------------------------
# SyntheticMateGenerator
# ---------------------------------------------------------------------------
class TestSyntheticMateGenerator:
    """Core behaviour of ``SyntheticMateGenerator.generate()``."""

    def test_generate_basic(self) -> None:
        """A three-body request returns three bodies, some mates and an analysis."""
        generator = SyntheticMateGenerator(seed=42)
        generated_bodies, generated_mates, outcome = generator.generate(3)

        assert len(generated_bodies) == 3
        assert len(generated_mates) >= 1
        assert outcome.analysis is not None

    def test_deterministic_with_seed(self) -> None:
        """Two generators built with the same seed emit the same mate sequence."""
        _, first_run, _ = SyntheticMateGenerator(seed=123).generate(3)
        _, second_run, _ = SyntheticMateGenerator(seed=123).generate(3)

        assert len(first_run) == len(second_run)
        # Compare pairwise: mate kind and the body referenced by ref_a must match.
        for left, right in zip(first_run, second_run, strict=True):
            assert left.mate_type == right.mate_type
            assert left.ref_a.body_id == right.ref_a.body_id

    def test_grounded(self) -> None:
        """``generate()`` also succeeds when asked for a grounded assembly."""
        bodies, _mates, result = SyntheticMateGenerator(seed=42).generate(3, grounded=True)

        assert len(bodies) == 3
        assert result.analysis is not None

    def test_revolute_produces_two_mates(self) -> None:
        """A revolute joint reverse-maps to exactly two mates.

        With two bodies the generator is expected to produce a single revolute
        joint, which becomes one concentric and one coincident mate.
        """
        _bodies, mates, _result = SyntheticMateGenerator(seed=42).generate(2)

        assert len(mates) == 2
        assert {MateType.CONCENTRIC, MateType.COINCIDENT} <= {m.mate_type for m in mates}
class TestReverseMapping:
    """Joint-to-mate reverse mapping performed by the generator."""

    def test_revolute_mapping(self) -> None:
        """REVOLUTE maps back to a Concentric mate plus a Coincident mate."""
        _bodies, produced, _result = SyntheticMateGenerator(seed=42).generate(2)
        observed_types = [mate.mate_type for mate in produced]

        for expected in (MateType.CONCENTRIC, MateType.COINCIDENT):
            assert expected in observed_types

    def test_round_trip_analysis(self) -> None:
        """Mates produced by the generator survive a full analysis round trip."""
        generator = SyntheticMateGenerator(seed=42)
        _bodies, _mates, analysed = generator.generate(4)

        assert analysed.analysis is not None
        assert analysed.labels is not None
        # Analysing the generated mates should recover at least one joint.
        assert len(analysed.joints) >= 1
class TestNoiseInjection:
    """Noise injection mechanisms."""

    def test_redundant_injection(self) -> None:
        """Redundant prob > 0 produces more mates than the clean version."""
        gen_clean = SyntheticMateGenerator(seed=42, redundant_prob=0.0)
        _, mates_clean, _ = gen_clean.generate(4)
        gen_noisy = SyntheticMateGenerator(seed=42, redundant_prob=1.0)
        _, mates_noisy, _ = gen_noisy.generate(4)
        assert len(mates_noisy) > len(mates_clean)

    def test_missing_injection(self) -> None:
        """Missing prob > 0 never adds mates and actually drops some.

        A single-seed ``<=`` check would also pass if missing-mate injection
        were a no-op. Several seeds are therefore compared:
        - every noisy assembly must have no more mates than its clean
          counterpart;
        - at least one noisy assembly must have strictly fewer.
        """
        counts: list[tuple[int, int]] = []
        for seed in (42, 43, 44, 45, 46):
            _, mates_clean, _ = SyntheticMateGenerator(seed=seed, missing_prob=0.0).generate(4)
            _, mates_noisy, _ = SyntheticMateGenerator(seed=seed, missing_prob=0.5).generate(4)
            assert len(mates_noisy) <= len(mates_clean), f"seed={seed}"
            counts.append((len(mates_clean), len(mates_noisy)))
        # At a 50% drop rate over several assemblies, a working implementation
        # drops at least one mate somewhere; a no-op would leave every count equal.
        assert any(noisy < clean for clean, noisy in counts), counts

    def test_incompatible_injection(self) -> None:
        """Incompatible prob > 0 adds mates with wrong geometry."""
        gen = SyntheticMateGenerator(seed=42, incompatible_prob=1.0)
        _, mates, _ = gen.generate(3)
        gen_clean = SyntheticMateGenerator(seed=42)
        _, mates_clean, _ = gen_clean.generate(3)
        # The noisy assembly should carry extra mates beyond the clean count.
        assert len(mates) > len(mates_clean)
# ---------------------------------------------------------------------------
# generate_mate_training_batch
# ---------------------------------------------------------------------------
class TestGenerateMateTrainingBatch:
    """Batch generation function."""

    # Keys every example returned by generate_mate_training_batch must carry.
    REQUIRED_KEYS = (
        "bodies",
        "mates",
        "patterns",
        "labels",
        "n_bodies",
        "n_mates",
        "n_joints",
    )

    def test_batch_structure(self) -> None:
        """Each example has the required keys."""
        examples = generate_mate_training_batch(batch_size=3, seed=42)
        assert len(examples) == 3
        for ex in examples:
            # Report every absent key at once rather than stopping at the first.
            missing = [key for key in self.REQUIRED_KEYS if key not in ex]
            assert not missing, f"example missing keys: {missing}"

    def test_batch_deterministic(self) -> None:
        """Same seed produces the same batch (sizes and per-example counts)."""
        batch1 = generate_mate_training_batch(batch_size=5, seed=99)
        batch2 = generate_mate_training_batch(batch_size=5, seed=99)
        assert len(batch1) == len(batch2) == 5
        for ex1, ex2 in zip(batch1, batch2, strict=True):
            assert ex1["n_bodies"] == ex2["n_bodies"]
            assert ex1["n_mates"] == ex2["n_mates"]
            assert ex1["n_joints"] == ex2["n_joints"]

    def test_batch_grounded_ratio(self) -> None:
        """Batch generation succeeds at both grounded_ratio extremes.

        Covers 0.0 (no grounded assemblies) and 1.0 (all grounded). This test
        only checks that a full batch comes back; it does not inspect whether
        individual examples are grounded.
        """
        for ratio in (0.0, 1.0):
            examples = generate_mate_training_batch(batch_size=5, seed=42, grounded_ratio=ratio)
            assert len(examples) == 5, f"grounded_ratio={ratio}"

    def test_batch_with_noise(self) -> None:
        """Batch with noise injection runs without error."""
        examples = generate_mate_training_batch(
            batch_size=3,
            seed=42,
            redundant_prob=0.3,
            missing_prob=0.1,
        )
        assert len(examples) == 3
        for ex in examples:
            assert ex["n_mates"] >= 0