Files
solver/tests/models/test_heads.py
forbes fe41fa3b00 feat(models): implement GNN model layer for assembly constraint analysis
Add complete model layer with GIN and GAT encoders, 5 task-specific
prediction heads, uncertainty-weighted multi-task loss (Kendall et al.
2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type,
  DOF regression, per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads
- losses.py: MultiTaskLoss with learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests covering shapes, gradients, edge cases, and end-to-end
datagen-to-model pipeline.
2026-02-07 10:14:19 -06:00

183 lines
5.9 KiB
Python

"""Tests for solver.models.heads -- task-specific prediction heads."""
from __future__ import annotations
import torch
from solver.models.heads import (
DOFRegressionHead,
DOFTrackingHead,
EdgeClassificationHead,
GraphClassificationHead,
JointTypeHead,
)
class TestEdgeClassificationHead:
    """EdgeClassificationHead produces correct shape and gradients."""

    def test_output_shape(self) -> None:
        """One logit per edge for a 128-D embedding."""
        head = EdgeClassificationHead(hidden_dim=128)
        edge_emb = torch.randn(20, 128)
        out = head(edge_emb)
        assert out.shape == (20, 1)

    def test_output_shape_small(self) -> None:
        """The (num_edges, 1) shape contract also holds for a small hidden size."""
        head = EdgeClassificationHead(hidden_dim=32)
        edge_emb = torch.randn(5, 32)
        out = head(edge_emb)
        assert out.shape == (5, 1)

    def test_gradients_flow(self) -> None:
        """Backprop reaches the input embedding with a non-zero gradient."""
        head = EdgeClassificationHead(hidden_dim=64)
        edge_emb = torch.randn(10, 64, requires_grad=True)
        out = head(edge_emb)
        out.sum().backward()
        assert edge_emb.grad is not None
        assert edge_emb.grad.abs().sum() > 0

    def test_zero_edges(self) -> None:
        """An empty edge set yields an empty (0, 1) output instead of failing."""
        head = EdgeClassificationHead(hidden_dim=64)
        edge_emb = torch.zeros(0, 64)
        out = head(edge_emb)
        assert out.shape == (0, 1)

    def test_output_is_logits(self) -> None:
        """Output should be unbounded logits (not probabilities)."""
        # Seed before building the head so that its weight initialisation,
        # and not only the input tensor, is reproducible across runs.
        torch.manual_seed(42)
        head = EdgeClassificationHead(hidden_dim=64)
        edge_emb = torch.randn(100, 64)
        out = head(edge_emb)
        # Probabilities would be confined to [0, 1]; raw logits over 100
        # random inputs are expected to leave that range on at least one side.
        assert out.min().item() < 0 or out.max().item() > 1
class TestGraphClassificationHead:
    """GraphClassificationHead produces correct shape and gradients."""

    def test_output_shape(self) -> None:
        """One row of class logits per graph in the batch."""
        head = GraphClassificationHead(hidden_dim=128, num_classes=4)
        graph_emb = torch.randn(3, 128)
        out = head(graph_emb)
        assert out.shape == (3, 4)

    def test_custom_num_classes(self) -> None:
        """The output width follows the num_classes argument."""
        head = GraphClassificationHead(hidden_dim=64, num_classes=8)
        graph_emb = torch.randn(2, 64)
        out = head(graph_emb)
        assert out.shape == (2, 8)

    def test_gradients_flow(self) -> None:
        """Backprop reaches the input embedding with a non-zero gradient."""
        head = GraphClassificationHead(hidden_dim=64)
        graph_emb = torch.randn(2, 64, requires_grad=True)
        out = head(graph_emb)
        out.sum().backward()
        assert graph_emb.grad is not None
        # `is not None` alone also passes for an all-zero gradient; mirror
        # the edge-head test and require an actual signal.
        assert graph_emb.grad.abs().sum() > 0

    def test_single_graph(self) -> None:
        """A batch containing a single graph is handled."""
        head = GraphClassificationHead(hidden_dim=128)
        graph_emb = torch.randn(1, 128)
        out = head(graph_emb)
        # No num_classes is passed, so this also pins the default class count (4).
        assert out.shape == (1, 4)
class TestJointTypeHead:
    """JointTypeHead produces correct shape and gradients."""

    def test_output_shape(self) -> None:
        """One row of joint-type logits per edge."""
        head = JointTypeHead(hidden_dim=128, num_classes=11)
        edge_emb = torch.randn(20, 128)
        out = head(edge_emb)
        assert out.shape == (20, 11)

    def test_custom_classes(self) -> None:
        """The output width follows the num_classes argument."""
        head = JointTypeHead(hidden_dim=64, num_classes=7)
        edge_emb = torch.randn(10, 64)
        out = head(edge_emb)
        assert out.shape == (10, 7)

    def test_gradients_flow(self) -> None:
        """Backprop reaches the input embedding with a non-zero gradient."""
        head = JointTypeHead(hidden_dim=64)
        edge_emb = torch.randn(10, 64, requires_grad=True)
        out = head(edge_emb)
        out.sum().backward()
        assert edge_emb.grad is not None
        # `is not None` alone also passes for an all-zero gradient; mirror
        # the edge-head test and require an actual signal.
        assert edge_emb.grad.abs().sum() > 0

    def test_zero_edges(self) -> None:
        """An empty edge set yields an empty logits tensor instead of failing."""
        head = JointTypeHead(hidden_dim=64)
        edge_emb = torch.zeros(0, 64)
        out = head(edge_emb)
        # No num_classes is passed, so this also pins the default class count (11).
        assert out.shape == (0, 11)
class TestDOFRegressionHead:
    """DOFRegressionHead produces correct shape and non-negative output."""

    def test_output_shape(self) -> None:
        """One scalar prediction per graph."""
        head = DOFRegressionHead(hidden_dim=128)
        graph_emb = torch.randn(3, 128)
        out = head(graph_emb)
        assert out.shape == (3, 1)

    def test_output_non_negative(self) -> None:
        """Softplus ensures non-negative output."""
        # Seed before building the head so that its weight initialisation,
        # and not only the input tensor, is reproducible across runs.
        torch.manual_seed(0)
        head = DOFRegressionHead(hidden_dim=64)
        graph_emb = torch.randn(50, 64)
        out = head(graph_emb)
        assert (out >= 0).all()

    def test_gradients_flow(self) -> None:
        """Backprop reaches the input embedding with a non-zero gradient."""
        head = DOFRegressionHead(hidden_dim=64)
        graph_emb = torch.randn(2, 64, requires_grad=True)
        out = head(graph_emb)
        out.sum().backward()
        assert graph_emb.grad is not None
        # `is not None` alone also passes for an all-zero gradient; mirror
        # the edge-head test and require an actual signal.
        assert graph_emb.grad.abs().sum() > 0

    def test_single_graph(self) -> None:
        """A batch containing a single graph is handled and stays non-negative."""
        head = DOFRegressionHead(hidden_dim=32)
        graph_emb = torch.randn(1, 32)
        out = head(graph_emb)
        assert out.shape == (1, 1)
        assert out.item() >= 0
class TestDOFTrackingHead:
    """DOFTrackingHead produces correct shape and non-negative output."""

    def test_output_shape(self) -> None:
        """Two predictions (columns) per node."""
        head = DOFTrackingHead(hidden_dim=128)
        node_emb = torch.randn(10, 128)
        out = head(node_emb)
        assert out.shape == (10, 2)

    def test_output_non_negative(self) -> None:
        """Softplus ensures non-negative output."""
        # Seed before building the head so that its weight initialisation,
        # and not only the input tensor, is reproducible across runs.
        torch.manual_seed(0)
        head = DOFTrackingHead(hidden_dim=64)
        node_emb = torch.randn(50, 64)
        out = head(node_emb)
        assert (out >= 0).all()

    def test_gradients_flow(self) -> None:
        """Backprop reaches the input embedding with a non-zero gradient."""
        head = DOFTrackingHead(hidden_dim=64)
        node_emb = torch.randn(10, 64, requires_grad=True)
        out = head(node_emb)
        out.sum().backward()
        assert node_emb.grad is not None
        # `is not None` alone also passes for an all-zero gradient; mirror
        # the edge-head test and require an actual signal.
        assert node_emb.grad.abs().sum() > 0

    def test_single_node(self) -> None:
        """A single-node input is handled and stays non-negative."""
        head = DOFTrackingHead(hidden_dim=32)
        node_emb = torch.randn(1, 32)
        out = head(node_emb)
        assert out.shape == (1, 2)
        assert (out >= 0).all()

    def test_two_columns_independent(self) -> None:
        """Translational and rotational DOF are independently predicted."""
        # Without a seed both the weights and the inputs were random on every
        # run; fixing it makes the "columns differ" check reproducible.
        torch.manual_seed(0)
        head = DOFTrackingHead(hidden_dim=64)
        node_emb = torch.randn(20, 64)
        out = head(node_emb)
        # The two columns should generally differ.
        assert not torch.allclose(out[:, 0], out[:, 1], atol=1e-6)