Files
solver/tests/models/test_factory.py
forbes fe41fa3b00 feat(models): implement GNN model layer for assembly constraint analysis
Add complete model layer with GIN and GAT encoders, 5 task-specific
prediction heads, uncertainty-weighted multi-task loss (Kendall et al.
2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type,
  DOF regression, per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads
- losses.py: MultiTaskLoss with learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests covering shapes, gradients, edge cases, and the end-to-end
datagen-to-model pipeline.
2026-02-07 10:14:19 -06:00

111 lines
3.7 KiB
Python

"""Tests for solver.models.factory -- model and loss construction from config."""
from __future__ import annotations
import yaml
from solver.models.assembly_gnn import AssemblyGNN
from solver.models.factory import build_loss, build_model
from solver.models.losses import MultiTaskLoss
def _load_yaml(path: str) -> dict:
    """Load a YAML file and return its parsed contents.

    Relative paths are resolved against the current working directory, so the
    ``configs/...`` paths used by the tests below assume pytest is launched
    from the directory that contains ``configs/``.

    Args:
        path: Path to the YAML file to read.

    Returns:
        The document as parsed by ``yaml.safe_load``.
    """
    # Decode explicitly as UTF-8 so results don't depend on the platform's
    # locale encoding (e.g. cp1252 on Windows).
    with open(path, encoding="utf-8") as f:
        return yaml.safe_load(f)
class TestBuildModel:
    """build_model constructs AssemblyGNN from config."""

    def test_baseline_config(self) -> None:
        """The baseline YAML builds an AssemblyGNN with a 128-wide encoder."""
        config = _load_yaml("configs/model/baseline.yaml")
        model = build_model(config)
        assert isinstance(model, AssemblyGNN)
        assert model.encoder.hidden_dim == 128

    def test_gat_config(self) -> None:
        """The GAT YAML builds an AssemblyGNN with a 256-wide encoder."""
        config = _load_yaml("configs/model/gat.yaml")
        model = build_model(config)
        assert isinstance(model, AssemblyGNN)
        assert model.encoder.hidden_dim == 256

    def test_baseline_heads_present(self) -> None:
        """The baseline YAML enables the edge, graph, joint-type and DOF heads."""
        config = _load_yaml("configs/model/baseline.yaml")
        model = build_model(config)
        assert "edge_pred" in model.heads
        assert "graph_pred" in model.heads
        assert "joint_type_pred" in model.heads
        assert "dof_pred" in model.heads

    def test_gat_has_dof_tracking(self) -> None:
        """The GAT YAML enables the per-body DOF tracking head."""
        config = _load_yaml("configs/model/gat.yaml")
        model = build_model(config)
        assert "body_dof_pred" in model.heads

    def test_baseline_no_dof_tracking(self) -> None:
        """The baseline YAML leaves the per-body DOF tracking head disabled."""
        config = _load_yaml("configs/model/baseline.yaml")
        model = build_model(config)
        assert "body_dof_pred" not in model.heads

    def test_minimal_config(self) -> None:
        """A config naming only the architecture builds a model with no heads."""
        config = {"architecture": "gin"}
        model = build_model(config)
        assert isinstance(model, AssemblyGNN)
        # Without a "heads" section, no head is enabled.
        assert len(model.heads) == 0

    def test_custom_config(self) -> None:
        """Inline encoder/head settings are honoured, and unlisted heads stay off."""
        config = {
            "architecture": "gin",
            "encoder": {"hidden_dim": 64, "num_layers": 2},
            "heads": {
                "edge_classification": {"enabled": True},
                "graph_classification": {"enabled": True, "num_classes": 4},
            },
        }
        model = build_model(config)
        assert isinstance(model, AssemblyGNN)
        assert model.encoder.hidden_dim == 64
        assert "edge_pred" in model.heads
        assert "graph_pred" in model.heads
        # Only edge and graph heads were requested; every other head must be
        # absent, not just the joint-type one.
        assert "joint_type_pred" not in model.heads
        assert "dof_pred" not in model.heads
        assert "body_dof_pred" not in model.heads
class TestBuildLoss:
    """Checks that build_loss turns a training config into a MultiTaskLoss."""

    # Pretraining config shared by several tests (relative to the CWD).
    _PRETRAIN_PATH = "configs/training/pretrain.yaml"

    def _pretrain_loss(self):
        """Build the loss described by the pretraining YAML."""
        return build_loss(_load_yaml(self._PRETRAIN_PATH))

    def test_pretrain_config(self) -> None:
        """The pretraining YAML yields a MultiTaskLoss instance."""
        assert isinstance(self._pretrain_loss(), MultiTaskLoss)

    def test_weights_from_config(self) -> None:
        """Per-task weights are copied from the pretraining YAML."""
        weights = self._pretrain_loss().weights
        expected_weights = {
            "edge": 1.0,
            "graph": 0.5,
            "joint_type": 0.3,
            "dof": 0.2,
        }
        for task, expected in expected_weights.items():
            assert weights[task] == expected

    def test_redundant_penalty_from_config(self) -> None:
        """The redundant-constraint penalty is read from the pretraining YAML."""
        assert self._pretrain_loss().redundant_penalty == 2.0

    def test_empty_config_uses_defaults(self) -> None:
        """An empty config still builds a loss whose edge weight is 1.0."""
        criterion = build_loss({})
        assert isinstance(criterion, MultiTaskLoss)
        assert criterion.weights["edge"] == 1.0

    def test_custom_weights(self) -> None:
        """Inline loss settings override the corresponding weights/penalty."""
        loss_section = {
            "edge_weight": 2.0,
            "graph_weight": 1.0,
            "redundant_penalty": 5.0,
        }
        criterion = build_loss({"loss": loss_section})
        assert criterion.weights["edge"] == 2.0
        assert criterion.weights["graph"] == 1.0
        assert criterion.redundant_penalty == 5.0