"""Tests for solver.models.factory -- model and loss construction from config.""" from __future__ import annotations import yaml from solver.models.assembly_gnn import AssemblyGNN from solver.models.factory import build_loss, build_model from solver.models.losses import MultiTaskLoss def _load_yaml(path: str) -> dict: with open(path) as f: return yaml.safe_load(f) class TestBuildModel: """build_model constructs AssemblyGNN from config.""" def test_baseline_config(self) -> None: config = _load_yaml("configs/model/baseline.yaml") model = build_model(config) assert isinstance(model, AssemblyGNN) assert model.encoder.hidden_dim == 128 def test_gat_config(self) -> None: config = _load_yaml("configs/model/gat.yaml") model = build_model(config) assert isinstance(model, AssemblyGNN) assert model.encoder.hidden_dim == 256 def test_baseline_heads_present(self) -> None: config = _load_yaml("configs/model/baseline.yaml") model = build_model(config) assert "edge_pred" in model.heads assert "graph_pred" in model.heads assert "joint_type_pred" in model.heads assert "dof_pred" in model.heads def test_gat_has_dof_tracking(self) -> None: config = _load_yaml("configs/model/gat.yaml") model = build_model(config) assert "body_dof_pred" in model.heads def test_baseline_no_dof_tracking(self) -> None: config = _load_yaml("configs/model/baseline.yaml") model = build_model(config) assert "body_dof_pred" not in model.heads def test_minimal_config(self) -> None: config = {"architecture": "gin"} model = build_model(config) assert isinstance(model, AssemblyGNN) # No heads enabled. assert len(model.heads) == 0 def test_custom_config(self) -> None: config = { "architecture": "gin", "encoder": {"hidden_dim": 64, "num_layers": 2}, "heads": { "edge_classification": {"enabled": True}, "graph_classification": {"enabled": True, "num_classes": 4}, }, } model = build_model(config) assert model.encoder.hidden_dim == 64 assert "edge_pred" in model.heads assert "graph_pred" in model.heads assert "joint_type_pred" not in model.heads class TestBuildLoss: """build_loss constructs MultiTaskLoss from training config.""" def test_pretrain_config(self) -> None: config = _load_yaml("configs/training/pretrain.yaml") loss_fn = build_loss(config) assert isinstance(loss_fn, MultiTaskLoss) def test_weights_from_config(self) -> None: config = _load_yaml("configs/training/pretrain.yaml") loss_fn = build_loss(config) assert loss_fn.weights["edge"] == 1.0 assert loss_fn.weights["graph"] == 0.5 assert loss_fn.weights["joint_type"] == 0.3 assert loss_fn.weights["dof"] == 0.2 def test_redundant_penalty_from_config(self) -> None: config = _load_yaml("configs/training/pretrain.yaml") loss_fn = build_loss(config) assert loss_fn.redundant_penalty == 2.0 def test_empty_config_uses_defaults(self) -> None: loss_fn = build_loss({}) assert isinstance(loss_fn, MultiTaskLoss) assert loss_fn.weights["edge"] == 1.0 def test_custom_weights(self) -> None: config = { "loss": { "edge_weight": 2.0, "graph_weight": 1.0, "redundant_penalty": 5.0, }, } loss_fn = build_loss(config) assert loss_fn.weights["edge"] == 2.0 assert loss_fn.weights["graph"] == 1.0 assert loss_fn.redundant_penalty == 5.0