"""Tests for solver.models.encoder -- GIN and GAT graph encoders.""" from __future__ import annotations import pytest import torch from solver.models.encoder import GATEncoder, GINEncoder def _random_graph( n_nodes: int = 8, n_edges: int = 20, node_dim: int = 22, edge_dim: int = 22, batch_size: int = 2, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Create a random graph for testing.""" x = torch.randn(n_nodes, node_dim) edge_index = torch.randint(0, n_nodes, (2, n_edges)) edge_attr = torch.randn(n_edges, edge_dim) # Assign nodes to batches roughly evenly. batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size) # Handle remainder nodes. if len(batch) < n_nodes: batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)]) return x, edge_index, edge_attr, batch class TestGINEncoder: """GINEncoder shape and gradient tests.""" def test_output_shapes(self) -> None: enc = GINEncoder(node_features_dim=22, edge_features_dim=22, hidden_dim=64, num_layers=2) x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3) node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) assert node_emb.shape == (10, 64) assert edge_emb.shape == (16, 64) assert graph_emb.shape == (3, 64) def test_hidden_dim_property(self) -> None: enc = GINEncoder(hidden_dim=128) assert enc.hidden_dim == 128 def test_default_dimensions(self) -> None: enc = GINEncoder() x, ei, ea, batch = _random_graph() node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) assert node_emb.shape[1] == 128 assert edge_emb.shape[1] == 128 assert graph_emb.shape[1] == 128 def test_no_batch_defaults_to_single_graph(self) -> None: enc = GINEncoder(hidden_dim=64, num_layers=2) x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10) node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None) assert graph_emb.shape == (1, 64) def test_gradients_flow(self) -> None: enc = GINEncoder(hidden_dim=32, num_layers=2) x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) x.requires_grad_(True) node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) loss = graph_emb.sum() loss.backward() assert x.grad is not None assert x.grad.abs().sum() > 0 def test_zero_edges(self) -> None: enc = GINEncoder(hidden_dim=32, num_layers=2) x = torch.randn(4, 22) ei = torch.zeros(2, 0, dtype=torch.long) ea = torch.zeros(0, 22) batch = torch.tensor([0, 0, 1, 1]) node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) assert node_emb.shape == (4, 32) assert edge_emb.shape == (0, 32) assert graph_emb.shape == (2, 32) def test_single_node(self) -> None: enc = GINEncoder(hidden_dim=32, num_layers=2) # Train with a small batch first to populate BN running stats. x_train = torch.randn(4, 22) ei_train = torch.zeros(2, 0, dtype=torch.long) ea_train = torch.zeros(0, 22) batch_train = torch.tensor([0, 0, 1, 1]) enc.train() enc(x_train, ei_train, ea_train, batch_train) # Now test single node in eval mode (BN uses running stats). enc.eval() x = torch.randn(1, 22) ei = torch.zeros(2, 0, dtype=torch.long) ea = torch.zeros(0, 22) with torch.no_grad(): node_emb, edge_emb, graph_emb = enc(x, ei, ea) assert node_emb.shape == (1, 32) assert graph_emb.shape == (1, 32) def test_eval_mode(self) -> None: """Encoder works in eval mode (BatchNorm uses running stats).""" enc = GINEncoder(hidden_dim=32, num_layers=2) # Forward pass in train mode to populate BN stats. x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) enc.train() enc(x, ei, ea, batch) # Switch to eval. enc.eval() with torch.no_grad(): node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) assert node_emb.shape[1] == 32 class TestGATEncoder: """GATEncoder shape and gradient tests.""" def test_output_shapes(self) -> None: enc = GATEncoder( node_features_dim=22, edge_features_dim=22, hidden_dim=64, num_layers=2, num_heads=4, ) x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3) node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) assert node_emb.shape == (10, 64) assert edge_emb.shape == (16, 64) assert graph_emb.shape == (3, 64) def test_hidden_dim_property(self) -> None: enc = GATEncoder(hidden_dim=256, num_heads=8) assert enc.hidden_dim == 256 def test_default_dimensions(self) -> None: enc = GATEncoder() x, ei, ea, batch = _random_graph() node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) assert node_emb.shape[1] == 256 assert edge_emb.shape[1] == 256 assert graph_emb.shape[1] == 256 def test_no_batch_defaults_to_single_graph(self) -> None: enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4) x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10) node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None) assert graph_emb.shape == (1, 64) def test_gradients_flow(self) -> None: enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4) x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) x.requires_grad_(True) node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) loss = graph_emb.sum() loss.backward() assert x.grad is not None assert x.grad.abs().sum() > 0 def test_residual_connection(self) -> None: """With residual=True, output should differ from residual=False.""" x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) torch.manual_seed(0) enc_res = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=True) torch.manual_seed(0) enc_no = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=False) with torch.no_grad(): n1, _, _ = enc_res(x, ei, ea, batch) n2, _, _ = enc_no(x, ei, ea, batch) # Outputs should generally differ (unless by very unlikely coincidence). assert not torch.allclose(n1, n2, atol=1e-4) def test_hidden_dim_must_divide_heads(self) -> None: with pytest.raises(ValueError, match="divisible"): GATEncoder(hidden_dim=100, num_heads=8) def test_eval_mode(self) -> None: enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4) x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) enc.train() enc(x, ei, ea, batch) enc.eval() with torch.no_grad(): node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) assert node_emb.shape[1] == 64