feat(models): feature encoding for assembly graph nodes and edges #16

Closed
opened 2026-02-04 20:21:51 +00:00 by forbes · 0 comments

Summary

Implement solver/models/features.py with functions to convert raw datagen example dicts into fixed-dimension tensor features for GNN consumption.

Node features (22-dim per body)

| Dims | Feature | Source |
|------|---------|--------|
| 3 | Position | `body_positions[i]` |
| 9 | Orientation | flattened 3x3 rotation matrix from `body_orientations[i]` |
| 1 | Degree | number of joints connected to this body |
| 6 | Joint-type histogram | counts of connected joints by DOF-removal bucket: 1, 3, 4, 5, or 6 DOF, plus "other" |
| 1 | Is-grounded flag | 1 if this body is the ground body, else 0 |
| 1 | Normalized index | `body_id / n_bodies` |
| 1 | Log body count | `log(n_bodies)` |
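A minimal sketch of the node encoder, written in plain Python so it runs without PyTorch. It assumes a hypothetical example-dict schema (`body_positions` as `[x, y, z]` lists, `body_orientations` as 3x3 nested lists, `joints` as dicts with `type`, `body_a`, and `body_b`, and an integer `ground_body`). It also takes an extra `joint_dof` argument that stands in for `JointType.dof`, because neither the datagen schema nor the enum is shown in this issue. Counting only the joints incident to each body in the histogram is likewise an assumption.

```python
import math
from collections.abc import Sequence
from typing import Any

# DOF-removal buckets for the 6-dim histogram; any other value lands in "other".
DOF_BUCKETS = (1, 3, 4, 5, 6)
NODE_DIM = 22


def dof_bucket(dof_removed: int) -> int:
    """Map a joint's DOF-removed count to a histogram slot (slot 5 = "other")."""
    if dof_removed in DOF_BUCKETS:
        return DOF_BUCKETS.index(dof_removed)
    return len(DOF_BUCKETS)


def encode_node_features(
    example: dict[str, Any], joint_dof: Sequence[int]
) -> list[list[float]]:
    """Return an n_bodies x 22 nested list of node features.

    Hypothetical schema (not confirmed by the datagen code): see lead-in.
    joint_dof[t] stands in for JointType(t).dof.
    """
    positions = example["body_positions"]
    orientations = example["body_orientations"]
    n_bodies = len(positions)

    # Accumulate degree and DOF-bucket histogram from both endpoints of each joint.
    degree = [0] * n_bodies
    hist = [[0] * (len(DOF_BUCKETS) + 1) for _ in range(n_bodies)]
    for joint in example["joints"]:
        slot = dof_bucket(joint_dof[joint["type"]])
        for body in (joint["body_a"], joint["body_b"]):
            degree[body] += 1
            hist[body][slot] += 1

    rows: list[list[float]] = []
    for i in range(n_bodies):
        row = [float(x) for x in positions[i]]                    # 3: position
        row += [float(v) for r in orientations[i] for v in r]     # 9: rotation, row-major
        row.append(float(degree[i]))                              # 1: degree
        row += [float(c) for c in hist[i]]                        # 6: DOF histogram
        row.append(1.0 if i == example["ground_body"] else 0.0)   # 1: is-grounded
        row.append(i / n_bodies)                                  # 1: normalized index
        row.append(math.log(n_bodies))                            # 1: log body count
        assert len(row) == NODE_DIM  # sanity check on the 22-dim layout
        rows.append(row)
    return rows
```

Wrapping the result as `torch.tensor(rows, dtype=torch.float32)` would give the `Tensor[n_bodies, 22]` the issue asks for. Appending one column group at a time keeps the 3/9/1/6/1/1/1 layout easy to audit against the table.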

Edge features (22-dim per joint)

| Dims | Feature | Source |
|------|---------|--------|
| 11 | Joint type one-hot | `JointType` enum ordinals 0-10 |
| 3 | Axis | normalized joint axis vector |
| 3 | Relative position | `body_positions[body_b] - body_positions[body_a]` |
| 1 | Relative distance | norm of the relative position |
| 1 | DOF removed | `JointType.dof` value |
| 1 | Pitch | screw joint pitch (0 for non-screw joints) |
| 1 | Same-axis indicator | cosine similarity between the joint axis and the z-axis |
| 1 | Grounded indicator | 1 if either connected body is the ground body, else 0 |

Functions

  • `encode_node_features(example) -> Tensor[n_bodies, 22]`
  • `encode_edge_features(example) -> Tensor[n_joints, 22]`
  • `encode_targets(example) -> dict[str, Tensor]`: extracts `edge_labels`, `graph_label`, `joint_type_labels`, `dof_target`, and `per_body_dof`
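`encode_targets` is mostly a key lookup, but checking lengths catches misaligned labels early. The sketch below assumes, hypothetically, that the example dict stores each target under the key name listed above, that `edge_labels` and `joint_type_labels` have one entry per joint, that `per_body_dof` has one entry per body, and that `graph_label` and `dof_target` are graph-level scalars. It returns plain lists; a real version would convert each to a tensor with a suitable dtype (for example `torch.long` for class labels).

```python
from typing import Any

# Hypothetical grouping of targets by granularity (not confirmed by datagen).
PER_JOINT = ("edge_labels", "joint_type_labels")
PER_BODY = ("per_body_dof",)
PER_GRAPH = ("graph_label", "dof_target")


def encode_targets(example: dict[str, Any]) -> dict[str, list[Any]]:
    """Collect the five supervision targets as 1-D lists, validating their lengths."""
    n_joints = len(example["joints"])
    n_bodies = len(example["body_positions"])
    targets: dict[str, list[Any]] = {}
    for key in PER_JOINT + PER_BODY:
        values = list(example[key])
        expected = n_joints if key in PER_JOINT else n_bodies
        # Fail fast if a per-joint or per-body target does not line up with the graph.
        if len(values) != expected:
            raise ValueError(f"{key}: expected {expected} entries, got {len(values)}")
        targets[key] = values
    for key in PER_GRAPH:
        # Graph-level scalars become length-1 lists so every target is 1-D.
        targets[key] = [example[key]]
    return targets
```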

Files

  • solver/models/features.py
  • tests/models/__init__.py
  • tests/models/test_features.py

Acceptance criteria

  • Runs a datagen example dict through all three encode functions end to end
  • Output tensor shapes match (n_bodies, 22) and (n_joints, 22)
  • encode_targets produces correct label tensors
  • Passes ruff, mypy, pytest
Reference: kindred/solver#16