feat: port SyntheticAssemblyGenerator to solver/datagen/generator.py
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

Port chain, rigid, and overconstrained assembly generators plus
the training batch generation from data/synthetic/pebble-game.py.

- Refactored rng.choice on enums/callables to integer indexing (mypy)
- Typed n_bodies_range as tuple[int, int]
- Typed batch return as list[dict[str, Any]]
- Full type annotations (mypy strict)
- Re-exported from solver.datagen.__init__

Closes #5
This commit is contained in:
2026-02-02 13:54:32 -06:00
parent 9a31df4988
commit 831a10cdb4
2 changed files with 254 additions and 0 deletions

View File

@@ -1,6 +1,7 @@
"""Data generation utilities for assembly constraint training data."""
from solver.datagen.analysis import analyze_assembly
from solver.datagen.generator import SyntheticAssemblyGenerator
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
@@ -19,5 +20,6 @@ __all__ = [
"PebbleGame3D",
"PebbleState",
"RigidBody",
"SyntheticAssemblyGenerator",
"analyze_assembly",
]

252
solver/datagen/generator.py Normal file
View File

@@ -0,0 +1,252 @@
"""Synthetic assembly graph generator for training data production.
Generates assembly graphs with known constraint classifications using
the pebble game and Jacobian verification. Each assembly is fully labeled
with per-constraint independence flags and assembly-level classification.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from solver.datagen.analysis import analyze_assembly
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
if TYPE_CHECKING:
from typing import Any
__all__ = ["SyntheticAssemblyGenerator"]
class SyntheticAssemblyGenerator:
    """Generate synthetic assembly graphs together with their analysis.

    Every ``generate_*_assembly`` method builds a list of bodies and a list
    of joints, runs ``analyze_assembly`` on them with body 0 as the ground
    body, and returns ``(bodies, joints, analysis)``.

    ``generate_training_batch`` turns those results into plain dicts. The
    per-joint labels it emits are counts, keyed by joint id:

    - independent_constraints: edge results flagged ``independent``
    - redundant_constraints: edge results not flagged ``independent``
    - total_constraints: all edge results recorded for the joint

    All randomness comes from a single ``numpy`` ``Generator`` seeded in
    ``__init__``, so a given seed and call sequence is reproducible.
    """

    # Candidates tried, in this order, when relaxing a FIXED joint.
    _RELAXATION_TYPES = (
        JointType.REVOLUTE,
        JointType.CYLINDRICAL,
        JointType.BALL,
    )
    # Joint types drawn from when adding extra joints to a rigid assembly.
    _EXTRA_JOINT_TYPES = (
        JointType.REVOLUTE,
        JointType.FIXED,
        JointType.BALL,
    )
    # Joint types drawn from for chain examples in a training batch.
    _CHAIN_JOINT_TYPES = (
        JointType.REVOLUTE,
        JointType.BALL,
        JointType.CYLINDRICAL,
    )

    def __init__(self, seed: int = 42) -> None:
        """Create the generator with a NumPy RNG seeded by *seed*."""
        self.rng = np.random.default_rng(seed)

    def _random_unit_axis(self) -> np.ndarray[Any, np.dtype[np.float64]]:
        """Draw a 3-vector of standard normals and scale it to unit length."""
        axis = self.rng.standard_normal(3)
        axis /= np.linalg.norm(axis)
        return axis

    @staticmethod
    def _joint_label_counts(
        analysis: ConstraintAnalysis,
    ) -> dict[int, dict[str, int]]:
        """Tally independent/redundant/total edge results per joint id."""
        counts: dict[int, dict[str, int]] = {}
        for result in analysis.per_edge_results:
            entry = counts.setdefault(
                result["joint_id"],
                {
                    "independent_constraints": 0,
                    "redundant_constraints": 0,
                    "total_constraints": 0,
                },
            )
            entry["total_constraints"] += 1
            if result["independent"]:
                entry["independent_constraints"] += 1
            else:
                entry["redundant_constraints"] += 1
        return counts

    def generate_chain_assembly(
        self,
        n_bodies: int,
        joint_type: JointType = JointType.REVOLUTE,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a serial chain: body ``i`` is joined to body ``i + 1``.

        Bodies sit on the x axis 2.0 apart; each joint is anchored at the
        midpoint between its two bodies and gets a random unit axis. Every
        joint uses *joint_type*.

        NOTE(review): the original description called the result "never
        rigid without closing loops"; that presumably does not hold for
        ``JointType.FIXED`` -- rely on the returned analysis instead.

        Returns:
            ``(bodies, joints, analysis)`` with ``n_bodies`` bodies and
            ``n_bodies - 1`` joints.
        """
        bodies: list[RigidBody] = [
            RigidBody(body_id=i, position=np.array([i * 2.0, 0.0, 0.0]))
            for i in range(n_bodies)
        ]

        joints: list[Joint] = []
        for i in range(n_bodies - 1):
            axis = self._random_unit_axis()
            anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0])
            joints.append(
                Joint(
                    joint_id=i,
                    body_a=i,
                    body_b=i + 1,
                    joint_type=joint_type,
                    anchor_a=anchor,
                    anchor_b=anchor,
                    axis=axis,
                )
            )

        analysis = analyze_assembly(bodies, joints, ground_body=0)
        return bodies, joints, analysis

    def generate_rigid_assembly(
        self, n_bodies: int
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a rigid assembly from a random spanning tree.

        Strategy: place bodies uniformly in ``[-5, 5)^3``, connect each body
        ``i >= 1`` to a randomly chosen earlier body with a FIXED joint, then
        visit the joints in random order and try to replace each with a
        weaker type. A replacement is kept only if ``analyze_assembly`` still
        reports the assembly as rigid; otherwise the joint is restored.

        Returns:
            ``(bodies, joints, analysis)`` where *analysis* is computed on
            the final joint types.
        """
        bodies: list[RigidBody] = [
            RigidBody(body_id=i, position=self.rng.uniform(-5, 5, size=3))
            for i in range(n_bodies)
        ]

        # Spanning tree of FIXED joints: every body is attached to exactly
        # one earlier body, so the graph is connected and has no loops.
        joints: list[Joint] = []
        for i in range(1, n_bodies):
            parent = int(self.rng.integers(0, i))
            mid = (bodies[i].position + bodies[parent].position) / 2
            axis = self._random_unit_axis()
            joints.append(
                Joint(
                    joint_id=i - 1,
                    body_a=parent,
                    body_b=i,
                    joint_type=JointType.FIXED,
                    anchor_a=mid,
                    anchor_b=mid,
                    axis=axis,
                )
            )

        # Try relaxing joints to weaker types while maintaining rigidity.
        for idx in self.rng.permutation(len(joints)):
            joint = joints[idx]
            original_type = joint.joint_type
            for candidate in self._RELAXATION_TYPES:
                joint.joint_type = candidate
                if analyze_assembly(bodies, joints, ground_body=0).is_rigid:
                    break  # Keep the first weaker type that stays rigid.
            else:
                # No candidate preserved rigidity: restore the joint.
                joint.joint_type = original_type

        analysis = analyze_assembly(bodies, joints, ground_body=0)
        return bodies, joints, analysis

    def generate_overconstrained_assembly(
        self,
        n_bodies: int,
        extra_joints: int = 2,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a rigid assembly and add extra joints on top of it.

        Starts from ``generate_rigid_assembly(n_bodies)`` and appends
        *extra_joints* joints, each between two distinct randomly chosen
        bodies, anchored at their midpoint, with a random unit axis and a
        type drawn from REVOLUTE / FIXED / BALL. The intent is that the
        analysis flags the additions as redundant.

        Raises:
            ValueError: if extra joints are requested but there are fewer
                than two bodies to connect.
        """
        # Two distinct bodies are required per extra joint; fail with a
        # clear message instead of NumPy's sampling error.
        if extra_joints > 0 and n_bodies < 2:
            raise ValueError(
                "generate_overconstrained_assembly needs at least 2 bodies "
                f"to add extra joints, got n_bodies={n_bodies}"
            )

        bodies, joints, _ = self.generate_rigid_assembly(n_bodies)

        next_joint_id = len(joints)
        for _ in range(extra_joints):
            picked = self.rng.choice(n_bodies, size=2, replace=False)
            body_a, body_b = int(picked[0]), int(picked[1])
            mid = (bodies[body_a].position + bodies[body_b].position) / 2
            axis = self._random_unit_axis()
            type_idx = int(self.rng.integers(len(self._EXTRA_JOINT_TYPES)))
            joints.append(
                Joint(
                    joint_id=next_joint_id,
                    body_a=body_a,
                    body_b=body_b,
                    joint_type=self._EXTRA_JOINT_TYPES[type_idx],
                    anchor_a=mid,
                    anchor_b=mid,
                    axis=axis,
                )
            )
            next_joint_id += 1

        analysis = analyze_assembly(bodies, joints, ground_body=0)
        return bodies, joints, analysis

    def generate_training_batch(
        self,
        batch_size: int = 100,
        n_bodies_range: tuple[int, int] = (3, 8),
    ) -> list[dict[str, Any]]:
        """Generate a batch of labeled training examples.

        For each example a body count is drawn from *n_bodies_range* and one
        of the three generators (chain, rigid, overconstrained) is picked
        uniformly at random. Overconstrained examples get 1-3 extra joints.

        Args:
            batch_size: number of examples to produce.
            n_bodies_range: ``(low, high)`` forwarded to
                ``Generator.integers``; *high* is exclusive, so the default
                ``(3, 8)`` yields 3 to 7 bodies.

        Returns:
            One dict per example with the keys ``example_id``, ``n_bodies``,
            ``n_joints``, ``body_positions``, ``joints`` (each with
            ``joint_id``, ``body_a``, ``body_b``, ``type``, ``axis``),
            ``joint_labels`` (per-joint constraint counts),
            ``assembly_classification``, ``is_rigid``,
            ``is_minimally_rigid``, ``internal_dof`` and
            ``geometric_degeneracies``.
        """
        examples: list[dict[str, Any]] = []

        for i in range(batch_size):
            n = int(self.rng.integers(*n_bodies_range))
            gen_idx = int(self.rng.integers(3))

            if gen_idx == 0:
                type_idx = int(self.rng.integers(len(self._CHAIN_JOINT_TYPES)))
                bodies, joints, analysis = self.generate_chain_assembly(
                    n, self._CHAIN_JOINT_TYPES[type_idx]
                )
            elif gen_idx == 1:
                bodies, joints, analysis = self.generate_rigid_assembly(n)
            else:
                extra = int(self.rng.integers(1, 4))
                bodies, joints, analysis = self.generate_overconstrained_assembly(
                    n, extra
                )

            examples.append(
                {
                    "example_id": i,
                    "n_bodies": len(bodies),
                    "n_joints": len(joints),
                    "body_positions": [b.position.tolist() for b in bodies],
                    "joints": [
                        {
                            "joint_id": j.joint_id,
                            "body_a": j.body_a,
                            "body_b": j.body_b,
                            "type": j.joint_type.name,
                            "axis": j.axis.tolist(),
                        }
                        for j in joints
                    ],
                    "joint_labels": self._joint_label_counts(analysis),
                    "assembly_classification": analysis.combinatorial_classification,
                    "is_rigid": analysis.is_rigid,
                    "is_minimally_rigid": analysis.is_minimally_rigid,
                    "internal_dof": analysis.jacobian_internal_dof,
                    "geometric_degeneracies": analysis.geometric_degeneracies,
                }
            )

        return examples