diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index e69de29..6e53b70 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -0,0 +1,17 @@ +"""Data generation utilities for assembly constraint training data.""" + +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + PebbleState, + RigidBody, +) + +__all__ = [ + "ConstraintAnalysis", + "Joint", + "JointType", + "PebbleState", + "RigidBody", +] diff --git a/solver/datagen/types.py b/solver/datagen/types.py new file mode 100644 index 0000000..754e5dc --- /dev/null +++ b/solver/datagen/types.py @@ -0,0 +1,137 @@ +"""Shared data types for assembly constraint analysis. + +Types ported from the pebble-game synthetic data generator for reuse +across the solver package (data generation, training, inference). +""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "ConstraintAnalysis", + "Joint", + "JointType", + "PebbleState", + "RigidBody", +] + + +# --------------------------------------------------------------------------- +# Joint definitions: each joint type removes a known number of DOF +# --------------------------------------------------------------------------- + + +class JointType(enum.Enum): + """Standard CAD joint types with their DOF-removal counts. + + Each joint between two 6-DOF rigid bodies removes a specific number + of relative degrees of freedom. In the body-bar-hinge multigraph + representation, each joint maps to a number of edges equal to the + DOF it removes. + + DOF removed = number of scalar constraint equations the joint imposes. + """ + + FIXED = 6 # Locks all relative motion + REVOLUTE = 5 # Allows rotation about one axis + CYLINDRICAL = 4 # Allows rotation + translation along one axis + SLIDER = 5 # Allows translation along one axis (prismatic) + BALL = 3 # Allows rotation about a point (spherical) + PLANAR = 3 # Allows 2D translation + rotation normal to plane + SCREW = 5 # Coupled rotation-translation (helical) + UNIVERSAL = 4 # Two rotational DOF (Cardan/U-joint) + PARALLEL = 3 # Forces parallel orientation (3 rotation constraints) + PERPENDICULAR = 1 # Single angular constraint + DISTANCE = 1 # Single scalar distance constraint + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class RigidBody: + """A rigid body in the assembly with pose and geometry info.""" + + body_id: int + position: np.ndarray = field(default_factory=lambda: np.zeros(3)) + orientation: np.ndarray = field(default_factory=lambda: np.eye(3)) + + # Anchor points for joints, in local frame + # Populated when joints reference specific geometry + local_anchors: dict[str, np.ndarray] = field(default_factory=dict) + + +@dataclass +class Joint: + """A joint connecting two rigid bodies.""" + + joint_id: int + body_a: int # Index of first body + body_b: int # Index of second body + joint_type: JointType + + # Joint parameters in world frame + anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3)) + anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3)) + axis: np.ndarray = field( + default_factory=lambda: np.array([0.0, 0.0, 1.0]), + ) + + # For screw joints + pitch: float = 0.0 + + +@dataclass +class PebbleState: + """Tracks the state of the pebble game on the multigraph.""" + + # Number of free pebbles per body (vertex). Starts at 6. + free_pebbles: dict[int, int] = field(default_factory=dict) + + # Directed edges: edge_id -> (source_body, target_body) + # Edge is directed away from the body that "spent" a pebble. + directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict) + + # Track which edges are independent vs redundant + independent_edges: set[int] = field(default_factory=set) + redundant_edges: set[int] = field(default_factory=set) + + # Adjacency: body_id -> set of (edge_id, neighbor_body_id) + # Following directed edges *towards* a body (incoming edges) + incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict) + + # Outgoing edges from a body + outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict) + + +@dataclass +class ConstraintAnalysis: + """Results of analyzing an assembly's constraint system.""" + + # Pebble game (combinatorial) results + combinatorial_dof: int + combinatorial_internal_dof: int + combinatorial_redundant: int + combinatorial_classification: str + per_edge_results: list[dict[str, Any]] + + # Numerical (Jacobian) results + jacobian_rank: int + jacobian_nullity: int # = 6n - rank = total DOF + jacobian_internal_dof: int # = nullity - 6 + numerically_dependent: list[int] + + # Combined + geometric_degeneracies: int # = combinatorial_independent - jacobian_rank + is_rigid: bool + is_minimally_rigid: bool