# ---------------------------------------------------------------------------
# Complexity tiers
#
# Every tier maps to a half-open ``(low, high)`` body-count interval. ``high``
# is exclusive so a pair can be unpacked directly into
# ``rng.integers(low, high)``; each upper bound is therefore written below as
# "largest allowed body count + 1".
# ---------------------------------------------------------------------------

ComplexityTier = Literal["simple", "medium", "complex"]

COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
    "simple": (2, 5 + 1),  # 2 to 5 bodies
    "medium": (6, 15 + 1),  # 6 to 15 bodies
    "complex": (16, 50 + 1),  # 16 to 50 bodies
}
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------

def _random_position(self, scale: float = 5.0) -> np.ndarray:
    """Return a random 3D position sampled uniformly per axis from ``[-scale, scale)``.

    Args:
        scale: Half-width of the sampling cube along every axis.
    """
    return self.rng.uniform(-scale, scale, size=3)


def _random_axis(self) -> np.ndarray:
    """Return a random unit-length 3D axis.

    The axis is a standard-normal sample divided by its Euclidean norm.
    A sample whose norm is (numerically) zero cannot be normalised and is
    redrawn, so the result never contains NaN/inf. For every ordinary draw
    the random stream is consumed exactly once, as before.
    """
    while True:
        axis = self.rng.standard_normal(3)
        norm = np.linalg.norm(axis)
        if norm > 1e-12:
            return axis / norm


def _select_joint_type(
    self,
    joint_types: JointType | list[JointType] | tuple[JointType, ...],
) -> JointType:
    """Pick one joint type from a single value or a list/tuple of candidates.

    Args:
        joint_types: Either one joint type (returned unchanged, without
            consuming randomness) or a non-empty list/tuple to sample
            from uniformly.

    Raises:
        ValueError: If *joint_types* is an empty list or tuple.
    """
    # Tuples are accepted as well as lists; previously a tuple fell through
    # and was returned whole as if it were a single joint type.
    if isinstance(joint_types, (list, tuple)):
        if not joint_types:
            msg = "joint_types must contain at least one joint type"
            raise ValueError(msg)
        idx = int(self.rng.integers(0, len(joint_types)))
        return joint_types[idx]
    return joint_types


def _create_joint(
    self,
    joint_id: int,
    body_a_id: int,
    body_b_id: int,
    pos_a: np.ndarray,
    pos_b: np.ndarray,
    joint_type: JointType,
) -> Joint:
    """Create a joint between two bodies, anchored at their midpoint.

    Args:
        joint_id: Identifier stored on the new joint.
        body_a_id: Id of the first connected body.
        body_b_id: Id of the second connected body.
        pos_a: Position of the first body.
        pos_b: Position of the second body.
        joint_type: Type assigned to the joint.

    Returns:
        A ``Joint`` whose two anchors are both the midpoint of *pos_a* and
        *pos_b* and whose axis is a fresh random unit vector.
    """
    anchor = (pos_a + pos_b) / 2.0
    return Joint(
        joint_id=joint_id,
        body_a=body_a_id,
        body_b=body_b_id,
        joint_type=joint_type,
        anchor_a=anchor,
        anchor_b=anchor,
        axis=self._random_axis(),
    )
# ------------------------------------------------------------------
# New topology generators
# ------------------------------------------------------------------

def generate_tree_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    branching_factor: int = 3,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a random tree topology with a capped number of children per body.

    Body 0 is the root at the origin. Bodies are attached one parent at a
    time: a random open parent receives a random number of children, each
    placed 1.5-3.0 units away in a random direction. No body ever gets
    more than *branching_factor* children, even if it is picked several
    times. The result has ``n_bodies - 1`` joints and no closed loops.

    NOTE(review): a loop-free assembly is underconstrained only if its
    joints leave some relative freedom; a tree made solely of FIXED joints
    is presumably classified as rigid by ``analyze_assembly`` - confirm.

    Args:
        n_bodies: Total bodies (root + children), at least 1.
        joint_types: Single type or list to sample from per joint.
        branching_factor: Max children per parent, at least 1
            (1-5 recommended).

    Returns:
        ``(bodies, joints, analysis)`` with ``analysis`` computed by
        ``analyze_assembly`` using body 0 as ground.

    Raises:
        ValueError: If *n_bodies* < 1 or *branching_factor* < 1.
    """
    if n_bodies < 1:
        msg = "Tree assembly requires at least 1 body"
        raise ValueError(msg)
    if branching_factor < 1:
        # Previously surfaced as an opaque "low >= high" from rng.integers().
        msg = "branching_factor must be at least 1"
        raise ValueError(msg)

    bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
    joints: list[Joint] = []

    available_parents = [0]  # bodies that can still receive children
    child_counts: dict[int, int] = {0: 0}  # children attached so far, per body
    next_id = 1
    joint_id = 0

    while next_id < n_bodies and available_parents:
        pidx = int(self.rng.integers(0, len(available_parents)))
        parent_id = available_parents[pidx]
        parent_pos = bodies[parent_id].position

        # A parent may be drawn more than once, so cap by what it has left
        # rather than by branching_factor itself. Open parents always have
        # capacity >= 1 and next_id < n_bodies, hence max_children >= 1.
        capacity = branching_factor - child_counts[parent_id]
        max_children = min(capacity, n_bodies - next_id)
        n_children = int(self.rng.integers(1, max_children + 1))

        for _ in range(n_children):
            direction = self._random_axis()
            distance = self.rng.uniform(1.5, 3.0)
            child_pos = parent_pos + direction * distance

            bodies.append(RigidBody(body_id=next_id, position=child_pos))
            jtype = self._select_joint_type(joint_types)
            joints.append(
                self._create_joint(joint_id, parent_id, next_id, parent_pos, child_pos, jtype)
            )

            # New children are appended at the end, so pidx stays valid.
            available_parents.append(next_id)
            child_counts[next_id] = 0
            next_id += 1
            joint_id += 1

        child_counts[parent_id] += n_children

        # Retire parent if it reached its branching limit, or randomly (30%)
        if child_counts[parent_id] >= branching_factor or self.rng.random() < 0.3:
            available_parents.pop(pidx)

    analysis = analyze_assembly(bodies, joints, ground_body=0)
    return bodies, joints, analysis


def generate_loop_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a single closed loop (ring) of bodies.

    Bodies are spread evenly by angle on a jittered circle near the XY
    plane; body ``i`` is joined to body ``(i + 1) % n_bodies``, so the
    last joint closes the ring. The closing joint can introduce
    redundancy; whether it does depends on the joint types and ring size
    as judged by ``analyze_assembly``.

    Args:
        n_bodies: Bodies in the ring (>= 3).
        joint_types: Single type or list to sample from per joint.

    Returns:
        ``(bodies, joints, analysis)`` with exactly ``n_bodies`` joints.

    Raises:
        ValueError: If *n_bodies* < 3.
    """
    if n_bodies < 3:
        msg = "Loop assembly requires at least 3 bodies"
        raise ValueError(msg)

    bodies: list[RigidBody] = []
    joints: list[Joint] = []

    # Radius grows with the body count so neighbours do not crowd together.
    base_radius = max(2.0, n_bodies * 0.4)
    for i in range(n_bodies):
        angle = 2 * np.pi * i / n_bodies
        radius = base_radius + self.rng.uniform(-0.5, 0.5)
        x = radius * np.cos(angle)
        y = radius * np.sin(angle)
        z = float(self.rng.uniform(-0.2, 0.2))
        bodies.append(RigidBody(body_id=i, position=np.array([x, y, z])))

    for i in range(n_bodies):
        next_i = (i + 1) % n_bodies  # wraps the last body back to body 0
        jtype = self._select_joint_type(joint_types)
        joints.append(
            self._create_joint(i, i, next_i, bodies[i].position, bodies[next_i].position, jtype)
        )

    analysis = analyze_assembly(bodies, joints, ground_body=0)
    return bodies, joints, analysis


def generate_star_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a star topology with a central hub and satellites.

    Body 0 is the hub at the origin; every other body sits 2.0-5.0 units
    away in a random direction and is joined directly to the hub. There
    are no inter-satellite joints.

    NOTE(review): the assembly is underconstrained only when the joints
    leave some relative freedom; an all-FIXED star is presumably rigid -
    confirm against ``analyze_assembly``.

    Args:
        n_bodies: Total bodies including hub (>= 2).
        joint_types: Single type or list to sample from per joint.

    Returns:
        ``(bodies, joints, analysis)`` with ``n_bodies - 1`` joints.

    Raises:
        ValueError: If *n_bodies* < 2.
    """
    if n_bodies < 2:
        msg = "Star assembly requires at least 2 bodies"
        raise ValueError(msg)

    bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
    joints: list[Joint] = []

    for i in range(1, n_bodies):
        direction = self._random_axis()
        distance = self.rng.uniform(2.0, 5.0)
        pos = direction * distance
        bodies.append(RigidBody(body_id=i, position=pos))

        jtype = self._select_joint_type(joint_types)
        joints.append(self._create_joint(i - 1, 0, i, np.zeros(3), pos, jtype))

    analysis = analyze_assembly(bodies, joints, ground_body=0)
    return bodies, joints, analysis


def generate_mixed_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    edge_density: float = 0.3,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a mixed topology combining tree and loop elements.

    Builds a random spanning tree for connectivity, then adds
    ``int(edge_density * k)`` extra joints chosen at random from the ``k``
    body pairs the tree left unconnected. No body pair is joined twice.

    Args:
        n_bodies: Number of bodies.
        joint_types: Single type or list to sample from per joint.
        edge_density: Fraction of non-tree edges to add (0.0-1.0).

    Returns:
        ``(bodies, joints, analysis)``; tree joints come first, followed
        by the extra joints, with consecutive joint ids.

    Raises:
        ValueError: If *edge_density* not in [0.0, 1.0].
    """
    if not 0.0 <= edge_density <= 1.0:
        msg = "edge_density must be in [0.0, 1.0]"
        raise ValueError(msg)

    bodies: list[RigidBody] = [
        RigidBody(body_id=i, position=self._random_position()) for i in range(n_bodies)
    ]
    joints: list[Joint] = []

    # Phase 1: spanning tree - each body attaches to a random earlier body
    joint_id = 0
    existing_edges: set[frozenset[int]] = set()
    for i in range(1, n_bodies):
        parent = int(self.rng.integers(0, i))
        jtype = self._select_joint_type(joint_types)
        joints.append(
            self._create_joint(
                joint_id,
                parent,
                i,
                bodies[parent].position,
                bodies[i].position,
                jtype,
            )
        )
        existing_edges.add(frozenset([parent, i]))
        joint_id += 1

    # Phase 2: extra edges based on density, drawn from unconnected pairs
    candidates: list[tuple[int, int]] = [
        (i, j)
        for i in range(n_bodies)
        for j in range(i + 1, n_bodies)
        if frozenset([i, j]) not in existing_edges
    ]

    n_extra = int(edge_density * len(candidates))
    self.rng.shuffle(candidates)

    for a, b in candidates[:n_extra]:
        jtype = self._select_joint_type(joint_types)
        joints.append(
            self._create_joint(
                joint_id,
                a,
                b,
                bodies[a].position,
                bodies[b].position,
                jtype,
            )
        )
        joint_id += 1

    analysis = analyze_assembly(bodies, joints, ground_body=0)
    return bodies, joints, analysis
- Each example contains: - - bodies: list of body positions/orientations - - joints: list of joints with types and parameters - - labels: per-joint independence/redundancy flags - - assembly_label: overall classification + Each example contains body positions, joint descriptions, + per-joint independence labels, and assembly-level classification. + + Args: + batch_size: Number of assemblies to generate. + n_bodies_range: ``(min, max_exclusive)`` body count. + Overridden by *complexity_tier* when both are given. + complexity_tier: Predefined range (``"simple"`` / ``"medium"`` + / ``"complex"``). Overrides *n_bodies_range*. """ + if complexity_tier is not None: + n_bodies_range = COMPLEXITY_RANGES[complexity_tier] + elif n_bodies_range is None: + n_bodies_range = (3, 8) + + _joint_pool = [ + JointType.REVOLUTE, + JointType.BALL, + JointType.CYLINDRICAL, + JointType.FIXED, + ] + examples: list[dict[str, Any]] = [] for i in range(batch_size): n = int(self.rng.integers(*n_bodies_range)) - gen_idx = int(self.rng.integers(3)) + gen_idx = int(self.rng.integers(7)) if gen_idx == 0: _chain_types = [ @@ -202,11 +488,30 @@ class SyntheticAssemblyGenerator: ] jtype = _chain_types[int(self.rng.integers(len(_chain_types)))] bodies, joints, analysis = self.generate_chain_assembly(n, jtype) + gen_name = "chain" elif gen_idx == 1: bodies, joints, analysis = self.generate_rigid_assembly(n) - else: + gen_name = "rigid" + elif gen_idx == 2: extra = int(self.rng.integers(1, 4)) bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra) + gen_name = "overconstrained" + elif gen_idx == 3: + branching = int(self.rng.integers(2, 5)) + bodies, joints, analysis = self.generate_tree_assembly(n, _joint_pool, branching) + gen_name = "tree" + elif gen_idx == 4: + n = max(n, 3) + bodies, joints, analysis = self.generate_loop_assembly(n, _joint_pool) + gen_name = "loop" + elif gen_idx == 5: + n = max(n, 2) + bodies, joints, analysis = self.generate_star_assembly(n, _joint_pool) + 
# ---------------------------------------------------------------------------
# New topology generators
# ---------------------------------------------------------------------------


def _seeded() -> SyntheticAssemblyGenerator:
    """Return a fresh generator with the fixed seed used throughout these tests."""
    return SyntheticAssemblyGenerator(seed=42)


class TestTreeAssembly:
    """Tree generator yields n bodies joined by n - 1 joints."""

    def test_body_and_joint_counts(self) -> None:
        bodies, joints, _ = _seeded().generate_tree_assembly(6)
        assert (len(bodies), len(joints)) == (6, 5)  # n bodies, n - 1 joints

    def test_underconstrained(self) -> None:
        analysis = _seeded().generate_tree_assembly(6)[2]
        assert analysis.combinatorial_classification == "underconstrained"

    def test_branching_factor(self) -> None:
        bodies, joints, _ = _seeded().generate_tree_assembly(10, branching_factor=2)
        assert (len(bodies), len(joints)) == (10, 9)

    def test_mixed_joint_types(self) -> None:
        pool = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
        joints = _seeded().generate_tree_assembly(10, joint_types=pool)[1]
        # With 9 joints and 3 types, very likely to use at least 2
        assert len({joint.joint_type for joint in joints}) >= 2

    def test_single_joint_type(self) -> None:
        joints = _seeded().generate_tree_assembly(5, joint_types=JointType.BALL)[1]
        for joint in joints:
            assert joint.joint_type is JointType.BALL

    def test_sequential_body_ids(self) -> None:
        bodies = _seeded().generate_tree_assembly(7)[0]
        assert [body.body_id for body in bodies] == list(range(7))


class TestLoopAssembly:
    """Loop generator yields a closed ring with one joint per body."""

    def test_body_and_joint_counts(self) -> None:
        bodies, joints, _ = _seeded().generate_loop_assembly(5)
        assert (len(bodies), len(joints)) == (5, 5)  # n joints for n bodies

    def test_has_redundancy(self) -> None:
        analysis = _seeded().generate_loop_assembly(5)[2]
        assert analysis.combinatorial_redundant > 0

    def test_wrap_around_connectivity(self) -> None:
        joints = _seeded().generate_loop_assembly(4)[1]
        edges = {(joint.body_a, joint.body_b) for joint in joints}
        # (3, 0) is the wrap-around edge that closes the ring.
        for expected in [(0, 1), (1, 2), (2, 3), (3, 0)]:
            assert expected in edges

    def test_minimum_bodies_error(self) -> None:
        generator = _seeded()
        with pytest.raises(ValueError, match="at least 3"):
            generator.generate_loop_assembly(2)

    def test_mixed_joint_types(self) -> None:
        pool = [JointType.REVOLUTE, JointType.FIXED]
        joints = _seeded().generate_loop_assembly(8, joint_types=pool)[1]
        assert len({joint.joint_type for joint in joints}) >= 2


class TestStarAssembly:
    """Star generator yields a hub (body 0) with directly attached satellites."""

    def test_body_and_joint_counts(self) -> None:
        bodies, joints, _ = _seeded().generate_star_assembly(6)
        assert (len(bodies), len(joints)) == (6, 5)  # n bodies, n - 1 joints

    def test_all_joints_connect_to_hub(self) -> None:
        joints = _seeded().generate_star_assembly(6)[1]
        for joint in joints:
            assert 0 in (joint.body_a, joint.body_b)

    def test_underconstrained(self) -> None:
        analysis = _seeded().generate_star_assembly(5)[2]
        assert analysis.combinatorial_classification == "underconstrained"

    def test_minimum_bodies_error(self) -> None:
        generator = _seeded()
        with pytest.raises(ValueError, match="at least 2"):
            generator.generate_star_assembly(1)

    def test_hub_at_origin(self) -> None:
        hub = _seeded().generate_star_assembly(4)[0][0]
        np.testing.assert_array_equal(hub.position, np.zeros(3))


class TestMixedAssembly:
    """Mixed generator yields a spanning tree plus density-controlled extra edges."""

    def test_more_joints_than_tree(self) -> None:
        bodies, joints, _ = _seeded().generate_mixed_assembly(8, edge_density=0.3)
        assert len(joints) > len(bodies) - 1

    def test_density_zero_is_tree(self) -> None:
        joints = _seeded().generate_mixed_assembly(5, edge_density=0.0)[1]
        assert len(joints) == 4  # spanning tree only

    def test_density_validation(self) -> None:
        generator = _seeded()
        with pytest.raises(ValueError, match="must be in"):
            generator.generate_mixed_assembly(5, edge_density=1.5)
        with pytest.raises(ValueError, match="must be in"):
            generator.generate_mixed_assembly(5, edge_density=-0.1)

    def test_no_duplicate_edges(self) -> None:
        joints = _seeded().generate_mixed_assembly(6, edge_density=0.5)[1]
        edges = [frozenset([joint.body_a, joint.body_b]) for joint in joints]
        assert len(set(edges)) == len(edges)

    def test_high_density(self) -> None:
        joints = _seeded().generate_mixed_assembly(5, edge_density=1.0)[1]
        # Fully connected: 5*(5-1)/2 = 10 edges
        assert len(joints) == 10


# ---------------------------------------------------------------------------
# Complexity tiers
# ---------------------------------------------------------------------------


def _assert_within_tier(batch: list[dict], tier: str) -> None:
    """Assert every example's body count lies in the tier's half-open range."""
    lo, hi = COMPLEXITY_RANGES[tier]
    for example in batch:
        assert lo <= example["n_bodies"] < hi


class TestComplexityTiers:
    """``complexity_tier`` selects the body-count range for batch generation."""

    def test_simple_range(self) -> None:
        batch = _seeded().generate_training_batch(20, complexity_tier="simple")
        _assert_within_tier(batch, "simple")

    def test_medium_range(self) -> None:
        batch = _seeded().generate_training_batch(20, complexity_tier="medium")
        _assert_within_tier(batch, "medium")

    def test_complex_range(self) -> None:
        batch = _seeded().generate_training_batch(3, complexity_tier="complex")
        _assert_within_tier(batch, "complex")

    def test_tier_overrides_range(self) -> None:
        batch = _seeded().generate_training_batch(
            10, n_bodies_range=(2, 3), complexity_tier="medium"
        )
        _assert_within_tier(batch, "medium")
"""Batch should contain multiple classification types.""" - classes = {ex["assembly_classification"] for ex in batch} - # With the 3-way generator split we expect at least 2 types - assert len(classes) >= 2 - - def test_body_count_in_range(self, batch: list[dict]) -> None: + def test_generator_type_valid(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(50) for ex in batch: - assert 3 <= ex["n_bodies"] <= 5 # range is [3, 6) + assert ex["generator_type"] in self.VALID_GEN_TYPES - def test_joint_labels_match_joints(self, batch: list[dict]) -> None: + def test_generator_type_diversity(self) -> None: + """100-sample batch should use at least 5 of 7 generator types.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(100) + types = {ex["generator_type"] for ex in batch} + assert len(types) >= 5 + + def test_default_body_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(30) for ex in batch: - label_jids = set(ex["joint_labels"].keys()) - joint_jids = {j["joint_id"] for j in ex["joints"]} - assert label_jids == joint_jids + assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp - def test_joint_label_fields(self, batch: list[dict]) -> None: - expected_fields = { - "independent_constraints", - "redundant_constraints", - "total_constraints", - } - for ex in batch: - for label in ex["joint_labels"].values(): - assert set(label.keys()) == expected_fields - - def test_joint_label_consistency(self, batch: list[dict]) -> None: + def test_joint_label_consistency(self) -> None: """independent + redundant == total for every joint.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(30) for ex in batch: for label in ex["joint_labels"].values(): total = label["independent_constraints"] + label["redundant_constraints"] @@ -157,10 +357,17 @@ class TestSeedReproducibility: g2 = 
def test_same_seed_identical(self) -> None:
    """Two generators built with the same seed produce the same tree.

    Samples from several joint types so that the per-joint type choice is
    part of what must be reproducible, and compares topology and joint
    types in addition to body positions.
    """
    pool = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
    g1 = SyntheticAssemblyGenerator(seed=123)
    g2 = SyntheticAssemblyGenerator(seed=123)
    b1, j1, _ = g1.generate_tree_assembly(5, joint_types=pool)
    b2, j2, _ = g2.generate_tree_assembly(5, joint_types=pool)

    for a, b in zip(b1, b2, strict=True):
        assert a.body_id == b.body_id
        np.testing.assert_array_almost_equal(a.position, b.position)

    # Explicit length check first so a mismatch reports as a test failure
    # rather than as zip(strict=True) raising ValueError.
    assert len(j1) == len(j2)
    for ja, jb in zip(j1, j2, strict=True):
        assert (ja.body_a, ja.body_b) == (jb.body_a, jb.body_b)
        assert ja.joint_type is jb.joint_type