"""Synthetic assembly graph generator for training data production.

Generates assembly graphs with known constraint classifications using
the pebble game and Jacobian verification. Each assembly is fully labeled
with per-constraint independence flags and assembly-level classification.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np

from solver.datagen.analysis import analyze_assembly
from solver.datagen.types import (
    ConstraintAnalysis,
    Joint,
    JointType,
    RigidBody,
)

if TYPE_CHECKING:
    from typing import Any

# The package initializer (solver/datagen/__init__.py) re-exports
# SyntheticAssemblyGenerator from this module and lists it in its own __all__.
__all__ = ["SyntheticAssemblyGenerator"]


class SyntheticAssemblyGenerator:
    """Generate randomly built assemblies together with their analysis labels.

    Every ``generate_*_assembly`` method builds a list of bodies and joints
    and hands them to ``analyze_assembly`` (with body 0 as ground); the
    returned ``ConstraintAnalysis`` is the source of all labels.

    ``generate_training_batch`` flattens that analysis into plain dicts.
    Per-joint labels are emitted as integer counts keyed by joint id:
        - independent_constraints: edges of the joint reported independent
        - redundant_constraints: edges of the joint reported not independent
        - total_constraints: number of edges reported for the joint

    All random draws made by this class come from ``self.rng``.
    """

    # Joint types sampled for serial chains in generate_training_batch.
    _CHAIN_JOINT_TYPES = (
        JointType.REVOLUTE,
        JointType.BALL,
        JointType.CYLINDRICAL,
    )
    # Candidate replacements tried, in this order, when relaxing a FIXED joint.
    _RELAXED_JOINT_TYPES = (
        JointType.REVOLUTE,
        JointType.CYLINDRICAL,
        JointType.BALL,
    )
    # Joint types sampled for the extra joints of overconstrained assemblies.
    _EXTRA_JOINT_TYPES = (
        JointType.REVOLUTE,
        JointType.FIXED,
        JointType.BALL,
    )

    def __init__(self, seed: int = 42) -> None:
        """Create the generator with a NumPy ``Generator`` seeded by *seed*."""
        self.rng = np.random.default_rng(seed)

    def _random_unit_axis(self) -> np.ndarray:
        """Draw a standard-normal 3-vector and scale it to unit length."""
        axis = self.rng.standard_normal(3)
        axis /= np.linalg.norm(axis)
        return axis

    def _pick_joint_type(self, options: tuple[JointType, ...]) -> JointType:
        """Return one entry of *options*, chosen uniformly with ``self.rng``."""
        return options[int(self.rng.integers(len(options)))]

    def generate_chain_assembly(
        self,
        n_bodies: int,
        joint_type: JointType = JointType.REVOLUTE,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a serial kinematic chain.

        Bodies are placed on the x axis at ``x = 2 * i``. Body ``i`` is
        joined to body ``i + 1`` by a joint of *joint_type* anchored halfway
        between them, with a random unit axis.

        The original author's note is that a serial chain is underconstrained
        unless loops are closed; presumably that does not hold when
        *joint_type* is ``FIXED`` -- rely on the returned analysis.

        Args:
            n_bodies: Number of bodies in the chain.
            joint_type: Joint type used for every link of the chain.

        Returns:
            ``(bodies, joints, analysis)`` where *analysis* is the result of
            ``analyze_assembly`` with body 0 as ground.
        """
        bodies = [
            RigidBody(body_id=i, position=np.array([i * 2.0, 0.0, 0.0]))
            for i in range(n_bodies)
        ]

        joints = []
        for i in range(n_bodies - 1):
            axis = self._random_unit_axis()
            # Midpoint between body i (x = 2i) and body i + 1 (x = 2i + 2).
            anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0])
            joints.append(
                Joint(
                    joint_id=i,
                    body_a=i,
                    body_b=i + 1,
                    joint_type=joint_type,
                    anchor_a=anchor,
                    anchor_b=anchor,
                    axis=axis,
                )
            )

        analysis = analyze_assembly(bodies, joints, ground_body=0)
        return bodies, joints, analysis

    def generate_rigid_assembly(
        self, n_bodies: int
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a rigid assembly from a random spanning tree of joints.

        Strategy: place bodies uniformly in ``[-5, 5)^3``, connect them with
        FIXED joints along a random spanning tree (the rigid starting point),
        then visit the joints in random order and try to swap each one for a
        weaker type. A swap is kept only if ``analyze_assembly`` still
        reports the assembly as rigid; otherwise the joint stays FIXED.

        Whether the result is *minimally* rigid is reported by the returned
        analysis rather than guaranteed here.

        Args:
            n_bodies: Number of bodies to create.

        Returns:
            ``(bodies, joints, analysis)`` where *analysis* is the result of
            ``analyze_assembly`` on the final joint types, body 0 as ground.
        """
        bodies = [
            RigidBody(body_id=i, position=self.rng.uniform(-5, 5, size=3))
            for i in range(n_bodies)
        ]

        # Spanning tree of FIXED joints: every body i > 0 attaches to a
        # randomly chosen earlier body, anchored at the midpoint of the pair.
        joints: list[Joint] = []
        for i in range(1, n_bodies):
            parent = int(self.rng.integers(0, i))
            mid = (bodies[i].position + bodies[parent].position) / 2
            axis = self._random_unit_axis()
            joints.append(
                Joint(
                    joint_id=i - 1,
                    body_a=parent,
                    body_b=i,
                    joint_type=JointType.FIXED,
                    anchor_a=mid,
                    anchor_b=mid,
                    axis=axis,
                )
            )

        # Try relaxing joints to weaker types while maintaining rigidity.
        for idx in self.rng.permutation(len(joints)):
            joint = joints[idx]
            original_type = joint.joint_type
            for candidate in self._RELAXED_JOINT_TYPES:
                joint.joint_type = candidate
                if analyze_assembly(bodies, joints, ground_body=0).is_rigid:
                    break  # Keep the first weaker type that stays rigid.
            else:
                # No candidate kept the assembly rigid: undo the change.
                joint.joint_type = original_type

        analysis = analyze_assembly(bodies, joints, ground_body=0)
        return bodies, joints, analysis

    def generate_overconstrained_assembly(
        self,
        n_bodies: int,
        extra_joints: int = 2,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate an assembly with additional joints on top of a rigid one.

        Starts from ``generate_rigid_assembly`` and appends *extra_joints*
        joints between randomly chosen pairs of distinct bodies. The pair is
        not checked against existing joints, so a pair may be joined more
        than once. The extra joints are expected to show up as redundant in
        the returned analysis.

        Args:
            n_bodies: Number of bodies to create.
            extra_joints: Number of joints appended after the rigid base.

        Returns:
            ``(bodies, joints, analysis)`` where *analysis* is the result of
            ``analyze_assembly`` on the full joint list, body 0 as ground.

        Raises:
            ValueError: If extra joints are requested but there are fewer
                than two bodies to connect.
        """
        bodies, joints, _ = self.generate_rigid_assembly(n_bodies)

        if extra_joints > 0 and n_bodies < 2:
            raise ValueError(
                "extra joints need at least two bodies to connect, "
                f"got n_bodies={n_bodies}"
            )

        first_extra_id = len(joints)
        for offset in range(extra_joints):
            # Two distinct body indices (sampling without replacement).
            a, b = self.rng.choice(n_bodies, size=2, replace=False)
            mid = (bodies[a].position + bodies[b].position) / 2
            axis = self._random_unit_axis()
            jtype = self._pick_joint_type(self._EXTRA_JOINT_TYPES)
            joints.append(
                Joint(
                    joint_id=first_extra_id + offset,
                    body_a=int(a),
                    body_b=int(b),
                    joint_type=jtype,
                    anchor_a=mid,
                    anchor_b=mid,
                    axis=axis,
                )
            )

        analysis = analyze_assembly(bodies, joints, ground_body=0)
        return bodies, joints, analysis

    @staticmethod
    def _count_joint_labels(
        analysis: ConstraintAnalysis,
    ) -> dict[int, dict[str, int]]:
        """Tally independent/redundant/total edge results per joint id."""
        labels: dict[int, dict[str, int]] = {}
        for result in analysis.per_edge_results:
            counts = labels.setdefault(
                result["joint_id"],
                {
                    "independent_constraints": 0,
                    "redundant_constraints": 0,
                    "total_constraints": 0,
                },
            )
            counts["total_constraints"] += 1
            if result["independent"]:
                counts["independent_constraints"] += 1
            else:
                counts["redundant_constraints"] += 1
        return labels

    @staticmethod
    def _build_example(
        example_id: int,
        bodies: list[RigidBody],
        joints: list[Joint],
        analysis: ConstraintAnalysis,
    ) -> dict[str, Any]:
        """Flatten one generated assembly and its analysis into a plain dict."""
        return {
            "example_id": example_id,
            "n_bodies": len(bodies),
            "n_joints": len(joints),
            "body_positions": [b.position.tolist() for b in bodies],
            "joints": [
                {
                    "joint_id": j.joint_id,
                    "body_a": j.body_a,
                    "body_b": j.body_b,
                    "type": j.joint_type.name,
                    "axis": j.axis.tolist(),
                }
                for j in joints
            ],
            "joint_labels": SyntheticAssemblyGenerator._count_joint_labels(analysis),
            "assembly_classification": analysis.combinatorial_classification,
            "is_rigid": analysis.is_rigid,
            "is_minimally_rigid": analysis.is_minimally_rigid,
            "internal_dof": analysis.jacobian_internal_dof,
            "geometric_degeneracies": analysis.geometric_degeneracies,
        }

    def generate_training_batch(
        self,
        batch_size: int = 100,
        n_bodies_range: tuple[int, int] = (3, 8),
    ) -> list[dict[str, Any]]:
        """Generate a batch of labeled training examples.

        For each example a body count and one of the three generators
        (chain, rigid, overconstrained) are drawn at random.

        Each example is a dict with the keys:
            - example_id: index of the example within the batch
            - n_bodies, n_joints: sizes of the generated assembly
            - body_positions: one ``[x, y, z]`` list per body
            - joints: per joint ``joint_id``, ``body_a``, ``body_b``,
              ``type`` (joint type name) and ``axis``
            - joint_labels: per-joint-id counts ``independent_constraints``,
              ``redundant_constraints`` and ``total_constraints``
            - assembly_classification, is_rigid, is_minimally_rigid,
              internal_dof, geometric_degeneracies: copied from the analysis

        Args:
            batch_size: Number of examples to generate.
            n_bodies_range: ``(low, high)`` forwarded to
                ``Generator.integers``; ``high`` is exclusive, so the
                default ``(3, 8)`` yields 3 to 7 bodies.

        Returns:
            A list of *batch_size* example dicts.
        """
        examples: list[dict[str, Any]] = []

        for example_id in range(batch_size):
            n = int(self.rng.integers(*n_bodies_range))
            gen_idx = int(self.rng.integers(3))

            if gen_idx == 0:
                jtype = self._pick_joint_type(self._CHAIN_JOINT_TYPES)
                bodies, joints, analysis = self.generate_chain_assembly(n, jtype)
            elif gen_idx == 1:
                bodies, joints, analysis = self.generate_rigid_assembly(n)
            else:
                # Upper bound is exclusive: 1 to 3 extra joints.
                extra = int(self.rng.integers(1, 4))
                bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra)

            examples.append(self._build_example(example_id, bodies, joints, analysis))

        return examples