feat: geometric diversity for synthetic assembly generation
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

- Add AxisStrategy type (cardinal, random, near_parallel)
- Add random body orientations via scipy.spatial.transform.Rotation
- Add parallel axis injection with configurable probability
- Add grounded parameter on all 7 generators (grounded/floating)
- Add axis sampling strategies: cardinal, random, near-parallel
- Update _create_joint with orientation-aware anchor offsets
- Add _resolve_axis helper for parallel axis propagation
- Update generate_training_batch with axis_strategy, parallel_axis_prob,
  grounded_ratio parameters
- Add body_orientations and grounded fields to batch output
- Export AxisStrategy from datagen package
- Add 28 new tests (72 total generator tests, 158 total)

Closes #8
This commit is contained in:
2026-02-02 14:57:49 -06:00
parent 0b5813b5a9
commit 78289494e2
3 changed files with 710 additions and 80 deletions

View File

@@ -1,7 +1,11 @@
"""Data generation utilities for assembly constraint training data."""
from solver.datagen.analysis import analyze_assembly
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
from solver.datagen.generator import (
COMPLEXITY_RANGES,
AxisStrategy,
SyntheticAssemblyGenerator,
)
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
@@ -14,6 +18,7 @@ from solver.datagen.types import (
__all__ = [
"COMPLEXITY_RANGES",
"AxisStrategy",
"ConstraintAnalysis",
"JacobianVerifier",
"Joint",

View File

@@ -10,6 +10,7 @@ from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import numpy as np
from scipy.spatial.transform import Rotation
from solver.datagen.analysis import analyze_assembly
from solver.datagen.types import (
@@ -22,7 +23,12 @@ from solver.datagen.types import (
if TYPE_CHECKING:
from typing import Any
__all__ = ["COMPLEXITY_RANGES", "ComplexityTier", "SyntheticAssemblyGenerator"]
__all__ = [
"COMPLEXITY_RANGES",
"AxisStrategy",
"ComplexityTier",
"SyntheticAssemblyGenerator",
]
# ---------------------------------------------------------------------------
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
@@ -36,6 +42,12 @@ COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
"complex": (16, 51),
}
# ---------------------------------------------------------------------------
# Axis sampling strategies
# ---------------------------------------------------------------------------
AxisStrategy = Literal["cardinal", "random", "near_parallel"]
class SyntheticAssemblyGenerator:
"""Generates assembly graphs with known minimal constraint sets.
@@ -67,6 +79,63 @@ class SyntheticAssemblyGenerator:
axis /= np.linalg.norm(axis)
return axis
def _random_orientation(self) -> np.ndarray:
"""Generate a random 3x3 rotation matrix."""
mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix()
return mat
def _cardinal_axis(self) -> np.ndarray:
"""Pick uniformly from the six signed cardinal directions."""
axes = np.array(
[
[1, 0, 0],
[-1, 0, 0],
[0, 1, 0],
[0, -1, 0],
[0, 0, 1],
[0, 0, -1],
],
dtype=float,
)
result: np.ndarray = axes[int(self.rng.integers(6))]
return result
def _near_parallel_axis(
self,
base_axis: np.ndarray,
perturbation_scale: float = 0.05,
) -> np.ndarray:
"""Return *base_axis* with a small random perturbation, re-normalized."""
perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale
return perturbed / np.linalg.norm(perturbed)
def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray:
"""Sample a joint axis using the specified strategy."""
if strategy == "cardinal":
return self._cardinal_axis()
if strategy == "near_parallel":
return self._near_parallel_axis(np.array([0.0, 0.0, 1.0]))
return self._random_axis()
def _resolve_axis(
self,
strategy: AxisStrategy,
parallel_axis_prob: float,
shared_axis: np.ndarray | None,
) -> tuple[np.ndarray, np.ndarray | None]:
"""Return (axis_for_this_joint, shared_axis_to_propagate).
On the first call where *shared_axis* is ``None`` and parallel
injection triggers, a base axis is chosen and returned as
*shared_axis* for subsequent calls.
"""
if shared_axis is not None:
return self._near_parallel_axis(shared_axis), shared_axis
if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob:
base = self._sample_axis(strategy)
return base.copy(), base
return self._sample_axis(strategy), None
def _select_joint_type(
self,
joint_types: JointType | list[JointType],
@@ -85,17 +154,37 @@ class SyntheticAssemblyGenerator:
pos_a: np.ndarray,
pos_b: np.ndarray,
joint_type: JointType,
*,
axis: np.ndarray | None = None,
orient_a: np.ndarray | None = None,
orient_b: np.ndarray | None = None,
) -> Joint:
"""Create a joint between two bodies with random axis at midpoint."""
anchor = (pos_a + pos_b) / 2.0
"""Create a joint between two bodies.
When orientations are provided, anchor points are offset from
each body's center along a random local direction rotated into
world frame, rather than placed at the midpoint.
"""
if orient_a is not None and orient_b is not None:
dist = max(float(np.linalg.norm(pos_b - pos_a)), 0.1)
offset_scale = dist * 0.2
local_a = self.rng.standard_normal(3) * offset_scale
local_b = self.rng.standard_normal(3) * offset_scale
anchor_a = pos_a + orient_a @ local_a
anchor_b = pos_b + orient_b @ local_b
else:
anchor = (pos_a + pos_b) / 2.0
anchor_a = anchor
anchor_b = anchor
return Joint(
joint_id=joint_id,
body_a=body_a_id,
body_b=body_b_id,
joint_type=joint_type,
anchor_a=anchor,
anchor_b=anchor,
axis=self._random_axis(),
anchor_a=anchor_a,
anchor_b=anchor_b,
axis=axis if axis is not None else self._random_axis(),
)
# ------------------------------------------------------------------
@@ -106,6 +195,10 @@ class SyntheticAssemblyGenerator:
self,
n_bodies: int,
joint_type: JointType = JointType.REVOLUTE,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a serial kinematic chain.
@@ -118,27 +211,49 @@ class SyntheticAssemblyGenerator:
for i in range(n_bodies):
pos = np.array([i * 2.0, 0.0, 0.0])
bodies.append(RigidBody(body_id=i, position=pos))
for i in range(n_bodies - 1):
anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0])
joints.append(
Joint(
joint_id=i,
body_a=i,
body_b=i + 1,
joint_type=joint_type,
anchor_a=anchor,
anchor_b=anchor,
axis=self._random_axis(),
bodies.append(
RigidBody(
body_id=i,
position=pos,
orientation=self._random_orientation(),
)
)
analysis = analyze_assembly(bodies, joints, ground_body=0)
shared_axis: np.ndarray | None = None
for i in range(n_bodies - 1):
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
i,
i,
i + 1,
bodies[i].position,
bodies[i + 1].position,
joint_type,
axis=axis,
orient_a=bodies[i].orientation,
orient_b=bodies[i + 1].orientation,
)
)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
def generate_rigid_assembly(
self, n_bodies: int
self,
n_bodies: int,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a minimally rigid assembly by adding joints until rigid.
@@ -148,22 +263,35 @@ class SyntheticAssemblyGenerator:
"""
bodies = []
for i in range(n_bodies):
bodies.append(RigidBody(body_id=i, position=self._random_position()))
bodies.append(
RigidBody(
body_id=i,
position=self._random_position(),
orientation=self._random_orientation(),
)
)
# Build spanning tree with fixed joints (overconstrained)
joints: list[Joint] = []
shared_axis: np.ndarray | None = None
for i in range(1, n_bodies):
parent = int(self.rng.integers(0, i))
mid = (bodies[i].position + bodies[parent].position) / 2
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
Joint(
joint_id=i - 1,
body_a=parent,
body_b=i,
joint_type=JointType.FIXED,
anchor_a=mid,
anchor_b=mid,
axis=self._random_axis(),
self._create_joint(
i - 1,
parent,
i,
bodies[parent].position,
bodies[i].position,
JointType.FIXED,
axis=axis,
orient_a=bodies[parent].orientation,
orient_b=bodies[i].orientation,
)
)
@@ -174,56 +302,73 @@ class SyntheticAssemblyGenerator:
JointType.BALL,
]
ground = 0 if grounded else None
for idx in self.rng.permutation(len(joints)):
original_type = joints[idx].joint_type
for candidate in weaker_types:
joints[idx].joint_type = candidate
analysis = analyze_assembly(bodies, joints, ground_body=0)
analysis = analyze_assembly(bodies, joints, ground_body=ground)
if analysis.is_rigid:
break # Keep the weaker type
else:
joints[idx].joint_type = original_type
analysis = analyze_assembly(bodies, joints, ground_body=0)
analysis = analyze_assembly(bodies, joints, ground_body=ground)
return bodies, joints, analysis
def generate_overconstrained_assembly(
self,
n_bodies: int,
extra_joints: int = 2,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate an assembly with known redundant constraints.
Starts with a rigid assembly, then adds extra joints that
the pebble game will flag as redundant.
"""
bodies, joints, _ = self.generate_rigid_assembly(n_bodies)
bodies, joints, _ = self.generate_rigid_assembly(
n_bodies,
grounded=grounded,
axis_strategy=axis_strategy,
parallel_axis_prob=parallel_axis_prob,
)
joint_id = len(joints)
shared_axis: np.ndarray | None = None
for _ in range(extra_joints):
a, b = self.rng.choice(n_bodies, size=2, replace=False)
mid = (bodies[a].position + bodies[b].position) / 2
_overcon_types = [
JointType.REVOLUTE,
JointType.FIXED,
JointType.BALL,
]
jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))]
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
Joint(
joint_id=joint_id,
body_a=int(a),
body_b=int(b),
joint_type=jtype,
anchor_a=mid,
anchor_b=mid,
axis=self._random_axis(),
self._create_joint(
joint_id,
int(a),
int(b),
bodies[int(a)].position,
bodies[int(b)].position,
jtype,
axis=axis,
orient_a=bodies[int(a)].orientation,
orient_b=bodies[int(b)].orientation,
)
)
joint_id += 1
analysis = analyze_assembly(bodies, joints, ground_body=0)
ground = 0 if grounded else None
analysis = analyze_assembly(bodies, joints, ground_body=ground)
return bodies, joints, analysis
# ------------------------------------------------------------------
@@ -235,6 +380,10 @@ class SyntheticAssemblyGenerator:
n_bodies: int,
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
branching_factor: int = 3,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a random tree topology with configurable branching.
@@ -247,12 +396,19 @@ class SyntheticAssemblyGenerator:
joint_types: Single type or list to sample from per joint.
branching_factor: Max children per parent (1-5 recommended).
"""
bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
bodies: list[RigidBody] = [
RigidBody(
body_id=0,
position=np.zeros(3),
orientation=self._random_orientation(),
)
]
joints: list[Joint] = []
available_parents = [0]
next_id = 1
joint_id = 0
shared_axis: np.ndarray | None = None
while next_id < n_bodies and available_parents:
pidx = int(self.rng.integers(0, len(available_parents)))
@@ -266,11 +422,33 @@ class SyntheticAssemblyGenerator:
direction = self._random_axis()
distance = self.rng.uniform(1.5, 3.0)
child_pos = parent_pos + direction * distance
child_orient = self._random_orientation()
bodies.append(RigidBody(body_id=next_id, position=child_pos))
bodies.append(
RigidBody(
body_id=next_id,
position=child_pos,
orientation=child_orient,
)
)
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(joint_id, parent_id, next_id, parent_pos, child_pos, jtype)
self._create_joint(
joint_id,
parent_id,
next_id,
parent_pos,
child_pos,
jtype,
axis=axis,
orient_a=bodies[parent_id].orientation,
orient_b=child_orient,
)
)
available_parents.append(next_id)
@@ -283,13 +461,21 @@ class SyntheticAssemblyGenerator:
if n_children >= branching_factor or self.rng.random() < 0.3:
available_parents.pop(pidx)
analysis = analyze_assembly(bodies, joints, ground_body=0)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
def generate_loop_assembly(
self,
n_bodies: int,
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a single closed loop (ring) of bodies.
@@ -317,22 +503,52 @@ class SyntheticAssemblyGenerator:
x = radius * np.cos(angle)
y = radius * np.sin(angle)
z = float(self.rng.uniform(-0.2, 0.2))
bodies.append(RigidBody(body_id=i, position=np.array([x, y, z])))
bodies.append(
RigidBody(
body_id=i,
position=np.array([x, y, z]),
orientation=self._random_orientation(),
)
)
shared_axis: np.ndarray | None = None
for i in range(n_bodies):
next_i = (i + 1) % n_bodies
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(i, i, next_i, bodies[i].position, bodies[next_i].position, jtype)
self._create_joint(
i,
i,
next_i,
bodies[i].position,
bodies[next_i].position,
jtype,
axis=axis,
orient_a=bodies[i].orientation,
orient_b=bodies[next_i].orientation,
)
)
analysis = analyze_assembly(bodies, joints, ground_body=0)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
def generate_star_assembly(
self,
n_bodies: int,
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a star topology with central hub and satellites.
@@ -350,19 +566,49 @@ class SyntheticAssemblyGenerator:
msg = "Star assembly requires at least 2 bodies"
raise ValueError(msg)
bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
hub_orient = self._random_orientation()
bodies: list[RigidBody] = [
RigidBody(
body_id=0,
position=np.zeros(3),
orientation=hub_orient,
)
]
joints: list[Joint] = []
shared_axis: np.ndarray | None = None
for i in range(1, n_bodies):
direction = self._random_axis()
distance = self.rng.uniform(2.0, 5.0)
pos = direction * distance
bodies.append(RigidBody(body_id=i, position=pos))
sat_orient = self._random_orientation()
bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient))
jtype = self._select_joint_type(joint_types)
joints.append(self._create_joint(i - 1, 0, i, np.zeros(3), pos, jtype))
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
i - 1,
0,
i,
np.zeros(3),
pos,
jtype,
axis=axis,
orient_a=hub_orient,
orient_b=sat_orient,
)
)
analysis = analyze_assembly(bodies, joints, ground_body=0)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
def generate_mixed_assembly(
@@ -370,6 +616,10 @@ class SyntheticAssemblyGenerator:
n_bodies: int,
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
edge_density: float = 0.3,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a mixed topology combining tree and loop elements.
@@ -392,14 +642,26 @@ class SyntheticAssemblyGenerator:
joints: list[Joint] = []
for i in range(n_bodies):
bodies.append(RigidBody(body_id=i, position=self._random_position()))
bodies.append(
RigidBody(
body_id=i,
position=self._random_position(),
orientation=self._random_orientation(),
)
)
# Phase 1: spanning tree
joint_id = 0
existing_edges: set[frozenset[int]] = set()
shared_axis: np.ndarray | None = None
for i in range(1, n_bodies):
parent = int(self.rng.integers(0, i))
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
joint_id,
@@ -408,6 +670,9 @@ class SyntheticAssemblyGenerator:
bodies[parent].position,
bodies[i].position,
jtype,
axis=axis,
orient_a=bodies[parent].orientation,
orient_b=bodies[i].orientation,
)
)
existing_edges.add(frozenset([parent, i]))
@@ -425,6 +690,11 @@ class SyntheticAssemblyGenerator:
for a, b in candidates[:n_extra]:
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
joint_id,
@@ -433,11 +703,18 @@ class SyntheticAssemblyGenerator:
bodies[a].position,
bodies[b].position,
jtype,
axis=axis,
orient_a=bodies[a].orientation,
orient_b=bodies[b].orientation,
)
)
joint_id += 1
analysis = analyze_assembly(bodies, joints, ground_body=0)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
# ------------------------------------------------------------------
@@ -449,6 +726,10 @@ class SyntheticAssemblyGenerator:
batch_size: int = 100,
n_bodies_range: tuple[int, int] | None = None,
complexity_tier: ComplexityTier | None = None,
*,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
grounded_ratio: float = 1.0,
) -> list[dict[str, Any]]:
"""Generate a batch of labeled training examples.
@@ -461,6 +742,9 @@ class SyntheticAssemblyGenerator:
Overridden by *complexity_tier* when both are given.
complexity_tier: Predefined range (``"simple"`` / ``"medium"``
/ ``"complex"``). Overrides *n_bodies_range*.
axis_strategy: Axis sampling strategy for joint axes.
parallel_axis_prob: Probability of parallel axis injection.
grounded_ratio: Fraction of examples that are grounded.
"""
if complexity_tier is not None:
n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
@@ -474,11 +758,17 @@ class SyntheticAssemblyGenerator:
JointType.FIXED,
]
geo_kw: dict[str, Any] = {
"axis_strategy": axis_strategy,
"parallel_axis_prob": parallel_axis_prob,
}
examples: list[dict[str, Any]] = []
for i in range(batch_size):
n = int(self.rng.integers(*n_bodies_range))
gen_idx = int(self.rng.integers(7))
grounded = bool(self.rng.random() < grounded_ratio)
if gen_idx == 0:
_chain_types = [
@@ -487,30 +777,66 @@ class SyntheticAssemblyGenerator:
JointType.CYLINDRICAL,
]
jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
bodies, joints, analysis = self.generate_chain_assembly(n, jtype)
bodies, joints, analysis = self.generate_chain_assembly(
n,
jtype,
grounded=grounded,
**geo_kw,
)
gen_name = "chain"
elif gen_idx == 1:
bodies, joints, analysis = self.generate_rigid_assembly(n)
bodies, joints, analysis = self.generate_rigid_assembly(
n,
grounded=grounded,
**geo_kw,
)
gen_name = "rigid"
elif gen_idx == 2:
extra = int(self.rng.integers(1, 4))
bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra)
bodies, joints, analysis = self.generate_overconstrained_assembly(
n,
extra,
grounded=grounded,
**geo_kw,
)
gen_name = "overconstrained"
elif gen_idx == 3:
branching = int(self.rng.integers(2, 5))
bodies, joints, analysis = self.generate_tree_assembly(n, _joint_pool, branching)
bodies, joints, analysis = self.generate_tree_assembly(
n,
_joint_pool,
branching,
grounded=grounded,
**geo_kw,
)
gen_name = "tree"
elif gen_idx == 4:
n = max(n, 3)
bodies, joints, analysis = self.generate_loop_assembly(n, _joint_pool)
bodies, joints, analysis = self.generate_loop_assembly(
n,
_joint_pool,
grounded=grounded,
**geo_kw,
)
gen_name = "loop"
elif gen_idx == 5:
n = max(n, 2)
bodies, joints, analysis = self.generate_star_assembly(n, _joint_pool)
bodies, joints, analysis = self.generate_star_assembly(
n,
_joint_pool,
grounded=grounded,
**geo_kw,
)
gen_name = "star"
else:
density = float(self.rng.uniform(0.2, 0.5))
bodies, joints, analysis = self.generate_mixed_assembly(n, _joint_pool, density)
bodies, joints, analysis = self.generate_mixed_assembly(
n,
_joint_pool,
density,
grounded=grounded,
**geo_kw,
)
gen_name = "mixed"
# Build per-joint labels from edge results
@@ -533,9 +859,11 @@ class SyntheticAssemblyGenerator:
{
"example_id": i,
"generator_type": gen_name,
"grounded": grounded,
"n_bodies": len(bodies),
"n_joints": len(joints),
"body_positions": [b.position.tolist() for b in bodies],
"body_orientations": [b.orientation.tolist() for b in bodies],
"joints": [
{
"joint_id": j.joint_id,
@@ -547,11 +875,11 @@ class SyntheticAssemblyGenerator:
for j in joints
],
"joint_labels": joint_labels,
"assembly_classification": analysis.combinatorial_classification,
"assembly_classification": (analysis.combinatorial_classification),
"is_rigid": analysis.is_rigid,
"is_minimally_rigid": analysis.is_minimally_rigid,
"internal_dof": analysis.jacobian_internal_dof,
"geometric_degeneracies": analysis.geometric_degeneracies,
"geometric_degeneracies": (analysis.geometric_degeneracies),
}
)

View File

@@ -44,7 +44,10 @@ class TestChainAssembly:
def test_chain_custom_joint_type(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
_, joints, _ = gen.generate_chain_assembly(3, joint_type=JointType.BALL)
_, joints, _ = gen.generate_chain_assembly(
3,
joint_type=JointType.BALL,
)
assert all(j.joint_type is JointType.BALL for j in joints)
@@ -77,7 +80,10 @@ class TestOverconstrainedAssembly:
def test_has_redundant(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_overconstrained_assembly(4, extra_joints=2)
_, _, analysis = gen.generate_overconstrained_assembly(
4,
extra_joints=2,
)
assert analysis.combinatorial_redundant > 0
def test_extra_joints_added(self) -> None:
@@ -85,7 +91,10 @@ class TestOverconstrainedAssembly:
_, joints_base, _ = gen.generate_rigid_assembly(4)
gen2 = SyntheticAssemblyGenerator(seed=42)
_, joints_over, _ = gen2.generate_overconstrained_assembly(4, extra_joints=3)
_, joints_over, _ = gen2.generate_overconstrained_assembly(
4,
extra_joints=3,
)
# Overconstrained has base joints + extra
assert len(joints_over) > len(joints_base)
@@ -111,7 +120,10 @@ class TestTreeAssembly:
def test_branching_factor(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2)
bodies, joints, _ = gen.generate_tree_assembly(
10,
branching_factor=2,
)
assert len(bodies) == 10
assert len(joints) == 9
@@ -125,7 +137,10 @@ class TestTreeAssembly:
def test_single_joint_type(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL)
_, joints, _ = gen.generate_tree_assembly(
5,
joint_types=JointType.BALL,
)
assert all(j.joint_type is JointType.BALL for j in joints)
def test_sequential_body_ids(self) -> None:
@@ -206,12 +221,18 @@ class TestMixedAssembly:
def test_more_joints_than_tree(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3)
bodies, joints, _ = gen.generate_mixed_assembly(
8,
edge_density=0.3,
)
assert len(joints) > len(bodies) - 1
def test_density_zero_is_tree(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0)
_bodies, joints, _ = gen.generate_mixed_assembly(
5,
edge_density=0.0,
)
assert len(joints) == 4 # spanning tree only
def test_density_validation(self) -> None:
@@ -229,11 +250,210 @@ class TestMixedAssembly:
def test_high_density(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0)
_bodies, joints, _ = gen.generate_mixed_assembly(
5,
edge_density=1.0,
)
# Fully connected: 5*(5-1)/2 = 10 edges
assert len(joints) == 10
# ---------------------------------------------------------------------------
# Axis sampling strategies
# ---------------------------------------------------------------------------
class TestAxisStrategy:
"""Axis sampling strategies produce valid unit vectors."""
def test_cardinal_axis_from_six(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axes = {tuple(gen._cardinal_axis()) for _ in range(200)}
expected = {
(1, 0, 0),
(-1, 0, 0),
(0, 1, 0),
(0, -1, 0),
(0, 0, 1),
(0, 0, -1),
}
assert axes == expected
def test_random_axis_unit_norm(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
for _ in range(50):
axis = gen._sample_axis("random")
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
def test_near_parallel_close_to_base(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
base = np.array([0.0, 0.0, 1.0])
for _ in range(50):
axis = gen._near_parallel_axis(base)
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
assert np.dot(axis, base) > 0.95
def test_sample_axis_cardinal(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axis = gen._sample_axis("cardinal")
cardinals = [
np.array(v, dtype=float)
for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
]
assert any(np.allclose(axis, c) for c in cardinals)
def test_sample_axis_near_parallel(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axis = gen._sample_axis("near_parallel")
z = np.array([0.0, 0.0, 1.0])
assert np.dot(axis, z) > 0.95
# ---------------------------------------------------------------------------
# Geometric diversity: orientations
# ---------------------------------------------------------------------------
class TestRandomOrientations:
"""Bodies should have non-identity orientations."""
def test_bodies_have_orientations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, _, _ = gen.generate_tree_assembly(5)
non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3)))
assert non_identity >= 3
def test_orientations_are_valid_rotations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, _, _ = gen.generate_star_assembly(6)
for b in bodies:
r = b.orientation
# R^T R == I
np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10)
# det(R) == 1
assert abs(np.linalg.det(r) - 1.0) < 1e-10
def test_all_generators_set_orientations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
# Chain
bodies, _, _ = gen.generate_chain_assembly(3)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# Loop
bodies, _, _ = gen.generate_loop_assembly(4)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# Mixed
bodies, _, _ = gen.generate_mixed_assembly(4)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# ---------------------------------------------------------------------------
# Geometric diversity: grounded parameter
# ---------------------------------------------------------------------------
class TestGroundedParameter:
"""Grounded parameter controls ground_body in analysis."""
def test_chain_grounded_default(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_chain_assembly(4)
assert analysis.combinatorial_dof >= 0
def test_chain_floating(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_chain_assembly(
4,
grounded=False,
)
# Floating: 6 trivial DOF not subtracted by ground
assert analysis.combinatorial_dof >= 6
def test_floating_vs_grounded_dof_difference(self) -> None:
gen1 = SyntheticAssemblyGenerator(seed=42)
_, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
gen2 = SyntheticAssemblyGenerator(seed=42)
_, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
# Floating should have higher DOF due to missing ground constraint
assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof
@pytest.mark.parametrize(
"gen_method",
[
"generate_chain_assembly",
"generate_rigid_assembly",
"generate_tree_assembly",
"generate_loop_assembly",
"generate_star_assembly",
"generate_mixed_assembly",
],
)
def test_all_generators_accept_grounded(
self,
gen_method: str,
) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
method = getattr(gen, gen_method)
n = 4
# Should not raise
if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"):
method(n, grounded=False)
else:
method(n, grounded=False)
# ---------------------------------------------------------------------------
# Geometric diversity: parallel axis injection
# ---------------------------------------------------------------------------
class TestParallelAxisInjection:
"""parallel_axis_prob causes shared axis direction."""
def test_parallel_axes_similar(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_chain_assembly(
6,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
# Near-parallel: |dot| close to 1
assert abs(np.dot(j.axis, base)) > 0.9
def test_zero_prob_no_forced_parallel(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_chain_assembly(
6,
parallel_axis_prob=0.0,
)
base = joints[0].axis
dots = [abs(np.dot(j.axis, base)) for j in joints[1:]]
# With 5 random axes, extremely unlikely all are parallel
assert min(dots) < 0.95
def test_parallel_on_loop(self) -> None:
"""Parallel axes on a loop assembly."""
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_loop_assembly(
5,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
assert abs(np.dot(j.axis, base)) > 0.9
def test_parallel_on_star(self) -> None:
"""Parallel axes on a star assembly."""
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_star_assembly(
5,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
assert abs(np.dot(j.axis, base)) > 0.9
# ---------------------------------------------------------------------------
# Complexity tiers
# ---------------------------------------------------------------------------
@@ -265,7 +485,11 @@ class TestComplexityTiers:
def test_tier_overrides_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium")
batch = gen.generate_training_batch(
10,
n_bodies_range=(2, 3),
complexity_tier="medium",
)
lo, hi = COMPLEXITY_RANGES["medium"]
for ex in batch:
assert lo <= ex["n_bodies"] < hi
@@ -282,9 +506,11 @@ class TestTrainingBatch:
EXPECTED_KEYS: ClassVar[set[str]] = {
"example_id",
"generator_type",
"grounded",
"n_bodies",
"n_joints",
"body_positions",
"body_orientations",
"joints",
"joint_labels",
"assembly_classification",
@@ -337,7 +563,8 @@ class TestTrainingBatch:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(30)
for ex in batch:
assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp
# default (3, 8), but loop/star may clamp
assert 2 <= ex["n_bodies"] <= 7
def test_joint_label_consistency(self) -> None:
"""independent + redundant == total for every joint."""
@@ -348,6 +575,70 @@ class TestTrainingBatch:
total = label["independent_constraints"] + label["redundant_constraints"]
assert total == label["total_constraints"]
def test_body_orientations_present(self) -> None:
"""Each example includes body_orientations as 3x3 lists."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
orients = ex["body_orientations"]
assert len(orients) == ex["n_bodies"]
for o in orients:
assert len(o) == 3
assert len(o[0]) == 3
def test_grounded_field_present(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
assert isinstance(ex["grounded"], bool)
# ---------------------------------------------------------------------------
# Batch grounded ratio
# ---------------------------------------------------------------------------
class TestBatchGroundedRatio:
"""grounded_ratio controls the mix in batch generation."""
def test_all_grounded(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20, grounded_ratio=1.0)
assert all(ex["grounded"] for ex in batch)
def test_none_grounded(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20, grounded_ratio=0.0)
assert not any(ex["grounded"] for ex in batch)
def test_mixed_ratio(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(100, grounded_ratio=0.5)
grounded_count = sum(1 for ex in batch if ex["grounded"])
# With 100 samples and p=0.5, should be roughly 50 +/- 20
assert 20 < grounded_count < 80
def test_batch_axis_strategy_cardinal(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(
10,
axis_strategy="cardinal",
)
assert len(batch) == 10
def test_batch_parallel_axis_prob(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(
10,
parallel_axis_prob=0.5,
)
assert len(batch) == 10
# ---------------------------------------------------------------------------
# Seed reproducibility
# ---------------------------------------------------------------------------
class TestSeedReproducibility:
"""Different seeds produce different results."""
@@ -355,8 +646,14 @@ class TestSeedReproducibility:
def test_different_seeds_differ(self) -> None:
g1 = SyntheticAssemblyGenerator(seed=1)
g2 = SyntheticAssemblyGenerator(seed=2)
b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
b1 = g1.generate_training_batch(
batch_size=5,
n_bodies_range=(3, 6),
)
b2 = g2.generate_training_batch(
batch_size=5,
n_bodies_range=(3, 6),
)
c1 = [ex["assembly_classification"] for ex in b1]
c2 = [ex["assembly_classification"] for ex in b2]
r1 = [ex["is_rigid"] for ex in b1]