"""Tests for solver.mates.generator -- synthetic mate generator.""" from __future__ import annotations from solver.mates.generator import SyntheticMateGenerator, generate_mate_training_batch from solver.mates.primitives import MateType # --------------------------------------------------------------------------- # SyntheticMateGenerator # --------------------------------------------------------------------------- class TestSyntheticMateGenerator: """SyntheticMateGenerator core functionality.""" def test_generate_basic(self) -> None: """Generate a simple assembly with mates.""" gen = SyntheticMateGenerator(seed=42) bodies, mates, result = gen.generate(3) assert len(bodies) == 3 assert len(mates) > 0 assert result.analysis is not None def test_deterministic_with_seed(self) -> None: """Same seed produces same output.""" gen1 = SyntheticMateGenerator(seed=123) _, mates1, _ = gen1.generate(3) gen2 = SyntheticMateGenerator(seed=123) _, mates2, _ = gen2.generate(3) assert len(mates1) == len(mates2) for m1, m2 in zip(mates1, mates2, strict=True): assert m1.mate_type == m2.mate_type assert m1.ref_a.body_id == m2.ref_a.body_id def test_grounded(self) -> None: """Grounded assembly should work.""" gen = SyntheticMateGenerator(seed=42) bodies, _mates, result = gen.generate(3, grounded=True) assert len(bodies) == 3 assert result.analysis is not None def test_revolute_produces_two_mates(self) -> None: """A revolute joint should reverse-map to 2 mates.""" gen = SyntheticMateGenerator(seed=42) _bodies, mates, _result = gen.generate(2) # 2 bodies -> 1 revolute joint -> 2 mates (concentric + coincident) assert len(mates) == 2 mate_types = {m.mate_type for m in mates} assert MateType.CONCENTRIC in mate_types assert MateType.COINCIDENT in mate_types class TestReverseMapping: """Reverse mapping from joints to mates.""" def test_revolute_mapping(self) -> None: """REVOLUTE -> Concentric + Coincident.""" gen = SyntheticMateGenerator(seed=42) _bodies, mates, _result = gen.generate(2) types = [m.mate_type for m in mates] assert MateType.CONCENTRIC in types assert MateType.COINCIDENT in types def test_round_trip_analysis(self) -> None: """Generated mates round-trip through analysis successfully.""" gen = SyntheticMateGenerator(seed=42) _bodies, _mates, result = gen.generate(4) assert result.analysis is not None assert result.labels is not None # Should produce joints from the mates assert len(result.joints) > 0 class TestNoiseInjection: """Noise injection mechanisms.""" def test_redundant_injection(self) -> None: """Redundant prob > 0 produces more mates than clean version.""" gen_clean = SyntheticMateGenerator(seed=42, redundant_prob=0.0) _, mates_clean, _ = gen_clean.generate(4) gen_noisy = SyntheticMateGenerator(seed=42, redundant_prob=1.0) _, mates_noisy, _ = gen_noisy.generate(4) assert len(mates_noisy) > len(mates_clean) def test_missing_injection(self) -> None: """Missing prob > 0 produces fewer mates than clean version.""" gen_clean = SyntheticMateGenerator(seed=42, missing_prob=0.0) _, mates_clean, _ = gen_clean.generate(4) gen_noisy = SyntheticMateGenerator(seed=42, missing_prob=0.5) _, mates_noisy, _ = gen_noisy.generate(4) # With 50% drop rate on 6 mates, very likely to drop at least one assert len(mates_noisy) <= len(mates_clean) def test_incompatible_injection(self) -> None: """Incompatible prob > 0 adds mates with wrong geometry.""" gen = SyntheticMateGenerator(seed=42, incompatible_prob=1.0) _, mates, _ = gen.generate(3) # Should have extra mates beyond the clean count gen_clean = SyntheticMateGenerator(seed=42) _, mates_clean, _ = gen_clean.generate(3) assert len(mates) > len(mates_clean) # --------------------------------------------------------------------------- # generate_mate_training_batch # --------------------------------------------------------------------------- class TestGenerateMateTrainingBatch: """Batch generation function.""" def test_batch_structure(self) -> None: """Each example has required keys.""" examples = generate_mate_training_batch(batch_size=3, seed=42) assert len(examples) == 3 for ex in examples: assert "bodies" in ex assert "mates" in ex assert "patterns" in ex assert "labels" in ex assert "n_bodies" in ex assert "n_mates" in ex assert "n_joints" in ex def test_batch_deterministic(self) -> None: """Same seed produces same batch.""" batch1 = generate_mate_training_batch(batch_size=5, seed=99) batch2 = generate_mate_training_batch(batch_size=5, seed=99) for ex1, ex2 in zip(batch1, batch2, strict=True): assert ex1["n_bodies"] == ex2["n_bodies"] assert ex1["n_mates"] == ex2["n_mates"] def test_batch_grounded_ratio(self) -> None: """Batch respects grounded_ratio parameter.""" # All grounded examples = generate_mate_training_batch(batch_size=5, seed=42, grounded_ratio=1.0) assert len(examples) == 5 def test_batch_with_noise(self) -> None: """Batch with noise injection runs without error.""" examples = generate_mate_training_batch( batch_size=3, seed=42, redundant_prob=0.3, missing_prob=0.1, ) assert len(examples) == 3 for ex in examples: assert ex["n_mates"] >= 0