Add complete model layer with GIN and GAT encoders, 5 task-specific prediction heads, uncertainty-weighted multi-task loss (Kendall et al. 2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type, DOF regression, and per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring the encoder to configurable heads
- losses.py: MultiTaskLoss with a learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, and pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests covering shapes, gradients, edge cases, and the end-to-end datagen-to-model pipeline.
145 lines · 5.2 KiB · Python
"""Tests for solver.models.graph_conv -- assembly to PyG conversion."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
|
|
from solver.datagen.generator import SyntheticAssemblyGenerator
|
|
from solver.datagen.types import JointType
|
|
from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg
|
|
|
|
|
|
def _make_example(n_bodies: int = 4, grounded: bool = True, seed: int = 0) -> dict:
    """Generate a single training example via the datagen pipeline.

    Args:
        n_bodies: Number of bodies to request. It is passed to the generator
            as ``n_bodies_range=(n_bodies, n_bodies + 1)``. The shape tests
            expect exactly ``n_bodies`` nodes, so the upper bound is presumably
            exclusive.
        grounded: Value stored under the example's ``"grounded"`` key, so the
            PyG converter sees the grounding the caller asked for.
        seed: Seed for ``SyntheticAssemblyGenerator``, for reproducibility.

    Returns:
        The single example dict from a batch of size one.
    """
    gen = SyntheticAssemblyGenerator(seed=seed)
    batch = gen.generate_training_batch(
        batch_size=1,
        n_bodies_range=(n_bodies, n_bodies + 1),
    )
    example = batch[0]
    # Honour the ``grounded`` argument instead of silently ignoring it;
    # assembly_to_pyg reads this key when building node features.
    example["grounded"] = grounded
    return example
|
|
|
|
|
|
class TestAssemblyToPyg:
    """assembly_to_pyg converts datagen dicts to PyG Data correctly."""

    # Feature widths (22D node/edge features) and the column indices the
    # tests below inspect. Naming them keeps the magic numbers in one place.
    NODE_DIM = 22
    EDGE_DIM = 22
    GROUNDED_COL = 12  # node-feature column holding the grounded flag
    DOF_REMOVED_COL = 21  # edge-feature column holding normalized DOF removed
    # Every label attribute attached when include_labels=True.
    LABEL_ATTRS = ("y_edge", "y_graph", "y_joint_type", "y_dof", "y_body_dof")

    def test_node_feature_shape(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        assert data.x.shape == (5, self.NODE_DIM)

    def test_edge_feature_shape(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        # Each joint is emitted as two directed edges.
        assert data.edge_attr.shape == (ex["n_joints"] * 2, self.EDGE_DIM)

    def test_edge_index_bidirectional(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        ei = data.edge_index
        # Each joint produces 2 consecutive directed edges: a->b then b->a.
        # An odd count would break the pairing, so check it explicitly.
        assert ei.size(1) % 2 == 0
        forward = ei[:, 0::2]
        backward = ei[:, 1::2]
        # Swapping rows of the b->a edges must reproduce the a->b edges.
        assert torch.equal(forward, backward.flip(0))

    def test_edge_index_shape(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        assert data.edge_index.shape == (2, ex["n_joints"] * 2)

    def test_node_features_centered(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        # Positions (dims 0-2) should be centered (mean ~0).
        pos = data.x[:, :3]
        assert pos.mean(dim=0).abs().max().item() < 1e-5

    def test_grounded_flag_set(self) -> None:
        ex = _make_example(n_bodies=4, grounded=True)
        # Set explicitly so this test does not depend on the helper's handling.
        ex["grounded"] = True
        data = assembly_to_pyg(ex)
        # The first body (row 0) is expected to carry the flag.
        assert data.x[0, self.GROUNDED_COL].item() == 1.0

    def test_ungrounded_flag_clear(self) -> None:
        ex = _make_example(n_bodies=4)
        ex["grounded"] = False
        data = assembly_to_pyg(ex)
        assert (data.x[:, self.GROUNDED_COL] == 0.0).all()

    def test_edge_type_one_hot_valid(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        if data.edge_attr.size(0) > 0:
            # The leading columns one-hot encode the joint type. Derive the
            # width from JOINT_TYPE_NAMES rather than hard-coding it.
            onehot = data.edge_attr[:, : len(JOINT_TYPE_NAMES)]
            # Each row should have exactly one 1.0.
            assert (onehot.sum(dim=1) == 1.0).all()

    def test_labels_present_when_requested(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex, include_labels=True)
        for attr in self.LABEL_ATTRS:
            assert hasattr(data, attr), attr

    def test_labels_absent_when_not_requested(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex, include_labels=False)
        # Check every label, not only y_edge/y_graph, so a leak of any
        # per-task target into unlabeled data is caught.
        for attr in self.LABEL_ATTRS:
            assert not hasattr(data, attr), attr

    def test_graph_classification_label_mapping(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        expected = ASSEMBLY_CLASSES[ex["assembly_classification"]]
        assert data.y_graph.item() == expected

    def test_body_dof_shape(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        assert data.y_body_dof.shape == (5, 2)

    def test_edge_labels_binary(self) -> None:
        ex = _make_example(n_bodies=5)
        data = assembly_to_pyg(ex)
        y_edge = data.y_edge
        if y_edge.numel() > 0:
            assert ((y_edge == 0.0) | (y_edge == 1.0)).all()

    def test_dof_removed_normalized(self) -> None:
        ex = _make_example(n_bodies=4)
        data = assembly_to_pyg(ex)
        if data.edge_attr.size(0) > 0:
            dof_norm = data.edge_attr[:, self.DOF_REMOVED_COL]
            assert (dof_norm >= 0.0).all()
            assert (dof_norm <= 1.0).all()

    def test_roundtrip_with_generator(self) -> None:
        """Generate a real example and convert -- no crash."""
        gen = SyntheticAssemblyGenerator(seed=42)
        batch = gen.generate_training_batch(batch_size=5, complexity_tier="simple")
        for ex in batch:
            data = assembly_to_pyg(ex)
            assert data.x.shape[1] == self.NODE_DIM
            assert data.edge_attr.shape[1] == self.EDGE_DIM
|
|
|
|
|
|
class TestAssemblyClasses:
    """Sanity checks on the ASSEMBLY_CLASSES label mapping."""

    def test_four_classes(self) -> None:
        n_classes = len(ASSEMBLY_CLASSES)
        assert n_classes == 4

    def test_values_are_0_to_3(self) -> None:
        # Class ids must be exactly the contiguous range 0..3.
        expected_ids = set(range(4))
        actual_ids = set(ASSEMBLY_CLASSES.values())
        assert actual_ids == expected_ids
|
|
|
|
|
|
class TestJointTypeNames:
    """Consistency checks between JOINT_TYPE_NAMES and the JointType enum."""

    def test_length_matches_enum(self) -> None:
        enum_size = len(JointType)
        assert len(JOINT_TYPE_NAMES) == enum_size

    def test_order_matches_ordinal(self) -> None:
        # Entry k of JOINT_TYPE_NAMES must name the enum member whose value's
        # first element is k.
        for ordinal, name in enumerate(JOINT_TYPE_NAMES):
            member = JointType[name]
            assert member.value[0] == ordinal
|