"""Tests for solver.models.graph_conv -- assembly to PyG conversion.""" from __future__ import annotations import torch from solver.datagen.generator import SyntheticAssemblyGenerator from solver.datagen.types import JointType from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg def _make_example(n_bodies: int = 4, grounded: bool = True, seed: int = 0) -> dict: """Generate a single training example via the datagen pipeline.""" gen = SyntheticAssemblyGenerator(seed=seed) batch = gen.generate_training_batch(batch_size=1, n_bodies_range=(n_bodies, n_bodies + 1)) return batch[0] class TestAssemblyToPyg: """assembly_to_pyg converts datagen dicts to PyG Data correctly.""" def test_node_feature_shape(self) -> None: ex = _make_example(n_bodies=5) data = assembly_to_pyg(ex) assert data.x.shape == (5, 22) def test_edge_feature_shape(self) -> None: ex = _make_example(n_bodies=4) data = assembly_to_pyg(ex) n_joints = ex["n_joints"] assert data.edge_attr.shape == (n_joints * 2, 22) def test_edge_index_bidirectional(self) -> None: ex = _make_example(n_bodies=4) data = assembly_to_pyg(ex) ei = data.edge_index # Each joint produces 2 directed edges: a->b and b->a. for i in range(0, ei.size(1), 2): assert ei[0, i].item() == ei[1, i + 1].item() assert ei[1, i].item() == ei[0, i + 1].item() def test_edge_index_shape(self) -> None: ex = _make_example(n_bodies=4) data = assembly_to_pyg(ex) n_joints = ex["n_joints"] assert data.edge_index.shape == (2, n_joints * 2) def test_node_features_centered(self) -> None: ex = _make_example(n_bodies=5) data = assembly_to_pyg(ex) # Positions (dims 0-2) should be centered (mean ~0). pos = data.x[:, :3] assert pos.mean(dim=0).abs().max().item() < 1e-5 def test_grounded_flag_set(self) -> None: ex = _make_example(n_bodies=4, grounded=True) ex["grounded"] = True data = assembly_to_pyg(ex) assert data.x[0, 12].item() == 1.0 def test_ungrounded_flag_clear(self) -> None: ex = _make_example(n_bodies=4) ex["grounded"] = False data = assembly_to_pyg(ex) assert (data.x[:, 12] == 0.0).all() def test_edge_type_one_hot_valid(self) -> None: ex = _make_example(n_bodies=5) data = assembly_to_pyg(ex) if data.edge_attr.size(0) > 0: onehot = data.edge_attr[:, :11] # Each row should have exactly one 1.0. assert (onehot.sum(dim=1) == 1.0).all() def test_labels_present_when_requested(self) -> None: ex = _make_example(n_bodies=4) data = assembly_to_pyg(ex, include_labels=True) assert hasattr(data, "y_edge") assert hasattr(data, "y_graph") assert hasattr(data, "y_joint_type") assert hasattr(data, "y_dof") assert hasattr(data, "y_body_dof") def test_labels_absent_when_not_requested(self) -> None: ex = _make_example(n_bodies=4) data = assembly_to_pyg(ex, include_labels=False) assert not hasattr(data, "y_edge") assert not hasattr(data, "y_graph") def test_graph_classification_label_mapping(self) -> None: ex = _make_example(n_bodies=4) data = assembly_to_pyg(ex) cls = ex["assembly_classification"] expected = ASSEMBLY_CLASSES[cls] assert data.y_graph.item() == expected def test_body_dof_shape(self) -> None: ex = _make_example(n_bodies=5) data = assembly_to_pyg(ex) assert data.y_body_dof.shape == (5, 2) def test_edge_labels_binary(self) -> None: ex = _make_example(n_bodies=5) data = assembly_to_pyg(ex) if data.y_edge.numel() > 0: assert ((data.y_edge == 0.0) | (data.y_edge == 1.0)).all() def test_dof_removed_normalized(self) -> None: ex = _make_example(n_bodies=4) data = assembly_to_pyg(ex) if data.edge_attr.size(0) > 0: dof_norm = data.edge_attr[:, 21] assert (dof_norm >= 0.0).all() assert (dof_norm <= 1.0).all() def test_roundtrip_with_generator(self) -> None: """Generate a real example and convert -- no crash.""" gen = SyntheticAssemblyGenerator(seed=42) batch = gen.generate_training_batch(batch_size=5, complexity_tier="simple") for ex in batch: data = assembly_to_pyg(ex) assert data.x.shape[1] == 22 assert data.edge_attr.shape[1] == 22 class TestAssemblyClasses: """ASSEMBLY_CLASSES covers all classifications.""" def test_four_classes(self) -> None: assert len(ASSEMBLY_CLASSES) == 4 def test_values_are_0_to_3(self) -> None: assert set(ASSEMBLY_CLASSES.values()) == {0, 1, 2, 3} class TestJointTypeNames: """JOINT_TYPE_NAMES matches the JointType enum.""" def test_length_matches_enum(self) -> None: assert len(JOINT_TYPE_NAMES) == len(JointType) def test_order_matches_ordinal(self) -> None: for i, name in enumerate(JOINT_TYPE_NAMES): assert JointType[name].value[0] == i