Add complete model layer with GIN and GAT encoders, 5 task-specific prediction heads, uncertainty-weighted multi-task loss (Kendall et al. 2018), and config-driven factory functions. New modules: - graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features) - encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head) - heads.py: edge classification, graph classification, joint type, DOF regression, per-body DOF tracking heads - assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads - losses.py: MultiTaskLoss with learnable log-variance per task - factory.py: build_model() and build_loss() from YAML configs Supporting changes: - generator.py: serialize anchor_a, anchor_b, pitch in joint dicts - configs: fix joint_type num_classes 12 -> 11 (matches JointType enum) 92 tests covering shapes, gradients, edge cases, and end-to-end datagen-to-model pipeline.
184 lines · 7.0 KiB · Python
"""Tests for solver.models.encoder -- GIN and GAT graph encoders."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from solver.models.encoder import GATEncoder, GINEncoder
|
|
|
|
|
|
def _random_graph(
|
|
n_nodes: int = 8,
|
|
n_edges: int = 20,
|
|
node_dim: int = 22,
|
|
edge_dim: int = 22,
|
|
batch_size: int = 2,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
"""Create a random graph for testing."""
|
|
x = torch.randn(n_nodes, node_dim)
|
|
edge_index = torch.randint(0, n_nodes, (2, n_edges))
|
|
edge_attr = torch.randn(n_edges, edge_dim)
|
|
# Assign nodes to batches roughly evenly.
|
|
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
|
|
# Handle remainder nodes.
|
|
if len(batch) < n_nodes:
|
|
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
|
|
return x, edge_index, edge_attr, batch
|
|
|
|
|
|
class TestGINEncoder:
    """Shape, gradient, and edge-case coverage for GINEncoder."""

    def test_output_shapes(self) -> None:
        encoder = GINEncoder(node_features_dim=22, edge_features_dim=22, hidden_dim=64, num_layers=2)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
        nodes, edges, graphs = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape == (10, 64)
        assert edges.shape == (16, 64)
        assert graphs.shape == (3, 64)

    def test_hidden_dim_property(self) -> None:
        assert GINEncoder(hidden_dim=128).hidden_dim == 128

    def test_default_dimensions(self) -> None:
        encoder = GINEncoder()
        x, edge_index, edge_attr, batch = _random_graph()
        nodes, edges, graphs = encoder(x, edge_index, edge_attr, batch)
        # Default hidden width is 128 across all three embedding outputs.
        for emb in (nodes, edges, graphs):
            assert emb.shape[1] == 128

    def test_no_batch_defaults_to_single_graph(self) -> None:
        encoder = GINEncoder(hidden_dim=64, num_layers=2)
        x, edge_index, edge_attr, _ = _random_graph(n_nodes=6, n_edges=10)
        _, _, graphs = encoder(x, edge_index, edge_attr, batch=None)
        # Omitting batch should pool everything into one graph embedding.
        assert graphs.shape == (1, 64)

    def test_gradients_flow(self) -> None:
        encoder = GINEncoder(hidden_dim=32, num_layers=2)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        x.requires_grad_(True)
        _, _, graphs = encoder(x, edge_index, edge_attr, batch)
        graphs.sum().backward()
        assert x.grad is not None
        assert x.grad.abs().sum() > 0

    def test_zero_edges(self) -> None:
        encoder = GINEncoder(hidden_dim=32, num_layers=2)
        x = torch.randn(4, 22)
        edge_index = torch.zeros(2, 0, dtype=torch.long)
        edge_attr = torch.zeros(0, 22)
        batch = torch.tensor([0, 0, 1, 1])
        nodes, edges, graphs = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape == (4, 32)
        assert edges.shape == (0, 32)
        assert graphs.shape == (2, 32)

    def test_single_node(self) -> None:
        encoder = GINEncoder(hidden_dim=32, num_layers=2)
        # Run one training-mode forward pass first so BatchNorm has
        # running statistics to fall back on in eval mode.
        warm_x = torch.randn(4, 22)
        warm_ei = torch.zeros(2, 0, dtype=torch.long)
        warm_ea = torch.zeros(0, 22)
        encoder.train()
        encoder(warm_x, warm_ei, warm_ea, torch.tensor([0, 0, 1, 1]))
        # A lone node in eval mode must not crash (BN uses running stats).
        encoder.eval()
        x = torch.randn(1, 22)
        edge_index = torch.zeros(2, 0, dtype=torch.long)
        edge_attr = torch.zeros(0, 22)
        with torch.no_grad():
            nodes, _, graphs = encoder(x, edge_index, edge_attr)
        assert nodes.shape == (1, 32)
        assert graphs.shape == (1, 32)

    def test_eval_mode(self) -> None:
        """Encoder works in eval mode (BatchNorm uses running stats)."""
        encoder = GINEncoder(hidden_dim=32, num_layers=2)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        # Populate BN running stats with one train-mode pass, then switch.
        encoder.train()
        encoder(x, edge_index, edge_attr, batch)
        encoder.eval()
        with torch.no_grad():
            nodes, _, _ = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape[1] == 32
|
|
|
|
|
|
class TestGATEncoder:
    """Shape, gradient, and configuration coverage for GATEncoder."""

    def test_output_shapes(self) -> None:
        encoder = GATEncoder(
            node_features_dim=22,
            edge_features_dim=22,
            hidden_dim=64,
            num_layers=2,
            num_heads=4,
        )
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
        nodes, edges, graphs = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape == (10, 64)
        assert edges.shape == (16, 64)
        assert graphs.shape == (3, 64)

    def test_hidden_dim_property(self) -> None:
        assert GATEncoder(hidden_dim=256, num_heads=8).hidden_dim == 256

    def test_default_dimensions(self) -> None:
        encoder = GATEncoder()
        x, edge_index, edge_attr, batch = _random_graph()
        nodes, edges, graphs = encoder(x, edge_index, edge_attr, batch)
        # Default hidden width is 256 across all three embedding outputs.
        for emb in (nodes, edges, graphs):
            assert emb.shape[1] == 256

    def test_no_batch_defaults_to_single_graph(self) -> None:
        encoder = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
        x, edge_index, edge_attr, _ = _random_graph(n_nodes=6, n_edges=10)
        _, _, graphs = encoder(x, edge_index, edge_attr, batch=None)
        # Omitting batch should pool everything into one graph embedding.
        assert graphs.shape == (1, 64)

    def test_gradients_flow(self) -> None:
        encoder = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        x.requires_grad_(True)
        _, _, graphs = encoder(x, edge_index, edge_attr, batch)
        graphs.sum().backward()
        assert x.grad is not None
        assert x.grad.abs().sum() > 0

    def test_residual_connection(self) -> None:
        """With residual=True, output should differ from residual=False."""
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        # Seed identically so the two encoders share initial weights and the
        # only difference is the residual wiring.
        torch.manual_seed(0)
        with_residual = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=True)
        torch.manual_seed(0)
        without_residual = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=False)
        with torch.no_grad():
            nodes_res, _, _ = with_residual(x, edge_index, edge_attr, batch)
            nodes_plain, _, _ = without_residual(x, edge_index, edge_attr, batch)
        # Outputs should generally differ (unless by very unlikely coincidence).
        assert not torch.allclose(nodes_res, nodes_plain, atol=1e-4)

    def test_hidden_dim_must_divide_heads(self) -> None:
        with pytest.raises(ValueError, match="divisible"):
            GATEncoder(hidden_dim=100, num_heads=8)

    def test_eval_mode(self) -> None:
        encoder = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=8, n_edges=12)
        # One train-mode pass to populate normalization stats, then eval.
        encoder.train()
        encoder(x, edge_index, edge_attr, batch)
        encoder.eval()
        with torch.no_grad():
            nodes, _, _ = encoder(x, edge_index, edge_attr, batch)
        assert nodes.shape[1] == 64
|