feat: parameterized assembly templates and complexity tiers
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

Add 4 new topology generators to SyntheticAssemblyGenerator:
- generate_tree_assembly: random spanning tree with configurable branching
- generate_loop_assembly: closed ring producing overconstrained data
- generate_star_assembly: hub-and-spoke topology
- generate_mixed_assembly: tree + loops with configurable edge density

Each accepts joint_types as JointType | list[JointType] for per-joint
type sampling.

Add complexity tiers (simple/medium/complex) with predefined body count
ranges via COMPLEXITY_RANGES dict and ComplexityTier type alias.

Update generate_training_batch with 7-way generator selection,
complexity_tier parameter, and generator_type field in output dicts.

Extract private helpers (_random_position, _random_axis,
_select_joint_type, _create_joint) to reduce duplication.

44 generator tests, 130 total — all passing.

Closes #7
This commit is contained in:
2026-02-02 14:38:05 -06:00
parent dc742bfc82
commit 0b5813b5a9
3 changed files with 590 additions and 76 deletions

View File

@@ -2,11 +2,18 @@
from __future__ import annotations
from typing import ClassVar
import numpy as np
import pytest
from solver.datagen.generator import SyntheticAssemblyGenerator
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
from solver.datagen.types import JointType
# ---------------------------------------------------------------------------
# Original generators (chain / rigid / overconstrained)
# ---------------------------------------------------------------------------
class TestChainAssembly:
"""generate_chain_assembly produces valid underconstrained chains."""
@@ -83,66 +90,259 @@ class TestOverconstrainedAssembly:
assert len(joints_over) > len(joints_base)
# ---------------------------------------------------------------------------
# New topology generators
# ---------------------------------------------------------------------------
class TestTreeAssembly:
    """Checks on the output of ``generate_tree_assembly``."""

    def test_body_and_joint_counts(self) -> None:
        """Six bodies are returned together with five joints (n - 1)."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_tree_assembly(6)
        n_bodies, n_joints = len(bodies), len(joints)
        assert n_bodies == 6
        assert n_joints == 5  # n - 1

    def test_underconstrained(self) -> None:
        """The analysis of a 6-body tree is classified as underconstrained."""
        generator = SyntheticAssemblyGenerator(seed=42)
        result = generator.generate_tree_assembly(6)
        analysis = result[2]
        assert analysis.combinatorial_classification == "underconstrained"

    def test_branching_factor(self) -> None:
        """Passing ``branching_factor=2`` keeps the body and joint counts."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_tree_assembly(
            10, branching_factor=2
        )
        n_bodies, n_joints = len(bodies), len(joints)
        assert n_bodies == 10
        assert n_joints == 9

    def test_mixed_joint_types(self) -> None:
        """A list of candidate joint types yields at least two distinct types."""
        generator = SyntheticAssemblyGenerator(seed=42)
        candidates = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
        _, joints, _ = generator.generate_tree_assembly(
            10, joint_types=candidates
        )
        seen = set()
        for joint in joints:
            seen.add(joint.joint_type)
        # With 9 joints and 3 types, very likely to use at least 2
        assert len(seen) >= 2

    def test_single_joint_type(self) -> None:
        """A single ``JointType`` argument is applied to every joint."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_tree_assembly(
            5, joint_types=JointType.BALL
        )
        for joint in joints:
            assert joint.joint_type is JointType.BALL

    def test_sequential_body_ids(self) -> None:
        """Body ids run 0..6 in the order the bodies are returned."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = generator.generate_tree_assembly(7)
        observed_ids = [body.body_id for body in bodies]
        expected_ids = list(range(7))
        assert observed_ids == expected_ids
class TestLoopAssembly:
    """Checks on the output of ``generate_loop_assembly``."""

    def test_body_and_joint_counts(self) -> None:
        """Five bodies are returned together with five joints."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_loop_assembly(5)
        n_bodies, n_joints = len(bodies), len(joints)
        assert n_bodies == 5
        assert n_joints == 5  # n joints for n bodies

    def test_has_redundancy(self) -> None:
        """The analysis of a 5-body loop reports a positive redundant count."""
        generator = SyntheticAssemblyGenerator(seed=42)
        result = generator.generate_loop_assembly(5)
        analysis = result[2]
        assert analysis.combinatorial_redundant > 0

    def test_wrap_around_connectivity(self) -> None:
        """Joints link consecutive bodies and close the ring with (3, 0)."""
        n_bodies = 4
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_loop_assembly(n_bodies)
        connected = {(joint.body_a, joint.body_b) for joint in joints}
        # Expected edges: (0, 1), (1, 2), (2, 3) and the wrap-around (3, 0).
        for start in range(n_bodies):
            assert (start, (start + 1) % n_bodies) in connected

    def test_minimum_bodies_error(self) -> None:
        """Requesting two bodies raises ``ValueError`` mentioning "at least 3"."""
        generator = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="at least 3"):
            generator.generate_loop_assembly(2)

    def test_mixed_joint_types(self) -> None:
        """A list of candidate joint types yields at least two distinct types."""
        generator = SyntheticAssemblyGenerator(seed=42)
        candidates = [JointType.REVOLUTE, JointType.FIXED]
        _, joints, _ = generator.generate_loop_assembly(
            8, joint_types=candidates
        )
        seen = set()
        for joint in joints:
            seen.add(joint.joint_type)
        assert len(seen) >= 2
class TestStarAssembly:
    """Checks on the output of ``generate_star_assembly``."""

    def test_body_and_joint_counts(self) -> None:
        """Six bodies are returned together with five joints (n - 1)."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_star_assembly(6)
        n_bodies, n_joints = len(bodies), len(joints)
        assert n_bodies == 6
        assert n_joints == 5  # n - 1

    def test_all_joints_connect_to_hub(self) -> None:
        """Every joint has body 0 as one of its two endpoints."""
        hub_id = 0
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_star_assembly(6)
        for joint in joints:
            assert hub_id in (joint.body_a, joint.body_b)

    def test_underconstrained(self) -> None:
        """The analysis of a 5-body star is classified as underconstrained."""
        generator = SyntheticAssemblyGenerator(seed=42)
        result = generator.generate_star_assembly(5)
        analysis = result[2]
        assert analysis.combinatorial_classification == "underconstrained"

    def test_minimum_bodies_error(self) -> None:
        """Requesting one body raises ``ValueError`` mentioning "at least 2"."""
        generator = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="at least 2"):
            generator.generate_star_assembly(1)

    def test_hub_at_origin(self) -> None:
        """The first returned body sits exactly at the zero vector."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = generator.generate_star_assembly(4)
        hub = bodies[0]
        origin = np.zeros(3)
        np.testing.assert_array_equal(hub.position, origin)
class TestMixedAssembly:
    """Checks on the output of ``generate_mixed_assembly``."""

    def test_more_joints_than_tree(self) -> None:
        """With ``edge_density=0.3`` there are more joints than n - 1."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_mixed_assembly(
            8, edge_density=0.3
        )
        spanning_tree_size = len(bodies) - 1
        assert len(joints) > spanning_tree_size

    def test_density_zero_is_tree(self) -> None:
        """With ``edge_density=0.0`` five bodies get exactly four joints."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_mixed_assembly(5, edge_density=0.0)
        n_joints = len(joints)
        assert n_joints == 4  # spanning tree only

    def test_density_validation(self) -> None:
        """Densities 1.5 and -0.1 are both rejected with ``ValueError``."""
        generator = SyntheticAssemblyGenerator(seed=42)
        for bad_density in (1.5, -0.1):
            with pytest.raises(ValueError, match="must be in"):
                generator.generate_mixed_assembly(5, edge_density=bad_density)

    def test_no_duplicate_edges(self) -> None:
        """No unordered body pair is joined by more than one joint."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_mixed_assembly(6, edge_density=0.5)
        # frozenset makes (a, b) and (b, a) compare equal.
        undirected = [
            frozenset([joint.body_a, joint.body_b]) for joint in joints
        ]
        unique_pairs = set(undirected)
        assert len(undirected) == len(unique_pairs)

    def test_high_density(self) -> None:
        """With ``edge_density=1.0`` five bodies get ten joints."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_mixed_assembly(5, edge_density=1.0)
        # Fully connected: 5*(5-1)/2 = 10 edges
        n_joints = len(joints)
        assert n_joints == 10
# ---------------------------------------------------------------------------
# Complexity tiers
# ---------------------------------------------------------------------------
class TestComplexityTiers:
    """Checks on the ``complexity_tier`` argument of ``generate_training_batch``."""

    @staticmethod
    def _assert_within_tier(examples, tier) -> None:
        """Assert every example's ``n_bodies`` is in the tier's half-open range."""
        lower, upper = COMPLEXITY_RANGES[tier]
        for example in examples:
            assert lower <= example["n_bodies"] < upper

    def test_simple_range(self) -> None:
        """A 20-example "simple" batch stays inside the "simple" range."""
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(
            20, complexity_tier="simple"
        )
        self._assert_within_tier(examples, "simple")

    def test_medium_range(self) -> None:
        """A 20-example "medium" batch stays inside the "medium" range."""
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(
            20, complexity_tier="medium"
        )
        self._assert_within_tier(examples, "medium")

    def test_complex_range(self) -> None:
        """A 3-example "complex" batch stays inside the "complex" range."""
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(
            3, complexity_tier="complex"
        )
        self._assert_within_tier(examples, "complex")

    def test_tier_overrides_range(self) -> None:
        """With both arguments given, body counts follow the tier's range."""
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(
            10, n_bodies_range=(2, 3), complexity_tier="medium"
        )
        self._assert_within_tier(examples, "medium")
# ---------------------------------------------------------------------------
# Training batch
# ---------------------------------------------------------------------------
class TestTrainingBatch:
"""generate_training_batch produces well-structured examples."""
@pytest.fixture()
def batch(self) -> list[dict]:
gen = SyntheticAssemblyGenerator(seed=42)
return gen.generate_training_batch(batch_size=20, n_bodies_range=(3, 6))
EXPECTED_KEYS: ClassVar[set[str]] = {
"example_id",
"generator_type",
"n_bodies",
"n_joints",
"body_positions",
"joints",
"joint_labels",
"assembly_classification",
"is_rigid",
"is_minimally_rigid",
"internal_dof",
"geometric_degeneracies",
}
def test_batch_size(self, batch: list[dict]) -> None:
VALID_GEN_TYPES: ClassVar[set[str]] = {
"chain",
"rigid",
"overconstrained",
"tree",
"loop",
"star",
"mixed",
}
def test_batch_size(self) -> None:
    """Requesting 20 examples returns a batch of length 20."""
    generator = SyntheticAssemblyGenerator(seed=42)
    examples = generator.generate_training_batch(20)
    n_examples = len(examples)
    assert n_examples == 20
def test_example_keys(self, batch: list[dict]) -> None:
expected = {
"example_id",
"n_bodies",
"n_joints",
"body_positions",
"joints",
"joint_labels",
"assembly_classification",
"is_rigid",
"is_minimally_rigid",
"internal_dof",
"geometric_degeneracies",
}
def test_example_keys(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
assert set(ex.keys()) == expected
assert set(ex.keys()) == self.EXPECTED_KEYS
def test_example_ids_sequential(self, batch: list[dict]) -> None:
ids = [ex["example_id"] for ex in batch]
assert ids == list(range(20))
def test_example_ids_sequential(self) -> None:
    """``example_id`` values run 0..14 in batch order."""
    generator = SyntheticAssemblyGenerator(seed=42)
    examples = generator.generate_training_batch(15)
    observed_ids = [example["example_id"] for example in examples]
    expected_ids = list(range(15))
    assert observed_ids == expected_ids
def test_classification_distribution(self, batch: list[dict]) -> None:
"""Batch should contain multiple classification types."""
classes = {ex["assembly_classification"] for ex in batch}
# With the 3-way generator split we expect at least 2 types
assert len(classes) >= 2
def test_body_count_in_range(self, batch: list[dict]) -> None:
def test_generator_type_valid(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(50)
for ex in batch:
assert 3 <= ex["n_bodies"] <= 5 # range is [3, 6)
assert ex["generator_type"] in self.VALID_GEN_TYPES
def test_joint_labels_match_joints(self, batch: list[dict]) -> None:
def test_generator_type_diversity(self) -> None:
    """A 100-example batch reports at least 5 distinct ``generator_type`` values."""
    generator = SyntheticAssemblyGenerator(seed=42)
    examples = generator.generate_training_batch(100)
    seen_types = set()
    for example in examples:
        seen_types.add(example["generator_type"])
    assert len(seen_types) >= 5
def test_default_body_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(30)
for ex in batch:
label_jids = set(ex["joint_labels"].keys())
joint_jids = {j["joint_id"] for j in ex["joints"]}
assert label_jids == joint_jids
assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp
def test_joint_label_fields(self, batch: list[dict]) -> None:
expected_fields = {
"independent_constraints",
"redundant_constraints",
"total_constraints",
}
for ex in batch:
for label in ex["joint_labels"].values():
assert set(label.keys()) == expected_fields
def test_joint_label_consistency(self, batch: list[dict]) -> None:
def test_joint_label_consistency(self) -> None:
"""independent + redundant == total for every joint."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(30)
for ex in batch:
for label in ex["joint_labels"].values():
total = label["independent_constraints"] + label["redundant_constraints"]
@@ -157,10 +357,17 @@ class TestSeedReproducibility:
g2 = SyntheticAssemblyGenerator(seed=2)
b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
# Very unlikely to be identical with different seeds
c1 = [ex["assembly_classification"] for ex in b1]
c2 = [ex["assembly_classification"] for ex in b2]
r1 = [ex["is_rigid"] for ex in b1]
r2 = [ex["is_rigid"] for ex in b2]
# At least one of these should differ (probabilistically certain)
assert c1 != c2 or r1 != r2
def test_same_seed_identical(self) -> None:
    """Two generators built with the same seed produce the same tree.

    Body positions must agree, and the joints must connect the same body
    pairs with the same joint types.  Comparing only ``len(joints)`` is
    not enough: a tree on 5 bodies always has 4 joints, so that check
    alone passes even when the generated joints differ.
    """
    g1 = SyntheticAssemblyGenerator(seed=123)
    g2 = SyntheticAssemblyGenerator(seed=123)
    b1, j1, _ = g1.generate_tree_assembly(5)
    b2, j2, _ = g2.generate_tree_assembly(5)
    # strict=True also fails the test if the body counts differ.
    for a, b in zip(b1, b2, strict=True):
        np.testing.assert_array_almost_equal(a.position, b.position)
    assert len(j1) == len(j2)
    # Joint-by-joint comparison of endpoints and type, in generation order.
    signature_1 = [(j.body_a, j.body_b, j.joint_type) for j in j1]
    signature_2 = [(j.body_a, j.body_b, j.joint_type) for j in j2]
    assert signature_1 == signature_2