feat: geometric diversity for synthetic assembly generation
- Add AxisStrategy type (cardinal, random, near_parallel) - Add random body orientations via scipy.spatial.transform.Rotation - Add parallel axis injection with configurable probability - Add grounded parameter on all 7 generators (grounded/floating) - Add axis sampling strategies: cardinal, random, near-parallel - Update _create_joint with orientation-aware anchor offsets - Add _resolve_axis helper for parallel axis propagation - Update generate_training_batch with axis_strategy, parallel_axis_prob, grounded_ratio parameters - Add body_orientations and grounded fields to batch output - Export AxisStrategy from datagen package - Add 28 new tests (72 total generator tests, 158 total) Closes #8
This commit is contained in:
@@ -1,7 +1,11 @@
|
||||
"""Data generation utilities for assembly constraint training data."""
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
|
||||
from solver.datagen.generator import (
|
||||
COMPLEXITY_RANGES,
|
||||
AxisStrategy,
|
||||
SyntheticAssemblyGenerator,
|
||||
)
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
@@ -14,6 +18,7 @@ from solver.datagen.types import (
|
||||
|
||||
__all__ = [
|
||||
"COMPLEXITY_RANGES",
|
||||
"AxisStrategy",
|
||||
"ConstraintAnalysis",
|
||||
"JacobianVerifier",
|
||||
"Joint",
|
||||
|
||||
@@ -10,6 +10,7 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.types import (
|
||||
@@ -22,7 +23,12 @@ from solver.datagen.types import (
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["COMPLEXITY_RANGES", "ComplexityTier", "SyntheticAssemblyGenerator"]
|
||||
__all__ = [
|
||||
"COMPLEXITY_RANGES",
|
||||
"AxisStrategy",
|
||||
"ComplexityTier",
|
||||
"SyntheticAssemblyGenerator",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
|
||||
@@ -36,6 +42,12 @@ COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
|
||||
"complex": (16, 51),
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Axis sampling strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
AxisStrategy = Literal["cardinal", "random", "near_parallel"]
|
||||
|
||||
|
||||
class SyntheticAssemblyGenerator:
|
||||
"""Generates assembly graphs with known minimal constraint sets.
|
||||
@@ -67,6 +79,63 @@ class SyntheticAssemblyGenerator:
|
||||
axis /= np.linalg.norm(axis)
|
||||
return axis
|
||||
|
||||
def _random_orientation(self) -> np.ndarray:
|
||||
"""Generate a random 3x3 rotation matrix."""
|
||||
mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix()
|
||||
return mat
|
||||
|
||||
def _cardinal_axis(self) -> np.ndarray:
|
||||
"""Pick uniformly from the six signed cardinal directions."""
|
||||
axes = np.array(
|
||||
[
|
||||
[1, 0, 0],
|
||||
[-1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, -1, 0],
|
||||
[0, 0, 1],
|
||||
[0, 0, -1],
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
result: np.ndarray = axes[int(self.rng.integers(6))]
|
||||
return result
|
||||
|
||||
def _near_parallel_axis(
|
||||
self,
|
||||
base_axis: np.ndarray,
|
||||
perturbation_scale: float = 0.05,
|
||||
) -> np.ndarray:
|
||||
"""Return *base_axis* with a small random perturbation, re-normalized."""
|
||||
perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale
|
||||
return perturbed / np.linalg.norm(perturbed)
|
||||
|
||||
def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray:
|
||||
"""Sample a joint axis using the specified strategy."""
|
||||
if strategy == "cardinal":
|
||||
return self._cardinal_axis()
|
||||
if strategy == "near_parallel":
|
||||
return self._near_parallel_axis(np.array([0.0, 0.0, 1.0]))
|
||||
return self._random_axis()
|
||||
|
||||
def _resolve_axis(
|
||||
self,
|
||||
strategy: AxisStrategy,
|
||||
parallel_axis_prob: float,
|
||||
shared_axis: np.ndarray | None,
|
||||
) -> tuple[np.ndarray, np.ndarray | None]:
|
||||
"""Return (axis_for_this_joint, shared_axis_to_propagate).
|
||||
|
||||
On the first call where *shared_axis* is ``None`` and parallel
|
||||
injection triggers, a base axis is chosen and returned as
|
||||
*shared_axis* for subsequent calls.
|
||||
"""
|
||||
if shared_axis is not None:
|
||||
return self._near_parallel_axis(shared_axis), shared_axis
|
||||
if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob:
|
||||
base = self._sample_axis(strategy)
|
||||
return base.copy(), base
|
||||
return self._sample_axis(strategy), None
|
||||
|
||||
def _select_joint_type(
|
||||
self,
|
||||
joint_types: JointType | list[JointType],
|
||||
@@ -85,17 +154,37 @@ class SyntheticAssemblyGenerator:
|
||||
pos_a: np.ndarray,
|
||||
pos_b: np.ndarray,
|
||||
joint_type: JointType,
|
||||
*,
|
||||
axis: np.ndarray | None = None,
|
||||
orient_a: np.ndarray | None = None,
|
||||
orient_b: np.ndarray | None = None,
|
||||
) -> Joint:
|
||||
"""Create a joint between two bodies with random axis at midpoint."""
|
||||
anchor = (pos_a + pos_b) / 2.0
|
||||
"""Create a joint between two bodies.
|
||||
|
||||
When orientations are provided, anchor points are offset from
|
||||
each body's center along a random local direction rotated into
|
||||
world frame, rather than placed at the midpoint.
|
||||
"""
|
||||
if orient_a is not None and orient_b is not None:
|
||||
dist = max(float(np.linalg.norm(pos_b - pos_a)), 0.1)
|
||||
offset_scale = dist * 0.2
|
||||
local_a = self.rng.standard_normal(3) * offset_scale
|
||||
local_b = self.rng.standard_normal(3) * offset_scale
|
||||
anchor_a = pos_a + orient_a @ local_a
|
||||
anchor_b = pos_b + orient_b @ local_b
|
||||
else:
|
||||
anchor = (pos_a + pos_b) / 2.0
|
||||
anchor_a = anchor
|
||||
anchor_b = anchor
|
||||
|
||||
return Joint(
|
||||
joint_id=joint_id,
|
||||
body_a=body_a_id,
|
||||
body_b=body_b_id,
|
||||
joint_type=joint_type,
|
||||
anchor_a=anchor,
|
||||
anchor_b=anchor,
|
||||
axis=self._random_axis(),
|
||||
anchor_a=anchor_a,
|
||||
anchor_b=anchor_b,
|
||||
axis=axis if axis is not None else self._random_axis(),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -106,6 +195,10 @@ class SyntheticAssemblyGenerator:
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_type: JointType = JointType.REVOLUTE,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a serial kinematic chain.
|
||||
|
||||
@@ -118,27 +211,49 @@ class SyntheticAssemblyGenerator:
|
||||
|
||||
for i in range(n_bodies):
|
||||
pos = np.array([i * 2.0, 0.0, 0.0])
|
||||
bodies.append(RigidBody(body_id=i, position=pos))
|
||||
|
||||
for i in range(n_bodies - 1):
|
||||
anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0])
|
||||
joints.append(
|
||||
Joint(
|
||||
joint_id=i,
|
||||
body_a=i,
|
||||
body_b=i + 1,
|
||||
joint_type=joint_type,
|
||||
anchor_a=anchor,
|
||||
anchor_b=anchor,
|
||||
axis=self._random_axis(),
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=i,
|
||||
position=pos,
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(n_bodies - 1):
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
i,
|
||||
i,
|
||||
i + 1,
|
||||
bodies[i].position,
|
||||
bodies[i + 1].position,
|
||||
joint_type,
|
||||
axis=axis,
|
||||
orient_a=bodies[i].orientation,
|
||||
orient_b=bodies[i + 1].orientation,
|
||||
)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_rigid_assembly(
|
||||
self, n_bodies: int
|
||||
self,
|
||||
n_bodies: int,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a minimally rigid assembly by adding joints until rigid.
|
||||
|
||||
@@ -148,22 +263,35 @@ class SyntheticAssemblyGenerator:
|
||||
"""
|
||||
bodies = []
|
||||
for i in range(n_bodies):
|
||||
bodies.append(RigidBody(body_id=i, position=self._random_position()))
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=i,
|
||||
position=self._random_position(),
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
)
|
||||
|
||||
# Build spanning tree with fixed joints (overconstrained)
|
||||
joints: list[Joint] = []
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(1, n_bodies):
|
||||
parent = int(self.rng.integers(0, i))
|
||||
mid = (bodies[i].position + bodies[parent].position) / 2
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
Joint(
|
||||
joint_id=i - 1,
|
||||
body_a=parent,
|
||||
body_b=i,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=mid,
|
||||
anchor_b=mid,
|
||||
axis=self._random_axis(),
|
||||
self._create_joint(
|
||||
i - 1,
|
||||
parent,
|
||||
i,
|
||||
bodies[parent].position,
|
||||
bodies[i].position,
|
||||
JointType.FIXED,
|
||||
axis=axis,
|
||||
orient_a=bodies[parent].orientation,
|
||||
orient_b=bodies[i].orientation,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -174,56 +302,73 @@ class SyntheticAssemblyGenerator:
|
||||
JointType.BALL,
|
||||
]
|
||||
|
||||
ground = 0 if grounded else None
|
||||
for idx in self.rng.permutation(len(joints)):
|
||||
original_type = joints[idx].joint_type
|
||||
for candidate in weaker_types:
|
||||
joints[idx].joint_type = candidate
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=ground)
|
||||
if analysis.is_rigid:
|
||||
break # Keep the weaker type
|
||||
else:
|
||||
joints[idx].joint_type = original_type
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=ground)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_overconstrained_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
extra_joints: int = 2,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate an assembly with known redundant constraints.
|
||||
|
||||
Starts with a rigid assembly, then adds extra joints that
|
||||
the pebble game will flag as redundant.
|
||||
"""
|
||||
bodies, joints, _ = self.generate_rigid_assembly(n_bodies)
|
||||
bodies, joints, _ = self.generate_rigid_assembly(
|
||||
n_bodies,
|
||||
grounded=grounded,
|
||||
axis_strategy=axis_strategy,
|
||||
parallel_axis_prob=parallel_axis_prob,
|
||||
)
|
||||
|
||||
joint_id = len(joints)
|
||||
shared_axis: np.ndarray | None = None
|
||||
for _ in range(extra_joints):
|
||||
a, b = self.rng.choice(n_bodies, size=2, replace=False)
|
||||
mid = (bodies[a].position + bodies[b].position) / 2
|
||||
|
||||
_overcon_types = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.FIXED,
|
||||
JointType.BALL,
|
||||
]
|
||||
jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))]
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
Joint(
|
||||
joint_id=joint_id,
|
||||
body_a=int(a),
|
||||
body_b=int(b),
|
||||
joint_type=jtype,
|
||||
anchor_a=mid,
|
||||
anchor_b=mid,
|
||||
axis=self._random_axis(),
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
int(a),
|
||||
int(b),
|
||||
bodies[int(a)].position,
|
||||
bodies[int(b)].position,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[int(a)].orientation,
|
||||
orient_b=bodies[int(b)].orientation,
|
||||
)
|
||||
)
|
||||
joint_id += 1
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
ground = 0 if grounded else None
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=ground)
|
||||
return bodies, joints, analysis
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -235,6 +380,10 @@ class SyntheticAssemblyGenerator:
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
branching_factor: int = 3,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a random tree topology with configurable branching.
|
||||
|
||||
@@ -247,12 +396,19 @@ class SyntheticAssemblyGenerator:
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
branching_factor: Max children per parent (1-5 recommended).
|
||||
"""
|
||||
bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
|
||||
bodies: list[RigidBody] = [
|
||||
RigidBody(
|
||||
body_id=0,
|
||||
position=np.zeros(3),
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
]
|
||||
joints: list[Joint] = []
|
||||
|
||||
available_parents = [0]
|
||||
next_id = 1
|
||||
joint_id = 0
|
||||
shared_axis: np.ndarray | None = None
|
||||
|
||||
while next_id < n_bodies and available_parents:
|
||||
pidx = int(self.rng.integers(0, len(available_parents)))
|
||||
@@ -266,11 +422,33 @@ class SyntheticAssemblyGenerator:
|
||||
direction = self._random_axis()
|
||||
distance = self.rng.uniform(1.5, 3.0)
|
||||
child_pos = parent_pos + direction * distance
|
||||
child_orient = self._random_orientation()
|
||||
|
||||
bodies.append(RigidBody(body_id=next_id, position=child_pos))
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=next_id,
|
||||
position=child_pos,
|
||||
orientation=child_orient,
|
||||
)
|
||||
)
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(joint_id, parent_id, next_id, parent_pos, child_pos, jtype)
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
parent_id,
|
||||
next_id,
|
||||
parent_pos,
|
||||
child_pos,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[parent_id].orientation,
|
||||
orient_b=child_orient,
|
||||
)
|
||||
)
|
||||
|
||||
available_parents.append(next_id)
|
||||
@@ -283,13 +461,21 @@ class SyntheticAssemblyGenerator:
|
||||
if n_children >= branching_factor or self.rng.random() < 0.3:
|
||||
available_parents.pop(pidx)
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_loop_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a single closed loop (ring) of bodies.
|
||||
|
||||
@@ -317,22 +503,52 @@ class SyntheticAssemblyGenerator:
|
||||
x = radius * np.cos(angle)
|
||||
y = radius * np.sin(angle)
|
||||
z = float(self.rng.uniform(-0.2, 0.2))
|
||||
bodies.append(RigidBody(body_id=i, position=np.array([x, y, z])))
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=i,
|
||||
position=np.array([x, y, z]),
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
)
|
||||
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(n_bodies):
|
||||
next_i = (i + 1) % n_bodies
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(i, i, next_i, bodies[i].position, bodies[next_i].position, jtype)
|
||||
self._create_joint(
|
||||
i,
|
||||
i,
|
||||
next_i,
|
||||
bodies[i].position,
|
||||
bodies[next_i].position,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[i].orientation,
|
||||
orient_b=bodies[next_i].orientation,
|
||||
)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_star_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a star topology with central hub and satellites.
|
||||
|
||||
@@ -350,19 +566,49 @@ class SyntheticAssemblyGenerator:
|
||||
msg = "Star assembly requires at least 2 bodies"
|
||||
raise ValueError(msg)
|
||||
|
||||
bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))]
|
||||
hub_orient = self._random_orientation()
|
||||
bodies: list[RigidBody] = [
|
||||
RigidBody(
|
||||
body_id=0,
|
||||
position=np.zeros(3),
|
||||
orientation=hub_orient,
|
||||
)
|
||||
]
|
||||
joints: list[Joint] = []
|
||||
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(1, n_bodies):
|
||||
direction = self._random_axis()
|
||||
distance = self.rng.uniform(2.0, 5.0)
|
||||
pos = direction * distance
|
||||
bodies.append(RigidBody(body_id=i, position=pos))
|
||||
sat_orient = self._random_orientation()
|
||||
bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient))
|
||||
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
joints.append(self._create_joint(i - 1, 0, i, np.zeros(3), pos, jtype))
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
i - 1,
|
||||
0,
|
||||
i,
|
||||
np.zeros(3),
|
||||
pos,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=hub_orient,
|
||||
orient_b=sat_orient,
|
||||
)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_mixed_assembly(
|
||||
@@ -370,6 +616,10 @@ class SyntheticAssemblyGenerator:
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
edge_density: float = 0.3,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a mixed topology combining tree and loop elements.
|
||||
|
||||
@@ -392,14 +642,26 @@ class SyntheticAssemblyGenerator:
|
||||
joints: list[Joint] = []
|
||||
|
||||
for i in range(n_bodies):
|
||||
bodies.append(RigidBody(body_id=i, position=self._random_position()))
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=i,
|
||||
position=self._random_position(),
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
)
|
||||
|
||||
# Phase 1: spanning tree
|
||||
joint_id = 0
|
||||
existing_edges: set[frozenset[int]] = set()
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(1, n_bodies):
|
||||
parent = int(self.rng.integers(0, i))
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
@@ -408,6 +670,9 @@ class SyntheticAssemblyGenerator:
|
||||
bodies[parent].position,
|
||||
bodies[i].position,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[parent].orientation,
|
||||
orient_b=bodies[i].orientation,
|
||||
)
|
||||
)
|
||||
existing_edges.add(frozenset([parent, i]))
|
||||
@@ -425,6 +690,11 @@ class SyntheticAssemblyGenerator:
|
||||
|
||||
for a, b in candidates[:n_extra]:
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
@@ -433,11 +703,18 @@ class SyntheticAssemblyGenerator:
|
||||
bodies[a].position,
|
||||
bodies[b].position,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[a].orientation,
|
||||
orient_b=bodies[b].orientation,
|
||||
)
|
||||
)
|
||||
joint_id += 1
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
@@ -449,6 +726,10 @@ class SyntheticAssemblyGenerator:
|
||||
batch_size: int = 100,
|
||||
n_bodies_range: tuple[int, int] | None = None,
|
||||
complexity_tier: ComplexityTier | None = None,
|
||||
*,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
grounded_ratio: float = 1.0,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate a batch of labeled training examples.
|
||||
|
||||
@@ -461,6 +742,9 @@ class SyntheticAssemblyGenerator:
|
||||
Overridden by *complexity_tier* when both are given.
|
||||
complexity_tier: Predefined range (``"simple"`` / ``"medium"``
|
||||
/ ``"complex"``). Overrides *n_bodies_range*.
|
||||
axis_strategy: Axis sampling strategy for joint axes.
|
||||
parallel_axis_prob: Probability of parallel axis injection.
|
||||
grounded_ratio: Fraction of examples that are grounded.
|
||||
"""
|
||||
if complexity_tier is not None:
|
||||
n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
|
||||
@@ -474,11 +758,17 @@ class SyntheticAssemblyGenerator:
|
||||
JointType.FIXED,
|
||||
]
|
||||
|
||||
geo_kw: dict[str, Any] = {
|
||||
"axis_strategy": axis_strategy,
|
||||
"parallel_axis_prob": parallel_axis_prob,
|
||||
}
|
||||
|
||||
examples: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
n = int(self.rng.integers(*n_bodies_range))
|
||||
gen_idx = int(self.rng.integers(7))
|
||||
grounded = bool(self.rng.random() < grounded_ratio)
|
||||
|
||||
if gen_idx == 0:
|
||||
_chain_types = [
|
||||
@@ -487,30 +777,66 @@ class SyntheticAssemblyGenerator:
|
||||
JointType.CYLINDRICAL,
|
||||
]
|
||||
jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
|
||||
bodies, joints, analysis = self.generate_chain_assembly(n, jtype)
|
||||
bodies, joints, analysis = self.generate_chain_assembly(
|
||||
n,
|
||||
jtype,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "chain"
|
||||
elif gen_idx == 1:
|
||||
bodies, joints, analysis = self.generate_rigid_assembly(n)
|
||||
bodies, joints, analysis = self.generate_rigid_assembly(
|
||||
n,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "rigid"
|
||||
elif gen_idx == 2:
|
||||
extra = int(self.rng.integers(1, 4))
|
||||
bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra)
|
||||
bodies, joints, analysis = self.generate_overconstrained_assembly(
|
||||
n,
|
||||
extra,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "overconstrained"
|
||||
elif gen_idx == 3:
|
||||
branching = int(self.rng.integers(2, 5))
|
||||
bodies, joints, analysis = self.generate_tree_assembly(n, _joint_pool, branching)
|
||||
bodies, joints, analysis = self.generate_tree_assembly(
|
||||
n,
|
||||
_joint_pool,
|
||||
branching,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "tree"
|
||||
elif gen_idx == 4:
|
||||
n = max(n, 3)
|
||||
bodies, joints, analysis = self.generate_loop_assembly(n, _joint_pool)
|
||||
bodies, joints, analysis = self.generate_loop_assembly(
|
||||
n,
|
||||
_joint_pool,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "loop"
|
||||
elif gen_idx == 5:
|
||||
n = max(n, 2)
|
||||
bodies, joints, analysis = self.generate_star_assembly(n, _joint_pool)
|
||||
bodies, joints, analysis = self.generate_star_assembly(
|
||||
n,
|
||||
_joint_pool,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "star"
|
||||
else:
|
||||
density = float(self.rng.uniform(0.2, 0.5))
|
||||
bodies, joints, analysis = self.generate_mixed_assembly(n, _joint_pool, density)
|
||||
bodies, joints, analysis = self.generate_mixed_assembly(
|
||||
n,
|
||||
_joint_pool,
|
||||
density,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "mixed"
|
||||
|
||||
# Build per-joint labels from edge results
|
||||
@@ -533,9 +859,11 @@ class SyntheticAssemblyGenerator:
|
||||
{
|
||||
"example_id": i,
|
||||
"generator_type": gen_name,
|
||||
"grounded": grounded,
|
||||
"n_bodies": len(bodies),
|
||||
"n_joints": len(joints),
|
||||
"body_positions": [b.position.tolist() for b in bodies],
|
||||
"body_orientations": [b.orientation.tolist() for b in bodies],
|
||||
"joints": [
|
||||
{
|
||||
"joint_id": j.joint_id,
|
||||
@@ -547,11 +875,11 @@ class SyntheticAssemblyGenerator:
|
||||
for j in joints
|
||||
],
|
||||
"joint_labels": joint_labels,
|
||||
"assembly_classification": analysis.combinatorial_classification,
|
||||
"assembly_classification": (analysis.combinatorial_classification),
|
||||
"is_rigid": analysis.is_rigid,
|
||||
"is_minimally_rigid": analysis.is_minimally_rigid,
|
||||
"internal_dof": analysis.jacobian_internal_dof,
|
||||
"geometric_degeneracies": analysis.geometric_degeneracies,
|
||||
"geometric_degeneracies": (analysis.geometric_degeneracies),
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -44,7 +44,10 @@ class TestChainAssembly:
|
||||
|
||||
def test_chain_custom_joint_type(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
_, joints, _ = gen.generate_chain_assembly(3, joint_type=JointType.BALL)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
3,
|
||||
joint_type=JointType.BALL,
|
||||
)
|
||||
assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
|
||||
@@ -77,7 +80,10 @@ class TestOverconstrainedAssembly:
|
||||
|
||||
def test_has_redundant(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_overconstrained_assembly(4, extra_joints=2)
|
||||
_, _, analysis = gen.generate_overconstrained_assembly(
|
||||
4,
|
||||
extra_joints=2,
|
||||
)
|
||||
assert analysis.combinatorial_redundant > 0
|
||||
|
||||
def test_extra_joints_added(self) -> None:
|
||||
@@ -85,7 +91,10 @@ class TestOverconstrainedAssembly:
|
||||
_, joints_base, _ = gen.generate_rigid_assembly(4)
|
||||
|
||||
gen2 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints_over, _ = gen2.generate_overconstrained_assembly(4, extra_joints=3)
|
||||
_, joints_over, _ = gen2.generate_overconstrained_assembly(
|
||||
4,
|
||||
extra_joints=3,
|
||||
)
|
||||
# Overconstrained has base joints + extra
|
||||
assert len(joints_over) > len(joints_base)
|
||||
|
||||
@@ -111,7 +120,10 @@ class TestTreeAssembly:
|
||||
|
||||
def test_branching_factor(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(
|
||||
10,
|
||||
branching_factor=2,
|
||||
)
|
||||
assert len(bodies) == 10
|
||||
assert len(joints) == 9
|
||||
|
||||
@@ -125,7 +137,10 @@ class TestTreeAssembly:
|
||||
|
||||
def test_single_joint_type(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL)
|
||||
_, joints, _ = gen.generate_tree_assembly(
|
||||
5,
|
||||
joint_types=JointType.BALL,
|
||||
)
|
||||
assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
def test_sequential_body_ids(self) -> None:
|
||||
@@ -206,12 +221,18 @@ class TestMixedAssembly:
|
||||
|
||||
def test_more_joints_than_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3)
|
||||
bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
8,
|
||||
edge_density=0.3,
|
||||
)
|
||||
assert len(joints) > len(bodies) - 1
|
||||
|
||||
def test_density_zero_is_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
5,
|
||||
edge_density=0.0,
|
||||
)
|
||||
assert len(joints) == 4 # spanning tree only
|
||||
|
||||
def test_density_validation(self) -> None:
|
||||
@@ -229,11 +250,210 @@ class TestMixedAssembly:
|
||||
|
||||
def test_high_density(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
5,
|
||||
edge_density=1.0,
|
||||
)
|
||||
# Fully connected: 5*(5-1)/2 = 10 edges
|
||||
assert len(joints) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Axis sampling strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAxisStrategy:
|
||||
"""Axis sampling strategies produce valid unit vectors."""
|
||||
|
||||
def test_cardinal_axis_from_six(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axes = {tuple(gen._cardinal_axis()) for _ in range(200)}
|
||||
expected = {
|
||||
(1, 0, 0),
|
||||
(-1, 0, 0),
|
||||
(0, 1, 0),
|
||||
(0, -1, 0),
|
||||
(0, 0, 1),
|
||||
(0, 0, -1),
|
||||
}
|
||||
assert axes == expected
|
||||
|
||||
def test_random_axis_unit_norm(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
for _ in range(50):
|
||||
axis = gen._sample_axis("random")
|
||||
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
|
||||
|
||||
def test_near_parallel_close_to_base(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
base = np.array([0.0, 0.0, 1.0])
|
||||
for _ in range(50):
|
||||
axis = gen._near_parallel_axis(base)
|
||||
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
|
||||
assert np.dot(axis, base) > 0.95
|
||||
|
||||
def test_sample_axis_cardinal(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axis = gen._sample_axis("cardinal")
|
||||
cardinals = [
|
||||
np.array(v, dtype=float)
|
||||
for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
|
||||
]
|
||||
assert any(np.allclose(axis, c) for c in cardinals)
|
||||
|
||||
def test_sample_axis_near_parallel(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axis = gen._sample_axis("near_parallel")
|
||||
z = np.array([0.0, 0.0, 1.0])
|
||||
assert np.dot(axis, z) > 0.95
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: orientations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRandomOrientations:
|
||||
"""Bodies should have non-identity orientations."""
|
||||
|
||||
def test_bodies_have_orientations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_tree_assembly(5)
|
||||
non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3)))
|
||||
assert non_identity >= 3
|
||||
|
||||
def test_orientations_are_valid_rotations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_star_assembly(6)
|
||||
for b in bodies:
|
||||
r = b.orientation
|
||||
# R^T R == I
|
||||
np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10)
|
||||
# det(R) == 1
|
||||
assert abs(np.linalg.det(r) - 1.0) < 1e-10
|
||||
|
||||
def test_all_generators_set_orientations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
# Chain
|
||||
bodies, _, _ = gen.generate_chain_assembly(3)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
# Loop
|
||||
bodies, _, _ = gen.generate_loop_assembly(4)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
# Mixed
|
||||
bodies, _, _ = gen.generate_mixed_assembly(4)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: grounded parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGroundedParameter:
|
||||
"""Grounded parameter controls ground_body in analysis."""
|
||||
|
||||
def test_chain_grounded_default(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_chain_assembly(4)
|
||||
assert analysis.combinatorial_dof >= 0
|
||||
|
||||
def test_chain_floating(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_chain_assembly(
|
||||
4,
|
||||
grounded=False,
|
||||
)
|
||||
# Floating: 6 trivial DOF not subtracted by ground
|
||||
assert analysis.combinatorial_dof >= 6
|
||||
|
||||
def test_floating_vs_grounded_dof_difference(self) -> None:
|
||||
gen1 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
|
||||
gen2 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
|
||||
# Floating should have higher DOF due to missing ground constraint
|
||||
assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gen_method",
|
||||
[
|
||||
"generate_chain_assembly",
|
||||
"generate_rigid_assembly",
|
||||
"generate_tree_assembly",
|
||||
"generate_loop_assembly",
|
||||
"generate_star_assembly",
|
||||
"generate_mixed_assembly",
|
||||
],
|
||||
)
|
||||
def test_all_generators_accept_grounded(
|
||||
self,
|
||||
gen_method: str,
|
||||
) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
method = getattr(gen, gen_method)
|
||||
n = 4
|
||||
# Should not raise
|
||||
if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"):
|
||||
method(n, grounded=False)
|
||||
else:
|
||||
method(n, grounded=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: parallel axis injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelAxisInjection:
|
||||
"""parallel_axis_prob causes shared axis direction."""
|
||||
|
||||
def test_parallel_axes_similar(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
6,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
# Near-parallel: |dot| close to 1
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
def test_zero_prob_no_forced_parallel(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
6,
|
||||
parallel_axis_prob=0.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
dots = [abs(np.dot(j.axis, base)) for j in joints[1:]]
|
||||
# With 5 random axes, extremely unlikely all are parallel
|
||||
assert min(dots) < 0.95
|
||||
|
||||
def test_parallel_on_loop(self) -> None:
|
||||
"""Parallel axes on a loop assembly."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_loop_assembly(
|
||||
5,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
def test_parallel_on_star(self) -> None:
|
||||
"""Parallel axes on a star assembly."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_star_assembly(
|
||||
5,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -265,7 +485,11 @@ class TestComplexityTiers:
|
||||
|
||||
def test_tier_overrides_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium")
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
n_bodies_range=(2, 3),
|
||||
complexity_tier="medium",
|
||||
)
|
||||
lo, hi = COMPLEXITY_RANGES["medium"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
@@ -282,9 +506,11 @@ class TestTrainingBatch:
|
||||
EXPECTED_KEYS: ClassVar[set[str]] = {
|
||||
"example_id",
|
||||
"generator_type",
|
||||
"grounded",
|
||||
"n_bodies",
|
||||
"n_joints",
|
||||
"body_positions",
|
||||
"body_orientations",
|
||||
"joints",
|
||||
"joint_labels",
|
||||
"assembly_classification",
|
||||
@@ -337,7 +563,8 @@ class TestTrainingBatch:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(30)
|
||||
for ex in batch:
|
||||
assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp
|
||||
# default (3, 8), but loop/star may clamp
|
||||
assert 2 <= ex["n_bodies"] <= 7
|
||||
|
||||
def test_joint_label_consistency(self) -> None:
|
||||
"""independent + redundant == total for every joint."""
|
||||
@@ -348,6 +575,70 @@ class TestTrainingBatch:
|
||||
total = label["independent_constraints"] + label["redundant_constraints"]
|
||||
assert total == label["total_constraints"]
|
||||
|
||||
def test_body_orientations_present(self) -> None:
|
||||
"""Each example includes body_orientations as 3x3 lists."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
orients = ex["body_orientations"]
|
||||
assert len(orients) == ex["n_bodies"]
|
||||
for o in orients:
|
||||
assert len(o) == 3
|
||||
assert len(o[0]) == 3
|
||||
|
||||
def test_grounded_field_present(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
assert isinstance(ex["grounded"], bool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch grounded ratio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchGroundedRatio:
|
||||
"""grounded_ratio controls the mix in batch generation."""
|
||||
|
||||
def test_all_grounded(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, grounded_ratio=1.0)
|
||||
assert all(ex["grounded"] for ex in batch)
|
||||
|
||||
def test_none_grounded(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, grounded_ratio=0.0)
|
||||
assert not any(ex["grounded"] for ex in batch)
|
||||
|
||||
def test_mixed_ratio(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(100, grounded_ratio=0.5)
|
||||
grounded_count = sum(1 for ex in batch if ex["grounded"])
|
||||
# With 100 samples and p=0.5, should be roughly 50 +/- 20
|
||||
assert 20 < grounded_count < 80
|
||||
|
||||
def test_batch_axis_strategy_cardinal(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
axis_strategy="cardinal",
|
||||
)
|
||||
assert len(batch) == 10
|
||||
|
||||
def test_batch_parallel_axis_prob(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
parallel_axis_prob=0.5,
|
||||
)
|
||||
assert len(batch) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed reproducibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSeedReproducibility:
|
||||
"""Different seeds produce different results."""
|
||||
@@ -355,8 +646,14 @@ class TestSeedReproducibility:
|
||||
def test_different_seeds_differ(self) -> None:
|
||||
g1 = SyntheticAssemblyGenerator(seed=1)
|
||||
g2 = SyntheticAssemblyGenerator(seed=2)
|
||||
b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
|
||||
b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
|
||||
b1 = g1.generate_training_batch(
|
||||
batch_size=5,
|
||||
n_bodies_range=(3, 6),
|
||||
)
|
||||
b2 = g2.generate_training_batch(
|
||||
batch_size=5,
|
||||
n_bodies_range=(3, 6),
|
||||
)
|
||||
c1 = [ex["assembly_classification"] for ex in b1]
|
||||
c2 = [ex["assembly_classification"] for ex in b2]
|
||||
r1 = [ex["is_rigid"] for ex in b1]
|
||||
|
||||
Reference in New Issue
Block a user