"""Tests for solver.models.heads -- task-specific prediction heads."""

from __future__ import annotations

import torch

from solver.models.heads import (
    DOFRegressionHead,
    DOFTrackingHead,
    EdgeClassificationHead,
    GraphClassificationHead,
    JointTypeHead,
)


class TestEdgeClassificationHead:
    """EdgeClassificationHead produces correct shape and gradients."""

    def test_output_shape(self) -> None:
        head = EdgeClassificationHead(hidden_dim=128)
        edge_emb = torch.randn(20, 128)
        out = head(edge_emb)
        assert out.shape == (20, 1)

    def test_output_shape_small(self) -> None:
        head = EdgeClassificationHead(hidden_dim=32)
        edge_emb = torch.randn(5, 32)
        out = head(edge_emb)
        assert out.shape == (5, 1)

    def test_gradients_flow(self) -> None:
        head = EdgeClassificationHead(hidden_dim=64)
        edge_emb = torch.randn(10, 64, requires_grad=True)
        out = head(edge_emb)
        out.sum().backward()
        assert edge_emb.grad is not None
        assert edge_emb.grad.abs().sum() > 0

    def test_zero_edges(self) -> None:
        """An empty batch of edges yields an empty (0, 1) output."""
        head = EdgeClassificationHead(hidden_dim=64)
        edge_emb = torch.zeros(0, 64)
        out = head(edge_emb)
        assert out.shape == (0, 1)

    def test_output_is_logits(self) -> None:
        """Output should be unbounded logits (not probabilities)."""
        # Seed BEFORE building the head so that both its weight
        # initialisation and the input are reproducible; seeding afterwards
        # left the initialisation random and the assertion non-deterministic.
        torch.manual_seed(42)
        head = EdgeClassificationHead(hidden_dim=64)
        edge_emb = torch.randn(100, 64)
        out = head(edge_emb)
        # Probabilities would stay inside [0, 1]; logits are expected to
        # leave that range for at least one of 100 random inputs.
        assert out.min().item() < 0 or out.max().item() > 1


class TestGraphClassificationHead:
    """GraphClassificationHead produces correct shape and gradients."""

    def test_output_shape(self) -> None:
        head = GraphClassificationHead(hidden_dim=128, num_classes=4)
        graph_emb = torch.randn(3, 128)
        out = head(graph_emb)
        assert out.shape == (3, 4)

    def test_custom_num_classes(self) -> None:
        head = GraphClassificationHead(hidden_dim=64, num_classes=8)
        graph_emb = torch.randn(2, 64)
        out = head(graph_emb)
        assert out.shape == (2, 8)

    def test_gradients_flow(self) -> None:
        head = GraphClassificationHead(hidden_dim=64)
        graph_emb = torch.randn(2, 64, requires_grad=True)
        out = head(graph_emb)
        out.sum().backward()
        assert graph_emb.grad is not None

    def test_single_graph(self) -> None:
        """A batch of one graph works with the default ``num_classes``."""
        head = GraphClassificationHead(hidden_dim=128)
        graph_emb = torch.randn(1, 128)
        out = head(graph_emb)
        # Expects the head's default num_classes to be 4.
        assert out.shape == (1, 4)


class TestJointTypeHead:
    """JointTypeHead produces correct shape and gradients."""

    def test_output_shape(self) -> None:
        head = JointTypeHead(hidden_dim=128, num_classes=11)
        edge_emb = torch.randn(20, 128)
        out = head(edge_emb)
        assert out.shape == (20, 11)

    def test_custom_classes(self) -> None:
        head = JointTypeHead(hidden_dim=64, num_classes=7)
        edge_emb = torch.randn(10, 64)
        out = head(edge_emb)
        assert out.shape == (10, 7)

    def test_gradients_flow(self) -> None:
        head = JointTypeHead(hidden_dim=64)
        edge_emb = torch.randn(10, 64, requires_grad=True)
        out = head(edge_emb)
        out.sum().backward()
        assert edge_emb.grad is not None

    def test_zero_edges(self) -> None:
        """An empty batch of edges yields an empty output row set."""
        head = JointTypeHead(hidden_dim=64)
        edge_emb = torch.zeros(0, 64)
        out = head(edge_emb)
        # Expects the head's default num_classes to be 11.
        assert out.shape == (0, 11)


class TestDOFRegressionHead:
    """DOFRegressionHead produces correct shape and non-negative output."""

    def test_output_shape(self) -> None:
        head = DOFRegressionHead(hidden_dim=128)
        graph_emb = torch.randn(3, 128)
        out = head(graph_emb)
        assert out.shape == (3, 1)

    def test_output_non_negative(self) -> None:
        """Softplus ensures non-negative output."""
        # Seed before construction so weights and inputs are reproducible.
        torch.manual_seed(0)
        head = DOFRegressionHead(hidden_dim=64)
        graph_emb = torch.randn(50, 64)
        out = head(graph_emb)
        assert (out >= 0).all()

    def test_gradients_flow(self) -> None:
        head = DOFRegressionHead(hidden_dim=64)
        graph_emb = torch.randn(2, 64, requires_grad=True)
        out = head(graph_emb)
        out.sum().backward()
        assert graph_emb.grad is not None

    def test_single_graph(self) -> None:
        head = DOFRegressionHead(hidden_dim=32)
        graph_emb = torch.randn(1, 32)
        out = head(graph_emb)
        assert out.shape == (1, 1)
        assert out.item() >= 0


class TestDOFTrackingHead:
    """DOFTrackingHead produces correct shape and non-negative output."""

    def test_output_shape(self) -> None:
        head = DOFTrackingHead(hidden_dim=128)
        node_emb = torch.randn(10, 128)
        out = head(node_emb)
        assert out.shape == (10, 2)

    def test_output_non_negative(self) -> None:
        """Softplus ensures non-negative output."""
        # Seed before construction so weights and inputs are reproducible.
        torch.manual_seed(0)
        head = DOFTrackingHead(hidden_dim=64)
        node_emb = torch.randn(50, 64)
        out = head(node_emb)
        assert (out >= 0).all()

    def test_gradients_flow(self) -> None:
        head = DOFTrackingHead(hidden_dim=64)
        node_emb = torch.randn(10, 64, requires_grad=True)
        out = head(node_emb)
        out.sum().backward()
        assert node_emb.grad is not None

    def test_single_node(self) -> None:
        head = DOFTrackingHead(hidden_dim=32)
        node_emb = torch.randn(1, 32)
        out = head(node_emb)
        assert out.shape == (1, 2)
        assert (out >= 0).all()

    def test_two_columns_independent(self) -> None:
        """Translational and rotational DOF are independently predicted.

        This only checks that the two output columns are not identical,
        which would indicate one value copied into both columns.
        """
        # Seeded so the value-dependent assertion below is reproducible.
        torch.manual_seed(0)
        head = DOFTrackingHead(hidden_dim=64)
        node_emb = torch.randn(20, 64)
        out = head(node_emb)
        # The two columns should generally differ.
        assert not torch.allclose(out[:, 0], out[:, 1], atol=1e-6)