From 78289494e274e3796a4bf9e2d5ef97feb3e14ec4 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 14:57:49 -0600 Subject: [PATCH] feat: geometric diversity for synthetic assembly generation - Add AxisStrategy type (cardinal, random, near_parallel) - Add random body orientations via scipy.spatial.transform.Rotation - Add parallel axis injection with configurable probability - Add grounded parameter on all 7 generators (grounded/floating) - Add axis sampling strategies: cardinal, random, near-parallel - Update _create_joint with orientation-aware anchor offsets - Add _resolve_axis helper for parallel axis propagation - Update generate_training_batch with axis_strategy, parallel_axis_prob, grounded_ratio parameters - Add body_orientations and grounded fields to batch output - Export AxisStrategy from datagen package - Add 28 new tests (72 total generator tests, 158 total) Closes #8 --- solver/datagen/__init__.py | 7 +- solver/datagen/generator.py | 462 +++++++++++++++++++++++++++----- tests/datagen/test_generator.py | 321 +++++++++++++++++++++- 3 files changed, 710 insertions(+), 80 deletions(-) diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index 203adcc..a1d0c34 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,7 +1,11 @@ """Data generation utilities for assembly constraint training data.""" from solver.datagen.analysis import analyze_assembly -from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator +from solver.datagen.generator import ( + COMPLEXITY_RANGES, + AxisStrategy, + SyntheticAssemblyGenerator, +) from solver.datagen.jacobian import JacobianVerifier from solver.datagen.pebble_game import PebbleGame3D from solver.datagen.types import ( @@ -14,6 +18,7 @@ from solver.datagen.types import ( __all__ = [ "COMPLEXITY_RANGES", + "AxisStrategy", "ConstraintAnalysis", "JacobianVerifier", "Joint", diff --git a/solver/datagen/generator.py b/solver/datagen/generator.py index cd90ade..15df68f 100644 --- a/solver/datagen/generator.py +++ b/solver/datagen/generator.py @@ -10,6 +10,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Literal import numpy as np +from scipy.spatial.transform import Rotation from solver.datagen.analysis import analyze_assembly from solver.datagen.types import ( @@ -22,7 +23,12 @@ from solver.datagen.types import ( if TYPE_CHECKING: from typing import Any -__all__ = ["COMPLEXITY_RANGES", "ComplexityTier", "SyntheticAssemblyGenerator"] +__all__ = [ + "COMPLEXITY_RANGES", + "AxisStrategy", + "ComplexityTier", + "SyntheticAssemblyGenerator", +] # --------------------------------------------------------------------------- # Complexity tiers — ranges use exclusive upper bound for rng.integers() @@ -36,6 +42,12 @@ COMPLEXITY_RANGES: dict[str, tuple[int, int]] = { "complex": (16, 51), } +# --------------------------------------------------------------------------- +# Axis sampling strategies +# --------------------------------------------------------------------------- + +AxisStrategy = Literal["cardinal", "random", "near_parallel"] + class SyntheticAssemblyGenerator: """Generates assembly graphs with known minimal constraint sets. @@ -67,6 +79,63 @@ class SyntheticAssemblyGenerator: axis /= np.linalg.norm(axis) return axis + def _random_orientation(self) -> np.ndarray: + """Generate a random 3x3 rotation matrix.""" + mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix() + return mat + + def _cardinal_axis(self) -> np.ndarray: + """Pick uniformly from the six signed cardinal directions.""" + axes = np.array( + [ + [1, 0, 0], + [-1, 0, 0], + [0, 1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, -1], + ], + dtype=float, + ) + result: np.ndarray = axes[int(self.rng.integers(6))] + return result + + def _near_parallel_axis( + self, + base_axis: np.ndarray, + perturbation_scale: float = 0.05, + ) -> np.ndarray: + """Return *base_axis* with a small random perturbation, re-normalized.""" + perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale + return perturbed / np.linalg.norm(perturbed) + + def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray: + """Sample a joint axis using the specified strategy.""" + if strategy == "cardinal": + return self._cardinal_axis() + if strategy == "near_parallel": + return self._near_parallel_axis(np.array([0.0, 0.0, 1.0])) + return self._random_axis() + + def _resolve_axis( + self, + strategy: AxisStrategy, + parallel_axis_prob: float, + shared_axis: np.ndarray | None, + ) -> tuple[np.ndarray, np.ndarray | None]: + """Return (axis_for_this_joint, shared_axis_to_propagate). + + On the first call where *shared_axis* is ``None`` and parallel + injection triggers, a base axis is chosen and returned as + *shared_axis* for subsequent calls. + """ + if shared_axis is not None: + return self._near_parallel_axis(shared_axis), shared_axis + if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob: + base = self._sample_axis(strategy) + return base.copy(), base + return self._sample_axis(strategy), None + def _select_joint_type( self, joint_types: JointType | list[JointType], @@ -85,17 +154,37 @@ class SyntheticAssemblyGenerator: pos_a: np.ndarray, pos_b: np.ndarray, joint_type: JointType, + *, + axis: np.ndarray | None = None, + orient_a: np.ndarray | None = None, + orient_b: np.ndarray | None = None, ) -> Joint: - """Create a joint between two bodies with random axis at midpoint.""" - anchor = (pos_a + pos_b) / 2.0 + """Create a joint between two bodies. + + When orientations are provided, anchor points are offset from + each body's center along a random local direction rotated into + world frame, rather than placed at the midpoint. + """ + if orient_a is not None and orient_b is not None: + dist = max(float(np.linalg.norm(pos_b - pos_a)), 0.1) + offset_scale = dist * 0.2 + local_a = self.rng.standard_normal(3) * offset_scale + local_b = self.rng.standard_normal(3) * offset_scale + anchor_a = pos_a + orient_a @ local_a + anchor_b = pos_b + orient_b @ local_b + else: + anchor = (pos_a + pos_b) / 2.0 + anchor_a = anchor + anchor_b = anchor + return Joint( joint_id=joint_id, body_a=body_a_id, body_b=body_b_id, joint_type=joint_type, - anchor_a=anchor, - anchor_b=anchor, - axis=self._random_axis(), + anchor_a=anchor_a, + anchor_b=anchor_b, + axis=axis if axis is not None else self._random_axis(), ) # ------------------------------------------------------------------ @@ -106,6 +195,10 @@ class SyntheticAssemblyGenerator: self, n_bodies: int, joint_type: JointType = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a serial kinematic chain. @@ -118,27 +211,49 @@ class SyntheticAssemblyGenerator: for i in range(n_bodies): pos = np.array([i * 2.0, 0.0, 0.0]) - bodies.append(RigidBody(body_id=i, position=pos)) - - for i in range(n_bodies - 1): - anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0]) - joints.append( - Joint( - joint_id=i, - body_a=i, - body_b=i + 1, - joint_type=joint_type, - anchor_a=anchor, - anchor_b=anchor, - axis=self._random_axis(), + bodies.append( + RigidBody( + body_id=i, + position=pos, + orientation=self._random_orientation(), ) ) - analysis = analyze_assembly(bodies, joints, ground_body=0) + shared_axis: np.ndarray | None = None + for i in range(n_bodies - 1): + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + i, + i, + i + 1, + bodies[i].position, + bodies[i + 1].position, + joint_type, + axis=axis, + orient_a=bodies[i].orientation, + orient_b=bodies[i + 1].orientation, + ) + ) + + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis def generate_rigid_assembly( - self, n_bodies: int + self, + n_bodies: int, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a minimally rigid assembly by adding joints until rigid. @@ -148,22 +263,35 @@ class SyntheticAssemblyGenerator: """ bodies = [] for i in range(n_bodies): - bodies.append(RigidBody(body_id=i, position=self._random_position())) + bodies.append( + RigidBody( + body_id=i, + position=self._random_position(), + orientation=self._random_orientation(), + ) + ) # Build spanning tree with fixed joints (overconstrained) joints: list[Joint] = [] + shared_axis: np.ndarray | None = None for i in range(1, n_bodies): parent = int(self.rng.integers(0, i)) - mid = (bodies[i].position + bodies[parent].position) / 2 + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( - Joint( - joint_id=i - 1, - body_a=parent, - body_b=i, - joint_type=JointType.FIXED, - anchor_a=mid, - anchor_b=mid, - axis=self._random_axis(), + self._create_joint( + i - 1, + parent, + i, + bodies[parent].position, + bodies[i].position, + JointType.FIXED, + axis=axis, + orient_a=bodies[parent].orientation, + orient_b=bodies[i].orientation, ) ) @@ -174,56 +302,73 @@ class SyntheticAssemblyGenerator: JointType.BALL, ] + ground = 0 if grounded else None for idx in self.rng.permutation(len(joints)): original_type = joints[idx].joint_type for candidate in weaker_types: joints[idx].joint_type = candidate - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly(bodies, joints, ground_body=ground) if analysis.is_rigid: break # Keep the weaker type else: joints[idx].joint_type = original_type - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly(bodies, joints, ground_body=ground) return bodies, joints, analysis def generate_overconstrained_assembly( self, n_bodies: int, extra_joints: int = 2, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate an assembly with known redundant constraints. Starts with a rigid assembly, then adds extra joints that the pebble game will flag as redundant. """ - bodies, joints, _ = self.generate_rigid_assembly(n_bodies) + bodies, joints, _ = self.generate_rigid_assembly( + n_bodies, + grounded=grounded, + axis_strategy=axis_strategy, + parallel_axis_prob=parallel_axis_prob, + ) joint_id = len(joints) + shared_axis: np.ndarray | None = None for _ in range(extra_joints): a, b = self.rng.choice(n_bodies, size=2, replace=False) - mid = (bodies[a].position + bodies[b].position) / 2 - _overcon_types = [ JointType.REVOLUTE, JointType.FIXED, JointType.BALL, ] jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))] + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( - Joint( - joint_id=joint_id, - body_a=int(a), - body_b=int(b), - joint_type=jtype, - anchor_a=mid, - anchor_b=mid, - axis=self._random_axis(), + self._create_joint( + joint_id, + int(a), + int(b), + bodies[int(a)].position, + bodies[int(b)].position, + jtype, + axis=axis, + orient_a=bodies[int(a)].orientation, + orient_b=bodies[int(b)].orientation, ) ) joint_id += 1 - analysis = analyze_assembly(bodies, joints, ground_body=0) + ground = 0 if grounded else None + analysis = analyze_assembly(bodies, joints, ground_body=ground) return bodies, joints, analysis # ------------------------------------------------------------------ @@ -235,6 +380,10 @@ class SyntheticAssemblyGenerator: n_bodies: int, joint_types: JointType | list[JointType] = JointType.REVOLUTE, branching_factor: int = 3, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a random tree topology with configurable branching. @@ -247,12 +396,19 @@ class SyntheticAssemblyGenerator: joint_types: Single type or list to sample from per joint. branching_factor: Max children per parent (1-5 recommended). """ - bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))] + bodies: list[RigidBody] = [ + RigidBody( + body_id=0, + position=np.zeros(3), + orientation=self._random_orientation(), + ) + ] joints: list[Joint] = [] available_parents = [0] next_id = 1 joint_id = 0 + shared_axis: np.ndarray | None = None while next_id < n_bodies and available_parents: pidx = int(self.rng.integers(0, len(available_parents))) @@ -266,11 +422,33 @@ class SyntheticAssemblyGenerator: direction = self._random_axis() distance = self.rng.uniform(1.5, 3.0) child_pos = parent_pos + direction * distance + child_orient = self._random_orientation() - bodies.append(RigidBody(body_id=next_id, position=child_pos)) + bodies.append( + RigidBody( + body_id=next_id, + position=child_pos, + orientation=child_orient, + ) + ) jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( - self._create_joint(joint_id, parent_id, next_id, parent_pos, child_pos, jtype) + self._create_joint( + joint_id, + parent_id, + next_id, + parent_pos, + child_pos, + jtype, + axis=axis, + orient_a=bodies[parent_id].orientation, + orient_b=child_orient, + ) ) available_parents.append(next_id) @@ -283,13 +461,21 @@ class SyntheticAssemblyGenerator: if n_children >= branching_factor or self.rng.random() < 0.3: available_parents.pop(pidx) - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis def generate_loop_assembly( self, n_bodies: int, joint_types: JointType | list[JointType] = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a single closed loop (ring) of bodies. @@ -317,22 +503,52 @@ class SyntheticAssemblyGenerator: x = radius * np.cos(angle) y = radius * np.sin(angle) z = float(self.rng.uniform(-0.2, 0.2)) - bodies.append(RigidBody(body_id=i, position=np.array([x, y, z]))) + bodies.append( + RigidBody( + body_id=i, + position=np.array([x, y, z]), + orientation=self._random_orientation(), + ) + ) + shared_axis: np.ndarray | None = None for i in range(n_bodies): next_i = (i + 1) % n_bodies jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( - self._create_joint(i, i, next_i, bodies[i].position, bodies[next_i].position, jtype) + self._create_joint( + i, + i, + next_i, + bodies[i].position, + bodies[next_i].position, + jtype, + axis=axis, + orient_a=bodies[i].orientation, + orient_b=bodies[next_i].orientation, + ) ) - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis def generate_star_assembly( self, n_bodies: int, joint_types: JointType | list[JointType] = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a star topology with central hub and satellites. @@ -350,19 +566,49 @@ class SyntheticAssemblyGenerator: msg = "Star assembly requires at least 2 bodies" raise ValueError(msg) - bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))] + hub_orient = self._random_orientation() + bodies: list[RigidBody] = [ + RigidBody( + body_id=0, + position=np.zeros(3), + orientation=hub_orient, + ) + ] joints: list[Joint] = [] + shared_axis: np.ndarray | None = None for i in range(1, n_bodies): direction = self._random_axis() distance = self.rng.uniform(2.0, 5.0) pos = direction * distance - bodies.append(RigidBody(body_id=i, position=pos)) + sat_orient = self._random_orientation() + bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient)) jtype = self._select_joint_type(joint_types) - joints.append(self._create_joint(i - 1, 0, i, np.zeros(3), pos, jtype)) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + i - 1, + 0, + i, + np.zeros(3), + pos, + jtype, + axis=axis, + orient_a=hub_orient, + orient_b=sat_orient, + ) + ) - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis def generate_mixed_assembly( @@ -370,6 +616,10 @@ class SyntheticAssemblyGenerator: n_bodies: int, joint_types: JointType | list[JointType] = JointType.REVOLUTE, edge_density: float = 0.3, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a mixed topology combining tree and loop elements. @@ -392,14 +642,26 @@ class SyntheticAssemblyGenerator: joints: list[Joint] = [] for i in range(n_bodies): - bodies.append(RigidBody(body_id=i, position=self._random_position())) + bodies.append( + RigidBody( + body_id=i, + position=self._random_position(), + orientation=self._random_orientation(), + ) + ) # Phase 1: spanning tree joint_id = 0 existing_edges: set[frozenset[int]] = set() + shared_axis: np.ndarray | None = None for i in range(1, n_bodies): parent = int(self.rng.integers(0, i)) jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( self._create_joint( joint_id, @@ -408,6 +670,9 @@ class SyntheticAssemblyGenerator: bodies[parent].position, bodies[i].position, jtype, + axis=axis, + orient_a=bodies[parent].orientation, + orient_b=bodies[i].orientation, ) ) existing_edges.add(frozenset([parent, i])) @@ -425,6 +690,11 @@ class SyntheticAssemblyGenerator: for a, b in candidates[:n_extra]: jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( self._create_joint( joint_id, @@ -433,11 +703,18 @@ class SyntheticAssemblyGenerator: bodies[a].position, bodies[b].position, jtype, + axis=axis, + orient_a=bodies[a].orientation, + orient_b=bodies[b].orientation, ) ) joint_id += 1 - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis # ------------------------------------------------------------------ @@ -449,6 +726,10 @@ class SyntheticAssemblyGenerator: batch_size: int = 100, n_bodies_range: tuple[int, int] | None = None, complexity_tier: ComplexityTier | None = None, + *, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + grounded_ratio: float = 1.0, ) -> list[dict[str, Any]]: """Generate a batch of labeled training examples. @@ -461,6 +742,9 @@ class SyntheticAssemblyGenerator: Overridden by *complexity_tier* when both are given. complexity_tier: Predefined range (``"simple"`` / ``"medium"`` / ``"complex"``). Overrides *n_bodies_range*. + axis_strategy: Axis sampling strategy for joint axes. + parallel_axis_prob: Probability of parallel axis injection. + grounded_ratio: Fraction of examples that are grounded. """ if complexity_tier is not None: n_bodies_range = COMPLEXITY_RANGES[complexity_tier] @@ -474,11 +758,17 @@ class SyntheticAssemblyGenerator: JointType.FIXED, ] + geo_kw: dict[str, Any] = { + "axis_strategy": axis_strategy, + "parallel_axis_prob": parallel_axis_prob, + } + examples: list[dict[str, Any]] = [] for i in range(batch_size): n = int(self.rng.integers(*n_bodies_range)) gen_idx = int(self.rng.integers(7)) + grounded = bool(self.rng.random() < grounded_ratio) if gen_idx == 0: _chain_types = [ @@ -487,30 +777,66 @@ class SyntheticAssemblyGenerator: JointType.CYLINDRICAL, ] jtype = _chain_types[int(self.rng.integers(len(_chain_types)))] - bodies, joints, analysis = self.generate_chain_assembly(n, jtype) + bodies, joints, analysis = self.generate_chain_assembly( + n, + jtype, + grounded=grounded, + **geo_kw, + ) gen_name = "chain" elif gen_idx == 1: - bodies, joints, analysis = self.generate_rigid_assembly(n) + bodies, joints, analysis = self.generate_rigid_assembly( + n, + grounded=grounded, + **geo_kw, + ) gen_name = "rigid" elif gen_idx == 2: extra = int(self.rng.integers(1, 4)) - bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra) + bodies, joints, analysis = self.generate_overconstrained_assembly( + n, + extra, + grounded=grounded, + **geo_kw, + ) gen_name = "overconstrained" elif gen_idx == 3: branching = int(self.rng.integers(2, 5)) - bodies, joints, analysis = self.generate_tree_assembly(n, _joint_pool, branching) + bodies, joints, analysis = self.generate_tree_assembly( + n, + _joint_pool, + branching, + grounded=grounded, + **geo_kw, + ) gen_name = "tree" elif gen_idx == 4: n = max(n, 3) - bodies, joints, analysis = self.generate_loop_assembly(n, _joint_pool) + bodies, joints, analysis = self.generate_loop_assembly( + n, + _joint_pool, + grounded=grounded, + **geo_kw, + ) gen_name = "loop" elif gen_idx == 5: n = max(n, 2) - bodies, joints, analysis = self.generate_star_assembly(n, _joint_pool) + bodies, joints, analysis = self.generate_star_assembly( + n, + _joint_pool, + grounded=grounded, + **geo_kw, + ) gen_name = "star" else: density = float(self.rng.uniform(0.2, 0.5)) - bodies, joints, analysis = self.generate_mixed_assembly(n, _joint_pool, density) + bodies, joints, analysis = self.generate_mixed_assembly( + n, + _joint_pool, + density, + grounded=grounded, + **geo_kw, + ) gen_name = "mixed" # Build per-joint labels from edge results @@ -533,9 +859,11 @@ class SyntheticAssemblyGenerator: { "example_id": i, "generator_type": gen_name, + "grounded": grounded, "n_bodies": len(bodies), "n_joints": len(joints), "body_positions": [b.position.tolist() for b in bodies], + "body_orientations": [b.orientation.tolist() for b in bodies], "joints": [ { "joint_id": j.joint_id, @@ -547,11 +875,11 @@ class SyntheticAssemblyGenerator: for j in joints ], "joint_labels": joint_labels, - "assembly_classification": analysis.combinatorial_classification, + "assembly_classification": (analysis.combinatorial_classification), "is_rigid": analysis.is_rigid, "is_minimally_rigid": analysis.is_minimally_rigid, "internal_dof": analysis.jacobian_internal_dof, - "geometric_degeneracies": analysis.geometric_degeneracies, + "geometric_degeneracies": (analysis.geometric_degeneracies), } ) diff --git a/tests/datagen/test_generator.py b/tests/datagen/test_generator.py index fa5d70f..dc70baa 100644 --- a/tests/datagen/test_generator.py +++ b/tests/datagen/test_generator.py @@ -44,7 +44,10 @@ class TestChainAssembly: def test_chain_custom_joint_type(self) -> None: gen = SyntheticAssemblyGenerator(seed=0) - _, joints, _ = gen.generate_chain_assembly(3, joint_type=JointType.BALL) + _, joints, _ = gen.generate_chain_assembly( + 3, + joint_type=JointType.BALL, + ) assert all(j.joint_type is JointType.BALL for j in joints) @@ -77,7 +80,10 @@ class TestOverconstrainedAssembly: def test_has_redundant(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - _, _, analysis = gen.generate_overconstrained_assembly(4, extra_joints=2) + _, _, analysis = gen.generate_overconstrained_assembly( + 4, + extra_joints=2, + ) assert analysis.combinatorial_redundant > 0 def test_extra_joints_added(self) -> None: @@ -85,7 +91,10 @@ class TestOverconstrainedAssembly: _, joints_base, _ = gen.generate_rigid_assembly(4) gen2 = SyntheticAssemblyGenerator(seed=42) - _, joints_over, _ = gen2.generate_overconstrained_assembly(4, extra_joints=3) + _, joints_over, _ = gen2.generate_overconstrained_assembly( + 4, + extra_joints=3, + ) # Overconstrained has base joints + extra assert len(joints_over) > len(joints_base) @@ -111,7 +120,10 @@ class TestTreeAssembly: def test_branching_factor(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2) + bodies, joints, _ = gen.generate_tree_assembly( + 10, + branching_factor=2, + ) assert len(bodies) == 10 assert len(joints) == 9 @@ -125,7 +137,10 @@ class TestTreeAssembly: def test_single_joint_type(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - _, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL) + _, joints, _ = gen.generate_tree_assembly( + 5, + joint_types=JointType.BALL, + ) assert all(j.joint_type is JointType.BALL for j in joints) def test_sequential_body_ids(self) -> None: @@ -206,12 +221,18 @@ class TestMixedAssembly: def test_more_joints_than_tree(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3) + bodies, joints, _ = gen.generate_mixed_assembly( + 8, + edge_density=0.3, + ) assert len(joints) > len(bodies) - 1 def test_density_zero_is_tree(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - _bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0) + _bodies, joints, _ = gen.generate_mixed_assembly( + 5, + edge_density=0.0, + ) assert len(joints) == 4 # spanning tree only def test_density_validation(self) -> None: @@ -229,11 +250,210 @@ class TestMixedAssembly: def test_high_density(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - _bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0) + _bodies, joints, _ = gen.generate_mixed_assembly( + 5, + edge_density=1.0, + ) # Fully connected: 5*(5-1)/2 = 10 edges assert len(joints) == 10 +# --------------------------------------------------------------------------- +# Axis sampling strategies +# --------------------------------------------------------------------------- + + +class TestAxisStrategy: + """Axis sampling strategies produce valid unit vectors.""" + + def test_cardinal_axis_from_six(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + axes = {tuple(gen._cardinal_axis()) for _ in range(200)} + expected = { + (1, 0, 0), + (-1, 0, 0), + (0, 1, 0), + (0, -1, 0), + (0, 0, 1), + (0, 0, -1), + } + assert axes == expected + + def test_random_axis_unit_norm(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + for _ in range(50): + axis = gen._sample_axis("random") + assert abs(np.linalg.norm(axis) - 1.0) < 1e-10 + + def test_near_parallel_close_to_base(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + base = np.array([0.0, 0.0, 1.0]) + for _ in range(50): + axis = gen._near_parallel_axis(base) + assert abs(np.linalg.norm(axis) - 1.0) < 1e-10 + assert np.dot(axis, base) > 0.95 + + def test_sample_axis_cardinal(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + axis = gen._sample_axis("cardinal") + cardinals = [ + np.array(v, dtype=float) + for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)] + ] + assert any(np.allclose(axis, c) for c in cardinals) + + def test_sample_axis_near_parallel(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + axis = gen._sample_axis("near_parallel") + z = np.array([0.0, 0.0, 1.0]) + assert np.dot(axis, z) > 0.95 + + +# --------------------------------------------------------------------------- +# Geometric diversity: orientations +# --------------------------------------------------------------------------- + + +class TestRandomOrientations: + """Bodies should have non-identity orientations.""" + + def test_bodies_have_orientations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_tree_assembly(5) + non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3))) + assert non_identity >= 3 + + def test_orientations_are_valid_rotations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_star_assembly(6) + for b in bodies: + r = b.orientation + # R^T R == I + np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10) + # det(R) == 1 + assert abs(np.linalg.det(r) - 1.0) < 1e-10 + + def test_all_generators_set_orientations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + # Chain + bodies, _, _ = gen.generate_chain_assembly(3) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + # Loop + bodies, _, _ = gen.generate_loop_assembly(4) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + # Mixed + bodies, _, _ = gen.generate_mixed_assembly(4) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + + +# --------------------------------------------------------------------------- +# Geometric diversity: grounded parameter +# --------------------------------------------------------------------------- + + +class TestGroundedParameter: + """Grounded parameter controls ground_body in analysis.""" + + def test_chain_grounded_default(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_chain_assembly(4) + assert analysis.combinatorial_dof >= 0 + + def test_chain_floating(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_chain_assembly( + 4, + grounded=False, + ) + # Floating: 6 trivial DOF not subtracted by ground + assert analysis.combinatorial_dof >= 6 + + def test_floating_vs_grounded_dof_difference(self) -> None: + gen1 = SyntheticAssemblyGenerator(seed=42) + _, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True) + gen2 = SyntheticAssemblyGenerator(seed=42) + _, _, a_floating = gen2.generate_chain_assembly(4, grounded=False) + # Floating should have higher DOF due to missing ground constraint + assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof + + @pytest.mark.parametrize( + "gen_method", + [ + "generate_chain_assembly", + "generate_rigid_assembly", + "generate_tree_assembly", + "generate_loop_assembly", + "generate_star_assembly", + "generate_mixed_assembly", + ], + ) + def test_all_generators_accept_grounded( + self, + gen_method: str, + ) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + method = getattr(gen, gen_method) + n = 4 + # Should not raise + if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"): + method(n, grounded=False) + else: + method(n, grounded=False) + + +# --------------------------------------------------------------------------- +# Geometric diversity: parallel axis injection +# --------------------------------------------------------------------------- + + +class TestParallelAxisInjection: + """parallel_axis_prob causes shared axis direction.""" + + def test_parallel_axes_similar(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_chain_assembly( + 6, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + # Near-parallel: |dot| close to 1 + assert abs(np.dot(j.axis, base)) > 0.9 + + def test_zero_prob_no_forced_parallel(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_chain_assembly( + 6, + parallel_axis_prob=0.0, + ) + base = joints[0].axis + dots = [abs(np.dot(j.axis, base)) for j in joints[1:]] + # With 5 random axes, extremely unlikely all are parallel + assert min(dots) < 0.95 + + def test_parallel_on_loop(self) -> None: + """Parallel axes on a loop assembly.""" + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_loop_assembly( + 5, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + assert abs(np.dot(j.axis, base)) > 0.9 + + def test_parallel_on_star(self) -> None: + """Parallel axes on a star assembly.""" + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_star_assembly( + 5, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + assert abs(np.dot(j.axis, base)) > 0.9 + + # --------------------------------------------------------------------------- # Complexity tiers # --------------------------------------------------------------------------- @@ -265,7 +485,11 @@ class TestComplexityTiers: def test_tier_overrides_range(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium") + batch = gen.generate_training_batch( + 10, + n_bodies_range=(2, 3), + complexity_tier="medium", + ) lo, hi = COMPLEXITY_RANGES["medium"] for ex in batch: assert lo <= ex["n_bodies"] < hi @@ -282,9 +506,11 @@ class TestTrainingBatch: EXPECTED_KEYS: ClassVar[set[str]] = { "example_id", "generator_type", + "grounded", "n_bodies", "n_joints", "body_positions", + "body_orientations", "joints", "joint_labels", "assembly_classification", @@ -337,7 +563,8 @@ class TestTrainingBatch: gen = SyntheticAssemblyGenerator(seed=42) batch = gen.generate_training_batch(30) for ex in batch: - assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp + # default (3, 8), but loop/star may clamp + assert 2 <= ex["n_bodies"] <= 7 def test_joint_label_consistency(self) -> None: """independent + redundant == total for every joint.""" @@ -348,6 +575,70 @@ class TestTrainingBatch: total = label["independent_constraints"] + label["redundant_constraints"] assert total == label["total_constraints"] + def test_body_orientations_present(self) -> None: + """Each example includes body_orientations as 3x3 lists.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + orients = ex["body_orientations"] + assert len(orients) == ex["n_bodies"] + for o in orients: + assert len(o) == 3 + assert len(o[0]) == 3 + + def test_grounded_field_present(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + assert isinstance(ex["grounded"], bool) + + +# --------------------------------------------------------------------------- +# Batch grounded ratio +# --------------------------------------------------------------------------- + + +class TestBatchGroundedRatio: + """grounded_ratio controls the mix in batch generation.""" + + def test_all_grounded(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, grounded_ratio=1.0) + assert all(ex["grounded"] for ex in batch) + + def test_none_grounded(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, grounded_ratio=0.0) + assert not any(ex["grounded"] for ex in batch) + + def test_mixed_ratio(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(100, grounded_ratio=0.5) + grounded_count = sum(1 for ex in batch if ex["grounded"]) + # With 100 samples and p=0.5, should be roughly 50 +/- 20 + assert 20 < grounded_count < 80 + + def test_batch_axis_strategy_cardinal(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch( + 10, + axis_strategy="cardinal", + ) + assert len(batch) == 10 + + def test_batch_parallel_axis_prob(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch( + 10, + parallel_axis_prob=0.5, + ) + assert len(batch) == 10 + + +# --------------------------------------------------------------------------- +# Seed reproducibility +# --------------------------------------------------------------------------- + class TestSeedReproducibility: """Different seeds produce different results.""" @@ -355,8 +646,14 @@ class TestSeedReproducibility: def test_different_seeds_differ(self) -> None: g1 = SyntheticAssemblyGenerator(seed=1) g2 = SyntheticAssemblyGenerator(seed=2) - b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6)) - b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6)) + b1 = g1.generate_training_batch( + batch_size=5, + n_bodies_range=(3, 6), + ) + b2 = g2.generate_training_batch( + batch_size=5, + n_bodies_range=(3, 6), + ) c1 = [ex["assembly_classification"] for ex in b1] c2 = [ex["assembly_classification"] for ex in b2] r1 = [ex["is_rigid"] for ex in b1]