feat: parameterized assembly templates and complexity tiers
Add 4 new topology generators to SyntheticAssemblyGenerator: - generate_tree_assembly: random spanning tree with configurable branching - generate_loop_assembly: closed ring producing overconstrained data - generate_star_assembly: hub-and-spoke topology - generate_mixed_assembly: tree + loops with configurable edge density Each accepts joint_types as JointType | list[JointType] for per-joint type sampling. Add complexity tiers (simple/medium/complex) with predefined body count ranges via COMPLEXITY_RANGES dict and ComplexityTier type alias. Update generate_training_batch with 7-way generator selection, complexity_tier parameter, and generator_type field in output dicts. Extract private helpers (_random_position, _random_axis, _select_joint_type, _create_joint) to reduce duplication. 44 generator tests, 130 total — all passing. Closes #7
This commit is contained in:
@@ -2,11 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
|
||||
from solver.datagen.types import JointType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Original generators (chain / rigid / overconstrained)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestChainAssembly:
|
||||
"""generate_chain_assembly produces valid underconstrained chains."""
|
||||
@@ -83,66 +90,259 @@ class TestOverconstrainedAssembly:
|
||||
assert len(joints_over) > len(joints_base)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# New topology generators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTreeAssembly:
|
||||
"""generate_tree_assembly produces tree-structured assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(6)
|
||||
assert len(bodies) == 6
|
||||
assert len(joints) == 5 # n - 1
|
||||
|
||||
def test_underconstrained(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_tree_assembly(6)
|
||||
assert analysis.combinatorial_classification == "underconstrained"
|
||||
|
||||
def test_branching_factor(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2)
|
||||
assert len(bodies) == 10
|
||||
assert len(joints) == 9
|
||||
|
||||
def test_mixed_joint_types(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
types = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
|
||||
_, joints, _ = gen.generate_tree_assembly(10, joint_types=types)
|
||||
used = {j.joint_type for j in joints}
|
||||
# With 9 joints and 3 types, very likely to use at least 2
|
||||
assert len(used) >= 2
|
||||
|
||||
def test_single_joint_type(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL)
|
||||
assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
def test_sequential_body_ids(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_tree_assembly(7)
|
||||
assert [b.body_id for b in bodies] == list(range(7))
|
||||
|
||||
|
||||
class TestLoopAssembly:
|
||||
"""generate_loop_assembly produces closed-loop assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_loop_assembly(5)
|
||||
assert len(bodies) == 5
|
||||
assert len(joints) == 5 # n joints for n bodies
|
||||
|
||||
def test_has_redundancy(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_loop_assembly(5)
|
||||
assert analysis.combinatorial_redundant > 0
|
||||
|
||||
def test_wrap_around_connectivity(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_loop_assembly(4)
|
||||
edges = {(j.body_a, j.body_b) for j in joints}
|
||||
assert (0, 1) in edges
|
||||
assert (1, 2) in edges
|
||||
assert (2, 3) in edges
|
||||
assert (3, 0) in edges # wrap-around
|
||||
|
||||
def test_minimum_bodies_error(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="at least 3"):
|
||||
gen.generate_loop_assembly(2)
|
||||
|
||||
def test_mixed_joint_types(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
types = [JointType.REVOLUTE, JointType.FIXED]
|
||||
_, joints, _ = gen.generate_loop_assembly(8, joint_types=types)
|
||||
used = {j.joint_type for j in joints}
|
||||
assert len(used) >= 2
|
||||
|
||||
|
||||
class TestStarAssembly:
|
||||
"""generate_star_assembly produces hub-and-spoke assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_star_assembly(6)
|
||||
assert len(bodies) == 6
|
||||
assert len(joints) == 5 # n - 1
|
||||
|
||||
def test_all_joints_connect_to_hub(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_star_assembly(6)
|
||||
for j in joints:
|
||||
assert j.body_a == 0 or j.body_b == 0
|
||||
|
||||
def test_underconstrained(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_star_assembly(5)
|
||||
assert analysis.combinatorial_classification == "underconstrained"
|
||||
|
||||
def test_minimum_bodies_error(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="at least 2"):
|
||||
gen.generate_star_assembly(1)
|
||||
|
||||
def test_hub_at_origin(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_star_assembly(4)
|
||||
np.testing.assert_array_equal(bodies[0].position, np.zeros(3))
|
||||
|
||||
|
||||
class TestMixedAssembly:
|
||||
"""generate_mixed_assembly produces tree+loop hybrid assemblies."""
|
||||
|
||||
def test_more_joints_than_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3)
|
||||
assert len(joints) > len(bodies) - 1
|
||||
|
||||
def test_density_zero_is_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0)
|
||||
assert len(joints) == 4 # spanning tree only
|
||||
|
||||
def test_density_validation(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="must be in"):
|
||||
gen.generate_mixed_assembly(5, edge_density=1.5)
|
||||
with pytest.raises(ValueError, match="must be in"):
|
||||
gen.generate_mixed_assembly(5, edge_density=-0.1)
|
||||
|
||||
def test_no_duplicate_edges(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_mixed_assembly(6, edge_density=0.5)
|
||||
edges = [frozenset([j.body_a, j.body_b]) for j in joints]
|
||||
assert len(edges) == len(set(edges))
|
||||
|
||||
def test_high_density(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0)
|
||||
# Fully connected: 5*(5-1)/2 = 10 edges
|
||||
assert len(joints) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComplexityTiers:
|
||||
"""Complexity tier parameter on batch generation."""
|
||||
|
||||
def test_simple_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, complexity_tier="simple")
|
||||
lo, hi = COMPLEXITY_RANGES["simple"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_medium_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, complexity_tier="medium")
|
||||
lo, hi = COMPLEXITY_RANGES["medium"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_complex_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(3, complexity_tier="complex")
|
||||
lo, hi = COMPLEXITY_RANGES["complex"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_tier_overrides_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium")
|
||||
lo, hi = COMPLEXITY_RANGES["medium"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrainingBatch:
|
||||
"""generate_training_batch produces well-structured examples."""
|
||||
|
||||
@pytest.fixture()
|
||||
def batch(self) -> list[dict]:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
return gen.generate_training_batch(batch_size=20, n_bodies_range=(3, 6))
|
||||
EXPECTED_KEYS: ClassVar[set[str]] = {
|
||||
"example_id",
|
||||
"generator_type",
|
||||
"n_bodies",
|
||||
"n_joints",
|
||||
"body_positions",
|
||||
"joints",
|
||||
"joint_labels",
|
||||
"assembly_classification",
|
||||
"is_rigid",
|
||||
"is_minimally_rigid",
|
||||
"internal_dof",
|
||||
"geometric_degeneracies",
|
||||
}
|
||||
|
||||
def test_batch_size(self, batch: list[dict]) -> None:
|
||||
VALID_GEN_TYPES: ClassVar[set[str]] = {
|
||||
"chain",
|
||||
"rigid",
|
||||
"overconstrained",
|
||||
"tree",
|
||||
"loop",
|
||||
"star",
|
||||
"mixed",
|
||||
}
|
||||
|
||||
def test_batch_size(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20)
|
||||
assert len(batch) == 20
|
||||
|
||||
def test_example_keys(self, batch: list[dict]) -> None:
|
||||
expected = {
|
||||
"example_id",
|
||||
"n_bodies",
|
||||
"n_joints",
|
||||
"body_positions",
|
||||
"joints",
|
||||
"joint_labels",
|
||||
"assembly_classification",
|
||||
"is_rigid",
|
||||
"is_minimally_rigid",
|
||||
"internal_dof",
|
||||
"geometric_degeneracies",
|
||||
}
|
||||
def test_example_keys(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
assert set(ex.keys()) == expected
|
||||
assert set(ex.keys()) == self.EXPECTED_KEYS
|
||||
|
||||
def test_example_ids_sequential(self, batch: list[dict]) -> None:
|
||||
ids = [ex["example_id"] for ex in batch]
|
||||
assert ids == list(range(20))
|
||||
def test_example_ids_sequential(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(15)
|
||||
assert [ex["example_id"] for ex in batch] == list(range(15))
|
||||
|
||||
def test_classification_distribution(self, batch: list[dict]) -> None:
|
||||
"""Batch should contain multiple classification types."""
|
||||
classes = {ex["assembly_classification"] for ex in batch}
|
||||
# With the 3-way generator split we expect at least 2 types
|
||||
assert len(classes) >= 2
|
||||
|
||||
def test_body_count_in_range(self, batch: list[dict]) -> None:
|
||||
def test_generator_type_valid(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(50)
|
||||
for ex in batch:
|
||||
assert 3 <= ex["n_bodies"] <= 5 # range is [3, 6)
|
||||
assert ex["generator_type"] in self.VALID_GEN_TYPES
|
||||
|
||||
def test_joint_labels_match_joints(self, batch: list[dict]) -> None:
|
||||
def test_generator_type_diversity(self) -> None:
|
||||
"""100-sample batch should use at least 5 of 7 generator types."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(100)
|
||||
types = {ex["generator_type"] for ex in batch}
|
||||
assert len(types) >= 5
|
||||
|
||||
def test_default_body_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(30)
|
||||
for ex in batch:
|
||||
label_jids = set(ex["joint_labels"].keys())
|
||||
joint_jids = {j["joint_id"] for j in ex["joints"]}
|
||||
assert label_jids == joint_jids
|
||||
assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp
|
||||
|
||||
def test_joint_label_fields(self, batch: list[dict]) -> None:
|
||||
expected_fields = {
|
||||
"independent_constraints",
|
||||
"redundant_constraints",
|
||||
"total_constraints",
|
||||
}
|
||||
for ex in batch:
|
||||
for label in ex["joint_labels"].values():
|
||||
assert set(label.keys()) == expected_fields
|
||||
|
||||
def test_joint_label_consistency(self, batch: list[dict]) -> None:
|
||||
def test_joint_label_consistency(self) -> None:
|
||||
"""independent + redundant == total for every joint."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(30)
|
||||
for ex in batch:
|
||||
for label in ex["joint_labels"].values():
|
||||
total = label["independent_constraints"] + label["redundant_constraints"]
|
||||
@@ -157,10 +357,17 @@ class TestSeedReproducibility:
|
||||
g2 = SyntheticAssemblyGenerator(seed=2)
|
||||
b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
|
||||
b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
|
||||
# Very unlikely to be identical with different seeds
|
||||
c1 = [ex["assembly_classification"] for ex in b1]
|
||||
c2 = [ex["assembly_classification"] for ex in b2]
|
||||
r1 = [ex["is_rigid"] for ex in b1]
|
||||
r2 = [ex["is_rigid"] for ex in b2]
|
||||
# At least one of these should differ (probabilistically certain)
|
||||
assert c1 != c2 or r1 != r2
|
||||
|
||||
def test_same_seed_identical(self) -> None:
|
||||
g1 = SyntheticAssemblyGenerator(seed=123)
|
||||
g2 = SyntheticAssemblyGenerator(seed=123)
|
||||
b1, j1, _ = g1.generate_tree_assembly(5)
|
||||
b2, j2, _ = g2.generate_tree_assembly(5)
|
||||
for a, b in zip(b1, b2, strict=True):
|
||||
np.testing.assert_array_almost_equal(a.position, b.position)
|
||||
assert len(j1) == len(j2)
|
||||
|
||||
Reference in New Issue
Block a user