Files
solver/tests/models/test_losses.py
forbes fe41fa3b00 feat(models): implement GNN model layer for assembly constraint analysis
Add complete model layer with GIN and GAT encoders, 5 task-specific
prediction heads, uncertainty-weighted multi-task loss (Kendall et al.
2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type,
  DOF regression, per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads
- losses.py: MultiTaskLoss with learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests covering shapes, gradients, edge cases, and end-to-end
datagen-to-model pipeline.
2026-02-07 10:14:19 -06:00

157 lines
5.8 KiB
Python

"""Tests for solver.models.losses -- uncertainty-weighted multi-task loss."""
from __future__ import annotations
import torch
from solver.models.losses import MultiTaskLoss
def _make_predictions_and_targets(
n_edges: int = 20,
batch_size: int = 3,
n_nodes: int = 10,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
preds = {
"edge_pred": torch.randn(n_edges, 1),
"graph_pred": torch.randn(batch_size, 4),
"joint_type_pred": torch.randn(n_edges, 11),
"dof_pred": torch.rand(batch_size, 1) * 10,
"body_dof_pred": torch.rand(n_nodes, 2) * 6,
}
targets = {
"y_edge": torch.randint(0, 2, (n_edges,)).float(),
"y_graph": torch.randint(0, 4, (batch_size,)),
"y_joint_type": torch.randint(0, 11, (n_edges,)),
"y_dof": torch.rand(batch_size, 1) * 10,
"y_body_dof": torch.rand(n_nodes, 2) * 6,
}
return preds, targets
class TestMultiTaskLoss:
    """Behavioral tests for the combined multi-task loss."""

    def test_returns_scalar_and_breakdown(self) -> None:
        criterion = MultiTaskLoss()
        preds, targets = _make_predictions_and_targets()
        loss, per_task = criterion(preds, targets)
        # The combined loss must be a 0-dim tensor; the second return value
        # is the per-task breakdown mapping.
        assert loss.dim() == 0
        assert isinstance(per_task, dict)

    def test_all_tasks_in_breakdown(self) -> None:
        criterion = MultiTaskLoss()
        preds, targets = _make_predictions_and_targets()
        _, per_task = criterion(preds, targets)
        for task in ("edge", "graph", "joint_type", "dof", "body_dof"):
            assert task in per_task

    def test_total_is_positive(self) -> None:
        criterion = MultiTaskLoss()
        preds, targets = _make_predictions_and_targets()
        loss, _ = criterion(preds, targets)
        # Random predictions against random targets cannot yield zero loss.
        assert loss.item() > 0

    def test_skips_missing_predictions(self) -> None:
        criterion = MultiTaskLoss()
        edge_only_preds = {"edge_pred": torch.randn(10, 1)}
        edge_only_targets = {"y_edge": torch.randint(0, 2, (10,)).float()}
        _, per_task = criterion(edge_only_preds, edge_only_targets)
        # Only the task with both a prediction and a target contributes.
        assert "edge" in per_task
        assert "graph" not in per_task
        assert "joint_type" not in per_task

    def test_skips_missing_targets(self) -> None:
        criterion = MultiTaskLoss()
        preds = {
            "edge_pred": torch.randn(10, 1),
            "graph_pred": torch.randn(2, 4),
        }
        targets = {"y_edge": torch.randint(0, 2, (10,)).float()}
        _, per_task = criterion(preds, targets)
        # graph has a prediction but no target, so it must be skipped.
        assert "edge" in per_task
        assert "graph" not in per_task

    def test_gradients_flow_to_log_vars(self) -> None:
        criterion = MultiTaskLoss()
        preds, targets = _make_predictions_and_targets()
        preds = {k: v.requires_grad_(True) for k, v in preds.items()}
        loss, _ = criterion(preds, targets)
        loss.backward()
        # Every learnable log-variance should receive a gradient.
        for name, param in criterion.log_vars.items():
            assert param.grad is not None, f"No gradient for log_var[{name}]"

    def test_gradients_flow_to_predictions(self) -> None:
        criterion = MultiTaskLoss()
        preds, targets = _make_predictions_and_targets()
        preds = {k: v.requires_grad_(True) for k, v in preds.items()}
        loss, _ = criterion(preds, targets)
        loss.backward()
        for k, v in preds.items():
            assert v.grad is not None, f"No gradient for prediction[{k}]"

    def test_redundant_penalty_applies(self) -> None:
        """Redundant edges (label=0) should have higher loss contribution."""
        penalized = MultiTaskLoss(redundant_penalty=5.0)
        unweighted = MultiTaskLoss(redundant_penalty=1.0)
        # All-zero predictions with every label=0 (redundant).
        preds = {"edge_pred": torch.zeros(10, 1)}
        targets = {"y_edge": torch.zeros(10)}
        loss_penalized, _ = penalized(preds, targets)
        loss_plain, _ = unweighted(preds, targets)
        # The stronger penalty must produce the strictly larger loss.
        assert loss_penalized.item() > loss_plain.item()

    def test_empty_predictions_returns_zero(self) -> None:
        criterion = MultiTaskLoss()
        loss, per_task = criterion({}, {})
        assert loss.item() == 0.0
        assert not per_task
class TestUncertaintyWeighting:
    """Tests for the learnable log-variance weighting mechanism."""

    def test_log_vars_initialized_to_zero(self) -> None:
        criterion = MultiTaskLoss()
        # log_var = 0 means all tasks start with equal effective weight.
        assert all(p.item() == 0.0 for p in criterion.log_vars.values())

    def test_log_vars_are_learnable(self) -> None:
        criterion = MultiTaskLoss()
        scalar_params = [p for p in criterion.parameters() if p.shape == (1,)]
        assert len(scalar_params) == 5  # one per task

    def test_weighting_reduces_high_loss_influence(self) -> None:
        """After a few gradient steps, log_var for a noisy task should increase."""
        criterion = MultiTaskLoss(edge_weight=1.0, graph_weight=1.0)
        optimizer = torch.optim.SGD(criterion.parameters(), lr=0.1)
        for _ in range(20):
            # Edge task: large random logits -> persistently high loss.
            # Graph task: zero logits with class-0 targets -> near-zero loss.
            preds = {
                "edge_pred": torch.randn(10, 1) * 10,
                "graph_pred": torch.zeros(2, 4),
            }
            targets = {
                "y_edge": torch.randint(0, 2, (10,)).float(),
                "y_graph": torch.zeros(2, dtype=torch.long),
            }
            optimizer.zero_grad()
            loss, _ = criterion(preds, targets)
            loss.backward()
            optimizer.step()
        # The noisier task should end up with the larger learned uncertainty.
        assert criterion.log_vars["edge"].item() > criterion.log_vars["graph"].item()