From fe41fa3b00fa4e49f5719356d20e2d35591937b6 Mon Sep 17 00:00:00 2001 From: forbes Date: Sat, 7 Feb 2026 10:14:19 -0600 Subject: [PATCH] feat(models): implement GNN model layer for assembly constraint analysis Add complete model layer with GIN and GAT encoders, 5 task-specific prediction heads, uncertainty-weighted multi-task loss (Kendall et al. 2018), and config-driven factory functions. New modules: - graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features) - encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head) - heads.py: edge classification, graph classification, joint type, DOF regression, per-body DOF tracking heads - assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads - losses.py: MultiTaskLoss with learnable log-variance per task - factory.py: build_model() and build_loss() from YAML configs Supporting changes: - generator.py: serialize anchor_a, anchor_b, pitch in joint dicts - configs: fix joint_type num_classes 12 -> 11 (matches JointType enum) 92 tests covering shapes, gradients, edge cases, and end-to-end datagen-to-model pipeline. 
--- configs/model/baseline.yaml | 4 +- configs/model/gat.yaml | 2 +- solver/datagen/generator.py | 3 + solver/models/__init__.py | 31 ++++ solver/models/assembly_gnn.py | 131 +++++++++++++++ solver/models/encoder.py | 194 ++++++++++++++++++++++ solver/models/factory.py | 82 ++++++++++ solver/models/graph_conv.py | 260 ++++++++++++++++++++++++++++++ solver/models/heads.py | 119 ++++++++++++++ solver/models/losses.py | 161 ++++++++++++++++++ tests/models/__init__.py | 0 tests/models/test_assembly_gnn.py | 188 +++++++++++++++++++++ tests/models/test_encoder.py | 183 +++++++++++++++++++++ tests/models/test_factory.py | 110 +++++++++++++ tests/models/test_graph_conv.py | 144 +++++++++++++++++ tests/models/test_heads.py | 182 +++++++++++++++++++++ tests/models/test_losses.py | 156 ++++++++++++++++++ 17 files changed, 1947 insertions(+), 3 deletions(-) create mode 100644 solver/models/assembly_gnn.py create mode 100644 solver/models/encoder.py create mode 100644 solver/models/factory.py create mode 100644 solver/models/graph_conv.py create mode 100644 solver/models/heads.py create mode 100644 solver/models/losses.py create mode 100644 tests/models/__init__.py create mode 100644 tests/models/test_assembly_gnn.py create mode 100644 tests/models/test_encoder.py create mode 100644 tests/models/test_factory.py create mode 100644 tests/models/test_graph_conv.py create mode 100644 tests/models/test_heads.py create mode 100644 tests/models/test_losses.py diff --git a/configs/model/baseline.yaml b/configs/model/baseline.yaml index ebefe9b..34fa989 100644 --- a/configs/model/baseline.yaml +++ b/configs/model/baseline.yaml @@ -16,9 +16,9 @@ heads: hidden_dim: 64 graph_classification: enabled: true - num_classes: 4 # rigid, under, over, mixed + num_classes: 4 # rigid, under, over, mixed joint_type: enabled: true - num_classes: 12 + num_classes: 11 dof_regression: enabled: true diff --git a/configs/model/gat.yaml b/configs/model/gat.yaml index 6592bc1..344dfd1 100644 --- 
a/configs/model/gat.yaml +++ b/configs/model/gat.yaml @@ -21,7 +21,7 @@ heads: num_classes: 4 joint_type: enabled: true - num_classes: 12 + num_classes: 11 dof_regression: enabled: true dof_tracking: diff --git a/solver/datagen/generator.py b/solver/datagen/generator.py index ad753f9..929f259 100644 --- a/solver/datagen/generator.py +++ b/solver/datagen/generator.py @@ -877,6 +877,9 @@ class SyntheticAssemblyGenerator: "body_b": j.body_b, "type": j.joint_type.name, "axis": j.axis.tolist(), + "anchor_a": j.anchor_a.tolist(), + "anchor_b": j.anchor_b.tolist(), + "pitch": j.pitch, } for j in joints ], diff --git a/solver/models/__init__.py b/solver/models/__init__.py index e69de29..bc66a3e 100644 --- a/solver/models/__init__.py +++ b/solver/models/__init__.py @@ -0,0 +1,31 @@ +"""GNN models for assembly constraint analysis.""" + +from solver.models.assembly_gnn import AssemblyGNN +from solver.models.encoder import GATEncoder, GINEncoder +from solver.models.factory import build_loss, build_model +from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg +from solver.models.heads import ( + DOFRegressionHead, + DOFTrackingHead, + EdgeClassificationHead, + GraphClassificationHead, + JointTypeHead, +) +from solver.models.losses import MultiTaskLoss + +__all__ = [ + "ASSEMBLY_CLASSES", + "AssemblyGNN", + "DOFRegressionHead", + "DOFTrackingHead", + "EdgeClassificationHead", + "GATEncoder", + "GINEncoder", + "GraphClassificationHead", + "JOINT_TYPE_NAMES", + "JointTypeHead", + "MultiTaskLoss", + "assembly_to_pyg", + "build_loss", + "build_model", +] diff --git a/solver/models/assembly_gnn.py b/solver/models/assembly_gnn.py new file mode 100644 index 0000000..5ddd680 --- /dev/null +++ b/solver/models/assembly_gnn.py @@ -0,0 +1,131 @@ +"""AssemblyGNN -- main model wiring encoder and task heads.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch.nn as nn + +from solver.models.encoder import GATEncoder, 
GINEncoder +from solver.models.heads import ( + DOFRegressionHead, + DOFTrackingHead, + EdgeClassificationHead, + GraphClassificationHead, + JointTypeHead, +) + +if TYPE_CHECKING: + from typing import Any + + import torch + +__all__ = ["AssemblyGNN"] + +_ENCODERS = { + "gin": GINEncoder, + "gat": GATEncoder, +} + + +class AssemblyGNN(nn.Module): + """Multi-task GNN for assembly constraint analysis. + + Wires an encoder (GIN or GAT) with optional task-specific prediction + heads for edge classification, graph classification, joint type + prediction, DOF regression, and per-body DOF tracking. + + Args: + encoder_type: ``"gin"`` or ``"gat"``. + encoder_config: Kwargs passed to the encoder constructor. + heads_config: Dict of head name → config dict. Each entry must have + an ``enabled`` bool. Additional keys are passed as kwargs. + """ + + def __init__( + self, + encoder_type: str = "gin", + encoder_config: dict[str, Any] | None = None, + heads_config: dict[str, dict[str, Any]] | None = None, + ) -> None: + super().__init__() + encoder_config = encoder_config or {} + heads_config = heads_config or {} + + if encoder_type not in _ENCODERS: + msg = f"Unknown encoder type: {encoder_type!r}. 
Choose from {list(_ENCODERS)}" + raise ValueError(msg) + + self.encoder = _ENCODERS[encoder_type](**encoder_config) + hidden_dim = self.encoder.hidden_dim + + self.heads = nn.ModuleDict() + self._build_heads(heads_config, hidden_dim) + + def _build_heads( + self, + heads_config: dict[str, dict[str, Any]], + hidden_dim: int, + ) -> None: + """Instantiate enabled heads.""" + cfg = heads_config.get("edge_classification", {}) + if cfg.get("enabled", False): + self.heads["edge_pred"] = EdgeClassificationHead( + hidden_dim=hidden_dim, + inner_dim=cfg.get("hidden_dim", 64), + ) + + cfg = heads_config.get("graph_classification", {}) + if cfg.get("enabled", False): + self.heads["graph_pred"] = GraphClassificationHead( + hidden_dim=hidden_dim, + num_classes=cfg.get("num_classes", 4), + ) + + cfg = heads_config.get("joint_type", {}) + if cfg.get("enabled", False): + self.heads["joint_type_pred"] = JointTypeHead( + hidden_dim=hidden_dim, + num_classes=cfg.get("num_classes", 11), + ) + + cfg = heads_config.get("dof_regression", {}) + if cfg.get("enabled", False): + self.heads["dof_pred"] = DOFRegressionHead(hidden_dim=hidden_dim) + + cfg = heads_config.get("dof_tracking", {}) + if cfg.get("enabled", False): + self.heads["body_dof_pred"] = DOFTrackingHead(hidden_dim=hidden_dim) + + def forward( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + batch: torch.Tensor | None = None, + ) -> dict[str, torch.Tensor]: + """Run encoder and all enabled heads. + + Returns: + Dict with keys matching enabled head names: + ``edge_pred``, ``graph_pred``, ``joint_type_pred``, + ``dof_pred``, ``body_dof_pred``. + """ + node_emb, edge_emb, graph_emb = self.encoder(x, edge_index, edge_attr, batch) + + preds: dict[str, torch.Tensor] = {} + + # Route embeddings to the appropriate heads. 
+ _edge_heads = {"edge_pred", "joint_type_pred"} + _graph_heads = {"graph_pred", "dof_pred"} + _node_heads = {"body_dof_pred"} + + for name, head in self.heads.items(): + if name in _edge_heads: + preds[name] = head(edge_emb) + elif name in _graph_heads: + preds[name] = head(graph_emb) + elif name in _node_heads: + preds[name] = head(node_emb) + + return preds diff --git a/solver/models/encoder.py b/solver/models/encoder.py new file mode 100644 index 0000000..87ca168 --- /dev/null +++ b/solver/models/encoder.py @@ -0,0 +1,194 @@ +"""GIN and GAT graph neural network encoders.""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from torch_geometric.nn import GATv2Conv, GINEConv, global_mean_pool + +__all__ = ["GATEncoder", "GINEncoder"] + + +def _make_gin_mlp(in_dim: int, hidden_dim: int) -> nn.Sequential: + """Two-layer MLP used inside GINEConv.""" + return nn.Sequential( + nn.Linear(in_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, hidden_dim), + ) + + +class GINEncoder(nn.Module): + """Graph Isomorphism Network encoder with edge features (GINE). + + Args: + node_features_dim: Input node feature dimension. + edge_features_dim: Input edge feature dimension. + hidden_dim: Hidden dimension for all layers. + num_layers: Number of GINEConv layers. + dropout: Dropout probability. 
+ """ + + def __init__( + self, + node_features_dim: int = 22, + edge_features_dim: int = 22, + hidden_dim: int = 128, + num_layers: int = 3, + dropout: float = 0.1, + ) -> None: + super().__init__() + self._hidden_dim = hidden_dim + + self.node_proj = nn.Sequential( + nn.Linear(node_features_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + ) + self.edge_proj = nn.Sequential( + nn.Linear(edge_features_dim, hidden_dim), + nn.BatchNorm1d(hidden_dim), + nn.ReLU(), + ) + + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + for _ in range(num_layers): + conv = GINEConv(nn=_make_gin_mlp(hidden_dim, hidden_dim), edge_dim=hidden_dim) + self.convs.append(conv) + self.norms.append(nn.BatchNorm1d(hidden_dim)) + + self.dropout = nn.Dropout(dropout) + + # Edge embedding from endpoint + edge features. + self.edge_mlp = nn.Linear(hidden_dim * 3, hidden_dim) + + @property + def hidden_dim(self) -> int: + return self._hidden_dim + + def forward( + self, + x: torch.Tensor, + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + batch: torch.Tensor | None = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Encode graph and return node, edge, and graph embeddings. + + Returns: + node_emb: [N, hidden_dim] + edge_emb: [E, hidden_dim] + graph_emb: [B, hidden_dim] + """ + h = self.node_proj(x) + e = self.edge_proj(edge_attr) + + for conv, norm in zip(self.convs, self.norms): + h = conv(h, edge_index, e) + h = norm(h) + h = torch.relu(h) + h = self.dropout(h) + + # Edge embeddings from endpoint concatenation. + src, dst = edge_index + edge_emb = self.edge_mlp(torch.cat([h[src], h[dst], e], dim=1)) + + # Graph embedding via mean pooling. + graph_emb = global_mean_pool(h, batch) + + return h, edge_emb, graph_emb + + +class GATEncoder(nn.Module): + """Graph Attention Network v2 encoder with edge features and residuals. + + Args: + node_features_dim: Input node feature dimension. + edge_features_dim: Input edge feature dimension. 
+ hidden_dim: Hidden dimension (must be divisible by num_heads). + num_layers: Number of GATv2Conv layers. + num_heads: Number of attention heads. + dropout: Dropout probability. + residual: Use residual connections. + """ + + def __init__( + self, + node_features_dim: int = 22, + edge_features_dim: int = 22, + hidden_dim: int = 256, + num_layers: int = 4, + num_heads: int = 8, + dropout: float = 0.1, + residual: bool = True, + ) -> None: + super().__init__() + if hidden_dim % num_heads != 0: + msg = f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})" + raise ValueError(msg) + + self._hidden_dim = hidden_dim + self.residual = residual + head_dim = hidden_dim // num_heads + + self.node_proj = nn.Sequential( + nn.Linear(node_features_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + ) + self.edge_proj = nn.Sequential( + nn.Linear(edge_features_dim, hidden_dim), + nn.LayerNorm(hidden_dim), + nn.ReLU(), + ) + + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + for _ in range(num_layers): + conv = GATv2Conv( + in_channels=hidden_dim, + out_channels=head_dim, + heads=num_heads, + edge_dim=hidden_dim, + concat=True, + ) + self.convs.append(conv) + self.norms.append(nn.LayerNorm(hidden_dim)) + + self.dropout = nn.Dropout(dropout) + + # Edge embedding from endpoint + edge features. 
def build_model(config: dict[str, Any]) -> AssemblyGNN:
    """Construct an AssemblyGNN from a parsed YAML model config.

    Expected config structure (matches ``configs/model/*.yaml``)::

        architecture: gin  # or gat
        encoder:
          hidden_dim: 128
          num_layers: 3
          ...
        node_features_dim: 22
        edge_features_dim: 22
        heads:
          edge_classification:
            enabled: true
          ...

    Args:
        config: Parsed YAML model config dict. The ``encoder`` and ``heads``
            sections may be missing or empty (``None``); both cases fall
            back to defaults.

    Returns:
        Configured ``AssemblyGNN`` instance.
    """
    encoder_type = config.get("architecture", "gin")

    # ``or {}`` (rather than a .get default) also covers a section that is
    # present but empty, which a YAML parser yields as None.
    encoder_config: dict[str, Any] = dict(config.get("encoder") or {})
    # Top-level feature dims only apply when the encoder section did not
    # set them itself.
    encoder_config.setdefault("node_features_dim", config.get("node_features_dim", 22))
    encoder_config.setdefault("edge_features_dim", config.get("edge_features_dim", 22))

    heads_config = config.get("heads") or {}

    return AssemblyGNN(
        encoder_type=encoder_type,
        encoder_config=encoder_config,
        heads_config=heads_config,
    )


def build_loss(config: dict[str, Any]) -> MultiTaskLoss:
    """Construct a MultiTaskLoss from a parsed YAML training config.

    Expected config structure (from ``configs/training/*.yaml`` ``loss`` section)::

        loss:
          edge_weight: 1.0
          graph_weight: 0.5
          joint_type_weight: 0.3
          dof_weight: 0.2
          body_dof_weight: 0.2
          redundant_penalty: 2.0

    Args:
        config: Parsed YAML training config dict (full config, not just loss
            section). A missing or empty ``loss`` section yields the defaults
            shown above.

    Returns:
        Configured ``MultiTaskLoss`` instance.
    """
    # An empty ``loss:`` key parses to None; treat it like a missing section.
    loss_config = config.get("loss") or {}

    return MultiTaskLoss(
        edge_weight=loss_config.get("edge_weight", 1.0),
        graph_weight=loss_config.get("graph_weight", 0.5),
        joint_type_weight=loss_config.get("joint_type_weight", 0.3),
        dof_weight=loss_config.get("dof_weight", 0.2),
        body_dof_weight=loss_config.get("body_dof_weight", 0.2),
        redundant_penalty=loss_config.get("redundant_penalty", 2.0),
    )
+JOINT_TYPE_NAMES: list[str] = [jt.name for jt in JointType] + +# Assembly classification label mapping. +ASSEMBLY_CLASSES: dict[str, int] = { + "well-constrained": 0, + "underconstrained": 1, + "overconstrained": 2, + "mixed": 3, +} + +# Joint type name -> ordinal for fast lookup. +_JOINT_TYPE_TO_ORD: dict[str, int] = {jt.name: jt.value[0] for jt in JointType} + +# Joint type name -> DOF removed. +_JOINT_TYPE_TO_DOF: dict[str, int] = {jt.name: jt.dof for jt in JointType} + +_NUM_JOINT_TYPES = len(JointType) + + +def _encode_node_features( + body_positions: list[list[float]], + body_orientations: list[list[list[float]]], + joints: list[dict[str, Any]], + grounded: bool, +) -> torch.Tensor: + """Encode node features as a [N, 22] tensor. + + Dims 0-2: position (centered per graph) + Dims 3-11: flattened 3x3 rotation matrix + Dim 12: is grounded flag + Dim 13: node degree / 10 + Dims 14-19: degree bucket one-hot (0, 1, 2, 3, 4, 5+) + Dim 20: total incident DOF removed / 30 + Dim 21: fraction of incident joints that are FIXED + """ + n_bodies = len(body_positions) + + # Positions centered per graph. + pos = torch.tensor(body_positions, dtype=torch.float32) + centroid = pos.mean(dim=0, keepdim=True) + pos = pos - centroid + + # Flattened orientation matrices [N, 9]. + orient = torch.tensor(body_orientations, dtype=torch.float32).reshape(n_bodies, 9) + + # Compute per-node degree and incident joint stats. + degree = torch.zeros(n_bodies, dtype=torch.float32) + dof_removed = torch.zeros(n_bodies, dtype=torch.float32) + fixed_count = torch.zeros(n_bodies, dtype=torch.float32) + + for j in joints: + a, b = j["body_a"], j["body_b"] + jtype = j["type"] + dof = _JOINT_TYPE_TO_DOF.get(jtype, 0) + degree[a] += 1 + degree[b] += 1 + dof_removed[a] += dof + dof_removed[b] += dof + if jtype == "FIXED": + fixed_count[a] += 1 + fixed_count[b] += 1 + + # Grounded flag: body 0 if assembly is grounded. 
+ grounded_flag = torch.zeros(n_bodies, 1, dtype=torch.float32) + if grounded and n_bodies > 0: + grounded_flag[0, 0] = 1.0 + + # Degree normalized. + degree_norm = (degree / 10.0).unsqueeze(1) + + # Degree bucket one-hot [N, 6]: buckets 0, 1, 2, 3, 4, 5+. + bucket = degree.clamp(max=5).long() + degree_onehot = torch.zeros(n_bodies, 6, dtype=torch.float32) + degree_onehot.scatter_(1, bucket.unsqueeze(1), 1.0) + + # Total incident DOF removed, normalized. + dof_norm = (dof_removed / 30.0).unsqueeze(1) + + # Fraction of incident joints that are FIXED. + safe_degree = degree.clamp(min=1) + fixed_frac = (fixed_count / safe_degree).unsqueeze(1) + + # Concatenate: [N, 3+9+1+1+6+1+1] = [N, 22]. + x = torch.cat( + [pos, orient, grounded_flag, degree_norm, degree_onehot, dof_norm, fixed_frac], dim=1 + ) + return x + + +def _encode_edge_features( + joints: list[dict[str, Any]], + body_positions: list[list[float]], +) -> tuple[torch.Tensor, torch.Tensor]: + """Encode edge features and build bidirectional edge_index. + + Returns: + edge_index: [2, 2*n_joints] (each joint as two directed edges). + edge_attr: [2*n_joints, 22] edge features. + """ + n_joints = len(joints) + if n_joints == 0: + return ( + torch.zeros(2, 0, dtype=torch.long), + torch.zeros(0, 22, dtype=torch.float32), + ) + + pos = body_positions + src_list: list[int] = [] + dst_list: list[int] = [] + features: list[list[float]] = [] + + for j in joints: + a, b = j["body_a"], j["body_b"] + jtype_name = j["type"] + ordinal = _JOINT_TYPE_TO_ORD.get(jtype_name, 0) + dof = _JOINT_TYPE_TO_DOF.get(jtype_name, 0) + + # One-hot joint type [11]. + onehot = [0.0] * _NUM_JOINT_TYPES + onehot[ordinal] = 1.0 + + # Axis [3]. + axis = j.get("axis", [0.0, 0.0, 1.0]) + + # Anchor offsets relative to body positions (fallback to zeros). 
+ anchor_a_raw = j.get("anchor_a") + anchor_b_raw = j.get("anchor_b") + if anchor_a_raw is not None: + anchor_a_off = [anchor_a_raw[k] - pos[a][k] for k in range(3)] + else: + anchor_a_off = [0.0, 0.0, 0.0] + if anchor_b_raw is not None: + anchor_b_off = [anchor_b_raw[k] - pos[b][k] for k in range(3)] + else: + anchor_b_off = [0.0, 0.0, 0.0] + + pitch = j.get("pitch", 0.0) + dof_norm = dof / 6.0 + + feat = onehot + axis + anchor_a_off + anchor_b_off + [pitch, dof_norm] + + # Bidirectional: a->b and b->a with identical features. + src_list.extend([a, b]) + dst_list.extend([b, a]) + features.append(feat) + features.append(feat) + + edge_index = torch.tensor([src_list, dst_list], dtype=torch.long) + edge_attr = torch.tensor(features, dtype=torch.float32) + return edge_index, edge_attr + + +def assembly_to_pyg( + example: dict[str, Any], + *, + include_labels: bool = True, +) -> Data: + """Convert a datagen training example dict to a PyG Data object. + + Args: + example: Dict from ``SyntheticAssemblyGenerator.generate_training_batch()``. + include_labels: Attach ground truth labels to the Data object. + + Returns: + ``torch_geometric.data.Data`` with node features ``x``, ``edge_index``, + ``edge_attr``, and optionally label tensors ``y_edge``, ``y_graph``, + ``y_joint_type``, ``y_dof``, ``y_body_dof``. 
+ """ + body_positions = example["body_positions"] + body_orientations = example["body_orientations"] + joints = example["joints"] + grounded = example.get("grounded", False) + + x = _encode_node_features(body_positions, body_orientations, joints, grounded) + edge_index, edge_attr = _encode_edge_features(joints, body_positions) + + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + data.num_nodes = x.size(0) + + if include_labels: + _attach_labels(data, example, joints) + + return data + + +def _attach_labels( + data: Data, + example: dict[str, Any], + joints: list[dict[str, Any]], +) -> None: + """Attach ground truth label tensors to a Data object.""" + joint_labels = example.get("joint_labels", {}) + labels = example.get("labels", {}) + + # Per-edge: binary independent (1) / redundant (0). + # Duplicated for bidirectional edges. + n_joints = len(joints) + edge_labels: list[float] = [] + joint_type_labels: list[int] = [] + for j in joints: + jid = j["joint_id"] + jl = joint_labels.get(jid) or joint_labels.get(str(jid), {}) + is_independent = 1.0 if jl.get("redundant_constraints", 0) == 0 else 0.0 + ordinal = _JOINT_TYPE_TO_ORD.get(j["type"], 0) + # Bidirectional: duplicate. + edge_labels.extend([is_independent, is_independent]) + joint_type_labels.extend([ordinal, ordinal]) + + if n_joints > 0: + data.y_edge = torch.tensor(edge_labels, dtype=torch.float32) + data.y_joint_type = torch.tensor(joint_type_labels, dtype=torch.long) + else: + data.y_edge = torch.zeros(0, dtype=torch.float32) + data.y_joint_type = torch.zeros(0, dtype=torch.long) + + # Assembly classification. + classification = example.get("assembly_classification", "") + data.y_graph = torch.tensor( + [ASSEMBLY_CLASSES.get(classification, 0)], + dtype=torch.long, + ) + + # Total DOF. + assembly_labels = labels.get("assembly", {}) + data.y_dof = torch.tensor( + [float(assembly_labels.get("total_dof", 0))], + dtype=torch.float32, + ) + + # Per-body DOF: [N, 2] (translational, rotational). 
+ per_body = labels.get("per_body", []) + n_bodies = data.num_nodes + body_dof = torch.zeros(n_bodies, 2, dtype=torch.float32) + for entry in per_body: + bid = entry["body_id"] + if 0 <= bid < n_bodies: + body_dof[bid, 0] = float(entry["translational_dof"]) + body_dof[bid, 1] = float(entry["rotational_dof"]) + data.y_body_dof = body_dof diff --git a/solver/models/heads.py b/solver/models/heads.py new file mode 100644 index 0000000..840de10 --- /dev/null +++ b/solver/models/heads.py @@ -0,0 +1,119 @@ +"""Task-specific prediction heads for assembly GNN.""" + +from __future__ import annotations + +import torch +import torch.nn as nn + +__all__ = [ + "DOFRegressionHead", + "DOFTrackingHead", + "EdgeClassificationHead", + "GraphClassificationHead", + "JointTypeHead", +] + + +class EdgeClassificationHead(nn.Module): + """Binary edge classification (independent vs redundant). + + Args: + hidden_dim: Input embedding dimension. + inner_dim: Internal MLP hidden dimension. + """ + + def __init__(self, hidden_dim: int = 128, inner_dim: int = 64) -> None: + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(hidden_dim, inner_dim), + nn.ReLU(), + nn.Linear(inner_dim, 1), + ) + + def forward(self, edge_emb: torch.Tensor) -> torch.Tensor: + """Return logits [E, 1].""" + return self.mlp(edge_emb) + + +class GraphClassificationHead(nn.Module): + """Assembly classification (well/under/over-constrained/mixed). + + Args: + hidden_dim: Input embedding dimension. + num_classes: Number of classification categories. + """ + + def __init__(self, hidden_dim: int = 128, num_classes: int = 4) -> None: + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(hidden_dim, hidden_dim), + nn.ReLU(), + nn.Linear(hidden_dim, num_classes), + ) + + def forward(self, graph_emb: torch.Tensor) -> torch.Tensor: + """Return logits [B, num_classes].""" + return self.mlp(graph_emb) + + +class JointTypeHead(nn.Module): + """Joint type classification from edge embeddings. 
class MultiTaskLoss(nn.Module):
    """Multi-task loss with learnable uncertainty weighting.

    Each task has a learnable ``log_var`` parameter (log variance) that
    automatically balances task contributions during training. The
    contribution of task *i* to the total is::

        exp(-log_var_i) * weight_i * L_i + 0.5 * log_var_i

    i.e. ``(1 / sigma_i^2) * weight_i * L_i + 0.5 * log(sigma_i^2)``. The
    extra factor of 1/2 that appears on the first term of the regression
    form in Kendall et al. (2018) is not applied; treat it as folded into
    ``weight_i``.

    Args:
        edge_weight: Initial scale for edge classification loss.
        graph_weight: Initial scale for graph classification loss.
        joint_type_weight: Initial scale for joint type loss.
        dof_weight: Initial scale for DOF regression loss.
        body_dof_weight: Initial scale for per-body DOF loss.
        redundant_penalty: Extra weight on redundant edges (label=0) in
            the edge BCE loss.
    """

    def __init__(
        self,
        edge_weight: float = 1.0,
        graph_weight: float = 0.5,
        joint_type_weight: float = 0.3,
        dof_weight: float = 0.2,
        body_dof_weight: float = 0.2,
        redundant_penalty: float = 2.0,
    ) -> None:
        super().__init__()
        self.weights = {
            "edge": edge_weight,
            "graph": graph_weight,
            "joint_type": joint_type_weight,
            "dof": dof_weight,
            "body_dof": body_dof_weight,
        }
        self.redundant_penalty = redundant_penalty

        # Learnable log-variance parameters, one per task.
        # Initialized to 0 -> sigma^2 = 1.
        self.log_vars = nn.ParameterDict(
            {name: nn.Parameter(torch.zeros(1)) for name in self.weights}
        )

    def forward(
        self,
        predictions: dict[str, torch.Tensor],
        targets: dict[str, torch.Tensor],
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Compute total loss and per-task breakdown.

        A task contributes only when both its prediction and its target are
        present and the prediction is non-empty. Empty predictions (e.g. a
        batch with no edges) are skipped because a mean-reduced loss over
        zero elements is NaN and would poison the total.

        Args:
            predictions: Dict with keys from AssemblyGNN output:
                ``edge_pred``, ``graph_pred``, ``joint_type_pred``,
                ``dof_pred``, ``body_dof_pred``.
            targets: Dict with label tensors:
                ``y_edge``, ``y_graph``, ``y_joint_type``,
                ``y_dof``, ``y_body_dof``.

        Returns:
            total_loss: Scalar total loss.
            breakdown: Dict of per-task raw loss values (before weighting),
                detached from the graph.

        Raises:
            RuntimeError: If a regression or edge target cannot be reshaped
                to match its prediction (element counts differ).
        """
        total = torch.tensor(0.0, device=self._device(predictions))
        breakdown: dict[str, torch.Tensor] = {}

        # (task name, prediction key, target key, raw loss function)
        task_specs = (
            ("edge", "edge_pred", "y_edge", self._edge_loss),
            ("graph", "graph_pred", "y_graph", nn.functional.cross_entropy),
            ("joint_type", "joint_type_pred", "y_joint_type", nn.functional.cross_entropy),
            ("dof", "dof_pred", "y_dof", self._regression_loss),
            ("body_dof", "body_dof_pred", "y_body_dof", self._regression_loss),
        )

        for task, pred_key, target_key, loss_fn in task_specs:
            if pred_key not in predictions or target_key not in targets:
                continue
            pred = predictions[pred_key]
            if pred.numel() == 0:
                continue
            loss = loss_fn(pred, targets[target_key])
            total = total + self._weighted(loss, task)
            breakdown[task] = loss.detach()

        return total, breakdown

    @staticmethod
    def _regression_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Smooth-L1 loss with the target aligned to the prediction's shape.

        Heads emit ``[B, 1]`` while batched scalar labels arrive as ``[B]``;
        without alignment ``smooth_l1_loss`` broadcasts the pair to
        ``[B, B]`` and compares every prediction against every target.
        """
        target = target.to(dtype=pred.dtype)
        if target.shape != pred.shape:
            # reshape raises if element counts differ, instead of broadcasting.
            target = target.reshape(pred.shape)
        return nn.functional.smooth_l1_loss(pred, target)

    def _edge_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """BCE loss with asymmetric weighting for redundant edges."""
        # pred: [E, 1] logits, target: [E] binary; flatten both to [E].
        logits = pred.reshape(-1)
        labels = target.to(dtype=logits.dtype).reshape(-1)
        # Weight: redundant (0) gets higher penalty, independent (1) gets 1.0.
        weight = torch.where(
            labels == 0,
            torch.full_like(labels, self.redundant_penalty),
            torch.ones_like(labels),
        )
        return nn.functional.binary_cross_entropy_with_logits(
            logits,
            labels,
            weight=weight,
        )

    def _weighted(self, loss: torch.Tensor, task_name: str) -> torch.Tensor:
        """Apply uncertainty weighting: exp(-log_var) * w * L + 0.5 * log_var."""
        log_var = self.log_vars[task_name].squeeze()
        weight = self.weights[task_name]
        return torch.exp(-log_var) * weight * loss + 0.5 * log_var

    @staticmethod
    def _device(predictions: dict[str, torch.Tensor]) -> torch.device:
        """Infer device from the first prediction tensor (CPU if there are none)."""
        first = next(iter(predictions.values()), None)
        return first.device if first is not None else torch.device("cpu")
torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)]) + return x, edge_index, edge_attr, batch + + +class TestAssemblyGNNGIN: + """AssemblyGNN with GIN encoder.""" + + def test_forward_all_heads(self) -> None: + model = AssemblyGNN( + encoder_type="gin", + encoder_config={"hidden_dim": 64, "num_layers": 2}, + heads_config=_default_heads_config(), + ) + x, ei, ea, batch = _random_graph() + preds = model(x, ei, ea, batch) + assert "edge_pred" in preds + assert "graph_pred" in preds + assert "joint_type_pred" in preds + assert "dof_pred" in preds + + def test_output_shapes(self) -> None: + model = AssemblyGNN( + encoder_type="gin", + encoder_config={"hidden_dim": 64, "num_layers": 2}, + heads_config=_default_heads_config(), + ) + x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=20, batch_size=3) + preds = model(x, ei, ea, batch) + assert preds["edge_pred"].shape == (20, 1) + assert preds["graph_pred"].shape == (3, 4) + assert preds["joint_type_pred"].shape == (20, 11) + assert preds["dof_pred"].shape == (3, 1) + + def test_gradients_flow(self) -> None: + model = AssemblyGNN( + encoder_type="gin", + encoder_config={"hidden_dim": 32, "num_layers": 2}, + heads_config=_default_heads_config(), + ) + x, ei, ea, batch = _random_graph() + x.requires_grad_(True) + preds = model(x, ei, ea, batch) + total = sum(p.sum() for p in preds.values()) + total.backward() + assert x.grad is not None + + def test_no_heads_returns_empty(self) -> None: + model = AssemblyGNN( + encoder_type="gin", + encoder_config={"hidden_dim": 32, "num_layers": 2}, + heads_config={}, + ) + x, ei, ea, batch = _random_graph() + preds = model(x, ei, ea, batch) + assert len(preds) == 0 + + +class TestAssemblyGNNGAT: + """AssemblyGNN with GAT encoder.""" + + def test_forward_all_heads(self) -> None: + model = AssemblyGNN( + encoder_type="gat", + encoder_config={"hidden_dim": 64, "num_layers": 2, "num_heads": 4}, + heads_config=_default_heads_config(dof_tracking=True), + ) + x, ei, ea, batch 
= _random_graph() + preds = model(x, ei, ea, batch) + assert "edge_pred" in preds + assert "graph_pred" in preds + assert "joint_type_pred" in preds + assert "dof_pred" in preds + assert "body_dof_pred" in preds + + def test_body_dof_shape(self) -> None: + model = AssemblyGNN( + encoder_type="gat", + encoder_config={"hidden_dim": 64, "num_layers": 2, "num_heads": 4}, + heads_config=_default_heads_config(dof_tracking=True), + ) + x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=20) + preds = model(x, ei, ea, batch) + assert preds["body_dof_pred"].shape == (10, 2) + + +class TestAssemblyGNNEdgeCases: + """Edge cases and error handling.""" + + def test_unknown_encoder_raises(self) -> None: + with pytest.raises(ValueError, match="Unknown encoder"): + AssemblyGNN(encoder_type="transformer") + + def test_selective_heads(self) -> None: + """Only enabled heads produce output.""" + config = { + "edge_classification": {"enabled": True}, + "graph_classification": {"enabled": False}, + "joint_type": {"enabled": True, "num_classes": 11}, + } + model = AssemblyGNN( + encoder_type="gin", + encoder_config={"hidden_dim": 32, "num_layers": 2}, + heads_config=config, + ) + x, ei, ea, batch = _random_graph() + preds = model(x, ei, ea, batch) + assert "edge_pred" in preds + assert "joint_type_pred" in preds + assert "graph_pred" not in preds + assert "dof_pred" not in preds + + def test_no_batch_single_graph(self) -> None: + model = AssemblyGNN( + encoder_type="gin", + encoder_config={"hidden_dim": 32, "num_layers": 2}, + heads_config=_default_heads_config(), + ) + x = torch.randn(6, 22) + ei = torch.randint(0, 6, (2, 10)) + ea = torch.randn(10, 22) + preds = model(x, ei, ea) + assert preds["graph_pred"].shape == (1, 4) + assert preds["dof_pred"].shape == (1, 1) + + def test_parameter_count_reasonable(self) -> None: + """Sanity check that model has learnable parameters.""" + model = AssemblyGNN( + encoder_type="gin", + encoder_config={"hidden_dim": 64, "num_layers": 2}, + 
heads_config=_default_heads_config(), + ) + n_params = sum(p.numel() for p in model.parameters()) + assert n_params > 1000 # non-trivial model + + +class TestAssemblyGNNEndToEnd: + """End-to-end test with datagen pipeline.""" + + def test_datagen_to_model(self) -> None: + from solver.datagen.generator import SyntheticAssemblyGenerator + from solver.models.graph_conv import assembly_to_pyg + + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(batch_size=2, complexity_tier="simple") + + model = AssemblyGNN( + encoder_type="gin", + encoder_config={"hidden_dim": 32, "num_layers": 2}, + heads_config=_default_heads_config(), + ) + model.eval() + + for ex in batch: + data = assembly_to_pyg(ex) + with torch.no_grad(): + preds = model(data.x, data.edge_index, data.edge_attr) + assert "edge_pred" in preds + assert preds["edge_pred"].shape[0] == data.edge_index.shape[1] diff --git a/tests/models/test_encoder.py b/tests/models/test_encoder.py new file mode 100644 index 0000000..d2948a4 --- /dev/null +++ b/tests/models/test_encoder.py @@ -0,0 +1,183 @@ +"""Tests for solver.models.encoder -- GIN and GAT graph encoders.""" + +from __future__ import annotations + +import pytest +import torch + +from solver.models.encoder import GATEncoder, GINEncoder + + +def _random_graph( + n_nodes: int = 8, + n_edges: int = 20, + node_dim: int = 22, + edge_dim: int = 22, + batch_size: int = 2, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Create a random graph for testing.""" + x = torch.randn(n_nodes, node_dim) + edge_index = torch.randint(0, n_nodes, (2, n_edges)) + edge_attr = torch.randn(n_edges, edge_dim) + # Assign nodes to batches roughly evenly. + batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size) + # Handle remainder nodes. 
+ if len(batch) < n_nodes: + batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)]) + return x, edge_index, edge_attr, batch + + +class TestGINEncoder: + """GINEncoder shape and gradient tests.""" + + def test_output_shapes(self) -> None: + enc = GINEncoder(node_features_dim=22, edge_features_dim=22, hidden_dim=64, num_layers=2) + x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3) + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + assert node_emb.shape == (10, 64) + assert edge_emb.shape == (16, 64) + assert graph_emb.shape == (3, 64) + + def test_hidden_dim_property(self) -> None: + enc = GINEncoder(hidden_dim=128) + assert enc.hidden_dim == 128 + + def test_default_dimensions(self) -> None: + enc = GINEncoder() + x, ei, ea, batch = _random_graph() + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + assert node_emb.shape[1] == 128 + assert edge_emb.shape[1] == 128 + assert graph_emb.shape[1] == 128 + + def test_no_batch_defaults_to_single_graph(self) -> None: + enc = GINEncoder(hidden_dim=64, num_layers=2) + x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10) + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None) + assert graph_emb.shape == (1, 64) + + def test_gradients_flow(self) -> None: + enc = GINEncoder(hidden_dim=32, num_layers=2) + x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) + x.requires_grad_(True) + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + loss = graph_emb.sum() + loss.backward() + assert x.grad is not None + assert x.grad.abs().sum() > 0 + + def test_zero_edges(self) -> None: + enc = GINEncoder(hidden_dim=32, num_layers=2) + x = torch.randn(4, 22) + ei = torch.zeros(2, 0, dtype=torch.long) + ea = torch.zeros(0, 22) + batch = torch.tensor([0, 0, 1, 1]) + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + assert node_emb.shape == (4, 32) + assert edge_emb.shape == (0, 32) + assert graph_emb.shape == (2, 32) + + def test_single_node(self) -> None: + enc = 
GINEncoder(hidden_dim=32, num_layers=2) + # Train with a small batch first to populate BN running stats. + x_train = torch.randn(4, 22) + ei_train = torch.zeros(2, 0, dtype=torch.long) + ea_train = torch.zeros(0, 22) + batch_train = torch.tensor([0, 0, 1, 1]) + enc.train() + enc(x_train, ei_train, ea_train, batch_train) + # Now test single node in eval mode (BN uses running stats). + enc.eval() + x = torch.randn(1, 22) + ei = torch.zeros(2, 0, dtype=torch.long) + ea = torch.zeros(0, 22) + with torch.no_grad(): + node_emb, edge_emb, graph_emb = enc(x, ei, ea) + assert node_emb.shape == (1, 32) + assert graph_emb.shape == (1, 32) + + def test_eval_mode(self) -> None: + """Encoder works in eval mode (BatchNorm uses running stats).""" + enc = GINEncoder(hidden_dim=32, num_layers=2) + # Forward pass in train mode to populate BN stats. + x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) + enc.train() + enc(x, ei, ea, batch) + # Switch to eval. + enc.eval() + with torch.no_grad(): + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + assert node_emb.shape[1] == 32 + + +class TestGATEncoder: + """GATEncoder shape and gradient tests.""" + + def test_output_shapes(self) -> None: + enc = GATEncoder( + node_features_dim=22, + edge_features_dim=22, + hidden_dim=64, + num_layers=2, + num_heads=4, + ) + x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3) + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + assert node_emb.shape == (10, 64) + assert edge_emb.shape == (16, 64) + assert graph_emb.shape == (3, 64) + + def test_hidden_dim_property(self) -> None: + enc = GATEncoder(hidden_dim=256, num_heads=8) + assert enc.hidden_dim == 256 + + def test_default_dimensions(self) -> None: + enc = GATEncoder() + x, ei, ea, batch = _random_graph() + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + assert node_emb.shape[1] == 256 + assert edge_emb.shape[1] == 256 + assert graph_emb.shape[1] == 256 + + def test_no_batch_defaults_to_single_graph(self) 
-> None: + enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4) + x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10) + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None) + assert graph_emb.shape == (1, 64) + + def test_gradients_flow(self) -> None: + enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4) + x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) + x.requires_grad_(True) + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + loss = graph_emb.sum() + loss.backward() + assert x.grad is not None + assert x.grad.abs().sum() > 0 + + def test_residual_connection(self) -> None: + """With residual=True, output should differ from residual=False.""" + x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) + torch.manual_seed(0) + enc_res = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=True) + torch.manual_seed(0) + enc_no = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=False) + with torch.no_grad(): + n1, _, _ = enc_res(x, ei, ea, batch) + n2, _, _ = enc_no(x, ei, ea, batch) + # Outputs should generally differ (unless by very unlikely coincidence). 
+ assert not torch.allclose(n1, n2, atol=1e-4) + + def test_hidden_dim_must_divide_heads(self) -> None: + with pytest.raises(ValueError, match="divisible"): + GATEncoder(hidden_dim=100, num_heads=8) + + def test_eval_mode(self) -> None: + enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4) + x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12) + enc.train() + enc(x, ei, ea, batch) + enc.eval() + with torch.no_grad(): + node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch) + assert node_emb.shape[1] == 64 diff --git a/tests/models/test_factory.py b/tests/models/test_factory.py new file mode 100644 index 0000000..636cf2b --- /dev/null +++ b/tests/models/test_factory.py @@ -0,0 +1,110 @@ +"""Tests for solver.models.factory -- model and loss construction from config.""" + +from __future__ import annotations + +import yaml + +from solver.models.assembly_gnn import AssemblyGNN +from solver.models.factory import build_loss, build_model +from solver.models.losses import MultiTaskLoss + + +def _load_yaml(path: str) -> dict: + with open(path) as f: + return yaml.safe_load(f) + + +class TestBuildModel: + """build_model constructs AssemblyGNN from config.""" + + def test_baseline_config(self) -> None: + config = _load_yaml("configs/model/baseline.yaml") + model = build_model(config) + assert isinstance(model, AssemblyGNN) + assert model.encoder.hidden_dim == 128 + + def test_gat_config(self) -> None: + config = _load_yaml("configs/model/gat.yaml") + model = build_model(config) + assert isinstance(model, AssemblyGNN) + assert model.encoder.hidden_dim == 256 + + def test_baseline_heads_present(self) -> None: + config = _load_yaml("configs/model/baseline.yaml") + model = build_model(config) + assert "edge_pred" in model.heads + assert "graph_pred" in model.heads + assert "joint_type_pred" in model.heads + assert "dof_pred" in model.heads + + def test_gat_has_dof_tracking(self) -> None: + config = _load_yaml("configs/model/gat.yaml") + model = build_model(config) + 
assert "body_dof_pred" in model.heads + + def test_baseline_no_dof_tracking(self) -> None: + config = _load_yaml("configs/model/baseline.yaml") + model = build_model(config) + assert "body_dof_pred" not in model.heads + + def test_minimal_config(self) -> None: + config = {"architecture": "gin"} + model = build_model(config) + assert isinstance(model, AssemblyGNN) + # No heads enabled. + assert len(model.heads) == 0 + + def test_custom_config(self) -> None: + config = { + "architecture": "gin", + "encoder": {"hidden_dim": 64, "num_layers": 2}, + "heads": { + "edge_classification": {"enabled": True}, + "graph_classification": {"enabled": True, "num_classes": 4}, + }, + } + model = build_model(config) + assert model.encoder.hidden_dim == 64 + assert "edge_pred" in model.heads + assert "graph_pred" in model.heads + assert "joint_type_pred" not in model.heads + + +class TestBuildLoss: + """build_loss constructs MultiTaskLoss from training config.""" + + def test_pretrain_config(self) -> None: + config = _load_yaml("configs/training/pretrain.yaml") + loss_fn = build_loss(config) + assert isinstance(loss_fn, MultiTaskLoss) + + def test_weights_from_config(self) -> None: + config = _load_yaml("configs/training/pretrain.yaml") + loss_fn = build_loss(config) + assert loss_fn.weights["edge"] == 1.0 + assert loss_fn.weights["graph"] == 0.5 + assert loss_fn.weights["joint_type"] == 0.3 + assert loss_fn.weights["dof"] == 0.2 + + def test_redundant_penalty_from_config(self) -> None: + config = _load_yaml("configs/training/pretrain.yaml") + loss_fn = build_loss(config) + assert loss_fn.redundant_penalty == 2.0 + + def test_empty_config_uses_defaults(self) -> None: + loss_fn = build_loss({}) + assert isinstance(loss_fn, MultiTaskLoss) + assert loss_fn.weights["edge"] == 1.0 + + def test_custom_weights(self) -> None: + config = { + "loss": { + "edge_weight": 2.0, + "graph_weight": 1.0, + "redundant_penalty": 5.0, + }, + } + loss_fn = build_loss(config) + assert 
loss_fn.weights["edge"] == 2.0 + assert loss_fn.weights["graph"] == 1.0 + assert loss_fn.redundant_penalty == 5.0 diff --git a/tests/models/test_graph_conv.py b/tests/models/test_graph_conv.py new file mode 100644 index 0000000..e3562fa --- /dev/null +++ b/tests/models/test_graph_conv.py @@ -0,0 +1,144 @@ +"""Tests for solver.models.graph_conv -- assembly to PyG conversion.""" + +from __future__ import annotations + +import torch + +from solver.datagen.generator import SyntheticAssemblyGenerator +from solver.datagen.types import JointType +from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg + + +def _make_example(n_bodies: int = 4, grounded: bool = True, seed: int = 0) -> dict: + """Generate a single training example via the datagen pipeline.""" + gen = SyntheticAssemblyGenerator(seed=seed) + batch = gen.generate_training_batch(batch_size=1, n_bodies_range=(n_bodies, n_bodies + 1)) + return batch[0] + + +class TestAssemblyToPyg: + """assembly_to_pyg converts datagen dicts to PyG Data correctly.""" + + def test_node_feature_shape(self) -> None: + ex = _make_example(n_bodies=5) + data = assembly_to_pyg(ex) + assert data.x.shape == (5, 22) + + def test_edge_feature_shape(self) -> None: + ex = _make_example(n_bodies=4) + data = assembly_to_pyg(ex) + n_joints = ex["n_joints"] + assert data.edge_attr.shape == (n_joints * 2, 22) + + def test_edge_index_bidirectional(self) -> None: + ex = _make_example(n_bodies=4) + data = assembly_to_pyg(ex) + ei = data.edge_index + # Each joint produces 2 directed edges: a->b and b->a. 
+ for i in range(0, ei.size(1), 2): + assert ei[0, i].item() == ei[1, i + 1].item() + assert ei[1, i].item() == ei[0, i + 1].item() + + def test_edge_index_shape(self) -> None: + ex = _make_example(n_bodies=4) + data = assembly_to_pyg(ex) + n_joints = ex["n_joints"] + assert data.edge_index.shape == (2, n_joints * 2) + + def test_node_features_centered(self) -> None: + ex = _make_example(n_bodies=5) + data = assembly_to_pyg(ex) + # Positions (dims 0-2) should be centered (mean ~0). + pos = data.x[:, :3] + assert pos.mean(dim=0).abs().max().item() < 1e-5 + + def test_grounded_flag_set(self) -> None: + ex = _make_example(n_bodies=4, grounded=True) + ex["grounded"] = True + data = assembly_to_pyg(ex) + assert data.x[0, 12].item() == 1.0 + + def test_ungrounded_flag_clear(self) -> None: + ex = _make_example(n_bodies=4) + ex["grounded"] = False + data = assembly_to_pyg(ex) + assert (data.x[:, 12] == 0.0).all() + + def test_edge_type_one_hot_valid(self) -> None: + ex = _make_example(n_bodies=5) + data = assembly_to_pyg(ex) + if data.edge_attr.size(0) > 0: + onehot = data.edge_attr[:, :11] + # Each row should have exactly one 1.0. 
+ assert (onehot.sum(dim=1) == 1.0).all() + + def test_labels_present_when_requested(self) -> None: + ex = _make_example(n_bodies=4) + data = assembly_to_pyg(ex, include_labels=True) + assert hasattr(data, "y_edge") + assert hasattr(data, "y_graph") + assert hasattr(data, "y_joint_type") + assert hasattr(data, "y_dof") + assert hasattr(data, "y_body_dof") + + def test_labels_absent_when_not_requested(self) -> None: + ex = _make_example(n_bodies=4) + data = assembly_to_pyg(ex, include_labels=False) + assert not hasattr(data, "y_edge") + assert not hasattr(data, "y_graph") + + def test_graph_classification_label_mapping(self) -> None: + ex = _make_example(n_bodies=4) + data = assembly_to_pyg(ex) + cls = ex["assembly_classification"] + expected = ASSEMBLY_CLASSES[cls] + assert data.y_graph.item() == expected + + def test_body_dof_shape(self) -> None: + ex = _make_example(n_bodies=5) + data = assembly_to_pyg(ex) + assert data.y_body_dof.shape == (5, 2) + + def test_edge_labels_binary(self) -> None: + ex = _make_example(n_bodies=5) + data = assembly_to_pyg(ex) + if data.y_edge.numel() > 0: + assert ((data.y_edge == 0.0) | (data.y_edge == 1.0)).all() + + def test_dof_removed_normalized(self) -> None: + ex = _make_example(n_bodies=4) + data = assembly_to_pyg(ex) + if data.edge_attr.size(0) > 0: + dof_norm = data.edge_attr[:, 21] + assert (dof_norm >= 0.0).all() + assert (dof_norm <= 1.0).all() + + def test_roundtrip_with_generator(self) -> None: + """Generate a real example and convert -- no crash.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(batch_size=5, complexity_tier="simple") + for ex in batch: + data = assembly_to_pyg(ex) + assert data.x.shape[1] == 22 + assert data.edge_attr.shape[1] == 22 + + +class TestAssemblyClasses: + """ASSEMBLY_CLASSES covers all classifications.""" + + def test_four_classes(self) -> None: + assert len(ASSEMBLY_CLASSES) == 4 + + def test_values_are_0_to_3(self) -> None: + assert 
set(ASSEMBLY_CLASSES.values()) == {0, 1, 2, 3} + + +class TestJointTypeNames: + """JOINT_TYPE_NAMES matches the JointType enum.""" + + def test_length_matches_enum(self) -> None: + assert len(JOINT_TYPE_NAMES) == len(JointType) + + def test_order_matches_ordinal(self) -> None: + for i, name in enumerate(JOINT_TYPE_NAMES): + assert JointType[name].value[0] == i diff --git a/tests/models/test_heads.py b/tests/models/test_heads.py new file mode 100644 index 0000000..2b63e97 --- /dev/null +++ b/tests/models/test_heads.py @@ -0,0 +1,182 @@ +"""Tests for solver.models.heads -- task-specific prediction heads.""" + +from __future__ import annotations + +import torch + +from solver.models.heads import ( + DOFRegressionHead, + DOFTrackingHead, + EdgeClassificationHead, + GraphClassificationHead, + JointTypeHead, +) + + +class TestEdgeClassificationHead: + """EdgeClassificationHead produces correct shape and gradients.""" + + def test_output_shape(self) -> None: + head = EdgeClassificationHead(hidden_dim=128) + edge_emb = torch.randn(20, 128) + out = head(edge_emb) + assert out.shape == (20, 1) + + def test_output_shape_small(self) -> None: + head = EdgeClassificationHead(hidden_dim=32) + edge_emb = torch.randn(5, 32) + out = head(edge_emb) + assert out.shape == (5, 1) + + def test_gradients_flow(self) -> None: + head = EdgeClassificationHead(hidden_dim=64) + edge_emb = torch.randn(10, 64, requires_grad=True) + out = head(edge_emb) + out.sum().backward() + assert edge_emb.grad is not None + assert edge_emb.grad.abs().sum() > 0 + + def test_zero_edges(self) -> None: + head = EdgeClassificationHead(hidden_dim=64) + edge_emb = torch.zeros(0, 64) + out = head(edge_emb) + assert out.shape == (0, 1) + + def test_output_is_logits(self) -> None: + """Output should be unbounded logits (not probabilities).""" + head = EdgeClassificationHead(hidden_dim=64) + torch.manual_seed(42) + edge_emb = torch.randn(100, 64) + out = head(edge_emb) + # Logits can be negative. 
+ assert out.min().item() < 0 or out.max().item() > 1 + + +class TestGraphClassificationHead: + """GraphClassificationHead produces correct shape and gradients.""" + + def test_output_shape(self) -> None: + head = GraphClassificationHead(hidden_dim=128, num_classes=4) + graph_emb = torch.randn(3, 128) + out = head(graph_emb) + assert out.shape == (3, 4) + + def test_custom_num_classes(self) -> None: + head = GraphClassificationHead(hidden_dim=64, num_classes=8) + graph_emb = torch.randn(2, 64) + out = head(graph_emb) + assert out.shape == (2, 8) + + def test_gradients_flow(self) -> None: + head = GraphClassificationHead(hidden_dim=64) + graph_emb = torch.randn(2, 64, requires_grad=True) + out = head(graph_emb) + out.sum().backward() + assert graph_emb.grad is not None + + def test_single_graph(self) -> None: + head = GraphClassificationHead(hidden_dim=128) + graph_emb = torch.randn(1, 128) + out = head(graph_emb) + assert out.shape == (1, 4) + + +class TestJointTypeHead: + """JointTypeHead produces correct shape and gradients.""" + + def test_output_shape(self) -> None: + head = JointTypeHead(hidden_dim=128, num_classes=11) + edge_emb = torch.randn(20, 128) + out = head(edge_emb) + assert out.shape == (20, 11) + + def test_custom_classes(self) -> None: + head = JointTypeHead(hidden_dim=64, num_classes=7) + edge_emb = torch.randn(10, 64) + out = head(edge_emb) + assert out.shape == (10, 7) + + def test_gradients_flow(self) -> None: + head = JointTypeHead(hidden_dim=64) + edge_emb = torch.randn(10, 64, requires_grad=True) + out = head(edge_emb) + out.sum().backward() + assert edge_emb.grad is not None + + def test_zero_edges(self) -> None: + head = JointTypeHead(hidden_dim=64) + edge_emb = torch.zeros(0, 64) + out = head(edge_emb) + assert out.shape == (0, 11) + + +class TestDOFRegressionHead: + """DOFRegressionHead produces correct shape and non-negative output.""" + + def test_output_shape(self) -> None: + head = DOFRegressionHead(hidden_dim=128) + graph_emb = 
torch.randn(3, 128) + out = head(graph_emb) + assert out.shape == (3, 1) + + def test_output_non_negative(self) -> None: + """Softplus ensures non-negative output.""" + head = DOFRegressionHead(hidden_dim=64) + torch.manual_seed(0) + graph_emb = torch.randn(50, 64) + out = head(graph_emb) + assert (out >= 0).all() + + def test_gradients_flow(self) -> None: + head = DOFRegressionHead(hidden_dim=64) + graph_emb = torch.randn(2, 64, requires_grad=True) + out = head(graph_emb) + out.sum().backward() + assert graph_emb.grad is not None + + def test_single_graph(self) -> None: + head = DOFRegressionHead(hidden_dim=32) + graph_emb = torch.randn(1, 32) + out = head(graph_emb) + assert out.shape == (1, 1) + assert out.item() >= 0 + + +class TestDOFTrackingHead: + """DOFTrackingHead produces correct shape and non-negative output.""" + + def test_output_shape(self) -> None: + head = DOFTrackingHead(hidden_dim=128) + node_emb = torch.randn(10, 128) + out = head(node_emb) + assert out.shape == (10, 2) + + def test_output_non_negative(self) -> None: + """Softplus ensures non-negative output.""" + head = DOFTrackingHead(hidden_dim=64) + torch.manual_seed(0) + node_emb = torch.randn(50, 64) + out = head(node_emb) + assert (out >= 0).all() + + def test_gradients_flow(self) -> None: + head = DOFTrackingHead(hidden_dim=64) + node_emb = torch.randn(10, 64, requires_grad=True) + out = head(node_emb) + out.sum().backward() + assert node_emb.grad is not None + + def test_single_node(self) -> None: + head = DOFTrackingHead(hidden_dim=32) + node_emb = torch.randn(1, 32) + out = head(node_emb) + assert out.shape == (1, 2) + assert (out >= 0).all() + + def test_two_columns_independent(self) -> None: + """Translational and rotational DOF are independently predicted.""" + head = DOFTrackingHead(hidden_dim=64) + node_emb = torch.randn(20, 64) + out = head(node_emb) + # The two columns should generally differ. 
+ assert not torch.allclose(out[:, 0], out[:, 1], atol=1e-6) diff --git a/tests/models/test_losses.py b/tests/models/test_losses.py new file mode 100644 index 0000000..4d7b941 --- /dev/null +++ b/tests/models/test_losses.py @@ -0,0 +1,156 @@ +"""Tests for solver.models.losses -- uncertainty-weighted multi-task loss.""" + +from __future__ import annotations + +import torch + +from solver.models.losses import MultiTaskLoss + + +def _make_predictions_and_targets( + n_edges: int = 20, + batch_size: int = 3, + n_nodes: int = 10, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + preds = { + "edge_pred": torch.randn(n_edges, 1), + "graph_pred": torch.randn(batch_size, 4), + "joint_type_pred": torch.randn(n_edges, 11), + "dof_pred": torch.rand(batch_size, 1) * 10, + "body_dof_pred": torch.rand(n_nodes, 2) * 6, + } + targets = { + "y_edge": torch.randint(0, 2, (n_edges,)).float(), + "y_graph": torch.randint(0, 4, (batch_size,)), + "y_joint_type": torch.randint(0, 11, (n_edges,)), + "y_dof": torch.rand(batch_size, 1) * 10, + "y_body_dof": torch.rand(n_nodes, 2) * 6, + } + return preds, targets + + +class TestMultiTaskLoss: + """MultiTaskLoss computation tests.""" + + def test_returns_scalar_and_breakdown(self) -> None: + loss_fn = MultiTaskLoss() + preds, targets = _make_predictions_and_targets() + total, breakdown = loss_fn(preds, targets) + assert total.dim() == 0 # scalar + assert isinstance(breakdown, dict) + + def test_all_tasks_in_breakdown(self) -> None: + loss_fn = MultiTaskLoss() + preds, targets = _make_predictions_and_targets() + _, breakdown = loss_fn(preds, targets) + assert "edge" in breakdown + assert "graph" in breakdown + assert "joint_type" in breakdown + assert "dof" in breakdown + assert "body_dof" in breakdown + + def test_total_is_positive(self) -> None: + loss_fn = MultiTaskLoss() + preds, targets = _make_predictions_and_targets() + total, _ = loss_fn(preds, targets) + # With random predictions, loss should be positive. 
+ assert total.item() > 0 + + def test_skips_missing_predictions(self) -> None: + loss_fn = MultiTaskLoss() + preds = {"edge_pred": torch.randn(10, 1)} + targets = {"y_edge": torch.randint(0, 2, (10,)).float()} + total, breakdown = loss_fn(preds, targets) + assert "edge" in breakdown + assert "graph" not in breakdown + assert "joint_type" not in breakdown + + def test_skips_missing_targets(self) -> None: + loss_fn = MultiTaskLoss() + preds = { + "edge_pred": torch.randn(10, 1), + "graph_pred": torch.randn(2, 4), + } + targets = {"y_edge": torch.randint(0, 2, (10,)).float()} + _, breakdown = loss_fn(preds, targets) + assert "edge" in breakdown + assert "graph" not in breakdown + + def test_gradients_flow_to_log_vars(self) -> None: + loss_fn = MultiTaskLoss() + preds, targets = _make_predictions_and_targets() + # Make preds require grad. + for k in preds: + preds[k] = preds[k].requires_grad_(True) + total, _ = loss_fn(preds, targets) + total.backward() + for name, param in loss_fn.log_vars.items(): + assert param.grad is not None, f"No gradient for log_var[{name}]" + + def test_gradients_flow_to_predictions(self) -> None: + loss_fn = MultiTaskLoss() + preds, targets = _make_predictions_and_targets() + for k in preds: + preds[k] = preds[k].requires_grad_(True) + total, _ = loss_fn(preds, targets) + total.backward() + for k, v in preds.items(): + assert v.grad is not None, f"No gradient for prediction[{k}]" + + def test_redundant_penalty_applies(self) -> None: + """Redundant edges (label=0) should have higher loss contribution.""" + loss_fn = MultiTaskLoss(redundant_penalty=5.0) + # All-zero predictions, label=0 (redundant). + preds_red = {"edge_pred": torch.zeros(10, 1)} + targets_red = {"y_edge": torch.zeros(10)} + total_red, _ = loss_fn(preds_red, targets_red) + + loss_fn2 = MultiTaskLoss(redundant_penalty=1.0) + total_eq, _ = loss_fn2(preds_red, targets_red) + + # Higher penalty should produce higher loss. 
+ assert total_red.item() > total_eq.item() + + def test_empty_predictions_returns_zero(self) -> None: + loss_fn = MultiTaskLoss() + total, breakdown = loss_fn({}, {}) + assert total.item() == 0.0 + assert len(breakdown) == 0 + + +class TestUncertaintyWeighting: + """Test uncertainty weighting mechanism specifically.""" + + def test_log_vars_initialized_to_zero(self) -> None: + loss_fn = MultiTaskLoss() + for param in loss_fn.log_vars.values(): + assert param.item() == 0.0 + + def test_log_vars_are_learnable(self) -> None: + loss_fn = MultiTaskLoss() + params = list(loss_fn.parameters()) + log_var_params = [p for p in params if p.shape == (1,)] + assert len(log_var_params) == 5 # one per task + + def test_weighting_reduces_high_loss_influence(self) -> None: + """After a few gradient steps, log_var for a noisy task should increase.""" + loss_fn = MultiTaskLoss(edge_weight=1.0, graph_weight=1.0) + optimizer = torch.optim.SGD(loss_fn.parameters(), lr=0.1) + + # Simulate: edge task has high loss, graph has low. + for _ in range(20): + preds = { + "edge_pred": torch.randn(10, 1) * 10, # high variance -> high loss + "graph_pred": torch.zeros(2, 4), # near-zero loss + } + targets = { + "y_edge": torch.randint(0, 2, (10,)).float(), + "y_graph": torch.zeros(2, dtype=torch.long), + } + optimizer.zero_grad() + total, _ = loss_fn(preds, targets) + total.backward() + optimizer.step() + + # The edge task log_var should have increased (higher uncertainty). + assert loss_fn.log_vars["edge"].item() > loss_fn.log_vars["graph"].item()