feat(models): implement GNN model layer for assembly constraint analysis
Add complete model layer with GIN and GAT encoders, 5 task-specific prediction heads, uncertainty-weighted multi-task loss (Kendall et al. 2018), and config-driven factory functions. New modules: - graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features) - encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head) - heads.py: edge classification, graph classification, joint type, DOF regression, per-body DOF tracking heads - assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads - losses.py: MultiTaskLoss with learnable log-variance per task - factory.py: build_model() and build_loss() from YAML configs Supporting changes: - generator.py: serialize anchor_a, anchor_b, pitch in joint dicts - configs: fix joint_type num_classes 12 -> 11 (matches JointType enum) 92 tests covering shapes, gradients, edge cases, and end-to-end datagen-to-model pipeline.
This commit is contained in:
183
tests/models/test_encoder.py
Normal file
183
tests/models/test_encoder.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Tests for solver.models.encoder -- GIN and GAT graph encoders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from solver.models.encoder import GATEncoder, GINEncoder
|
||||
|
||||
|
||||
def _random_graph(
|
||||
n_nodes: int = 8,
|
||||
n_edges: int = 20,
|
||||
node_dim: int = 22,
|
||||
edge_dim: int = 22,
|
||||
batch_size: int = 2,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Create a random graph for testing."""
|
||||
x = torch.randn(n_nodes, node_dim)
|
||||
edge_index = torch.randint(0, n_nodes, (2, n_edges))
|
||||
edge_attr = torch.randn(n_edges, edge_dim)
|
||||
# Assign nodes to batches roughly evenly.
|
||||
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
|
||||
# Handle remainder nodes.
|
||||
if len(batch) < n_nodes:
|
||||
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
|
||||
return x, edge_index, edge_attr, batch
|
||||
|
||||
|
||||
class TestGINEncoder:
|
||||
"""GINEncoder shape and gradient tests."""
|
||||
|
||||
def test_output_shapes(self) -> None:
|
||||
enc = GINEncoder(node_features_dim=22, edge_features_dim=22, hidden_dim=64, num_layers=2)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape == (10, 64)
|
||||
assert edge_emb.shape == (16, 64)
|
||||
assert graph_emb.shape == (3, 64)
|
||||
|
||||
def test_hidden_dim_property(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=128)
|
||||
assert enc.hidden_dim == 128
|
||||
|
||||
def test_default_dimensions(self) -> None:
|
||||
enc = GINEncoder()
|
||||
x, ei, ea, batch = _random_graph()
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape[1] == 128
|
||||
assert edge_emb.shape[1] == 128
|
||||
assert graph_emb.shape[1] == 128
|
||||
|
||||
def test_no_batch_defaults_to_single_graph(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=64, num_layers=2)
|
||||
x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None)
|
||||
assert graph_emb.shape == (1, 64)
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
x.requires_grad_(True)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
loss = graph_emb.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.abs().sum() > 0
|
||||
|
||||
def test_zero_edges(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||
x = torch.randn(4, 22)
|
||||
ei = torch.zeros(2, 0, dtype=torch.long)
|
||||
ea = torch.zeros(0, 22)
|
||||
batch = torch.tensor([0, 0, 1, 1])
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape == (4, 32)
|
||||
assert edge_emb.shape == (0, 32)
|
||||
assert graph_emb.shape == (2, 32)
|
||||
|
||||
def test_single_node(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||
# Train with a small batch first to populate BN running stats.
|
||||
x_train = torch.randn(4, 22)
|
||||
ei_train = torch.zeros(2, 0, dtype=torch.long)
|
||||
ea_train = torch.zeros(0, 22)
|
||||
batch_train = torch.tensor([0, 0, 1, 1])
|
||||
enc.train()
|
||||
enc(x_train, ei_train, ea_train, batch_train)
|
||||
# Now test single node in eval mode (BN uses running stats).
|
||||
enc.eval()
|
||||
x = torch.randn(1, 22)
|
||||
ei = torch.zeros(2, 0, dtype=torch.long)
|
||||
ea = torch.zeros(0, 22)
|
||||
with torch.no_grad():
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea)
|
||||
assert node_emb.shape == (1, 32)
|
||||
assert graph_emb.shape == (1, 32)
|
||||
|
||||
def test_eval_mode(self) -> None:
|
||||
"""Encoder works in eval mode (BatchNorm uses running stats)."""
|
||||
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||
# Forward pass in train mode to populate BN stats.
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
enc.train()
|
||||
enc(x, ei, ea, batch)
|
||||
# Switch to eval.
|
||||
enc.eval()
|
||||
with torch.no_grad():
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape[1] == 32
|
||||
|
||||
|
||||
class TestGATEncoder:
|
||||
"""GATEncoder shape and gradient tests."""
|
||||
|
||||
def test_output_shapes(self) -> None:
|
||||
enc = GATEncoder(
|
||||
node_features_dim=22,
|
||||
edge_features_dim=22,
|
||||
hidden_dim=64,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape == (10, 64)
|
||||
assert edge_emb.shape == (16, 64)
|
||||
assert graph_emb.shape == (3, 64)
|
||||
|
||||
def test_hidden_dim_property(self) -> None:
|
||||
enc = GATEncoder(hidden_dim=256, num_heads=8)
|
||||
assert enc.hidden_dim == 256
|
||||
|
||||
def test_default_dimensions(self) -> None:
|
||||
enc = GATEncoder()
|
||||
x, ei, ea, batch = _random_graph()
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape[1] == 256
|
||||
assert edge_emb.shape[1] == 256
|
||||
assert graph_emb.shape[1] == 256
|
||||
|
||||
def test_no_batch_defaults_to_single_graph(self) -> None:
|
||||
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||
x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None)
|
||||
assert graph_emb.shape == (1, 64)
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
x.requires_grad_(True)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
loss = graph_emb.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.abs().sum() > 0
|
||||
|
||||
def test_residual_connection(self) -> None:
|
||||
"""With residual=True, output should differ from residual=False."""
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
torch.manual_seed(0)
|
||||
enc_res = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=True)
|
||||
torch.manual_seed(0)
|
||||
enc_no = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=False)
|
||||
with torch.no_grad():
|
||||
n1, _, _ = enc_res(x, ei, ea, batch)
|
||||
n2, _, _ = enc_no(x, ei, ea, batch)
|
||||
# Outputs should generally differ (unless by very unlikely coincidence).
|
||||
assert not torch.allclose(n1, n2, atol=1e-4)
|
||||
|
||||
def test_hidden_dim_must_divide_heads(self) -> None:
|
||||
with pytest.raises(ValueError, match="divisible"):
|
||||
GATEncoder(hidden_dim=100, num_heads=8)
|
||||
|
||||
def test_eval_mode(self) -> None:
|
||||
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
enc.train()
|
||||
enc(x, ei, ea, batch)
|
||||
enc.eval()
|
||||
with torch.no_grad():
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape[1] == 64
|
||||
Reference in New Issue
Block a user