feat: port SyntheticAssemblyGenerator to solver/datagen/generator.py
Port chain, rigid, and overconstrained assembly generators plus the training batch generation from data/synthetic/pebble-game.py. - Refactored rng.choice on enums/callables to integer indexing (mypy) - Typed n_bodies_range as tuple[int, int] - Typed batch return as list[dict[str, Any]] - Full type annotations (mypy strict) - Re-exported from solver.datagen.__init__ Closes #5
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
"""Data generation utilities for assembly constraint training data."""
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
@@ -19,5 +20,6 @@ __all__ = [
|
||||
"PebbleGame3D",
|
||||
"PebbleState",
|
||||
"RigidBody",
|
||||
"SyntheticAssemblyGenerator",
|
||||
"analyze_assembly",
|
||||
]
|
||||
|
||||
252
solver/datagen/generator.py
Normal file
252
solver/datagen/generator.py
Normal file
@@ -0,0 +1,252 @@
|
||||
"""Synthetic assembly graph generator for training data production.
|
||||
|
||||
Generates assembly graphs with known constraint classifications using
|
||||
the pebble game and Jacobian verification. Each assembly is fully labeled
|
||||
with per-constraint independence flags and assembly-level classification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["SyntheticAssemblyGenerator"]
|
||||
|
||||
|
||||
class SyntheticAssemblyGenerator:
|
||||
"""Generates assembly graphs with known minimal constraint sets.
|
||||
|
||||
Uses the pebble game to incrementally build assemblies, tracking
|
||||
exactly which constraints are independent at each step. This produces
|
||||
labeled training data: (assembly_graph, constraint_set, labels).
|
||||
|
||||
Labels per constraint:
|
||||
- independent: bool (does this constraint remove a DOF?)
|
||||
- redundant: bool (is this constraint overconstrained?)
|
||||
- minimal_set: bool (part of a minimal rigidity basis?)
|
||||
"""
|
||||
|
||||
def __init__(self, seed: int = 42) -> None:
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
def generate_chain_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_type: JointType = JointType.REVOLUTE,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a serial kinematic chain.
|
||||
|
||||
Simple but useful: each body connects to the next with the
|
||||
specified joint type. Results in an underconstrained assembly
|
||||
(serial chain is never rigid without closing loops).
|
||||
"""
|
||||
bodies = []
|
||||
joints = []
|
||||
|
||||
for i in range(n_bodies):
|
||||
pos = np.array([i * 2.0, 0.0, 0.0])
|
||||
bodies.append(RigidBody(body_id=i, position=pos))
|
||||
|
||||
for i in range(n_bodies - 1):
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
|
||||
anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0])
|
||||
|
||||
joints.append(
|
||||
Joint(
|
||||
joint_id=i,
|
||||
body_a=i,
|
||||
body_b=i + 1,
|
||||
joint_type=joint_type,
|
||||
anchor_a=anchor,
|
||||
anchor_b=anchor,
|
||||
axis=axis,
|
||||
)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_rigid_assembly(
|
||||
self, n_bodies: int
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a minimally rigid assembly by adding joints until rigid.
|
||||
|
||||
Strategy: start with fixed joints on a spanning tree (guarantees
|
||||
rigidity), then randomly relax some to weaker joint types while
|
||||
maintaining rigidity via the pebble game check.
|
||||
"""
|
||||
bodies = []
|
||||
for i in range(n_bodies):
|
||||
pos = self.rng.uniform(-5, 5, size=3)
|
||||
bodies.append(RigidBody(body_id=i, position=pos))
|
||||
|
||||
# Build spanning tree with fixed joints (overconstrained)
|
||||
joints: list[Joint] = []
|
||||
for i in range(1, n_bodies):
|
||||
parent = self.rng.integers(0, i)
|
||||
mid = (bodies[i].position + bodies[parent].position) / 2
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
|
||||
joints.append(
|
||||
Joint(
|
||||
joint_id=i - 1,
|
||||
body_a=int(parent),
|
||||
body_b=i,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=mid,
|
||||
anchor_b=mid,
|
||||
axis=axis,
|
||||
)
|
||||
)
|
||||
|
||||
# Try relaxing joints to weaker types while maintaining rigidity
|
||||
weaker_types = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.CYLINDRICAL,
|
||||
JointType.BALL,
|
||||
]
|
||||
|
||||
for idx in self.rng.permutation(len(joints)):
|
||||
original_type = joints[idx].joint_type
|
||||
for candidate in weaker_types:
|
||||
joints[idx].joint_type = candidate
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
if analysis.is_rigid:
|
||||
break # Keep the weaker type
|
||||
else:
|
||||
joints[idx].joint_type = original_type
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_overconstrained_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
extra_joints: int = 2,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate an assembly with known redundant constraints.
|
||||
|
||||
Starts with a rigid assembly, then adds extra joints that
|
||||
the pebble game will flag as redundant.
|
||||
"""
|
||||
bodies, joints, _ = self.generate_rigid_assembly(n_bodies)
|
||||
|
||||
joint_id = len(joints)
|
||||
for _ in range(extra_joints):
|
||||
a, b = self.rng.choice(n_bodies, size=2, replace=False)
|
||||
mid = (bodies[a].position + bodies[b].position) / 2
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
|
||||
_overcon_types = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.FIXED,
|
||||
JointType.BALL,
|
||||
]
|
||||
jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))]
|
||||
joints.append(
|
||||
Joint(
|
||||
joint_id=joint_id,
|
||||
body_a=int(a),
|
||||
body_b=int(b),
|
||||
joint_type=jtype,
|
||||
anchor_a=mid,
|
||||
anchor_b=mid,
|
||||
axis=axis,
|
||||
)
|
||||
)
|
||||
joint_id += 1
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=0)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_training_batch(
|
||||
self,
|
||||
batch_size: int = 100,
|
||||
n_bodies_range: tuple[int, int] = (3, 8),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate a batch of labeled training examples.
|
||||
|
||||
Each example contains:
|
||||
- bodies: list of body positions/orientations
|
||||
- joints: list of joints with types and parameters
|
||||
- labels: per-joint independence/redundancy flags
|
||||
- assembly_label: overall classification
|
||||
"""
|
||||
examples: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
n = int(self.rng.integers(*n_bodies_range))
|
||||
gen_idx = int(self.rng.integers(3))
|
||||
|
||||
if gen_idx == 0:
|
||||
_chain_types = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.BALL,
|
||||
JointType.CYLINDRICAL,
|
||||
]
|
||||
jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
|
||||
bodies, joints, analysis = self.generate_chain_assembly(n, jtype)
|
||||
elif gen_idx == 1:
|
||||
bodies, joints, analysis = self.generate_rigid_assembly(n)
|
||||
else:
|
||||
extra = int(self.rng.integers(1, 4))
|
||||
bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra)
|
||||
|
||||
# Build per-joint labels from edge results
|
||||
joint_labels: dict[int, dict[str, int]] = {}
|
||||
for result in analysis.per_edge_results:
|
||||
jid = result["joint_id"]
|
||||
if jid not in joint_labels:
|
||||
joint_labels[jid] = {
|
||||
"independent_constraints": 0,
|
||||
"redundant_constraints": 0,
|
||||
"total_constraints": 0,
|
||||
}
|
||||
joint_labels[jid]["total_constraints"] += 1
|
||||
if result["independent"]:
|
||||
joint_labels[jid]["independent_constraints"] += 1
|
||||
else:
|
||||
joint_labels[jid]["redundant_constraints"] += 1
|
||||
|
||||
examples.append(
|
||||
{
|
||||
"example_id": i,
|
||||
"n_bodies": len(bodies),
|
||||
"n_joints": len(joints),
|
||||
"body_positions": [b.position.tolist() for b in bodies],
|
||||
"joints": [
|
||||
{
|
||||
"joint_id": j.joint_id,
|
||||
"body_a": j.body_a,
|
||||
"body_b": j.body_b,
|
||||
"type": j.joint_type.name,
|
||||
"axis": j.axis.tolist(),
|
||||
}
|
||||
for j in joints
|
||||
],
|
||||
"joint_labels": joint_labels,
|
||||
"assembly_classification": analysis.combinatorial_classification,
|
||||
"is_rigid": analysis.is_rigid,
|
||||
"is_minimally_rigid": analysis.is_minimally_rigid,
|
||||
"internal_dof": analysis.jacobian_internal_dof,
|
||||
"geometric_degeneracies": analysis.geometric_degeneracies,
|
||||
}
|
||||
)
|
||||
|
||||
return examples
|
||||
Reference in New Issue
Block a user