feat: parameterized assembly templates and complexity tiers
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

Add 4 new topology generators to SyntheticAssemblyGenerator:
- generate_tree_assembly: random spanning tree with configurable branching
- generate_loop_assembly: closed ring producing overconstrained data
- generate_star_assembly: hub-and-spoke topology
- generate_mixed_assembly: tree + loops with configurable edge density

Each accepts joint_types as JointType | list[JointType] for per-joint
type sampling.

Add complexity tiers (simple/medium/complex) with predefined body count
ranges via COMPLEXITY_RANGES dict and ComplexityTier type alias.

Update generate_training_batch with 7-way generator selection,
complexity_tier parameter, and generator_type field in output dicts.

Extract private helpers (_random_position, _random_axis,
_select_joint_type, _create_joint) to reduce duplication.

44 generator tests, 130 total — all passing.

Closes #7
This commit is contained in:
2026-02-02 14:38:05 -06:00
parent dc742bfc82
commit 0b5813b5a9
3 changed files with 590 additions and 76 deletions

View File

@@ -1,7 +1,7 @@
"""Data generation utilities for assembly constraint training data."""
from solver.datagen.analysis import analyze_assembly
from solver.datagen.generator import SyntheticAssemblyGenerator
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
@@ -13,6 +13,7 @@ from solver.datagen.types import (
)
__all__ = [
"COMPLEXITY_RANGES",
"ConstraintAnalysis",
"JacobianVerifier",
"Joint",

View File

@@ -7,7 +7,7 @@ with per-constraint independence flags and assembly-level classification.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Literal
import numpy as np
@@ -22,7 +22,19 @@ from solver.datagen.types import (
if TYPE_CHECKING:
from typing import Any
__all__ = ["SyntheticAssemblyGenerator"]
__all__ = ["COMPLEXITY_RANGES", "ComplexityTier", "SyntheticAssemblyGenerator"]
# ---------------------------------------------------------------------------
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
# ---------------------------------------------------------------------------
# Names of the predefined complexity tiers.
ComplexityTier = Literal["simple", "medium", "complex"]
# Body-count range per tier, stored as (min_inclusive, max_exclusive) so a
# tuple can be unpacked straight into rng.integers(low, high).
COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
"simple": (2, 6),  # 2 to 5 bodies
"medium": (6, 16),  # 6 to 15 bodies
"complex": (16, 51),  # 16 to 50 bodies
}
class SyntheticAssemblyGenerator:
@@ -41,6 +53,55 @@ class SyntheticAssemblyGenerator:
def __init__(self, seed: int = 42) -> None:
self.rng = np.random.default_rng(seed)
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _random_position(self, scale: float = 5.0) -> np.ndarray:
"""Generate random 3D position within [-scale, scale] cube."""
return self.rng.uniform(-scale, scale, size=3)
def _random_axis(self) -> np.ndarray:
"""Generate random normalized 3D axis."""
axis = self.rng.standard_normal(3)
axis /= np.linalg.norm(axis)
return axis
def _select_joint_type(
self,
joint_types: JointType | list[JointType],
) -> JointType:
"""Select a joint type from a single type or list."""
if isinstance(joint_types, list):
idx = int(self.rng.integers(0, len(joint_types)))
return joint_types[idx]
return joint_types
def _create_joint(
    self,
    joint_id: int,
    body_a_id: int,
    body_b_id: int,
    pos_a: np.ndarray,
    pos_b: np.ndarray,
    joint_type: JointType,
) -> Joint:
    """Build a joint connecting two bodies.

    Both anchors are placed at the midpoint of the two body positions
    and the joint axis is a fresh random unit vector.

    Args:
        joint_id: Identifier assigned to the new joint.
        body_a_id: Identifier of the first body.
        body_b_id: Identifier of the second body.
        pos_a: Position of the first body.
        pos_b: Position of the second body.
        joint_type: Type of joint to create.

    Returns:
        The constructed ``Joint``.
    """
    midpoint = (pos_a + pos_b) / 2.0
    random_axis = self._random_axis()
    return Joint(
        joint_id=joint_id,
        body_a=body_a_id,
        body_b=body_b_id,
        joint_type=joint_type,
        anchor_a=midpoint,
        anchor_b=midpoint,
        axis=random_axis,
    )
# ------------------------------------------------------------------
# Original generators (chain / rigid / overconstrained)
# ------------------------------------------------------------------
def generate_chain_assembly(
self,
n_bodies: int,
@@ -60,11 +121,7 @@ class SyntheticAssemblyGenerator:
bodies.append(RigidBody(body_id=i, position=pos))
for i in range(n_bodies - 1):
axis = self.rng.standard_normal(3)
axis /= np.linalg.norm(axis)
anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0])
joints.append(
Joint(
joint_id=i,
@@ -73,7 +130,7 @@ class SyntheticAssemblyGenerator:
joint_type=joint_type,
anchor_a=anchor,
anchor_b=anchor,
axis=axis,
axis=self._random_axis(),
)
)
@@ -91,26 +148,22 @@ class SyntheticAssemblyGenerator:
"""
bodies = []
for i in range(n_bodies):
pos = self.rng.uniform(-5, 5, size=3)
bodies.append(RigidBody(body_id=i, position=pos))
bodies.append(RigidBody(body_id=i, position=self._random_position()))
# Build spanning tree with fixed joints (overconstrained)
joints: list[Joint] = []
for i in range(1, n_bodies):
parent = self.rng.integers(0, i)
parent = int(self.rng.integers(0, i))
mid = (bodies[i].position + bodies[parent].position) / 2
axis = self.rng.standard_normal(3)
axis /= np.linalg.norm(axis)
joints.append(
Joint(
joint_id=i - 1,
body_a=int(parent),
body_a=parent,
body_b=i,
joint_type=JointType.FIXED,
anchor_a=mid,
anchor_b=mid,
axis=axis,
axis=self._random_axis(),
)
)
@@ -150,8 +203,6 @@ class SyntheticAssemblyGenerator:
for _ in range(extra_joints):
a, b = self.rng.choice(n_bodies, size=2, replace=False)
mid = (bodies[a].position + bodies[b].position) / 2
axis = self.rng.standard_normal(3)
axis /= np.linalg.norm(axis)
_overcon_types = [
JointType.REVOLUTE,
@@ -167,7 +218,7 @@ class SyntheticAssemblyGenerator:
joint_type=jtype,
anchor_a=mid,
anchor_b=mid,
axis=axis,
axis=self._random_axis(),
)
)
joint_id += 1
@@ -175,24 +226,259 @@ class SyntheticAssemblyGenerator:
analysis = analyze_assembly(bodies, joints, ground_body=0)
return bodies, joints, analysis
# ------------------------------------------------------------------
# New topology generators
# ------------------------------------------------------------------
def generate_tree_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    branching_factor: int = 3,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a random tree topology with configurable branching.

    Body 0 is the root at the origin. Bodies are attached one joint at
    a time to a randomly chosen open parent, so the joint graph has no
    closed loops. No body ever receives more than *branching_factor*
    children in total, even if it is selected as a parent several times.

    Args:
        n_bodies: Total bodies (root + children). The root is always
            created, so values below 1 still yield a single body.
        joint_types: Single type or list to sample from per joint.
        branching_factor: Max children per parent (1-5 recommended).

    Returns:
        ``(bodies, joints, analysis)`` where *analysis* comes from
        ``analyze_assembly`` with body 0 as ground.

    Raises:
        ValueError: If children must be added but *branching_factor* < 1.
    """
    if n_bodies > 1 and branching_factor < 1:
        msg = "branching_factor must be >= 1"
        raise ValueError(msg)

    bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
    joints: list[Joint] = []

    available_parents = [0]
    # Cumulative number of children per body; enforces the branching
    # limit across repeated selections of the same parent.
    child_counts: dict[int, int] = {0: 0}
    next_id = 1
    joint_id = 0

    while next_id < n_bodies and available_parents:
        pidx = int(self.rng.integers(0, len(available_parents)))
        parent_id = available_parents[pidx]
        parent_pos = bodies[parent_id].position

        # Open parents always have at least one free slot because full
        # parents are retired below, so max_children >= 1 here.
        free_slots = branching_factor - child_counts[parent_id]
        max_children = min(free_slots, n_bodies - next_id)
        n_children = int(self.rng.integers(1, max_children + 1))

        # n_children never exceeds the remaining body budget, so no
        # early exit from this loop is required.
        for _ in range(n_children):
            direction = self._random_axis()
            distance = self.rng.uniform(1.5, 3.0)
            child_pos = parent_pos + direction * distance

            bodies.append(RigidBody(body_id=next_id, position=child_pos))
            jtype = self._select_joint_type(joint_types)
            joints.append(
                self._create_joint(joint_id, parent_id, next_id, parent_pos, child_pos, jtype)
            )

            available_parents.append(next_id)
            child_counts[next_id] = 0
            next_id += 1
            joint_id += 1

        child_counts[parent_id] += n_children

        # Retire the parent once it is full, or at random to vary shape.
        # Children were appended at the end, so pidx still points at it.
        if child_counts[parent_id] >= branching_factor or self.rng.random() < 0.3:
            available_parents.pop(pidx)

    analysis = analyze_assembly(bodies, joints, ground_body=0)
    return bodies, joints, analysis
def generate_loop_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a single closed ring of bodies.

    Bodies are spread around a jittered circle near the z = 0 plane.
    Joint *i* links body *i* to body *i + 1*, and the last joint wraps
    back to body 0, which closes the loop.

    Args:
        n_bodies: Bodies in the ring (>= 3).
        joint_types: Single type or list to sample from per joint.

    Returns:
        ``(bodies, joints, analysis)`` where *analysis* comes from
        ``analyze_assembly`` with body 0 as ground.

    Raises:
        ValueError: If *n_bodies* < 3.
    """
    if n_bodies < 3:
        raise ValueError("Loop assembly requires at least 3 bodies")

    # Ring radius grows with the body count but never drops below 2.0.
    nominal_radius = max(2.0, n_bodies * 0.4)

    bodies: list[RigidBody] = []
    for idx in range(n_bodies):
        theta = 2 * np.pi * idx / n_bodies
        r = nominal_radius + self.rng.uniform(-0.5, 0.5)
        height = float(self.rng.uniform(-0.2, 0.2))
        location = np.array([r * np.cos(theta), r * np.sin(theta), height])
        bodies.append(RigidBody(body_id=idx, position=location))

    joints: list[Joint] = []
    for idx, body in enumerate(bodies):
        # Modulo wraps the final body back to the first one.
        successor = (idx + 1) % n_bodies
        chosen_type = self._select_joint_type(joint_types)
        joints.append(
            self._create_joint(
                idx,
                idx,
                successor,
                body.position,
                bodies[successor].position,
                chosen_type,
            )
        )

    analysis = analyze_assembly(bodies, joints, ground_body=0)
    return bodies, joints, analysis
def generate_star_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a hub-and-spoke assembly.

    Body 0 is the hub at the origin. Every other body is placed in a
    random direction at a random distance and joined directly to the
    hub; satellites are never joined to each other.

    Args:
        n_bodies: Total bodies including hub (>= 2).
        joint_types: Single type or list to sample from per joint.

    Returns:
        ``(bodies, joints, analysis)`` where *analysis* comes from
        ``analyze_assembly`` with body 0 as ground.

    Raises:
        ValueError: If *n_bodies* < 2.
    """
    if n_bodies < 2:
        raise ValueError("Star assembly requires at least 2 bodies")

    hub_id = 0
    bodies: list[RigidBody] = [RigidBody(body_id=hub_id, position=np.zeros(3))]
    joints: list[Joint] = []

    for satellite_id in range(1, n_bodies):
        heading = self._random_axis()
        reach = self.rng.uniform(2.0, 5.0)
        satellite_pos = heading * reach
        bodies.append(RigidBody(body_id=satellite_id, position=satellite_pos))

        chosen_type = self._select_joint_type(joint_types)
        spoke = self._create_joint(
            satellite_id - 1,
            hub_id,
            satellite_id,
            np.zeros(3),
            satellite_pos,
            chosen_type,
        )
        joints.append(spoke)

    analysis = analyze_assembly(bodies, joints, ground_body=0)
    return bodies, joints, analysis
def generate_mixed_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    edge_density: float = 0.3,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a hybrid topology: a spanning tree plus extra edges.

    Bodies get random positions. A random spanning tree keeps the
    graph connected, then a fraction *edge_density* of the body pairs
    not already joined receive an additional joint, which creates loops.

    Args:
        n_bodies: Number of bodies.
        joint_types: Single type or list to sample from per joint.
        edge_density: Fraction of non-tree edges to add (0.0-1.0).

    Returns:
        ``(bodies, joints, analysis)`` where *analysis* comes from
        ``analyze_assembly`` with body 0 as ground.

    Raises:
        ValueError: If *edge_density* not in [0.0, 1.0].
    """
    if edge_density < 0.0 or edge_density > 1.0:
        raise ValueError("edge_density must be in [0.0, 1.0]")

    bodies: list[RigidBody] = [
        RigidBody(body_id=idx, position=self._random_position()) for idx in range(n_bodies)
    ]
    joints: list[Joint] = []

    def connect(first: int, second: int) -> None:
        # The joint id is the current joint count, so ids stay sequential.
        chosen_type = self._select_joint_type(joint_types)
        joints.append(
            self._create_joint(
                len(joints),
                first,
                second,
                bodies[first].position,
                bodies[second].position,
                chosen_type,
            )
        )

    # Phase 1: each body after the first attaches to a random earlier body.
    tree_edges: set[frozenset[int]] = set()
    for child in range(1, n_bodies):
        parent = int(self.rng.integers(0, child))
        connect(parent, child)
        tree_edges.add(frozenset([parent, child]))

    # Phase 2: pick a random subset of the pairs that are not yet joined.
    candidates: list[tuple[int, int]] = [
        (lo, hi)
        for lo in range(n_bodies)
        for hi in range(lo + 1, n_bodies)
        if frozenset([lo, hi]) not in tree_edges
    ]
    n_extra = int(edge_density * len(candidates))
    self.rng.shuffle(candidates)
    for first, second in candidates[:n_extra]:
        connect(first, second)

    analysis = analyze_assembly(bodies, joints, ground_body=0)
    return bodies, joints, analysis
# ------------------------------------------------------------------
# Batch generation
# ------------------------------------------------------------------
def generate_training_batch(
self,
batch_size: int = 100,
n_bodies_range: tuple[int, int] = (3, 8),
n_bodies_range: tuple[int, int] | None = None,
complexity_tier: ComplexityTier | None = None,
) -> list[dict[str, Any]]:
"""Generate a batch of labeled training examples.
Each example contains:
- bodies: list of body positions/orientations
- joints: list of joints with types and parameters
- labels: per-joint independence/redundancy flags
- assembly_label: overall classification
Each example contains body positions, joint descriptions,
per-joint independence labels, and assembly-level classification.
Args:
batch_size: Number of assemblies to generate.
n_bodies_range: ``(min, max_exclusive)`` body count.
Overridden by *complexity_tier* when both are given.
complexity_tier: Predefined range (``"simple"`` / ``"medium"``
/ ``"complex"``). Overrides *n_bodies_range*.
"""
if complexity_tier is not None:
n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
elif n_bodies_range is None:
n_bodies_range = (3, 8)
_joint_pool = [
JointType.REVOLUTE,
JointType.BALL,
JointType.CYLINDRICAL,
JointType.FIXED,
]
examples: list[dict[str, Any]] = []
for i in range(batch_size):
n = int(self.rng.integers(*n_bodies_range))
gen_idx = int(self.rng.integers(3))
gen_idx = int(self.rng.integers(7))
if gen_idx == 0:
_chain_types = [
@@ -202,11 +488,30 @@ class SyntheticAssemblyGenerator:
]
jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
bodies, joints, analysis = self.generate_chain_assembly(n, jtype)
gen_name = "chain"
elif gen_idx == 1:
bodies, joints, analysis = self.generate_rigid_assembly(n)
else:
gen_name = "rigid"
elif gen_idx == 2:
extra = int(self.rng.integers(1, 4))
bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra)
gen_name = "overconstrained"
elif gen_idx == 3:
branching = int(self.rng.integers(2, 5))
bodies, joints, analysis = self.generate_tree_assembly(n, _joint_pool, branching)
gen_name = "tree"
elif gen_idx == 4:
n = max(n, 3)
bodies, joints, analysis = self.generate_loop_assembly(n, _joint_pool)
gen_name = "loop"
elif gen_idx == 5:
n = max(n, 2)
bodies, joints, analysis = self.generate_star_assembly(n, _joint_pool)
gen_name = "star"
else:
density = float(self.rng.uniform(0.2, 0.5))
bodies, joints, analysis = self.generate_mixed_assembly(n, _joint_pool, density)
gen_name = "mixed"
# Build per-joint labels from edge results
joint_labels: dict[int, dict[str, int]] = {}
@@ -227,6 +532,7 @@ class SyntheticAssemblyGenerator:
examples.append(
{
"example_id": i,
"generator_type": gen_name,
"n_bodies": len(bodies),
"n_joints": len(joints),
"body_positions": [b.position.tolist() for b in bodies],

View File

@@ -2,11 +2,18 @@
from __future__ import annotations
from typing import ClassVar
import numpy as np
import pytest
from solver.datagen.generator import SyntheticAssemblyGenerator
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
from solver.datagen.types import JointType
# ---------------------------------------------------------------------------
# Original generators (chain / rigid / overconstrained)
# ---------------------------------------------------------------------------
class TestChainAssembly:
"""generate_chain_assembly produces valid underconstrained chains."""
@@ -83,66 +90,259 @@ class TestOverconstrainedAssembly:
assert len(joints_over) > len(joints_base)
# ---------------------------------------------------------------------------
# New topology generators
# ---------------------------------------------------------------------------
class TestTreeAssembly:
    """generate_tree_assembly produces tree-structured assemblies."""

    def test_body_and_joint_counts(self) -> None:
        """A tree over n bodies has exactly n - 1 joints."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_tree_assembly(6)
        assert (len(bodies), len(joints)) == (6, 5)

    def test_underconstrained(self) -> None:
        """The default revolute tree is classified as underconstrained."""
        generator = SyntheticAssemblyGenerator(seed=42)
        result = generator.generate_tree_assembly(6)
        assert result[2].combinatorial_classification == "underconstrained"

    def test_branching_factor(self) -> None:
        """A custom branching factor still yields the requested sizes."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_tree_assembly(10, branching_factor=2)
        assert (len(bodies), len(joints)) == (10, 9)

    def test_mixed_joint_types(self) -> None:
        """Sampling from a list uses more than one joint type."""
        generator = SyntheticAssemblyGenerator(seed=42)
        pool = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
        _, joints, _ = generator.generate_tree_assembly(10, joint_types=pool)
        # With 9 joints and 3 types, very likely to use at least 2
        distinct = {joint.joint_type for joint in joints}
        assert len(distinct) >= 2

    def test_single_joint_type(self) -> None:
        """A single joint type is applied to every joint."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_tree_assembly(5, joint_types=JointType.BALL)
        for joint in joints:
            assert joint.joint_type is JointType.BALL

    def test_sequential_body_ids(self) -> None:
        """Body ids run 0..n-1 in creation order."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = generator.generate_tree_assembly(7)
        ids = [body.body_id for body in bodies]
        assert ids == list(range(7))
class TestLoopAssembly:
    """generate_loop_assembly produces closed-loop assemblies."""

    def test_body_and_joint_counts(self) -> None:
        """A ring of n bodies has n joints."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_loop_assembly(5)
        assert (len(bodies), len(joints)) == (5, 5)

    def test_has_redundancy(self) -> None:
        """The analysis of a 5-body ring reports redundant constraints."""
        generator = SyntheticAssemblyGenerator(seed=42)
        result = generator.generate_loop_assembly(5)
        assert result[2].combinatorial_redundant > 0

    def test_wrap_around_connectivity(self) -> None:
        """Consecutive bodies are joined and the last wraps to the first."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_loop_assembly(4)
        edges = {(joint.body_a, joint.body_b) for joint in joints}
        # (3, 0) is the wrap-around edge that closes the ring.
        for expected_edge in [(0, 1), (1, 2), (2, 3), (3, 0)]:
            assert expected_edge in edges

    def test_minimum_bodies_error(self) -> None:
        """Fewer than three bodies is rejected."""
        generator = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="at least 3"):
            generator.generate_loop_assembly(2)

    def test_mixed_joint_types(self) -> None:
        """Sampling from a list uses more than one joint type."""
        generator = SyntheticAssemblyGenerator(seed=42)
        pool = [JointType.REVOLUTE, JointType.FIXED]
        _, joints, _ = generator.generate_loop_assembly(8, joint_types=pool)
        distinct = {joint.joint_type for joint in joints}
        assert len(distinct) >= 2
class TestStarAssembly:
    """generate_star_assembly produces hub-and-spoke assemblies."""

    def test_body_and_joint_counts(self) -> None:
        """A star over n bodies has n - 1 joints."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_star_assembly(6)
        assert (len(bodies), len(joints)) == (6, 5)

    def test_all_joints_connect_to_hub(self) -> None:
        """Every joint has body 0 on one of its two ends."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_star_assembly(6)
        for joint in joints:
            assert 0 in (joint.body_a, joint.body_b)

    def test_underconstrained(self) -> None:
        """The default revolute star is classified as underconstrained."""
        generator = SyntheticAssemblyGenerator(seed=42)
        result = generator.generate_star_assembly(5)
        assert result[2].combinatorial_classification == "underconstrained"

    def test_minimum_bodies_error(self) -> None:
        """A lone hub with no satellites is rejected."""
        generator = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="at least 2"):
            generator.generate_star_assembly(1)

    def test_hub_at_origin(self) -> None:
        """The hub body sits exactly at the origin."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = generator.generate_star_assembly(4)
        hub = bodies[0]
        np.testing.assert_array_equal(hub.position, np.zeros(3))
class TestMixedAssembly:
    """generate_mixed_assembly produces tree+loop hybrid assemblies."""

    def test_more_joints_than_tree(self) -> None:
        """A positive density adds joints beyond the spanning tree."""
        generator = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = generator.generate_mixed_assembly(8, edge_density=0.3)
        tree_joint_count = len(bodies) - 1
        assert len(joints) > tree_joint_count

    def test_density_zero_is_tree(self) -> None:
        """Zero density leaves only the spanning tree."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_mixed_assembly(5, edge_density=0.0)
        assert len(joints) == 4  # spanning tree only

    def test_density_validation(self) -> None:
        """Densities outside [0, 1] are rejected on either side."""
        generator = SyntheticAssemblyGenerator(seed=42)
        for bad_density in (1.5, -0.1):
            with pytest.raises(ValueError, match="must be in"):
                generator.generate_mixed_assembly(5, edge_density=bad_density)

    def test_no_duplicate_edges(self) -> None:
        """No unordered body pair is joined more than once."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_mixed_assembly(6, edge_density=0.5)
        pairs = [frozenset([joint.body_a, joint.body_b]) for joint in joints]
        assert len(set(pairs)) == len(pairs)

    def test_high_density(self) -> None:
        """Density 1.0 joins every pair of bodies."""
        generator = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = generator.generate_mixed_assembly(5, edge_density=1.0)
        # Fully connected: 5*(5-1)/2 = 10 edges
        assert len(joints) == 10
# ---------------------------------------------------------------------------
# Complexity tiers
# ---------------------------------------------------------------------------
class TestComplexityTiers:
    """Complexity tier parameter on batch generation."""

    @staticmethod
    def _assert_within_tier(batch: list[dict], tier: str) -> None:
        """Check every example's body count against the tier's range."""
        lower, upper = COMPLEXITY_RANGES[tier]
        for example in batch:
            # The upper bound is exclusive.
            assert lower <= example["n_bodies"] < upper

    def test_simple_range(self) -> None:
        """The simple tier stays inside its body-count range."""
        generator = SyntheticAssemblyGenerator(seed=42)
        batch = generator.generate_training_batch(20, complexity_tier="simple")
        self._assert_within_tier(batch, "simple")

    def test_medium_range(self) -> None:
        """The medium tier stays inside its body-count range."""
        generator = SyntheticAssemblyGenerator(seed=42)
        batch = generator.generate_training_batch(20, complexity_tier="medium")
        self._assert_within_tier(batch, "medium")

    def test_complex_range(self) -> None:
        """The complex tier stays inside its body-count range."""
        generator = SyntheticAssemblyGenerator(seed=42)
        batch = generator.generate_training_batch(3, complexity_tier="complex")
        self._assert_within_tier(batch, "complex")

    def test_tier_overrides_range(self) -> None:
        """An explicit tier wins over an explicit n_bodies_range."""
        generator = SyntheticAssemblyGenerator(seed=42)
        batch = generator.generate_training_batch(
            10, n_bodies_range=(2, 3), complexity_tier="medium"
        )
        self._assert_within_tier(batch, "medium")
# ---------------------------------------------------------------------------
# Training batch
# ---------------------------------------------------------------------------
class TestTrainingBatch:
"""generate_training_batch produces well-structured examples."""
@pytest.fixture()
def batch(self) -> list[dict]:
gen = SyntheticAssemblyGenerator(seed=42)
return gen.generate_training_batch(batch_size=20, n_bodies_range=(3, 6))
EXPECTED_KEYS: ClassVar[set[str]] = {
"example_id",
"generator_type",
"n_bodies",
"n_joints",
"body_positions",
"joints",
"joint_labels",
"assembly_classification",
"is_rigid",
"is_minimally_rigid",
"internal_dof",
"geometric_degeneracies",
}
def test_batch_size(self, batch: list[dict]) -> None:
VALID_GEN_TYPES: ClassVar[set[str]] = {
"chain",
"rigid",
"overconstrained",
"tree",
"loop",
"star",
"mixed",
}
def test_batch_size(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20)
assert len(batch) == 20
def test_example_keys(self, batch: list[dict]) -> None:
expected = {
"example_id",
"n_bodies",
"n_joints",
"body_positions",
"joints",
"joint_labels",
"assembly_classification",
"is_rigid",
"is_minimally_rigid",
"internal_dof",
"geometric_degeneracies",
}
def test_example_keys(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
assert set(ex.keys()) == expected
assert set(ex.keys()) == self.EXPECTED_KEYS
def test_example_ids_sequential(self, batch: list[dict]) -> None:
ids = [ex["example_id"] for ex in batch]
assert ids == list(range(20))
def test_example_ids_sequential(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(15)
assert [ex["example_id"] for ex in batch] == list(range(15))
def test_classification_distribution(self, batch: list[dict]) -> None:
"""Batch should contain multiple classification types."""
classes = {ex["assembly_classification"] for ex in batch}
# With the 3-way generator split we expect at least 2 types
assert len(classes) >= 2
def test_body_count_in_range(self, batch: list[dict]) -> None:
def test_generator_type_valid(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(50)
for ex in batch:
assert 3 <= ex["n_bodies"] <= 5 # range is [3, 6)
assert ex["generator_type"] in self.VALID_GEN_TYPES
def test_joint_labels_match_joints(self, batch: list[dict]) -> None:
def test_generator_type_diversity(self) -> None:
"""100-sample batch should use at least 5 of 7 generator types."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(100)
types = {ex["generator_type"] for ex in batch}
assert len(types) >= 5
def test_default_body_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(30)
for ex in batch:
label_jids = set(ex["joint_labels"].keys())
joint_jids = {j["joint_id"] for j in ex["joints"]}
assert label_jids == joint_jids
assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp
def test_joint_label_fields(self, batch: list[dict]) -> None:
expected_fields = {
"independent_constraints",
"redundant_constraints",
"total_constraints",
}
for ex in batch:
for label in ex["joint_labels"].values():
assert set(label.keys()) == expected_fields
def test_joint_label_consistency(self, batch: list[dict]) -> None:
def test_joint_label_consistency(self) -> None:
"""independent + redundant == total for every joint."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(30)
for ex in batch:
for label in ex["joint_labels"].values():
total = label["independent_constraints"] + label["redundant_constraints"]
@@ -157,10 +357,17 @@ class TestSeedReproducibility:
g2 = SyntheticAssemblyGenerator(seed=2)
b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
# Very unlikely to be identical with different seeds
c1 = [ex["assembly_classification"] for ex in b1]
c2 = [ex["assembly_classification"] for ex in b2]
r1 = [ex["is_rigid"] for ex in b1]
r2 = [ex["is_rigid"] for ex in b2]
# At least one of these should differ (probabilistically certain)
assert c1 != c2 or r1 != r2
def test_same_seed_identical(self) -> None:
    """Two generators with the same seed build the same tree."""
    first = SyntheticAssemblyGenerator(seed=123)
    second = SyntheticAssemblyGenerator(seed=123)
    bodies_1, joints_1, _ = first.generate_tree_assembly(5)
    bodies_2, joints_2, _ = second.generate_tree_assembly(5)
    # strict=True makes a differing body count fail instead of truncating.
    for left, right in zip(bodies_1, bodies_2, strict=True):
        np.testing.assert_array_almost_equal(left.position, right.position)
    assert len(joints_1) == len(joints_2)