feat: port shared types to solver/datagen/types.py
Port JointType, RigidBody, Joint, PebbleState, and ConstraintAnalysis from data/synthetic/pebble-game.py into the solver package. - Add __all__ export list - Put typing.Any behind TYPE_CHECKING (ruff TCH003) - Parameterize list[dict] as list[dict[str, Any]] (mypy strict) - Re-export all types from solver.datagen.__init__ Closes #1
This commit is contained in:
@@ -0,0 +1,17 @@
|
||||
"""Data generation utilities for assembly constraint training data."""
|
||||
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
PebbleState,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ConstraintAnalysis",
|
||||
"Joint",
|
||||
"JointType",
|
||||
"PebbleState",
|
||||
"RigidBody",
|
||||
]
|
||||
|
||||
137
solver/datagen/types.py
Normal file
137
solver/datagen/types.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Shared data types for assembly constraint analysis.
|
||||
|
||||
Types ported from the pebble-game synthetic data generator for reuse
|
||||
across the solver package (data generation, training, inference).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"ConstraintAnalysis",
|
||||
"Joint",
|
||||
"JointType",
|
||||
"PebbleState",
|
||||
"RigidBody",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Joint definitions: each joint type removes a known number of DOF
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JointType(enum.Enum):
|
||||
"""Standard CAD joint types with their DOF-removal counts.
|
||||
|
||||
Each joint between two 6-DOF rigid bodies removes a specific number
|
||||
of relative degrees of freedom. In the body-bar-hinge multigraph
|
||||
representation, each joint maps to a number of edges equal to the
|
||||
DOF it removes.
|
||||
|
||||
DOF removed = number of scalar constraint equations the joint imposes.
|
||||
"""
|
||||
|
||||
FIXED = 6 # Locks all relative motion
|
||||
REVOLUTE = 5 # Allows rotation about one axis
|
||||
CYLINDRICAL = 4 # Allows rotation + translation along one axis
|
||||
SLIDER = 5 # Allows translation along one axis (prismatic)
|
||||
BALL = 3 # Allows rotation about a point (spherical)
|
||||
PLANAR = 3 # Allows 2D translation + rotation normal to plane
|
||||
SCREW = 5 # Coupled rotation-translation (helical)
|
||||
UNIVERSAL = 4 # Two rotational DOF (Cardan/U-joint)
|
||||
PARALLEL = 3 # Forces parallel orientation (3 rotation constraints)
|
||||
PERPENDICULAR = 1 # Single angular constraint
|
||||
DISTANCE = 1 # Single scalar distance constraint
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RigidBody:
|
||||
"""A rigid body in the assembly with pose and geometry info."""
|
||||
|
||||
body_id: int
|
||||
position: np.ndarray = field(default_factory=lambda: np.zeros(3))
|
||||
orientation: np.ndarray = field(default_factory=lambda: np.eye(3))
|
||||
|
||||
# Anchor points for joints, in local frame
|
||||
# Populated when joints reference specific geometry
|
||||
local_anchors: dict[str, np.ndarray] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Joint:
|
||||
"""A joint connecting two rigid bodies."""
|
||||
|
||||
joint_id: int
|
||||
body_a: int # Index of first body
|
||||
body_b: int # Index of second body
|
||||
joint_type: JointType
|
||||
|
||||
# Joint parameters in world frame
|
||||
anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3))
|
||||
anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3))
|
||||
axis: np.ndarray = field(
|
||||
default_factory=lambda: np.array([0.0, 0.0, 1.0]),
|
||||
)
|
||||
|
||||
# For screw joints
|
||||
pitch: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class PebbleState:
|
||||
"""Tracks the state of the pebble game on the multigraph."""
|
||||
|
||||
# Number of free pebbles per body (vertex). Starts at 6.
|
||||
free_pebbles: dict[int, int] = field(default_factory=dict)
|
||||
|
||||
# Directed edges: edge_id -> (source_body, target_body)
|
||||
# Edge is directed away from the body that "spent" a pebble.
|
||||
directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict)
|
||||
|
||||
# Track which edges are independent vs redundant
|
||||
independent_edges: set[int] = field(default_factory=set)
|
||||
redundant_edges: set[int] = field(default_factory=set)
|
||||
|
||||
# Adjacency: body_id -> set of (edge_id, neighbor_body_id)
|
||||
# Following directed edges *towards* a body (incoming edges)
|
||||
incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
|
||||
|
||||
# Outgoing edges from a body
|
||||
outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConstraintAnalysis:
|
||||
"""Results of analyzing an assembly's constraint system."""
|
||||
|
||||
# Pebble game (combinatorial) results
|
||||
combinatorial_dof: int
|
||||
combinatorial_internal_dof: int
|
||||
combinatorial_redundant: int
|
||||
combinatorial_classification: str
|
||||
per_edge_results: list[dict[str, Any]]
|
||||
|
||||
# Numerical (Jacobian) results
|
||||
jacobian_rank: int
|
||||
jacobian_nullity: int # = 6n - rank = total DOF
|
||||
jacobian_internal_dof: int # = nullity - 6
|
||||
numerically_dependent: list[int]
|
||||
|
||||
# Combined
|
||||
geometric_degeneracies: int # = combinatorial_independent - jacobian_rank
|
||||
is_rigid: bool
|
||||
is_minimally_rigid: bool
|
||||
Reference in New Issue
Block a user