Add complete model layer with GIN and GAT encoders, 5 task-specific prediction heads, uncertainty-weighted multi-task loss (Kendall et al. 2018), and config-driven factory functions. New modules: - graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features) - encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head) - heads.py: edge classification, graph classification, joint type, DOF regression, per-body DOF tracking heads - assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads - losses.py: MultiTaskLoss with learnable log-variance per task - factory.py: build_model() and build_loss() from YAML configs Supporting changes: - generator.py: serialize anchor_a, anchor_b, pitch in joint dicts - configs: fix joint_type num_classes 12 -> 11 (matches JointType enum) 92 tests covering shapes, gradients, edge cases, and end-to-end datagen-to-model pipeline.
157 lines · 5.8 KiB · Python
"""Tests for solver.models.losses -- uncertainty-weighted multi-task loss."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
from solver.models.losses import MultiTaskLoss
|
|
|
|
|
|
def _make_predictions_and_targets(
|
|
n_edges: int = 20,
|
|
batch_size: int = 3,
|
|
n_nodes: int = 10,
|
|
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
|
preds = {
|
|
"edge_pred": torch.randn(n_edges, 1),
|
|
"graph_pred": torch.randn(batch_size, 4),
|
|
"joint_type_pred": torch.randn(n_edges, 11),
|
|
"dof_pred": torch.rand(batch_size, 1) * 10,
|
|
"body_dof_pred": torch.rand(n_nodes, 2) * 6,
|
|
}
|
|
targets = {
|
|
"y_edge": torch.randint(0, 2, (n_edges,)).float(),
|
|
"y_graph": torch.randint(0, 4, (batch_size,)),
|
|
"y_joint_type": torch.randint(0, 11, (n_edges,)),
|
|
"y_dof": torch.rand(batch_size, 1) * 10,
|
|
"y_body_dof": torch.rand(n_nodes, 2) * 6,
|
|
}
|
|
return preds, targets
|
|
|
|
|
|
class TestMultiTaskLoss:
    """MultiTaskLoss computation tests."""

    def test_returns_scalar_and_breakdown(self) -> None:
        criterion = MultiTaskLoss()
        predictions, labels = _make_predictions_and_targets()
        total, breakdown = criterion(predictions, labels)
        assert total.dim() == 0  # scalar
        assert isinstance(breakdown, dict)

    def test_all_tasks_in_breakdown(self) -> None:
        criterion = MultiTaskLoss()
        predictions, labels = _make_predictions_and_targets()
        _, breakdown = criterion(predictions, labels)
        for task in ("edge", "graph", "joint_type", "dof", "body_dof"):
            assert task in breakdown

    def test_total_is_positive(self) -> None:
        criterion = MultiTaskLoss()
        predictions, labels = _make_predictions_and_targets()
        total, _ = criterion(predictions, labels)
        # With random predictions, loss should be positive.
        assert total.item() > 0

    def test_skips_missing_predictions(self) -> None:
        criterion = MultiTaskLoss()
        edge_only_preds = {"edge_pred": torch.randn(10, 1)}
        edge_only_targets = {"y_edge": torch.randint(0, 2, (10,)).float()}
        total, breakdown = criterion(edge_only_preds, edge_only_targets)
        assert "edge" in breakdown
        assert "graph" not in breakdown
        assert "joint_type" not in breakdown

    def test_skips_missing_targets(self) -> None:
        criterion = MultiTaskLoss()
        predictions = {
            "edge_pred": torch.randn(10, 1),
            "graph_pred": torch.randn(2, 4),
        }
        # Only the edge target is supplied; the graph task must be skipped.
        labels = {"y_edge": torch.randint(0, 2, (10,)).float()}
        _, breakdown = criterion(predictions, labels)
        assert "edge" in breakdown
        assert "graph" not in breakdown

    def test_gradients_flow_to_log_vars(self) -> None:
        criterion = MultiTaskLoss()
        predictions, labels = _make_predictions_and_targets()
        # Make preds require grad.
        predictions = {k: v.requires_grad_(True) for k, v in predictions.items()}
        total, _ = criterion(predictions, labels)
        total.backward()
        for name, param in criterion.log_vars.items():
            assert param.grad is not None, f"No gradient for log_var[{name}]"

    def test_gradients_flow_to_predictions(self) -> None:
        criterion = MultiTaskLoss()
        predictions, labels = _make_predictions_and_targets()
        predictions = {k: v.requires_grad_(True) for k, v in predictions.items()}
        total, _ = criterion(predictions, labels)
        total.backward()
        for k, v in predictions.items():
            assert v.grad is not None, f"No gradient for prediction[{k}]"

    def test_redundant_penalty_applies(self) -> None:
        """Redundant edges (label=0) should have higher loss contribution."""
        # All-zero predictions, label=0 (redundant).
        zero_preds = {"edge_pred": torch.zeros(10, 1)}
        zero_targets = {"y_edge": torch.zeros(10)}

        penalized, _ = MultiTaskLoss(redundant_penalty=5.0)(zero_preds, zero_targets)
        unweighted, _ = MultiTaskLoss(redundant_penalty=1.0)(zero_preds, zero_targets)

        # Higher penalty should produce higher loss.
        assert penalized.item() > unweighted.item()

    def test_empty_predictions_returns_zero(self) -> None:
        criterion = MultiTaskLoss()
        total, breakdown = criterion({}, {})
        assert total.item() == 0.0
        assert not breakdown
|
|
|
|
|
|
class TestUncertaintyWeighting:
    """Test uncertainty weighting mechanism specifically.

    Covers the Kendall-et-al.-style learnable log-variance parameters:
    initialization, registration as learnable parameters, and the expected
    adaptation (noisy tasks acquire higher uncertainty) under optimization.
    """

    def test_log_vars_initialized_to_zero(self) -> None:
        loss_fn = MultiTaskLoss()
        # log_var = 0 corresponds to unit variance, i.e. equal initial task weights.
        for param in loss_fn.log_vars.values():
            assert param.item() == 0.0

    def test_log_vars_are_learnable(self) -> None:
        loss_fn = MultiTaskLoss()
        params = list(loss_fn.parameters())
        # NOTE(review): identifies log-vars by their (1,) shape — assumes the
        # loss module holds no other shape-(1,) parameters.
        log_var_params = [p for p in params if p.shape == (1,)]
        assert len(log_var_params) == 5  # one per task

    def test_weighting_reduces_high_loss_influence(self) -> None:
        """After a few gradient steps, log_var for a noisy task should increase."""
        # Seed the RNG: this test optimizes over randomly drawn tensors and
        # asserts an ordering of learned parameters, so without a fixed seed
        # it can flake on an unlucky draw.
        torch.manual_seed(0)
        loss_fn = MultiTaskLoss(edge_weight=1.0, graph_weight=1.0)
        optimizer = torch.optim.SGD(loss_fn.parameters(), lr=0.1)

        # Simulate: edge task has high loss, graph has low.
        for _ in range(20):
            preds = {
                "edge_pred": torch.randn(10, 1) * 10,  # high variance -> high loss
                "graph_pred": torch.zeros(2, 4),  # near-zero loss
            }
            targets = {
                "y_edge": torch.randint(0, 2, (10,)).float(),
                "y_graph": torch.zeros(2, dtype=torch.long),
            }
            optimizer.zero_grad()
            total, _ = loss_fn(preds, targets)
            total.backward()
            optimizer.step()

        # The edge task log_var should have increased (higher uncertainty).
        assert loss_fn.log_vars["edge"].item() > loss_fn.log_vars["graph"].item()
|