feat: port shared types to solver/datagen/types.py
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

Port JointType, RigidBody, Joint, PebbleState, and ConstraintAnalysis
from data/synthetic/pebble-game.py into the solver package.

- Add __all__ export list
- Put typing.Any behind TYPE_CHECKING (ruff TCH003)
- Parameterize list[dict] as list[dict[str, Any]] (mypy strict)
- Re-export all types from solver.datagen.__init__

Closes #1
This commit is contained in:
2026-02-02 13:43:19 -06:00
parent 363b49281b
commit 1b6135129e
2 changed files with 154 additions and 0 deletions

View File

@@ -0,0 +1,17 @@
"""Data generation utilities for assembly constraint training data."""
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
PebbleState,
RigidBody,
)
__all__ = [
"ConstraintAnalysis",
"Joint",
"JointType",
"PebbleState",
"RigidBody",
]

137
solver/datagen/types.py Normal file
View File

@@ -0,0 +1,137 @@
"""Shared data types for assembly constraint analysis.
Types ported from the pebble-game synthetic data generator for reuse
across the solver package (data generation, training, inference).
"""
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from typing import Any
__all__ = [
"ConstraintAnalysis",
"Joint",
"JointType",
"PebbleState",
"RigidBody",
]
# ---------------------------------------------------------------------------
# Joint definitions: each joint type removes a known number of DOF
# ---------------------------------------------------------------------------
class JointType(enum.Enum):
"""Standard CAD joint types with their DOF-removal counts.
Each joint between two 6-DOF rigid bodies removes a specific number
of relative degrees of freedom. In the body-bar-hinge multigraph
representation, each joint maps to a number of edges equal to the
DOF it removes.
DOF removed = number of scalar constraint equations the joint imposes.
"""
FIXED = 6 # Locks all relative motion
REVOLUTE = 5 # Allows rotation about one axis
CYLINDRICAL = 4 # Allows rotation + translation along one axis
SLIDER = 5 # Allows translation along one axis (prismatic)
BALL = 3 # Allows rotation about a point (spherical)
PLANAR = 3 # Allows 2D translation + rotation normal to plane
SCREW = 5 # Coupled rotation-translation (helical)
UNIVERSAL = 4 # Two rotational DOF (Cardan/U-joint)
PARALLEL = 3 # Forces parallel orientation (3 rotation constraints)
PERPENDICULAR = 1 # Single angular constraint
DISTANCE = 1 # Single scalar distance constraint
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class RigidBody:
"""A rigid body in the assembly with pose and geometry info."""
body_id: int
position: np.ndarray = field(default_factory=lambda: np.zeros(3))
orientation: np.ndarray = field(default_factory=lambda: np.eye(3))
# Anchor points for joints, in local frame
# Populated when joints reference specific geometry
local_anchors: dict[str, np.ndarray] = field(default_factory=dict)
@dataclass
class Joint:
"""A joint connecting two rigid bodies."""
joint_id: int
body_a: int # Index of first body
body_b: int # Index of second body
joint_type: JointType
# Joint parameters in world frame
anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3))
anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3))
axis: np.ndarray = field(
default_factory=lambda: np.array([0.0, 0.0, 1.0]),
)
# For screw joints
pitch: float = 0.0
@dataclass
class PebbleState:
"""Tracks the state of the pebble game on the multigraph."""
# Number of free pebbles per body (vertex). Starts at 6.
free_pebbles: dict[int, int] = field(default_factory=dict)
# Directed edges: edge_id -> (source_body, target_body)
# Edge is directed away from the body that "spent" a pebble.
directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict)
# Track which edges are independent vs redundant
independent_edges: set[int] = field(default_factory=set)
redundant_edges: set[int] = field(default_factory=set)
# Adjacency: body_id -> set of (edge_id, neighbor_body_id)
# Following directed edges *towards* a body (incoming edges)
incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
# Outgoing edges from a body
outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
@dataclass
class ConstraintAnalysis:
"""Results of analyzing an assembly's constraint system."""
# Pebble game (combinatorial) results
combinatorial_dof: int
combinatorial_internal_dof: int
combinatorial_redundant: int
combinatorial_classification: str
per_edge_results: list[dict[str, Any]]
# Numerical (Jacobian) results
jacobian_rank: int
jacobian_nullity: int # = 6n - rank = total DOF
jacobian_internal_dof: int # = nullity - 6
numerically_dependent: list[int]
# Combined
geometric_degeneracies: int # = combinatorial_independent - jacobian_rank
is_rigid: bool
is_minimally_rigid: bool