Files
solver/tests/models/test_graph_conv.py
forbes fe41fa3b00 feat(models): implement GNN model layer for assembly constraint analysis
Add complete model layer with GIN and GAT encoders, 5 task-specific
prediction heads, uncertainty-weighted multi-task loss (Kendall et al.
2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type,
  DOF regression, per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads
- losses.py: MultiTaskLoss with learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests cover shapes, gradients, edge cases, and the end-to-end
datagen-to-model pipeline.
2026-02-07 10:14:19 -06:00

145 lines
5.2 KiB
Python

"""Tests for solver.models.graph_conv -- assembly to PyG conversion."""
from __future__ import annotations
import torch
from solver.datagen.generator import SyntheticAssemblyGenerator
from solver.datagen.types import JointType
from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg
def _make_example(
    n_bodies: int = 4,
    grounded: bool | None = None,
    seed: int = 0,
) -> dict:
    """Generate a single training example via the datagen pipeline.

    Args:
        n_bodies: Number of bodies to request; forwarded to the generator as
            ``n_bodies_range=(n_bodies, n_bodies + 1)``.
        grounded: When not ``None``, overwrite the example's ``"grounded"``
            flag with this value. Only the flag is changed -- labels the
            generator already computed are NOT recomputed. ``None`` (the
            default) keeps whatever the generator produced. Previously this
            argument was accepted but silently ignored.
        seed: Seed passed to ``SyntheticAssemblyGenerator``.

    Returns:
        The single example dict from a batch of size 1.
    """
    gen = SyntheticAssemblyGenerator(seed=seed)
    batch = gen.generate_training_batch(
        batch_size=1,
        n_bodies_range=(n_bodies, n_bodies + 1),
    )
    example = batch[0]
    if grounded is not None:
        example["grounded"] = grounded
    return example
class TestAssemblyToPyg:
    """assembly_to_pyg converts datagen dicts to PyG Data correctly."""

    # Feature layout as exercised by these tests: node and edge features are
    # both FEATURE_DIM wide; node dim GROUNDED_DIM carries the grounded flag;
    # edge dims [0, len(JOINT_TYPE_NAMES)) are the joint-type one-hot; edge
    # dim DOF_REMOVED_DIM holds the normalized DOF-removed value.
    FEATURE_DIM = 22
    GROUNDED_DIM = 12
    DOF_REMOVED_DIM = 21

    # Label attributes attached when include_labels=True (and absent otherwise).
    LABEL_ATTRS = ("y_edge", "y_graph", "y_joint_type", "y_dof", "y_body_dof")

    def test_node_feature_shape(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        assert data.x.shape == (5, self.FEATURE_DIM)

    def test_edge_feature_shape(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        # Each joint contributes two directed edges (a->b and b->a).
        assert data.edge_attr.shape == (ex["n_joints"] * 2, self.FEATURE_DIM)

    def test_edge_index_bidirectional(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        src, dst = data.edge_index
        # Each joint produces 2 consecutive directed edges: a->b then b->a.
        # An odd count would otherwise surface as an IndexError, not a failure.
        assert src.numel() % 2 == 0
        assert torch.equal(src[0::2], dst[1::2])
        assert torch.equal(dst[0::2], src[1::2])

    def test_edge_index_shape(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        assert data.edge_index.shape == (2, ex["n_joints"] * 2)

    def test_node_features_centered(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        # Positions (dims 0-2) should be centered (mean ~0).
        pos = data.x[:, :3]
        assert pos.mean(dim=0).abs().max().item() < 1e-5

    def test_grounded_flag_set(self) -> None:
        ex = _make_example(n_bodies=4, grounded=True)
        # Set explicitly so the test does not depend on the helper's handling
        # of the `grounded` argument.
        ex["grounded"] = True
        data = assembly_to_pyg(ex)
        assert data.x[0, self.GROUNDED_DIM].item() == 1.0

    def test_ungrounded_flag_clear(self) -> None:
        ex = _make_example(n_bodies=4)
        ex["grounded"] = False
        data = assembly_to_pyg(ex)
        assert (data.x[:, self.GROUNDED_DIM] == 0.0).all()

    def test_edge_type_one_hot_valid(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        onehot = data.edge_attr[:, : len(JOINT_TYPE_NAMES)]
        # Every entry must be exactly 0 or 1 (a row sum of 1 alone would also
        # accept e.g. 0.5 + 0.5), and each row has exactly one 1.0.
        # Both checks pass vacuously when there are no edges.
        assert ((onehot == 0.0) | (onehot == 1.0)).all()
        assert (onehot.sum(dim=1) == 1.0).all()

    def test_labels_present_when_requested(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex, include_labels=True)
        for attr in self.LABEL_ATTRS:
            assert hasattr(data, attr), attr

    def test_labels_absent_when_not_requested(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex, include_labels=False)
        # Mirror the "present" test: no label attribute may leak through.
        for attr in self.LABEL_ATTRS:
            assert not hasattr(data, attr), attr

    def test_graph_classification_label_mapping(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        expected = ASSEMBLY_CLASSES[ex["assembly_classification"]]
        assert data.y_graph.item() == expected

    def test_body_dof_shape(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        assert data.y_body_dof.shape == (5, 2)

    def test_edge_labels_binary(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        # Passes vacuously on an empty label tensor.
        assert ((data.y_edge == 0.0) | (data.y_edge == 1.0)).all()

    def test_dof_removed_normalized(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        dof_norm = data.edge_attr[:, self.DOF_REMOVED_DIM]
        assert ((dof_norm >= 0.0) & (dof_norm <= 1.0)).all()

    def test_roundtrip_with_generator(self) -> None:
        """Generate a real example and convert -- no crash."""
        gen = SyntheticAssemblyGenerator(seed=42)
        batch = gen.generate_training_batch(batch_size=5, complexity_tier="simple")
        for ex in batch:
            data = assembly_to_pyg(ex)
            assert data.x.shape[1] == self.FEATURE_DIM
            assert data.edge_attr.shape[1] == self.FEATURE_DIM
class TestAssemblyClasses:
    """ASSEMBLY_CLASSES has four entries whose values are exactly 0..3."""

    EXPECTED_COUNT = 4

    def test_four_classes(self) -> None:
        n_classes = len(ASSEMBLY_CLASSES)
        assert n_classes == self.EXPECTED_COUNT

    def test_values_are_0_to_3(self) -> None:
        indices = set(ASSEMBLY_CLASSES.values())
        assert indices == set(range(self.EXPECTED_COUNT))
class TestJointTypeNames:
    """JOINT_TYPE_NAMES lines up one-to-one with the JointType enum."""

    def test_length_matches_enum(self) -> None:
        n_members = len(JointType)
        assert len(JOINT_TYPE_NAMES) == n_members

    def test_order_matches_ordinal(self) -> None:
        # A name's position in JOINT_TYPE_NAMES must equal the first element
        # of the corresponding enum member's value.
        for ordinal, name in enumerate(JOINT_TYPE_NAMES):
            member = JointType[name]
            assert member.value[0] == ordinal