feat(datasets): PyG dataset adapter for assembly shard files #17

Closed
opened 2026-02-04 20:22:10 +00:00 by forbes · 0 comments
Owner

Summary

Implement solver/datasets/assembly.py with an AssemblyDataset(torch_geometric.data.Dataset) that loads shard files from data/synthetic/shards/ and converts each example to a torch_geometric.data.Data object, using the feature encoding from #16.

Depends on #16.

Data object fields

| Field | Shape | Description |
|-------|-------|-------------|
| x | [n_bodies, 22] | Node features |
| edge_index | [2, 2*n_joints] | Bidirectional edges from joint body_a/body_b |
| edge_attr | [2*n_joints, 22] | Edge features (duplicated for both directions) |
| y_edge | [2*n_joints] | Per-edge independence labels |
| y_graph | [1] | Graph classification label |
| y_joint_type | [2*n_joints] | Per-edge joint type labels |
| y_dof | [1] | Scalar DOF target |
| y_body_dof | [n_bodies, 2] | Per-body translational + rotational DOF |
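
The 2*n_joints shapes above come from expanding every joint into two directed edges. The following is a minimal sketch of that expansion in plain Python lists, not the adapter itself. Only body_a and body_b come from this issue; the keys features, independent, and joint_type, the function name mirror_joints, and the interleaved edge order (a->b immediately followed by b->a) are assumptions made for illustration. A real adapter would presumably wrap the results in tensors before building the Data object.

```python
from typing import Any


def mirror_joints(joints: list[dict[str, Any]]) -> dict[str, list[Any]]:
    """Expand each undirected joint into two directed edges (a->b, b->a).

    Hypothetical helper: the per-joint keys other than body_a/body_b are
    assumed names, and the interleaved edge order is an assumed convention.
    """
    src: list[int] = []
    dst: list[int] = []
    edge_attr: list[list[float]] = []
    y_edge: list[int] = []
    y_joint_type: list[int] = []
    for joint in joints:
        a, b = joint["body_a"], joint["body_b"]
        for u, v in ((a, b), (b, a)):
            src.append(u)
            dst.append(v)
            # Features and labels are copied unchanged for both directions.
            edge_attr.append(list(joint["features"]))
            y_edge.append(joint["independent"])
            y_joint_type.append(joint["joint_type"])
    return {
        "edge_index": [src, dst],      # shape [2, 2*n_joints]
        "edge_attr": edge_attr,        # shape [2*n_joints, n_features]
        "y_edge": y_edge,              # shape [2*n_joints]
        "y_joint_type": y_joint_type,  # shape [2*n_joints]
    }


joints = [
    {"body_a": 0, "body_b": 1, "features": [0.0] * 22, "independent": 1, "joint_type": 3},
    {"body_a": 1, "body_b": 2, "features": [1.0] * 22, "independent": 0, "joint_type": 5},
]
edges = mirror_joints(joints)
print(edges["edge_index"])  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```

Copying the feature list per direction (rather than sharing one list object) keeps the two edges independent if either is later modified.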

Requirements

  • Load both .pt and .json shard formats
  • Support train/val/test splits via index files
  • Compatible with torch_geometric.loader.DataLoader for batching
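
One possible way to meet the first two requirements is sketched below. This issue does not specify the shard layout or the index file format, so everything here is an assumption: each shard is taken to hold a list of example dicts, each split is taken to have an index file named <split>.json in the shard directory containing a JSON list of shard file names, and load_shard and shards_for_split are hypothetical names.

```python
import json
import tempfile
from pathlib import Path
from typing import Any

SPLITS = {"train", "val", "test"}


def load_shard(path: Path) -> list[dict[str, Any]]:
    """Return the raw examples in one shard, dispatching on the file suffix."""
    if path.suffix == ".json":
        with path.open(encoding="utf-8") as fh:
            return json.load(fh)
    if path.suffix == ".pt":
        # Imported lazily so the JSON path works without torch installed.
        import torch

        return torch.load(path, map_location="cpu")
    raise ValueError(f"unsupported shard format: {path.suffix!r}")


def shards_for_split(shard_dir: Path, split: str) -> list[Path]:
    """Resolve a split name to shard paths via the assumed <split>.json index."""
    if split not in SPLITS:
        raise ValueError(f"unknown split: {split!r}")
    with (shard_dir / f"{split}.json").open(encoding="utf-8") as fh:
        names = json.load(fh)
    return [shard_dir / name for name in names]


# Usage with a throwaway directory standing in for data/synthetic/shards/.
with tempfile.TemporaryDirectory() as tmp:
    shard_dir = Path(tmp)
    (shard_dir / "shard_000.json").write_text(json.dumps([{"n_bodies": 2}]))
    (shard_dir / "train.json").write_text(json.dumps(["shard_000.json"]))
    train_paths = shards_for_split(shard_dir, "train")
    examples = [ex for path in train_paths for ex in load_shard(path)]

print(len(examples))  # 1
```

Keeping the split logic in index files, separate from the shards, means a shard never has to be rewritten when the split assignment changes.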

Files

  • solver/datasets/assembly.py
  • solver/datasets/__init__.py
  • tests/datasets/__init__.py
  • tests/datasets/test_assembly.py

Acceptance criteria

  • Loads a shard file and produces valid Data objects
  • Bidirectional edges with correct feature duplication
  • Compatible with DataLoader (batching works)
  • Passes ruff, mypy, pytest
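
The "bidirectional edges with correct feature duplication" criterion could be checked with an order-independent invariant like the one sketched here. It works on plain lists rather than tensors so it stays self-contained; check_edges is a hypothetical helper, not part of the planned test file, and a real test would likely convert the Data fields with .tolist() first.

```python
from collections import Counter


def check_edges(
    n_bodies: int,
    edge_index: list[list[int]],
    edge_attr: list[list[float]],
) -> list[str]:
    """Return a list of problems found; an empty list means the edges look valid."""
    src, dst = edge_index
    if not (len(src) == len(dst) == len(edge_attr)):
        return ["edge_index and edge_attr disagree on the number of edges"]
    problems: list[str] = []
    if len(src) % 2 != 0:
        problems.append("edge count is odd, so some joint lacks a reverse edge")
    if any(not 0 <= i < n_bodies for i in src + dst):
        problems.append("edge_index refers to a body outside [0, n_bodies)")
    # Every (u, v, features) edge must be matched by a (v, u, features) edge,
    # regardless of where in the edge list the mirror copy is stored.
    forward = Counter((u, v, tuple(f)) for u, v, f in zip(src, dst, edge_attr))
    reverse = Counter((v, u, tuple(f)) for u, v, f in zip(src, dst, edge_attr))
    if forward != reverse:
        problems.append("some edge has no reverse edge with identical features")
    return problems


good = check_edges(2, [[0, 1], [1, 0]], [[1.0, 2.0], [1.0, 2.0]])
bad = check_edges(2, [[0, 1], [1, 0]], [[1.0, 2.0], [9.0, 2.0]])
print(good)  # []
print(bad)   # ['some edge has no reverse edge with identical features']
```

Comparing multisets rather than adjacent pairs avoids baking a particular edge ordering into the test.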
Reference: kindred/solver#17