Add the complete model layer: GIN and GAT encoders, five task-specific prediction heads, an uncertainty-weighted multi-task loss (Kendall et al., 2018), and config-driven factory functions.

New modules:
- graph_conv.py: convert datagen dicts to PyG Data objects (22-D node/edge features)
- encoder.py: GINEncoder (3 layers, 128-D) and GATEncoder (4 layers, 256-D, 8 heads)
- heads.py: edge-classification, graph-classification, joint-type, DOF-regression, and per-body DOF-tracking heads
- assembly_gnn.py: AssemblyGNN, which wires the encoder to the configurable heads
- losses.py: MultiTaskLoss with a learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, and pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (to match the JointType enum)

92 tests cover shapes, gradients, edge cases, and the end-to-end datagen-to-model pipeline.

111 lines · 3.7 KiB · Python
"""Tests for solver.models.factory -- model and loss construction from config."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import yaml
|
|
|
|
from solver.models.assembly_gnn import AssemblyGNN
|
|
from solver.models.factory import build_loss, build_model
|
|
from solver.models.losses import MultiTaskLoss
|
|
|
|
|
|
def _resolve_config_path(path: str) -> str:
    """Locate *path*, tolerating pytest being launched from a subdirectory.

    Returns *path* unchanged when it is absolute or already exists relative
    to the current working directory (the original lookup). Otherwise each
    ancestor directory of this module is tried in turn, nearest first, and
    the first existing match is returned as an absolute path string. If
    nothing matches, *path* is returned unchanged so that ``open`` raises the
    usual ``FileNotFoundError`` naming the path the caller asked for.
    """
    from pathlib import Path

    requested = Path(path)
    if requested.is_absolute() or requested.exists():
        return path
    # Repo-relative paths such as "configs/model/baseline.yaml" resolve from
    # whichever ancestor of this test module holds them.
    for ancestor in Path(__file__).resolve().parents:
        candidate = ancestor / requested
        if candidate.exists():
            return str(candidate)
    return path


def _load_yaml(path: str) -> dict:
    """Load the YAML file at *path* and return its contents.

    *path* may be relative to the current working directory or to any
    ancestor directory of this test module (see ``_resolve_config_path``).
    The file is decoded as UTF-8 regardless of the platform locale, and an
    empty document yields ``{}`` instead of ``None`` so callers can always
    treat the result as a mapping.

    Raises:
        FileNotFoundError: if the file cannot be located.
        yaml.YAMLError: if the file is not valid YAML.
    """
    with open(_resolve_config_path(path), encoding="utf-8") as f:
        data = yaml.safe_load(f)
    return {} if data is None else data
|
|
|
|
|
|
class TestBuildModel:
    """Checks that build_model assembles an AssemblyGNN from a model config."""

    @staticmethod
    def _model_from(config_path: str) -> AssemblyGNN:
        """Load the YAML config at *config_path* and build a model from it."""
        return build_model(_load_yaml(config_path))

    def test_baseline_config(self) -> None:
        gnn = self._model_from("configs/model/baseline.yaml")
        assert isinstance(gnn, AssemblyGNN)
        assert gnn.encoder.hidden_dim == 128

    def test_gat_config(self) -> None:
        gnn = self._model_from("configs/model/gat.yaml")
        assert isinstance(gnn, AssemblyGNN)
        assert gnn.encoder.hidden_dim == 256

    def test_baseline_heads_present(self) -> None:
        heads = self._model_from("configs/model/baseline.yaml").heads
        expected_heads = ("edge_pred", "graph_pred", "joint_type_pred", "dof_pred")
        for head_name in expected_heads:
            assert head_name in heads

    def test_gat_has_dof_tracking(self) -> None:
        heads = self._model_from("configs/model/gat.yaml").heads
        assert "body_dof_pred" in heads

    def test_baseline_no_dof_tracking(self) -> None:
        heads = self._model_from("configs/model/baseline.yaml").heads
        assert "body_dof_pred" not in heads

    def test_minimal_config(self) -> None:
        gnn = build_model({"architecture": "gin"})
        assert isinstance(gnn, AssemblyGNN)
        # Only the architecture is given, so no prediction heads are enabled.
        assert len(gnn.heads) == 0

    def test_custom_config(self) -> None:
        encoder_cfg = {"hidden_dim": 64, "num_layers": 2}
        heads_cfg = {
            "edge_classification": {"enabled": True},
            "graph_classification": {"enabled": True, "num_classes": 4},
        }
        gnn = build_model(
            {"architecture": "gin", "encoder": encoder_cfg, "heads": heads_cfg}
        )
        assert gnn.encoder.hidden_dim == 64
        for enabled_head in ("edge_pred", "graph_pred"):
            assert enabled_head in gnn.heads
        assert "joint_type_pred" not in gnn.heads
|
|
|
|
class TestBuildLoss:
    """Checks that build_loss creates a MultiTaskLoss from a training config."""

    @staticmethod
    def _pretrain_loss() -> MultiTaskLoss:
        """Build the loss described by the pretraining config file."""
        return build_loss(_load_yaml("configs/training/pretrain.yaml"))

    def test_pretrain_config(self) -> None:
        assert isinstance(self._pretrain_loss(), MultiTaskLoss)

    def test_weights_from_config(self) -> None:
        weights = self._pretrain_loss().weights
        expected = {"edge": 1.0, "graph": 0.5, "joint_type": 0.3, "dof": 0.2}
        for task, weight in expected.items():
            assert weights[task] == weight

    def test_redundant_penalty_from_config(self) -> None:
        assert self._pretrain_loss().redundant_penalty == 2.0

    def test_empty_config_uses_defaults(self) -> None:
        criterion = build_loss({})
        assert isinstance(criterion, MultiTaskLoss)
        assert criterion.weights["edge"] == 1.0

    def test_custom_weights(self) -> None:
        loss_section = {
            "edge_weight": 2.0,
            "graph_weight": 1.0,
            "redundant_penalty": 5.0,
        }
        criterion = build_loss({"loss": loss_section})
        assert criterion.weights["edge"] == 2.0
        assert criterion.weights["graph"] == 1.0
        assert criterion.redundant_penalty == 5.0
|