Files
solver/solver/models/heads.py
forbes fe41fa3b00 feat(models): implement GNN model layer for assembly constraint analysis
Add complete model layer with GIN and GAT encoders, 5 task-specific
prediction heads, uncertainty-weighted multi-task loss (Kendall et al.
2018), and config-driven factory functions.

New modules:
- graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features)
- encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head)
- heads.py: edge classification, graph classification, joint type,
  DOF regression, per-body DOF tracking heads
- assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads
- losses.py: MultiTaskLoss with learnable log-variance per task
- factory.py: build_model() and build_loss() from YAML configs

Supporting changes:
- generator.py: serialize anchor_a, anchor_b, pitch in joint dicts
- configs: fix joint_type num_classes 12 -> 11 (matches JointType enum)

92 tests covering shapes, gradients, edge cases, and the end-to-end
datagen-to-model pipeline.
2026-02-07 10:14:19 -06:00

120 lines
3.3 KiB
Python

"""Task-specific prediction heads for assembly GNN."""
from __future__ import annotations
import torch
import torch.nn as nn
# Public API of this module: the five task-specific prediction heads
# defined below. Kept as a plain list literal so linters can verify each
# exported name exists.
__all__ = [
    "DOFRegressionHead",
    "DOFTrackingHead",
    "EdgeClassificationHead",
    "GraphClassificationHead",
    "JointTypeHead",
]
class EdgeClassificationHead(nn.Module):
    """Per-edge binary classifier (independent vs redundant).

    A two-layer perceptron (Linear -> ReLU -> Linear) that maps every edge
    embedding to a single unnormalized score.

    Args:
        hidden_dim: Input embedding dimension.
        inner_dim: Internal MLP hidden dimension.
    """

    def __init__(self, hidden_dim: int = 128, inner_dim: int = 64) -> None:
        super().__init__()
        layers: list[nn.Module] = [
            nn.Linear(hidden_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, 1),
        ]
        # Unpacked positionally so parameter names stay ``mlp.0.*`` / ``mlp.2.*``
        # and existing checkpoints keep loading.
        self.mlp = nn.Sequential(*layers)

    def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
        """Score each edge embedding; returns logits of shape [E, 1]."""
        logits = self.mlp(edge_emb)
        return logits
class GraphClassificationHead(nn.Module):
    """Assembly-level classifier (well/under/over-constrained/mixed).

    Maps a graph-level embedding through a hidden layer of the same width
    (Linear -> ReLU -> Linear) to one unnormalized score per category.

    Args:
        hidden_dim: Input embedding dimension; also the hidden-layer width.
        num_classes: Number of classification categories.
    """

    def __init__(self, hidden_dim: int = 128, num_classes: int = 4) -> None:
        super().__init__()
        width = hidden_dim
        layers: list[nn.Module] = [
            nn.Linear(hidden_dim, width),
            nn.ReLU(),
            nn.Linear(width, num_classes),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
        """Map graph embeddings to logits of shape [B, num_classes]."""
        logits = self.mlp(graph_emb)
        return logits
class JointTypeHead(nn.Module):
    """Joint type classification from edge embeddings.

    A two-layer MLP (Linear -> ReLU -> Linear) mapping each edge embedding to
    one unnormalized score per joint type.

    Args:
        hidden_dim: Input embedding dimension.
        num_classes: Number of joint types.
        inner_dim: Internal MLP hidden dimension. Defaults to
            ``hidden_dim // 2`` (the width this head has always used), which
            keeps defaults and checkpoint shapes unchanged.

    Raises:
        ValueError: If the resolved ``inner_dim`` is less than 1 (e.g.
            ``hidden_dim=1`` with the default). A zero-width bottleneck would
            make the output equal to the final bias, ignoring the input.
    """

    def __init__(
        self,
        hidden_dim: int = 128,
        num_classes: int = 11,
        inner_dim: int | None = None,
    ) -> None:
        super().__init__()
        if inner_dim is None:
            inner_dim = hidden_dim // 2
        if inner_dim < 1:
            raise ValueError(
                f"inner_dim must be >= 1, got {inner_dim} (hidden_dim={hidden_dim})"
            )
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, num_classes),
        )

    def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
        """Return logits [E, num_classes]."""
        return self.mlp(edge_emb)
class DOFRegressionHead(nn.Module):
    """Total DOF regression from graph embedding.

    Linear -> ReLU -> Linear -> Softplus. The trailing Softplus keeps every
    prediction non-negative.

    Args:
        hidden_dim: Input embedding dimension.
        inner_dim: Internal MLP hidden dimension. Defaults to
            ``hidden_dim // 2`` (the width this head has always used), which
            keeps defaults and checkpoint shapes unchanged.

    Raises:
        ValueError: If the resolved ``inner_dim`` is less than 1 (e.g.
            ``hidden_dim=1`` with the default). A zero-width bottleneck would
            make the output a constant that ignores the input.
    """

    def __init__(self, hidden_dim: int = 128, inner_dim: int | None = None) -> None:
        super().__init__()
        if inner_dim is None:
            inner_dim = hidden_dim // 2
        if inner_dim < 1:
            raise ValueError(
                f"inner_dim must be >= 1, got {inner_dim} (hidden_dim={hidden_dim})"
            )
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, 1),
            nn.Softplus(),
        )

    def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
        """Return non-negative DOF prediction [B, 1]."""
        return self.mlp(graph_emb)
class DOFTrackingHead(nn.Module):
    """Per-body DOF prediction (translational, rotational) from node embeddings.

    Linear -> ReLU -> Linear -> Softplus producing two values per node. The
    trailing Softplus keeps both outputs non-negative.

    Args:
        hidden_dim: Input embedding dimension.
        inner_dim: Internal MLP hidden dimension. Defaults to
            ``hidden_dim // 2`` (the width this head has always used), which
            keeps defaults and checkpoint shapes unchanged.

    Raises:
        ValueError: If the resolved ``inner_dim`` is less than 1 (e.g.
            ``hidden_dim=1`` with the default). A zero-width bottleneck would
            make the output a constant that ignores the input.
    """

    def __init__(self, hidden_dim: int = 128, inner_dim: int | None = None) -> None:
        super().__init__()
        if inner_dim is None:
            inner_dim = hidden_dim // 2
        if inner_dim < 1:
            raise ValueError(
                f"inner_dim must be >= 1, got {inner_dim} (hidden_dim={hidden_dim})"
            )
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, 2),
            nn.Softplus(),
        )

    def forward(self, node_emb: torch.Tensor) -> torch.Tensor:
        """Return non-negative per-body DOF [N, 2]."""
        return self.mlp(node_emb)