Files
solver/tests/models/test_encoder.py
forbes fe41fa3b00 feat(models): implement GNN model layer for assembly constraint analysis
Add complete model layer with GIN and GAT encoders, 5 task-specific
prediction heads, uncertainty-weighted multi-task loss (Kendall et al.
2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type,
  DOF regression, per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads
- losses.py: MultiTaskLoss with learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests covering shapes, gradients, edge cases, and end-to-end
datagen-to-model pipeline.
2026-02-07 10:14:19 -06:00

184 lines
7.0 KiB
Python

"""Tests for solver.models.encoder -- GIN and GAT graph encoders."""
from __future__ import annotations
import pytest
import torch
from solver.models.encoder import GATEncoder, GINEncoder
def _random_graph(
n_nodes: int = 8,
n_edges: int = 20,
node_dim: int = 22,
edge_dim: int = 22,
batch_size: int = 2,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Create a random graph for testing."""
x = torch.randn(n_nodes, node_dim)
edge_index = torch.randint(0, n_nodes, (2, n_edges))
edge_attr = torch.randn(n_edges, edge_dim)
# Assign nodes to batches roughly evenly.
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
# Handle remainder nodes.
if len(batch) < n_nodes:
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
return x, edge_index, edge_attr, batch
class TestGINEncoder:
    """Shape, gradient, and edge-case tests for GINEncoder."""

    def test_output_shapes(self) -> None:
        encoder = GINEncoder(
            node_features_dim=22, edge_features_dim=22, hidden_dim=64, num_layers=2
        )
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
        nodes, edges, graphs = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape == (10, 64)
        assert edges.shape == (16, 64)
        assert graphs.shape == (3, 64)

    def test_hidden_dim_property(self) -> None:
        assert GINEncoder(hidden_dim=128).hidden_dim == 128

    def test_default_dimensions(self) -> None:
        encoder = GINEncoder()
        nodes, edges, graphs = encoder(*_random_graph())
        # Default hidden width is 128 for all three embedding levels.
        for emb in (nodes, edges, graphs):
            assert emb.shape[1] == 128

    def test_no_batch_defaults_to_single_graph(self) -> None:
        encoder = GINEncoder(hidden_dim=64, num_layers=2)
        x, edge_index, edge_attr, _ = _random_graph(n_nodes=6, n_edges=10)
        _, _, graphs = encoder(x, edge_index, edge_attr, batch=None)
        assert graphs.shape == (1, 64)

    def test_gradients_flow(self) -> None:
        encoder = GINEncoder(hidden_dim=32, num_layers=2)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        x.requires_grad_(True)
        _, _, graphs = encoder(x, edge_index, edge_attr, batch)
        graphs.sum().backward()
        assert x.grad is not None
        assert x.grad.abs().sum() > 0

    def test_zero_edges(self) -> None:
        # The encoder must cope with graphs that have nodes but no edges.
        encoder = GINEncoder(hidden_dim=32, num_layers=2)
        x = torch.randn(4, 22)
        empty_index = torch.zeros(2, 0, dtype=torch.long)
        empty_attr = torch.zeros(0, 22)
        batch = torch.tensor([0, 0, 1, 1])
        nodes, edges, graphs = encoder(x, empty_index, empty_attr, batch)
        assert nodes.shape == (4, 32)
        assert edges.shape == (0, 32)
        assert graphs.shape == (2, 32)

    def test_single_node(self) -> None:
        encoder = GINEncoder(hidden_dim=32, num_layers=2)
        # One training pass with more than one node populates BatchNorm
        # running stats; a 1-node batch in train mode would be undefined.
        encoder.train()
        encoder(
            torch.randn(4, 22),
            torch.zeros(2, 0, dtype=torch.long),
            torch.zeros(0, 22),
            torch.tensor([0, 0, 1, 1]),
        )
        # In eval mode BN uses running stats, so a single node is fine.
        encoder.eval()
        with torch.no_grad():
            nodes, _, graphs = encoder(
                torch.randn(1, 22),
                torch.zeros(2, 0, dtype=torch.long),
                torch.zeros(0, 22),
            )
        assert nodes.shape == (1, 32)
        assert graphs.shape == (1, 32)

    def test_eval_mode(self) -> None:
        """Encoder runs in eval mode (BatchNorm switches to running stats)."""
        encoder = GINEncoder(hidden_dim=32, num_layers=2)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        encoder.train()
        encoder(x, edge_index, edge_attr, batch)
        encoder.eval()
        with torch.no_grad():
            nodes, _, _ = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape[1] == 32
class TestGATEncoder:
    """Shape, gradient, and configuration tests for GATEncoder."""

    def test_output_shapes(self) -> None:
        encoder = GATEncoder(
            node_features_dim=22,
            edge_features_dim=22,
            hidden_dim=64,
            num_layers=2,
            num_heads=4,
        )
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
        nodes, edges, graphs = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape == (10, 64)
        assert edges.shape == (16, 64)
        assert graphs.shape == (3, 64)

    def test_hidden_dim_property(self) -> None:
        assert GATEncoder(hidden_dim=256, num_heads=8).hidden_dim == 256

    def test_default_dimensions(self) -> None:
        encoder = GATEncoder()
        nodes, edges, graphs = encoder(*_random_graph())
        # Default hidden width is 256 for all three embedding levels.
        for emb in (nodes, edges, graphs):
            assert emb.shape[1] == 256

    def test_no_batch_defaults_to_single_graph(self) -> None:
        encoder = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
        x, edge_index, edge_attr, _ = _random_graph(n_nodes=6, n_edges=10)
        _, _, graphs = encoder(x, edge_index, edge_attr, batch=None)
        assert graphs.shape == (1, 64)

    def test_gradients_flow(self) -> None:
        encoder = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        x.requires_grad_(True)
        _, _, graphs = encoder(x, edge_index, edge_attr, batch)
        graphs.sum().backward()
        assert x.grad is not None
        assert x.grad.abs().sum() > 0

    def test_residual_connection(self) -> None:
        """Identically seeded encoders with/without residual must diverge."""
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        torch.manual_seed(0)
        with_residual = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=True)
        torch.manual_seed(0)
        without_residual = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=False)
        with torch.no_grad():
            out_res, _, _ = with_residual(x, edge_index, edge_attr, batch)
            out_plain, _, _ = without_residual(x, edge_index, edge_attr, batch)
        # Barring an astronomically unlikely coincidence, outputs differ.
        assert not torch.allclose(out_res, out_plain, atol=1e-4)

    def test_hidden_dim_must_divide_heads(self) -> None:
        with pytest.raises(ValueError, match="divisible"):
            GATEncoder(hidden_dim=100, num_heads=8)

    def test_eval_mode(self) -> None:
        encoder = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        encoder.train()
        encoder(x, edge_index, edge_attr, batch)
        encoder.eval()
        with torch.no_grad():
            nodes, _, _ = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape[1] == 64