Add complete model layer with GIN and GAT encoders, 5 task-specific prediction heads, uncertainty-weighted multi-task loss (Kendall et al. 2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type, DOF regression, per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads
- losses.py: MultiTaskLoss with learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests covering shapes, gradients, edge cases, and end-to-end datagen-to-model pipeline.
183 lines
5.9 KiB
Python
"""Tests for solver.models.heads -- task-specific prediction heads."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
from solver.models.heads import (
|
|
DOFRegressionHead,
|
|
DOFTrackingHead,
|
|
EdgeClassificationHead,
|
|
GraphClassificationHead,
|
|
JointTypeHead,
|
|
)
|
|
|
|
|
|
class TestEdgeClassificationHead:
    """Shape, gradient, and edge-case behaviour of EdgeClassificationHead."""

    def test_output_shape(self) -> None:
        clf = EdgeClassificationHead(hidden_dim=128)
        logits = clf(torch.randn(20, 128))
        assert logits.shape == (20, 1)

    def test_output_shape_small(self) -> None:
        clf = EdgeClassificationHead(hidden_dim=32)
        logits = clf(torch.randn(5, 32))
        assert logits.shape == (5, 1)

    def test_gradients_flow(self) -> None:
        clf = EdgeClassificationHead(hidden_dim=64)
        feats = torch.randn(10, 64, requires_grad=True)
        clf(feats).sum().backward()
        assert feats.grad is not None
        assert feats.grad.abs().sum() > 0

    def test_zero_edges(self) -> None:
        # A graph with no candidate edges must still flow through cleanly.
        clf = EdgeClassificationHead(hidden_dim=64)
        logits = clf(torch.zeros(0, 64))
        assert logits.shape == (0, 1)

    def test_output_is_logits(self) -> None:
        """Output should be unbounded logits (not probabilities)."""
        clf = EdgeClassificationHead(hidden_dim=64)
        torch.manual_seed(42)
        logits = clf(torch.randn(100, 64))
        # A sigmoid-squashed output could never dip below 0 or exceed 1.
        assert logits.min().item() < 0 or logits.max().item() > 1
|
|
|
|
|
|
class TestGraphClassificationHead:
    """Shape and gradient behaviour of GraphClassificationHead."""

    def test_output_shape(self) -> None:
        clf = GraphClassificationHead(hidden_dim=128, num_classes=4)
        scores = clf(torch.randn(3, 128))
        assert scores.shape == (3, 4)

    def test_custom_num_classes(self) -> None:
        clf = GraphClassificationHead(hidden_dim=64, num_classes=8)
        scores = clf(torch.randn(2, 64))
        assert scores.shape == (2, 8)

    def test_gradients_flow(self) -> None:
        clf = GraphClassificationHead(hidden_dim=64)
        pooled = torch.randn(2, 64, requires_grad=True)
        clf(pooled).sum().backward()
        assert pooled.grad is not None

    def test_single_graph(self) -> None:
        # Batch of one graph; default num_classes is 4.
        clf = GraphClassificationHead(hidden_dim=128)
        scores = clf(torch.randn(1, 128))
        assert scores.shape == (1, 4)
|
|
|
|
|
|
class TestJointTypeHead:
    """Shape, gradient, and empty-input behaviour of JointTypeHead."""

    def test_output_shape(self) -> None:
        jt = JointTypeHead(hidden_dim=128, num_classes=11)
        scores = jt(torch.randn(20, 128))
        assert scores.shape == (20, 11)

    def test_custom_classes(self) -> None:
        jt = JointTypeHead(hidden_dim=64, num_classes=7)
        scores = jt(torch.randn(10, 64))
        assert scores.shape == (10, 7)

    def test_gradients_flow(self) -> None:
        jt = JointTypeHead(hidden_dim=64)
        feats = torch.randn(10, 64, requires_grad=True)
        jt(feats).sum().backward()
        assert feats.grad is not None

    def test_zero_edges(self) -> None:
        # Empty edge set; default num_classes is 11 (matches JointType enum).
        jt = JointTypeHead(hidden_dim=64)
        scores = jt(torch.zeros(0, 64))
        assert scores.shape == (0, 11)
|
|
|
|
|
|
class TestDOFRegressionHead:
    """Shape, gradient, and non-negativity behaviour of DOFRegressionHead."""

    def test_output_shape(self) -> None:
        reg = DOFRegressionHead(hidden_dim=128)
        preds = reg(torch.randn(3, 128))
        assert preds.shape == (3, 1)

    def test_output_non_negative(self) -> None:
        """Softplus ensures non-negative output."""
        reg = DOFRegressionHead(hidden_dim=64)
        torch.manual_seed(0)
        preds = reg(torch.randn(50, 64))
        assert (preds >= 0).all()

    def test_gradients_flow(self) -> None:
        reg = DOFRegressionHead(hidden_dim=64)
        pooled = torch.randn(2, 64, requires_grad=True)
        reg(pooled).sum().backward()
        assert pooled.grad is not None

    def test_single_graph(self) -> None:
        reg = DOFRegressionHead(hidden_dim=32)
        preds = reg(torch.randn(1, 32))
        assert preds.shape == (1, 1)
        assert preds.item() >= 0
|
|
|
|
|
|
class TestDOFTrackingHead:
    """Shape, gradient, and non-negativity behaviour of DOFTrackingHead."""

    def test_output_shape(self) -> None:
        tracker = DOFTrackingHead(hidden_dim=128)
        preds = tracker(torch.randn(10, 128))
        assert preds.shape == (10, 2)

    def test_output_non_negative(self) -> None:
        """Softplus ensures non-negative output."""
        tracker = DOFTrackingHead(hidden_dim=64)
        torch.manual_seed(0)
        preds = tracker(torch.randn(50, 64))
        assert (preds >= 0).all()

    def test_gradients_flow(self) -> None:
        tracker = DOFTrackingHead(hidden_dim=64)
        feats = torch.randn(10, 64, requires_grad=True)
        tracker(feats).sum().backward()
        assert feats.grad is not None

    def test_single_node(self) -> None:
        tracker = DOFTrackingHead(hidden_dim=32)
        preds = tracker(torch.randn(1, 32))
        assert preds.shape == (1, 2)
        assert (preds >= 0).all()

    def test_two_columns_independent(self) -> None:
        """Translational and rotational DOF are independently predicted."""
        tracker = DOFTrackingHead(hidden_dim=64)
        preds = tracker(torch.randn(20, 64))
        # With random inputs the two columns should not coincide.
        assert not torch.allclose(preds[:, 0], preds[:, 1], atol=1e-6)