feat: parameterized assembly templates and complexity tiers
Add 4 new topology generators to SyntheticAssemblyGenerator: - generate_tree_assembly: random spanning tree with configurable branching - generate_loop_assembly: closed ring producing overconstrained data - generate_star_assembly: hub-and-spoke topology - generate_mixed_assembly: tree + loops with configurable edge density Each accepts joint_types as JointType | list[JointType] for per-joint type sampling. Add complexity tiers (simple/medium/complex) with predefined body count ranges via COMPLEXITY_RANGES dict and ComplexityTier type alias. Update generate_training_batch with 7-way generator selection, complexity_tier parameter, and generator_type field in output dicts. Extract private helpers (_random_position, _random_axis, _select_joint_type, _create_joint) to reduce duplication. 44 generator tests, 130 total — all passing. Closes #7
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
"""Data generation utilities for assembly constraint training data."""
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
@@ -13,6 +13,7 @@ from solver.datagen.types import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"COMPLEXITY_RANGES",
|
||||
"ConstraintAnalysis",
|
||||
"JacobianVerifier",
|
||||
"Joint",
|
||||
|
||||
@@ -7,7 +7,7 @@ with per-constraint independence flags and assembly-level classification.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import numpy as np
|
||||
|
||||
@@ -22,7 +22,19 @@ from solver.datagen.types import (
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["SyntheticAssemblyGenerator"]
|
||||
__all__ = ["COMPLEXITY_RANGES", "ComplexityTier", "SyntheticAssemblyGenerator"]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ComplexityTier = Literal["simple", "medium", "complex"]
|
||||
|
||||
COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
|
||||
"simple": (2, 6),
|
||||
"medium": (6, 16),
|
||||
"complex": (16, 51),
|
||||
}
|
||||
|
||||
|
||||
class SyntheticAssemblyGenerator:
|
||||
@@ -41,6 +53,55 @@ class SyntheticAssemblyGenerator:
|
||||
def __init__(self, seed: int = 42) -> None:
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _random_position(self, scale: float = 5.0) -> np.ndarray:
|
||||
"""Generate random 3D position within [-scale, scale] cube."""
|
||||
return self.rng.uniform(-scale, scale, size=3)
|
||||
|
||||
def _random_axis(self) -> np.ndarray:
|
||||
"""Generate random normalized 3D axis."""
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
return axis
|
||||
|
||||
def _select_joint_type(
|
||||
self,
|
||||
joint_types: JointType | list[JointType],
|
||||
) -> JointType:
|
||||
"""Select a joint type from a single type or list."""
|
||||
if isinstance(joint_types, list):
|
||||
idx = int(self.rng.integers(0, len(joint_types)))
|
||||
return joint_types[idx]
|
||||
return joint_types
|
||||
|
||||
def _create_joint(
|
||||
self,
|
||||
joint_id: int,
|
||||
body_a_id: int,
|
||||
body_b_id: int,
|
||||
pos_a: np.ndarray,
|
||||
pos_b: np.ndarray,
|
||||
joint_type: JointType,
|
||||
) -> Joint:
|
||||
"""Create a joint between two bodies with random axis at midpoint."""
|
||||
anchor = (pos_a + pos_b) / 2.0
|
||||
return Joint(
|
||||
joint_id=joint_id,
|
||||
body_a=body_a_id,
|
||||
body_b=body_b_id,
|
||||
joint_type=joint_type,
|
||||
anchor_a=anchor,
|
||||
anchor_b=anchor,
|
||||
axis=self._random_axis(),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Original generators (chain / rigid / overconstrained)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_chain_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
@@ -60,11 +121,7 @@ class SyntheticAssemblyGenerator:
|
||||
bodies.append(RigidBody(body_id=i, position=pos))
|
||||
|
||||
for i in range(n_bodies - 1):
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
|
||||
anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0])
|
||||
|
||||
joints.append(
|
||||
Joint(
|
||||
joint_id=i,
|
||||
@@ -73,7 +130,7 @@ class SyntheticAssemblyGenerator:
|
||||
joint_type=joint_type,
|
||||
anchor_a=anchor,
|
||||
anchor_b=anchor,
|
||||
axis=axis,
|
||||
axis=self._random_axis(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -91,26 +148,22 @@ class SyntheticAssemblyGenerator:
|
||||
"""
|
||||
bodies = []
|
||||
for i in range(n_bodies):
|
||||
pos = self.rng.uniform(-5, 5, size=3)
|
||||
bodies.append(RigidBody(body_id=i, position=pos))
|
||||
bodies.append(RigidBody(body_id=i, position=self._random_position()))
|
||||
|
||||
# Build spanning tree with fixed joints (overconstrained)
|
||||
joints: list[Joint] = []
|
||||
for i in range(1, n_bodies):
|
||||
parent = self.rng.integers(0, i)
|
||||
parent = int(self.rng.integers(0, i))
|
||||
mid = (bodies[i].position + bodies[parent].position) / 2
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
|
||||
joints.append(
|
||||
Joint(
|
||||
joint_id=i - 1,
|
||||
body_a=int(parent),
|
||||
body_a=parent,
|
||||
body_b=i,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=mid,
|
||||
anchor_b=mid,
|
||||
axis=axis,
|
||||
axis=self._random_axis(),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -150,8 +203,6 @@ class SyntheticAssemblyGenerator:
|
||||
for _ in range(extra_joints):
|
||||
a, b = self.rng.choice(n_bodies, size=2, replace=False)
|
||||
mid = (bodies[a].position + bodies[b].position) / 2
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
|
||||
_overcon_types = [
|
||||
JointType.REVOLUTE,
|
||||
@@ -167,7 +218,7 @@ class SyntheticAssemblyGenerator:
|
||||
joint_type=jtype,
|
||||
anchor_a=mid,
|
||||
anchor_b=mid,
|
||||
axis=axis,
|
||||
axis=self._random_axis(),
|
||||
)
|
||||
)
|
||||
joint_id += 1
|
||||
@@ -175,24 +226,259 @@ class SyntheticAssemblyGenerator:
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
return bodies, joints, analysis
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# New topology generators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_tree_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
branching_factor: int = 3,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a random tree topology with configurable branching.
|
||||
|
||||
Creates a tree where each body can have up to *branching_factor*
|
||||
children. Different branches can use different joint types if a
|
||||
list is provided. Always underconstrained (no closed loops).
|
||||
|
||||
Args:
|
||||
n_bodies: Total bodies (root + children).
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
branching_factor: Max children per parent (1-5 recommended).
|
||||
"""
|
||||
bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
|
||||
joints: list[Joint] = []
|
||||
|
||||
available_parents = [0]
|
||||
next_id = 1
|
||||
joint_id = 0
|
||||
|
||||
while next_id < n_bodies and available_parents:
|
||||
pidx = int(self.rng.integers(0, len(available_parents)))
|
||||
parent_id = available_parents[pidx]
|
||||
parent_pos = bodies[parent_id].position
|
||||
|
||||
max_children = min(branching_factor, n_bodies - next_id)
|
||||
n_children = int(self.rng.integers(1, max_children + 1))
|
||||
|
||||
for _ in range(n_children):
|
||||
direction = self._random_axis()
|
||||
distance = self.rng.uniform(1.5, 3.0)
|
||||
child_pos = parent_pos + direction * distance
|
||||
|
||||
bodies.append(RigidBody(body_id=next_id, position=child_pos))
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
joints.append(
|
||||
self._create_joint(joint_id, parent_id, next_id, parent_pos, child_pos, jtype)
|
||||
)
|
||||
|
||||
available_parents.append(next_id)
|
||||
next_id += 1
|
||||
joint_id += 1
|
||||
if next_id >= n_bodies:
|
||||
break
|
||||
|
||||
# Retire parent if it reached branching limit or randomly
|
||||
if n_children >= branching_factor or self.rng.random() < 0.3:
|
||||
available_parents.pop(pidx)
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_loop_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a single closed loop (ring) of bodies.
|
||||
|
||||
The closing constraint introduces redundancy, making this
|
||||
useful for generating overconstrained training data.
|
||||
|
||||
Args:
|
||||
n_bodies: Bodies in the ring (>= 3).
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
|
||||
Raises:
|
||||
ValueError: If *n_bodies* < 3.
|
||||
"""
|
||||
if n_bodies < 3:
|
||||
msg = "Loop assembly requires at least 3 bodies"
|
||||
raise ValueError(msg)
|
||||
|
||||
bodies: list[RigidBody] = []
|
||||
joints: list[Joint] = []
|
||||
|
||||
base_radius = max(2.0, n_bodies * 0.4)
|
||||
for i in range(n_bodies):
|
||||
angle = 2 * np.pi * i / n_bodies
|
||||
radius = base_radius + self.rng.uniform(-0.5, 0.5)
|
||||
x = radius * np.cos(angle)
|
||||
y = radius * np.sin(angle)
|
||||
z = float(self.rng.uniform(-0.2, 0.2))
|
||||
bodies.append(RigidBody(body_id=i, position=np.array([x, y, z])))
|
||||
|
||||
for i in range(n_bodies):
|
||||
next_i = (i + 1) % n_bodies
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
joints.append(
|
||||
self._create_joint(i, i, next_i, bodies[i].position, bodies[next_i].position, jtype)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_star_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a star topology with central hub and satellites.
|
||||
|
||||
Body 0 is the hub; all other bodies connect directly to it.
|
||||
Underconstrained because there are no inter-satellite connections.
|
||||
|
||||
Args:
|
||||
n_bodies: Total bodies including hub (>= 2).
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
|
||||
Raises:
|
||||
ValueError: If *n_bodies* < 2.
|
||||
"""
|
||||
if n_bodies < 2:
|
||||
msg = "Star assembly requires at least 2 bodies"
|
||||
raise ValueError(msg)
|
||||
|
||||
bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
|
||||
joints: list[Joint] = []
|
||||
|
||||
for i in range(1, n_bodies):
|
||||
direction = self._random_axis()
|
||||
distance = self.rng.uniform(2.0, 5.0)
|
||||
pos = direction * distance
|
||||
bodies.append(RigidBody(body_id=i, position=pos))
|
||||
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
joints.append(self._create_joint(i - 1, 0, i, np.zeros(3), pos, jtype))
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_mixed_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
edge_density: float = 0.3,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a mixed topology combining tree and loop elements.
|
||||
|
||||
Builds a spanning tree for connectivity, then adds extra edges
|
||||
based on *edge_density* to create loops and redundancy.
|
||||
|
||||
Args:
|
||||
n_bodies: Number of bodies.
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
edge_density: Fraction of non-tree edges to add (0.0-1.0).
|
||||
|
||||
Raises:
|
||||
ValueError: If *edge_density* not in [0.0, 1.0].
|
||||
"""
|
||||
if not 0.0 <= edge_density <= 1.0:
|
||||
msg = "edge_density must be in [0.0, 1.0]"
|
||||
raise ValueError(msg)
|
||||
|
||||
bodies: list[RigidBody] = []
|
||||
joints: list[Joint] = []
|
||||
|
||||
for i in range(n_bodies):
|
||||
bodies.append(RigidBody(body_id=i, position=self._random_position()))
|
||||
|
||||
# Phase 1: spanning tree
|
||||
joint_id = 0
|
||||
existing_edges: set[frozenset[int]] = set()
|
||||
for i in range(1, n_bodies):
|
||||
parent = int(self.rng.integers(0, i))
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
parent,
|
||||
i,
|
||||
bodies[parent].position,
|
||||
bodies[i].position,
|
||||
jtype,
|
||||
)
|
||||
)
|
||||
existing_edges.add(frozenset([parent, i]))
|
||||
joint_id += 1
|
||||
|
||||
# Phase 2: extra edges based on density
|
||||
candidates: list[tuple[int, int]] = []
|
||||
for i in range(n_bodies):
|
||||
for j in range(i + 1, n_bodies):
|
||||
if frozenset([i, j]) not in existing_edges:
|
||||
candidates.append((i, j))
|
||||
|
||||
n_extra = int(edge_density * len(candidates))
|
||||
self.rng.shuffle(candidates)
|
||||
|
||||
for a, b in candidates[:n_extra]:
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
a,
|
||||
b,
|
||||
bodies[a].position,
|
||||
bodies[b].position,
|
||||
jtype,
|
||||
)
|
||||
)
|
||||
joint_id += 1
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
return bodies, joints, analysis
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Batch generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_training_batch(
|
||||
self,
|
||||
batch_size: int = 100,
|
||||
n_bodies_range: tuple[int, int] = (3, 8),
|
||||
n_bodies_range: tuple[int, int] | None = None,
|
||||
complexity_tier: ComplexityTier | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate a batch of labeled training examples.
|
||||
|
||||
Each example contains:
|
||||
- bodies: list of body positions/orientations
|
||||
- joints: list of joints with types and parameters
|
||||
- labels: per-joint independence/redundancy flags
|
||||
- assembly_label: overall classification
|
||||
Each example contains body positions, joint descriptions,
|
||||
per-joint independence labels, and assembly-level classification.
|
||||
|
||||
Args:
|
||||
batch_size: Number of assemblies to generate.
|
||||
n_bodies_range: ``(min, max_exclusive)`` body count.
|
||||
Overridden by *complexity_tier* when both are given.
|
||||
complexity_tier: Predefined range (``"simple"`` / ``"medium"``
|
||||
/ ``"complex"``). Overrides *n_bodies_range*.
|
||||
"""
|
||||
if complexity_tier is not None:
|
||||
n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
|
||||
elif n_bodies_range is None:
|
||||
n_bodies_range = (3, 8)
|
||||
|
||||
_joint_pool = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.BALL,
|
||||
JointType.CYLINDRICAL,
|
||||
JointType.FIXED,
|
||||
]
|
||||
|
||||
examples: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
n = int(self.rng.integers(*n_bodies_range))
|
||||
gen_idx = int(self.rng.integers(3))
|
||||
gen_idx = int(self.rng.integers(7))
|
||||
|
||||
if gen_idx == 0:
|
||||
_chain_types = [
|
||||
@@ -202,11 +488,30 @@ class SyntheticAssemblyGenerator:
|
||||
]
|
||||
jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
|
||||
bodies, joints, analysis = self.generate_chain_assembly(n, jtype)
|
||||
gen_name = "chain"
|
||||
elif gen_idx == 1:
|
||||
bodies, joints, analysis = self.generate_rigid_assembly(n)
|
||||
else:
|
||||
gen_name = "rigid"
|
||||
elif gen_idx == 2:
|
||||
extra = int(self.rng.integers(1, 4))
|
||||
bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra)
|
||||
gen_name = "overconstrained"
|
||||
elif gen_idx == 3:
|
||||
branching = int(self.rng.integers(2, 5))
|
||||
bodies, joints, analysis = self.generate_tree_assembly(n, _joint_pool, branching)
|
||||
gen_name = "tree"
|
||||
elif gen_idx == 4:
|
||||
n = max(n, 3)
|
||||
bodies, joints, analysis = self.generate_loop_assembly(n, _joint_pool)
|
||||
gen_name = "loop"
|
||||
elif gen_idx == 5:
|
||||
n = max(n, 2)
|
||||
bodies, joints, analysis = self.generate_star_assembly(n, _joint_pool)
|
||||
gen_name = "star"
|
||||
else:
|
||||
density = float(self.rng.uniform(0.2, 0.5))
|
||||
bodies, joints, analysis = self.generate_mixed_assembly(n, _joint_pool, density)
|
||||
gen_name = "mixed"
|
||||
|
||||
# Build per-joint labels from edge results
|
||||
joint_labels: dict[int, dict[str, int]] = {}
|
||||
@@ -227,6 +532,7 @@ class SyntheticAssemblyGenerator:
|
||||
examples.append(
|
||||
{
|
||||
"example_id": i,
|
||||
"generator_type": gen_name,
|
||||
"n_bodies": len(bodies),
|
||||
"n_joints": len(joints),
|
||||
"body_positions": [b.position.tolist() for b in bodies],
|
||||
|
||||
@@ -2,11 +2,18 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
|
||||
from solver.datagen.types import JointType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Original generators (chain / rigid / overconstrained)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestChainAssembly:
|
||||
"""generate_chain_assembly produces valid underconstrained chains."""
|
||||
@@ -83,20 +90,198 @@ class TestOverconstrainedAssembly:
|
||||
assert len(joints_over) > len(joints_base)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# New topology generators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTreeAssembly:
|
||||
"""generate_tree_assembly produces tree-structured assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(6)
|
||||
assert len(bodies) == 6
|
||||
assert len(joints) == 5 # n - 1
|
||||
|
||||
def test_underconstrained(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_tree_assembly(6)
|
||||
assert analysis.combinatorial_classification == "underconstrained"
|
||||
|
||||
def test_branching_factor(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2)
|
||||
assert len(bodies) == 10
|
||||
assert len(joints) == 9
|
||||
|
||||
def test_mixed_joint_types(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
types = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
|
||||
_, joints, _ = gen.generate_tree_assembly(10, joint_types=types)
|
||||
used = {j.joint_type for j in joints}
|
||||
# With 9 joints and 3 types, very likely to use at least 2
|
||||
assert len(used) >= 2
|
||||
|
||||
def test_single_joint_type(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL)
|
||||
assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
def test_sequential_body_ids(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_tree_assembly(7)
|
||||
assert [b.body_id for b in bodies] == list(range(7))
|
||||
|
||||
|
||||
class TestLoopAssembly:
|
||||
"""generate_loop_assembly produces closed-loop assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_loop_assembly(5)
|
||||
assert len(bodies) == 5
|
||||
assert len(joints) == 5 # n joints for n bodies
|
||||
|
||||
def test_has_redundancy(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_loop_assembly(5)
|
||||
assert analysis.combinatorial_redundant > 0
|
||||
|
||||
def test_wrap_around_connectivity(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_loop_assembly(4)
|
||||
edges = {(j.body_a, j.body_b) for j in joints}
|
||||
assert (0, 1) in edges
|
||||
assert (1, 2) in edges
|
||||
assert (2, 3) in edges
|
||||
assert (3, 0) in edges # wrap-around
|
||||
|
||||
def test_minimum_bodies_error(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="at least 3"):
|
||||
gen.generate_loop_assembly(2)
|
||||
|
||||
def test_mixed_joint_types(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
types = [JointType.REVOLUTE, JointType.FIXED]
|
||||
_, joints, _ = gen.generate_loop_assembly(8, joint_types=types)
|
||||
used = {j.joint_type for j in joints}
|
||||
assert len(used) >= 2
|
||||
|
||||
|
||||
class TestStarAssembly:
|
||||
"""generate_star_assembly produces hub-and-spoke assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_star_assembly(6)
|
||||
assert len(bodies) == 6
|
||||
assert len(joints) == 5 # n - 1
|
||||
|
||||
def test_all_joints_connect_to_hub(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_star_assembly(6)
|
||||
for j in joints:
|
||||
assert j.body_a == 0 or j.body_b == 0
|
||||
|
||||
def test_underconstrained(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_star_assembly(5)
|
||||
assert analysis.combinatorial_classification == "underconstrained"
|
||||
|
||||
def test_minimum_bodies_error(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="at least 2"):
|
||||
gen.generate_star_assembly(1)
|
||||
|
||||
def test_hub_at_origin(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_star_assembly(4)
|
||||
np.testing.assert_array_equal(bodies[0].position, np.zeros(3))
|
||||
|
||||
|
||||
class TestMixedAssembly:
|
||||
"""generate_mixed_assembly produces tree+loop hybrid assemblies."""
|
||||
|
||||
def test_more_joints_than_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3)
|
||||
assert len(joints) > len(bodies) - 1
|
||||
|
||||
def test_density_zero_is_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0)
|
||||
assert len(joints) == 4 # spanning tree only
|
||||
|
||||
def test_density_validation(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="must be in"):
|
||||
gen.generate_mixed_assembly(5, edge_density=1.5)
|
||||
with pytest.raises(ValueError, match="must be in"):
|
||||
gen.generate_mixed_assembly(5, edge_density=-0.1)
|
||||
|
||||
def test_no_duplicate_edges(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_mixed_assembly(6, edge_density=0.5)
|
||||
edges = [frozenset([j.body_a, j.body_b]) for j in joints]
|
||||
assert len(edges) == len(set(edges))
|
||||
|
||||
def test_high_density(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0)
|
||||
# Fully connected: 5*(5-1)/2 = 10 edges
|
||||
assert len(joints) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComplexityTiers:
|
||||
"""Complexity tier parameter on batch generation."""
|
||||
|
||||
def test_simple_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, complexity_tier="simple")
|
||||
lo, hi = COMPLEXITY_RANGES["simple"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_medium_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, complexity_tier="medium")
|
||||
lo, hi = COMPLEXITY_RANGES["medium"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_complex_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(3, complexity_tier="complex")
|
||||
lo, hi = COMPLEXITY_RANGES["complex"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_tier_overrides_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium")
|
||||
lo, hi = COMPLEXITY_RANGES["medium"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrainingBatch:
|
||||
"""generate_training_batch produces well-structured examples."""
|
||||
|
||||
@pytest.fixture()
|
||||
def batch(self) -> list[dict]:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
return gen.generate_training_batch(batch_size=20, n_bodies_range=(3, 6))
|
||||
|
||||
def test_batch_size(self, batch: list[dict]) -> None:
|
||||
assert len(batch) == 20
|
||||
|
||||
def test_example_keys(self, batch: list[dict]) -> None:
|
||||
expected = {
|
||||
EXPECTED_KEYS: ClassVar[set[str]] = {
|
||||
"example_id",
|
||||
"generator_type",
|
||||
"n_bodies",
|
||||
"n_joints",
|
||||
"body_positions",
|
||||
@@ -108,41 +293,56 @@ class TestTrainingBatch:
|
||||
"internal_dof",
|
||||
"geometric_degeneracies",
|
||||
}
|
||||
for ex in batch:
|
||||
assert set(ex.keys()) == expected
|
||||
|
||||
def test_example_ids_sequential(self, batch: list[dict]) -> None:
|
||||
ids = [ex["example_id"] for ex in batch]
|
||||
assert ids == list(range(20))
|
||||
|
||||
def test_classification_distribution(self, batch: list[dict]) -> None:
|
||||
"""Batch should contain multiple classification types."""
|
||||
classes = {ex["assembly_classification"] for ex in batch}
|
||||
# With the 3-way generator split we expect at least 2 types
|
||||
assert len(classes) >= 2
|
||||
|
||||
def test_body_count_in_range(self, batch: list[dict]) -> None:
|
||||
for ex in batch:
|
||||
assert 3 <= ex["n_bodies"] <= 5 # range is [3, 6)
|
||||
|
||||
def test_joint_labels_match_joints(self, batch: list[dict]) -> None:
|
||||
for ex in batch:
|
||||
label_jids = set(ex["joint_labels"].keys())
|
||||
joint_jids = {j["joint_id"] for j in ex["joints"]}
|
||||
assert label_jids == joint_jids
|
||||
|
||||
def test_joint_label_fields(self, batch: list[dict]) -> None:
|
||||
expected_fields = {
|
||||
"independent_constraints",
|
||||
"redundant_constraints",
|
||||
"total_constraints",
|
||||
VALID_GEN_TYPES: ClassVar[set[str]] = {
|
||||
"chain",
|
||||
"rigid",
|
||||
"overconstrained",
|
||||
"tree",
|
||||
"loop",
|
||||
"star",
|
||||
"mixed",
|
||||
}
|
||||
for ex in batch:
|
||||
for label in ex["joint_labels"].values():
|
||||
assert set(label.keys()) == expected_fields
|
||||
|
||||
def test_joint_label_consistency(self, batch: list[dict]) -> None:
|
||||
def test_batch_size(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20)
|
||||
assert len(batch) == 20
|
||||
|
||||
def test_example_keys(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
assert set(ex.keys()) == self.EXPECTED_KEYS
|
||||
|
||||
def test_example_ids_sequential(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(15)
|
||||
assert [ex["example_id"] for ex in batch] == list(range(15))
|
||||
|
||||
def test_generator_type_valid(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(50)
|
||||
for ex in batch:
|
||||
assert ex["generator_type"] in self.VALID_GEN_TYPES
|
||||
|
||||
def test_generator_type_diversity(self) -> None:
|
||||
"""100-sample batch should use at least 5 of 7 generator types."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(100)
|
||||
types = {ex["generator_type"] for ex in batch}
|
||||
assert len(types) >= 5
|
||||
|
||||
def test_default_body_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(30)
|
||||
for ex in batch:
|
||||
assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp
|
||||
|
||||
def test_joint_label_consistency(self) -> None:
|
||||
"""independent + redundant == total for every joint."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(30)
|
||||
for ex in batch:
|
||||
for label in ex["joint_labels"].values():
|
||||
total = label["independent_constraints"] + label["redundant_constraints"]
|
||||
@@ -157,10 +357,17 @@ class TestSeedReproducibility:
|
||||
g2 = SyntheticAssemblyGenerator(seed=2)
|
||||
b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
|
||||
b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
|
||||
# Very unlikely to be identical with different seeds
|
||||
c1 = [ex["assembly_classification"] for ex in b1]
|
||||
c2 = [ex["assembly_classification"] for ex in b2]
|
||||
r1 = [ex["is_rigid"] for ex in b1]
|
||||
r2 = [ex["is_rigid"] for ex in b2]
|
||||
# At least one of these should differ (probabilistically certain)
|
||||
assert c1 != c2 or r1 != r2
|
||||
|
||||
def test_same_seed_identical(self) -> None:
|
||||
g1 = SyntheticAssemblyGenerator(seed=123)
|
||||
g2 = SyntheticAssemblyGenerator(seed=123)
|
||||
b1, j1, _ = g1.generate_tree_assembly(5)
|
||||
b2, j2, _ = g2.generate_tree_assembly(5)
|
||||
for a, b in zip(b1, b2, strict=True):
|
||||
np.testing.assert_array_almost_equal(a.position, b.position)
|
||||
assert len(j1) == len(j2)
|
||||
|
||||
Reference in New Issue
Block a user