Add complete model layer with GIN and GAT encoders, 5 task-specific prediction heads, uncertainty-weighted multi-task loss (Kendall et al. 2018), and config-driven factory functions. New modules: - graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features) - encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head) - heads.py: edge classification, graph classification, joint type, DOF regression, per-body DOF tracking heads - assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads - losses.py: MultiTaskLoss with learnable log-variance per task - factory.py: build_model() and build_loss() from YAML configs Supporting changes: - generator.py: serialize anchor_a, anchor_b, pitch in joint dicts - configs: fix joint_type num_classes 12 -> 11 (matches JointType enum) 92 tests covering shapes, gradients, edge cases, and end-to-end datagen-to-model pipeline.
189 lines
6.6 KiB
Python
"""Tests for solver.models.assembly_gnn -- main model wiring."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from solver.models.assembly_gnn import AssemblyGNN
|
|
|
|
|
|
def _default_heads_config(dof_tracking: bool = False) -> dict:
|
|
return {
|
|
"edge_classification": {"enabled": True, "hidden_dim": 64},
|
|
"graph_classification": {"enabled": True, "num_classes": 4},
|
|
"joint_type": {"enabled": True, "num_classes": 11},
|
|
"dof_regression": {"enabled": True},
|
|
"dof_tracking": {"enabled": dof_tracking},
|
|
}
|
|
|
|
|
|
def _random_graph(
|
|
n_nodes: int = 8,
|
|
n_edges: int = 16,
|
|
batch_size: int = 2,
|
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
x = torch.randn(n_nodes, 22)
|
|
edge_index = torch.randint(0, n_nodes, (2, n_edges))
|
|
edge_attr = torch.randn(n_edges, 22)
|
|
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
|
|
if len(batch) < n_nodes:
|
|
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
|
|
return x, edge_index, edge_attr, batch
|
|
|
|
|
|
class TestAssemblyGNNGIN:
    """AssemblyGNN wired with the GIN encoder."""

    @staticmethod
    def _build(hidden_dim: int, heads_config: dict) -> AssemblyGNN:
        # Shared builder: a small 2-layer GIN model for every test in this class.
        return AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": hidden_dim, "num_layers": 2},
            heads_config=heads_config,
        )

    def test_forward_all_heads(self) -> None:
        """Every enabled head contributes a prediction key."""
        model = self._build(64, _default_heads_config())
        x, edge_index, edge_attr, batch = _random_graph()
        out = model(x, edge_index, edge_attr, batch)
        for key in ("edge_pred", "graph_pred", "joint_type_pred", "dof_pred"):
            assert key in out

    def test_output_shapes(self) -> None:
        """Head outputs have the documented per-edge / per-graph shapes."""
        model = self._build(64, _default_heads_config())
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=10, n_edges=20, batch_size=3)
        out = model(x, edge_index, edge_attr, batch)
        expected = {
            "edge_pred": (20, 1),
            "graph_pred": (3, 4),
            "joint_type_pred": (20, 11),
            "dof_pred": (3, 1),
        }
        for key, shape in expected.items():
            assert out[key].shape == shape

    def test_gradients_flow(self) -> None:
        """Backprop through the summed head outputs reaches the node features."""
        model = self._build(32, _default_heads_config())
        x, edge_index, edge_attr, batch = _random_graph()
        x.requires_grad_(True)
        out = model(x, edge_index, edge_attr, batch)
        loss = sum(pred.sum() for pred in out.values())
        loss.backward()
        assert x.grad is not None

    def test_no_heads_returns_empty(self) -> None:
        """An empty heads config yields an empty prediction dict."""
        model = self._build(32, {})
        x, edge_index, edge_attr, batch = _random_graph()
        out = model(x, edge_index, edge_attr, batch)
        assert len(out) == 0
|
|
|
|
|
|
class TestAssemblyGNNGAT:
    """AssemblyGNN wired with the GAT encoder."""

    @staticmethod
    def _build() -> AssemblyGNN:
        # 2-layer, 4-head GAT model with every head enabled, including
        # the per-body DOF-tracking head.
        return AssemblyGNN(
            encoder_type="gat",
            encoder_config={"hidden_dim": 64, "num_layers": 2, "num_heads": 4},
            heads_config=_default_heads_config(dof_tracking=True),
        )

    def test_forward_all_heads(self) -> None:
        """All five heads (including body DOF tracking) produce output."""
        model = self._build()
        x, edge_index, edge_attr, batch = _random_graph()
        out = model(x, edge_index, edge_attr, batch)
        expected_keys = (
            "edge_pred",
            "graph_pred",
            "joint_type_pred",
            "dof_pred",
            "body_dof_pred",
        )
        for key in expected_keys:
            assert key in out

    def test_body_dof_shape(self) -> None:
        """Body DOF predictions are per-node with two output channels."""
        model = self._build()
        x, edge_index, edge_attr, batch = _random_graph(n_nodes=10, n_edges=20)
        out = model(x, edge_index, edge_attr, batch)
        assert out["body_dof_pred"].shape == (10, 2)
|
|
|
|
|
|
class TestAssemblyGNNEdgeCases:
    """Edge cases and error handling."""

    def test_unknown_encoder_raises(self) -> None:
        """An unsupported encoder name is rejected at construction time."""
        with pytest.raises(ValueError, match="Unknown encoder"):
            AssemblyGNN(encoder_type="transformer")

    def test_selective_heads(self) -> None:
        """Only enabled heads produce output."""
        heads = {
            "edge_classification": {"enabled": True},
            "graph_classification": {"enabled": False},
            "joint_type": {"enabled": True, "num_classes": 11},
        }
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=heads,
        )
        x, edge_index, edge_attr, batch = _random_graph()
        out = model(x, edge_index, edge_attr, batch)
        assert "edge_pred" in out
        assert "joint_type_pred" in out
        # Disabled and absent heads must not appear at all.
        assert "graph_pred" not in out
        assert "dof_pred" not in out

    def test_no_batch_single_graph(self) -> None:
        """Omitting the batch vector treats all nodes as one graph."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        node_feats = torch.randn(6, 22)
        edge_index = torch.randint(0, 6, (2, 10))
        edge_attr = torch.randn(10, 22)
        out = model(node_feats, edge_index, edge_attr)
        assert out["graph_pred"].shape == (1, 4)
        assert out["dof_pred"].shape == (1, 1)

    def test_parameter_count_reasonable(self) -> None:
        """Sanity check that model has learnable parameters."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 64, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        total = sum(param.numel() for param in model.parameters())
        assert total > 1000  # non-trivial model
|
|
|
|
|
|
class TestAssemblyGNNEndToEnd:
    """End-to-end test with datagen pipeline."""

    def test_datagen_to_model(self) -> None:
        """Generated assemblies convert to PyG graphs the model can consume."""
        from solver.datagen.generator import SyntheticAssemblyGenerator
        from solver.models.graph_conv import assembly_to_pyg

        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(
            batch_size=2, complexity_tier="simple"
        )

        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        model.eval()

        for example in examples:
            data = assembly_to_pyg(example)
            with torch.no_grad():
                out = model(data.x, data.edge_index, data.edge_attr)
            assert "edge_pred" in out
            # One edge prediction per edge in the converted graph.
            assert out["edge_pred"].shape[0] == data.edge_index.shape[1]