Files
solver/tests/models/test_assembly_gnn.py
forbes fe41fa3b00 feat(models): implement GNN model layer for assembly constraint analysis
Add complete model layer with GIN and GAT encoders, 5 task-specific
prediction heads, uncertainty-weighted multi-task loss (Kendall et al.
2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type,
  DOF regression, per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads
- losses.py: MultiTaskLoss with learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests covering shapes, gradients, edge cases, and the end-to-end
datagen-to-model pipeline.
2026-02-07 10:14:19 -06:00

189 lines
6.6 KiB
Python

"""Tests for solver.models.assembly_gnn -- main model wiring."""
from __future__ import annotations
import pytest
import torch
from solver.models.assembly_gnn import AssemblyGNN
def _default_heads_config(dof_tracking: bool = False) -> dict:
return {
"edge_classification": {"enabled": True, "hidden_dim": 64},
"graph_classification": {"enabled": True, "num_classes": 4},
"joint_type": {"enabled": True, "num_classes": 11},
"dof_regression": {"enabled": True},
"dof_tracking": {"enabled": dof_tracking},
}
def _random_graph(
n_nodes: int = 8,
n_edges: int = 16,
batch_size: int = 2,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
x = torch.randn(n_nodes, 22)
edge_index = torch.randint(0, n_nodes, (2, n_edges))
edge_attr = torch.randn(n_edges, 22)
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
if len(batch) < n_nodes:
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
return x, edge_index, edge_attr, batch
class TestAssemblyGNNGIN:
    """AssemblyGNN with GIN encoder."""

    @staticmethod
    def _build(hidden_dim: int, heads_config: "dict | None" = None) -> "AssemblyGNN":
        """Build a 2-layer GIN AssemblyGNN.

        When *heads_config* is None, every head except DOF tracking is enabled
        (``_default_heads_config()``). Pass ``{}`` explicitly to get no heads.
        """
        if heads_config is None:
            heads_config = _default_heads_config()
        return AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": hidden_dim, "num_layers": 2},
            heads_config=heads_config,
        )

    def test_forward_all_heads(self) -> None:
        model = self._build(hidden_dim=64)
        x, ei, ea, batch = _random_graph()
        preds = model(x, ei, ea, batch)
        assert "edge_pred" in preds
        assert "graph_pred" in preds
        assert "joint_type_pred" in preds
        assert "dof_pred" in preds
        # dof_tracking is disabled in the default config, so it must not emit.
        assert "body_dof_pred" not in preds

    def test_output_shapes(self) -> None:
        model = self._build(hidden_dim=64)
        x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=20, batch_size=3)
        preds = model(x, ei, ea, batch)
        # Edge-level heads produce one row per edge; graph-level heads one
        # row per graph in the batch.
        assert preds["edge_pred"].shape == (20, 1)
        assert preds["graph_pred"].shape == (3, 4)
        assert preds["joint_type_pred"].shape == (20, 11)
        assert preds["dof_pred"].shape == (3, 1)

    def test_gradients_flow(self) -> None:
        model = self._build(hidden_dim=32)
        x, ei, ea, batch = _random_graph()
        x.requires_grad_(True)
        preds = model(x, ei, ea, batch)
        total = sum(p.sum() for p in preds.values())
        total.backward()
        assert x.grad is not None
        assert x.grad.shape == x.shape
        # A NaN/inf gradient is still "not None"; require finite values.
        assert torch.isfinite(x.grad).all()
        # Gradients must also reach the model's learnable weights.
        assert any(param.grad is not None for param in model.parameters())

    def test_no_heads_returns_empty(self) -> None:
        model = self._build(hidden_dim=32, heads_config={})
        x, ei, ea, batch = _random_graph()
        preds = model(x, ei, ea, batch)
        assert len(preds) == 0
class TestAssemblyGNNGAT:
    """AssemblyGNN built on the GAT encoder with DOF tracking enabled."""

    @staticmethod
    def _gat_model():
        """Construct the 2-layer, 4-head GAT model shared by these tests."""
        return AssemblyGNN(
            encoder_type="gat",
            encoder_config={"hidden_dim": 64, "num_layers": 2, "num_heads": 4},
            heads_config=_default_heads_config(dof_tracking=True),
        )

    def test_forward_all_heads(self) -> None:
        model = self._gat_model()
        inputs = _random_graph()
        outputs = model(*inputs)
        expected_keys = (
            "edge_pred",
            "graph_pred",
            "joint_type_pred",
            "dof_pred",
            "body_dof_pred",
        )
        missing = [key for key in expected_keys if key not in outputs]
        assert not missing

    def test_body_dof_shape(self) -> None:
        model = self._gat_model()
        inputs = _random_graph(n_nodes=10, n_edges=20)
        outputs = model(*inputs)
        # One row per node (10 nodes).
        assert outputs["body_dof_pred"].shape == (10, 2)
class TestAssemblyGNNEdgeCases:
    """Error handling, partial head configs, and unbatched input."""

    def test_unknown_encoder_raises(self) -> None:
        with pytest.raises(ValueError, match="Unknown encoder"):
            AssemblyGNN(encoder_type="transformer")

    def test_selective_heads(self) -> None:
        """Only enabled heads produce output."""
        heads = {}
        heads["edge_classification"] = {"enabled": True}
        heads["graph_classification"] = {"enabled": False}
        heads["joint_type"] = {"enabled": True, "num_classes": 11}
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=heads,
        )
        outputs = model(*_random_graph())
        for present in ("edge_pred", "joint_type_pred"):
            assert present in outputs
        # graph_classification is disabled; dof_regression is omitted entirely.
        for absent in ("graph_pred", "dof_pred"):
            assert absent not in outputs

    def test_no_batch_single_graph(self) -> None:
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        num_nodes, num_edges = 6, 10
        node_feats = torch.randn(num_nodes, 22)
        edges = torch.randint(0, num_nodes, (2, num_edges))
        edge_feats = torch.randn(num_edges, 22)
        # No batch vector is passed: the graph-level heads see a single graph.
        outputs = model(node_feats, edges, edge_feats)
        assert outputs["graph_pred"].shape == (1, 4)
        assert outputs["dof_pred"].shape == (1, 1)

    def test_parameter_count_reasonable(self) -> None:
        """Sanity check that model has learnable parameters."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 64, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        total_params = 0
        for param in model.parameters():
            total_params += param.numel()
        assert total_params > 1000  # non-trivial model
class TestAssemblyGNNEndToEnd:
    """End-to-end test: datagen output -> PyG conversion -> model forward."""

    def test_datagen_to_model(self) -> None:
        # Imported inside the test, so these modules load only when it runs.
        from solver.datagen.generator import SyntheticAssemblyGenerator
        from solver.models.graph_conv import assembly_to_pyg

        gen = SyntheticAssemblyGenerator(seed=42)
        # Materialize and require a non-empty result: an empty batch would
        # turn the loop below into a no-op and the test would pass vacuously.
        # (Named `examples` to avoid confusion with the PyG `batch` vector.)
        examples = list(
            gen.generate_training_batch(batch_size=2, complexity_tier="simple")
        )
        assert examples, "generate_training_batch returned no examples"

        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        model.eval()
        with torch.no_grad():
            for ex in examples:
                data = assembly_to_pyg(ex)
                preds = model(data.x, data.edge_index, data.edge_attr)
                assert "edge_pred" in preds
                # One edge prediction per edge in the converted graph.
                assert preds["edge_pred"].shape[0] == data.edge_index.shape[1]