Add complete model layer with GIN and GAT encoders, 5 task-specific prediction heads, uncertainty-weighted multi-task loss (Kendall et al. 2018), and config-driven factory functions. New modules: - graph_conv.py: datagen dict -> PyG Data conversion (22D node/edge features) - encoder.py: GINEncoder (3-layer, 128D) and GATEncoder (4-layer, 256D, 8-head) - heads.py: edge classification, graph classification, joint type, DOF regression, per-body DOF tracking heads - assembly_gnn.py: AssemblyGNN wiring encoder + configurable heads - losses.py: MultiTaskLoss with learnable log-variance per task - factory.py: build_model() and build_loss() from YAML configs Supporting changes: - generator.py: serialize anchor_a, anchor_b, pitch in joint dicts - configs: fix joint_type num_classes 12 -> 11 (matches JointType enum) 92 tests covering shapes, gradients, edge cases, and end-to-end datagen-to-model pipeline.
120 lines
3.3 KiB
Python
120 lines
3.3 KiB
Python
"""Task-specific prediction heads for assembly GNN."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
# Public API of this module: the five task-specific prediction heads,
# listed alphabetically for star-import and documentation tooling.
__all__ = [
    "DOFRegressionHead",
    "DOFTrackingHead",
    "EdgeClassificationHead",
    "GraphClassificationHead",
    "JointTypeHead",
]
|
|
|
|
|
|
class EdgeClassificationHead(nn.Module):
    """Predict whether each edge is independent or redundant (binary).

    A two-layer MLP maps each edge embedding to a single logit; callers
    apply the sigmoid / BCE-with-logits themselves.

    Args:
        hidden_dim: Dimensionality of the incoming edge embeddings.
        inner_dim: Width of the hidden layer inside the MLP.
    """

    def __init__(self, hidden_dim: int = 128, inner_dim: int = 64) -> None:
        super().__init__()
        # Same layer order as a hand-written Sequential, so state-dict
        # keys (mlp.0, mlp.1, mlp.2) remain checkpoint-compatible.
        layers = [
            nn.Linear(hidden_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
        """Return per-edge logits of shape [E, 1]."""
        logits = self.mlp(edge_emb)
        return logits
|
|
|
|
|
|
class GraphClassificationHead(nn.Module):
    """Classify a whole assembly (well/under/over-constrained/mixed).

    Maps pooled graph embeddings through a two-layer MLP with a square
    hidden layer to one logit per category.

    Args:
        hidden_dim: Dimensionality of the pooled graph embedding.
        num_classes: Number of assembly categories to score.
    """

    def __init__(self, hidden_dim: int = 128, num_classes: int = 4) -> None:
        super().__init__()
        # Hidden layer keeps the input width; only the output projects
        # down to the class count.
        blocks = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.mlp = nn.Sequential(*blocks)

    def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
        """Return per-graph logits of shape [B, num_classes]."""
        return self.mlp(graph_emb)
|
|
|
|
|
|
class JointTypeHead(nn.Module):
    """Classify the joint type of each edge from its embedding.

    Uses a bottleneck MLP (hidden_dim -> hidden_dim // 2 -> num_classes)
    and emits raw logits.

    Args:
        hidden_dim: Dimensionality of the incoming edge embeddings.
        num_classes: Number of joint types (matches the JointType enum).
    """

    def __init__(self, hidden_dim: int = 128, num_classes: int = 11) -> None:
        super().__init__()
        bottleneck = hidden_dim // 2
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, num_classes),
        )

    def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
        """Return per-edge logits of shape [E, num_classes]."""
        scores = self.mlp(edge_emb)
        return scores
|
|
|
|
|
|
class DOFRegressionHead(nn.Module):
    """Regress the total degrees of freedom of an assembly.

    A bottleneck MLP over the pooled graph embedding, terminated with
    Softplus so the predicted DOF count is always non-negative.

    Args:
        hidden_dim: Dimensionality of the pooled graph embedding.
    """

    def __init__(self, hidden_dim: int = 128) -> None:
        super().__init__()
        mid = hidden_dim // 2
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, mid),
            nn.ReLU(),
            nn.Linear(mid, 1),
            # Softplus keeps the regression output >= 0 (DOF counts
            # cannot be negative) while staying differentiable.
            nn.Softplus(),
        )

    def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
        """Return non-negative DOF prediction of shape [B, 1]."""
        return self.mlp(graph_emb)
|
|
|
|
|
|
class DOFTrackingHead(nn.Module):
    """Predict per-body (translational, rotational) degrees of freedom.

    Each node embedding passes through a bottleneck MLP ending in
    Softplus, yielding two non-negative DOF values per body.

    Args:
        hidden_dim: Dimensionality of the incoming node embeddings.
    """

    def __init__(self, hidden_dim: int = 128) -> None:
        super().__init__()
        mid = hidden_dim // 2
        stages = [
            nn.Linear(hidden_dim, mid),
            nn.ReLU(),
            nn.Linear(mid, 2),
            # Non-negativity constraint: DOF counts are >= 0 by definition.
            nn.Softplus(),
        ]
        self.mlp = nn.Sequential(*stages)

    def forward(self, node_emb: torch.Tensor) -> torch.Tensor:
        """Return non-negative per-body DOF of shape [N, 2]."""
        dof = self.mlp(node_emb)
        return dof
|