Compare commits
1 Commits
8e521b4519
...
feat/gnn-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe41fa3b00 |
@@ -16,9 +16,9 @@ heads:
|
||||
hidden_dim: 64
|
||||
graph_classification:
|
||||
enabled: true
|
||||
num_classes: 4 # rigid, under, over, mixed
|
||||
num_classes: 4 # rigid, under, over, mixed
|
||||
joint_type:
|
||||
enabled: true
|
||||
num_classes: 12
|
||||
num_classes: 11
|
||||
dof_regression:
|
||||
enabled: true
|
||||
|
||||
@@ -21,7 +21,7 @@ heads:
|
||||
num_classes: 4
|
||||
joint_type:
|
||||
enabled: true
|
||||
num_classes: 12
|
||||
num_classes: 11
|
||||
dof_regression:
|
||||
enabled: true
|
||||
dof_tracking:
|
||||
|
||||
@@ -877,6 +877,9 @@ class SyntheticAssemblyGenerator:
|
||||
"body_b": j.body_b,
|
||||
"type": j.joint_type.name,
|
||||
"axis": j.axis.tolist(),
|
||||
"anchor_a": j.anchor_a.tolist(),
|
||||
"anchor_b": j.anchor_b.tolist(),
|
||||
"pitch": j.pitch,
|
||||
}
|
||||
for j in joints
|
||||
],
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""GNN models for assembly constraint analysis."""
|
||||
|
||||
from solver.models.assembly_gnn import AssemblyGNN
|
||||
from solver.models.encoder import GATEncoder, GINEncoder
|
||||
from solver.models.factory import build_loss, build_model
|
||||
from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg
|
||||
from solver.models.heads import (
|
||||
DOFRegressionHead,
|
||||
DOFTrackingHead,
|
||||
EdgeClassificationHead,
|
||||
GraphClassificationHead,
|
||||
JointTypeHead,
|
||||
)
|
||||
from solver.models.losses import MultiTaskLoss
|
||||
|
||||
__all__ = [
|
||||
"ASSEMBLY_CLASSES",
|
||||
"AssemblyGNN",
|
||||
"DOFRegressionHead",
|
||||
"DOFTrackingHead",
|
||||
"EdgeClassificationHead",
|
||||
"GATEncoder",
|
||||
"GINEncoder",
|
||||
"GraphClassificationHead",
|
||||
"JOINT_TYPE_NAMES",
|
||||
"JointTypeHead",
|
||||
"MultiTaskLoss",
|
||||
"assembly_to_pyg",
|
||||
"build_loss",
|
||||
"build_model",
|
||||
]
|
||||
|
||||
131
solver/models/assembly_gnn.py
Normal file
131
solver/models/assembly_gnn.py
Normal file
@@ -0,0 +1,131 @@
|
||||
"""AssemblyGNN -- main model wiring encoder and task heads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from solver.models.encoder import GATEncoder, GINEncoder
|
||||
from solver.models.heads import (
|
||||
DOFRegressionHead,
|
||||
DOFTrackingHead,
|
||||
EdgeClassificationHead,
|
||||
GraphClassificationHead,
|
||||
JointTypeHead,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
__all__ = ["AssemblyGNN"]
|
||||
|
||||
_ENCODERS = {
|
||||
"gin": GINEncoder,
|
||||
"gat": GATEncoder,
|
||||
}
|
||||
|
||||
|
||||
class AssemblyGNN(nn.Module):
|
||||
"""Multi-task GNN for assembly constraint analysis.
|
||||
|
||||
Wires an encoder (GIN or GAT) with optional task-specific prediction
|
||||
heads for edge classification, graph classification, joint type
|
||||
prediction, DOF regression, and per-body DOF tracking.
|
||||
|
||||
Args:
|
||||
encoder_type: ``"gin"`` or ``"gat"``.
|
||||
encoder_config: Kwargs passed to the encoder constructor.
|
||||
heads_config: Dict of head name → config dict. Each entry must have
|
||||
an ``enabled`` bool. Additional keys are passed as kwargs.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_type: str = "gin",
|
||||
encoder_config: dict[str, Any] | None = None,
|
||||
heads_config: dict[str, dict[str, Any]] | None = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
encoder_config = encoder_config or {}
|
||||
heads_config = heads_config or {}
|
||||
|
||||
if encoder_type not in _ENCODERS:
|
||||
msg = f"Unknown encoder type: {encoder_type!r}. Choose from {list(_ENCODERS)}"
|
||||
raise ValueError(msg)
|
||||
|
||||
self.encoder = _ENCODERS[encoder_type](**encoder_config)
|
||||
hidden_dim = self.encoder.hidden_dim
|
||||
|
||||
self.heads = nn.ModuleDict()
|
||||
self._build_heads(heads_config, hidden_dim)
|
||||
|
||||
def _build_heads(
|
||||
self,
|
||||
heads_config: dict[str, dict[str, Any]],
|
||||
hidden_dim: int,
|
||||
) -> None:
|
||||
"""Instantiate enabled heads."""
|
||||
cfg = heads_config.get("edge_classification", {})
|
||||
if cfg.get("enabled", False):
|
||||
self.heads["edge_pred"] = EdgeClassificationHead(
|
||||
hidden_dim=hidden_dim,
|
||||
inner_dim=cfg.get("hidden_dim", 64),
|
||||
)
|
||||
|
||||
cfg = heads_config.get("graph_classification", {})
|
||||
if cfg.get("enabled", False):
|
||||
self.heads["graph_pred"] = GraphClassificationHead(
|
||||
hidden_dim=hidden_dim,
|
||||
num_classes=cfg.get("num_classes", 4),
|
||||
)
|
||||
|
||||
cfg = heads_config.get("joint_type", {})
|
||||
if cfg.get("enabled", False):
|
||||
self.heads["joint_type_pred"] = JointTypeHead(
|
||||
hidden_dim=hidden_dim,
|
||||
num_classes=cfg.get("num_classes", 11),
|
||||
)
|
||||
|
||||
cfg = heads_config.get("dof_regression", {})
|
||||
if cfg.get("enabled", False):
|
||||
self.heads["dof_pred"] = DOFRegressionHead(hidden_dim=hidden_dim)
|
||||
|
||||
cfg = heads_config.get("dof_tracking", {})
|
||||
if cfg.get("enabled", False):
|
||||
self.heads["body_dof_pred"] = DOFTrackingHead(hidden_dim=hidden_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
edge_index: torch.Tensor,
|
||||
edge_attr: torch.Tensor,
|
||||
batch: torch.Tensor | None = None,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Run encoder and all enabled heads.
|
||||
|
||||
Returns:
|
||||
Dict with keys matching enabled head names:
|
||||
``edge_pred``, ``graph_pred``, ``joint_type_pred``,
|
||||
``dof_pred``, ``body_dof_pred``.
|
||||
"""
|
||||
node_emb, edge_emb, graph_emb = self.encoder(x, edge_index, edge_attr, batch)
|
||||
|
||||
preds: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Route embeddings to the appropriate heads.
|
||||
_edge_heads = {"edge_pred", "joint_type_pred"}
|
||||
_graph_heads = {"graph_pred", "dof_pred"}
|
||||
_node_heads = {"body_dof_pred"}
|
||||
|
||||
for name, head in self.heads.items():
|
||||
if name in _edge_heads:
|
||||
preds[name] = head(edge_emb)
|
||||
elif name in _graph_heads:
|
||||
preds[name] = head(graph_emb)
|
||||
elif name in _node_heads:
|
||||
preds[name] = head(node_emb)
|
||||
|
||||
return preds
|
||||
194
solver/models/encoder.py
Normal file
194
solver/models/encoder.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""GIN and GAT graph neural network encoders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch_geometric.nn import GATv2Conv, GINEConv, global_mean_pool
|
||||
|
||||
__all__ = ["GATEncoder", "GINEncoder"]
|
||||
|
||||
|
||||
def _make_gin_mlp(in_dim: int, hidden_dim: int) -> nn.Sequential:
|
||||
"""Two-layer MLP used inside GINEConv."""
|
||||
return nn.Sequential(
|
||||
nn.Linear(in_dim, hidden_dim),
|
||||
nn.BatchNorm1d(hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
)
|
||||
|
||||
|
||||
class GINEncoder(nn.Module):
|
||||
"""Graph Isomorphism Network encoder with edge features (GINE).
|
||||
|
||||
Args:
|
||||
node_features_dim: Input node feature dimension.
|
||||
edge_features_dim: Input edge feature dimension.
|
||||
hidden_dim: Hidden dimension for all layers.
|
||||
num_layers: Number of GINEConv layers.
|
||||
dropout: Dropout probability.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_features_dim: int = 22,
|
||||
edge_features_dim: int = 22,
|
||||
hidden_dim: int = 128,
|
||||
num_layers: int = 3,
|
||||
dropout: float = 0.1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self._hidden_dim = hidden_dim
|
||||
|
||||
self.node_proj = nn.Sequential(
|
||||
nn.Linear(node_features_dim, hidden_dim),
|
||||
nn.BatchNorm1d(hidden_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.edge_proj = nn.Sequential(
|
||||
nn.Linear(edge_features_dim, hidden_dim),
|
||||
nn.BatchNorm1d(hidden_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.norms = nn.ModuleList()
|
||||
for _ in range(num_layers):
|
||||
conv = GINEConv(nn=_make_gin_mlp(hidden_dim, hidden_dim), edge_dim=hidden_dim)
|
||||
self.convs.append(conv)
|
||||
self.norms.append(nn.BatchNorm1d(hidden_dim))
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# Edge embedding from endpoint + edge features.
|
||||
self.edge_mlp = nn.Linear(hidden_dim * 3, hidden_dim)
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._hidden_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
edge_index: torch.Tensor,
|
||||
edge_attr: torch.Tensor,
|
||||
batch: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Encode graph and return node, edge, and graph embeddings.
|
||||
|
||||
Returns:
|
||||
node_emb: [N, hidden_dim]
|
||||
edge_emb: [E, hidden_dim]
|
||||
graph_emb: [B, hidden_dim]
|
||||
"""
|
||||
h = self.node_proj(x)
|
||||
e = self.edge_proj(edge_attr)
|
||||
|
||||
for conv, norm in zip(self.convs, self.norms):
|
||||
h = conv(h, edge_index, e)
|
||||
h = norm(h)
|
||||
h = torch.relu(h)
|
||||
h = self.dropout(h)
|
||||
|
||||
# Edge embeddings from endpoint concatenation.
|
||||
src, dst = edge_index
|
||||
edge_emb = self.edge_mlp(torch.cat([h[src], h[dst], e], dim=1))
|
||||
|
||||
# Graph embedding via mean pooling.
|
||||
graph_emb = global_mean_pool(h, batch)
|
||||
|
||||
return h, edge_emb, graph_emb
|
||||
|
||||
|
||||
class GATEncoder(nn.Module):
|
||||
"""Graph Attention Network v2 encoder with edge features and residuals.
|
||||
|
||||
Args:
|
||||
node_features_dim: Input node feature dimension.
|
||||
edge_features_dim: Input edge feature dimension.
|
||||
hidden_dim: Hidden dimension (must be divisible by num_heads).
|
||||
num_layers: Number of GATv2Conv layers.
|
||||
num_heads: Number of attention heads.
|
||||
dropout: Dropout probability.
|
||||
residual: Use residual connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
node_features_dim: int = 22,
|
||||
edge_features_dim: int = 22,
|
||||
hidden_dim: int = 256,
|
||||
num_layers: int = 4,
|
||||
num_heads: int = 8,
|
||||
dropout: float = 0.1,
|
||||
residual: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
if hidden_dim % num_heads != 0:
|
||||
msg = f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
|
||||
raise ValueError(msg)
|
||||
|
||||
self._hidden_dim = hidden_dim
|
||||
self.residual = residual
|
||||
head_dim = hidden_dim // num_heads
|
||||
|
||||
self.node_proj = nn.Sequential(
|
||||
nn.Linear(node_features_dim, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.edge_proj = nn.Sequential(
|
||||
nn.Linear(edge_features_dim, hidden_dim),
|
||||
nn.LayerNorm(hidden_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
self.norms = nn.ModuleList()
|
||||
for _ in range(num_layers):
|
||||
conv = GATv2Conv(
|
||||
in_channels=hidden_dim,
|
||||
out_channels=head_dim,
|
||||
heads=num_heads,
|
||||
edge_dim=hidden_dim,
|
||||
concat=True,
|
||||
)
|
||||
self.convs.append(conv)
|
||||
self.norms.append(nn.LayerNorm(hidden_dim))
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# Edge embedding from endpoint + edge features.
|
||||
self.edge_mlp = nn.Linear(hidden_dim * 3, hidden_dim)
|
||||
|
||||
@property
|
||||
def hidden_dim(self) -> int:
|
||||
return self._hidden_dim
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
edge_index: torch.Tensor,
|
||||
edge_attr: torch.Tensor,
|
||||
batch: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Encode graph and return node, edge, and graph embeddings."""
|
||||
h = self.node_proj(x)
|
||||
e = self.edge_proj(edge_attr)
|
||||
|
||||
for conv, norm in zip(self.convs, self.norms):
|
||||
h_new = conv(h, edge_index, e)
|
||||
h_new = norm(h_new)
|
||||
h_new = torch.relu(h_new)
|
||||
h_new = self.dropout(h_new)
|
||||
if self.residual:
|
||||
h = h + h_new
|
||||
else:
|
||||
h = h_new
|
||||
|
||||
src, dst = edge_index
|
||||
edge_emb = self.edge_mlp(torch.cat([h[src], h[dst], e], dim=1))
|
||||
graph_emb = global_mean_pool(h, batch)
|
||||
|
||||
return h, edge_emb, graph_emb
|
||||
82
solver/models/factory.py
Normal file
82
solver/models/factory.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""Factory functions to build model and loss from config dicts."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.models.assembly_gnn import AssemblyGNN
|
||||
from solver.models.losses import MultiTaskLoss
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["build_loss", "build_model"]
|
||||
|
||||
|
||||
def build_model(config: dict[str, Any]) -> AssemblyGNN:
|
||||
"""Construct an AssemblyGNN from a parsed YAML model config.
|
||||
|
||||
Expected config structure (matches ``configs/model/*.yaml``)::
|
||||
|
||||
architecture: gin # or gat
|
||||
encoder:
|
||||
hidden_dim: 128
|
||||
num_layers: 3
|
||||
...
|
||||
node_features_dim: 22
|
||||
edge_features_dim: 22
|
||||
heads:
|
||||
edge_classification:
|
||||
enabled: true
|
||||
...
|
||||
|
||||
Args:
|
||||
config: Parsed YAML model config dict.
|
||||
|
||||
Returns:
|
||||
Configured ``AssemblyGNN`` instance.
|
||||
"""
|
||||
encoder_type = config.get("architecture", "gin")
|
||||
|
||||
encoder_config: dict[str, Any] = dict(config.get("encoder", {}))
|
||||
encoder_config.setdefault("node_features_dim", config.get("node_features_dim", 22))
|
||||
encoder_config.setdefault("edge_features_dim", config.get("edge_features_dim", 22))
|
||||
|
||||
heads_config = config.get("heads", {})
|
||||
|
||||
return AssemblyGNN(
|
||||
encoder_type=encoder_type,
|
||||
encoder_config=encoder_config,
|
||||
heads_config=heads_config,
|
||||
)
|
||||
|
||||
|
||||
def build_loss(config: dict[str, Any]) -> MultiTaskLoss:
|
||||
"""Construct a MultiTaskLoss from a parsed YAML training config.
|
||||
|
||||
Expected config structure (from ``configs/training/*.yaml`` ``loss`` section)::
|
||||
|
||||
loss:
|
||||
edge_weight: 1.0
|
||||
graph_weight: 0.5
|
||||
joint_type_weight: 0.3
|
||||
dof_weight: 0.2
|
||||
body_dof_weight: 0.2
|
||||
redundant_penalty: 2.0
|
||||
|
||||
Args:
|
||||
config: Parsed YAML training config dict (full config, not just loss section).
|
||||
|
||||
Returns:
|
||||
Configured ``MultiTaskLoss`` instance.
|
||||
"""
|
||||
loss_config = config.get("loss", {})
|
||||
|
||||
return MultiTaskLoss(
|
||||
edge_weight=loss_config.get("edge_weight", 1.0),
|
||||
graph_weight=loss_config.get("graph_weight", 0.5),
|
||||
joint_type_weight=loss_config.get("joint_type_weight", 0.3),
|
||||
dof_weight=loss_config.get("dof_weight", 0.2),
|
||||
body_dof_weight=loss_config.get("body_dof_weight", 0.2),
|
||||
redundant_penalty=loss_config.get("redundant_penalty", 2.0),
|
||||
)
|
||||
260
solver/models/graph_conv.py
Normal file
260
solver/models/graph_conv.py
Normal file
@@ -0,0 +1,260 @@
|
||||
"""Convert datagen assembly dicts to PyTorch Geometric Data objects."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch_geometric.data import Data
|
||||
|
||||
from solver.datagen.types import JointType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"ASSEMBLY_CLASSES",
|
||||
"JOINT_TYPE_NAMES",
|
||||
"assembly_to_pyg",
|
||||
]
|
||||
|
||||
# Ordered list matching JointType ordinal values (0-10).
|
||||
JOINT_TYPE_NAMES: list[str] = [jt.name for jt in JointType]
|
||||
|
||||
# Assembly classification label mapping.
|
||||
ASSEMBLY_CLASSES: dict[str, int] = {
|
||||
"well-constrained": 0,
|
||||
"underconstrained": 1,
|
||||
"overconstrained": 2,
|
||||
"mixed": 3,
|
||||
}
|
||||
|
||||
# Joint type name -> ordinal for fast lookup.
|
||||
_JOINT_TYPE_TO_ORD: dict[str, int] = {jt.name: jt.value[0] for jt in JointType}
|
||||
|
||||
# Joint type name -> DOF removed.
|
||||
_JOINT_TYPE_TO_DOF: dict[str, int] = {jt.name: jt.dof for jt in JointType}
|
||||
|
||||
_NUM_JOINT_TYPES = len(JointType)
|
||||
|
||||
|
||||
def _encode_node_features(
|
||||
body_positions: list[list[float]],
|
||||
body_orientations: list[list[list[float]]],
|
||||
joints: list[dict[str, Any]],
|
||||
grounded: bool,
|
||||
) -> torch.Tensor:
|
||||
"""Encode node features as a [N, 22] tensor.
|
||||
|
||||
Dims 0-2: position (centered per graph)
|
||||
Dims 3-11: flattened 3x3 rotation matrix
|
||||
Dim 12: is grounded flag
|
||||
Dim 13: node degree / 10
|
||||
Dims 14-19: degree bucket one-hot (0, 1, 2, 3, 4, 5+)
|
||||
Dim 20: total incident DOF removed / 30
|
||||
Dim 21: fraction of incident joints that are FIXED
|
||||
"""
|
||||
n_bodies = len(body_positions)
|
||||
|
||||
# Positions centered per graph.
|
||||
pos = torch.tensor(body_positions, dtype=torch.float32)
|
||||
centroid = pos.mean(dim=0, keepdim=True)
|
||||
pos = pos - centroid
|
||||
|
||||
# Flattened orientation matrices [N, 9].
|
||||
orient = torch.tensor(body_orientations, dtype=torch.float32).reshape(n_bodies, 9)
|
||||
|
||||
# Compute per-node degree and incident joint stats.
|
||||
degree = torch.zeros(n_bodies, dtype=torch.float32)
|
||||
dof_removed = torch.zeros(n_bodies, dtype=torch.float32)
|
||||
fixed_count = torch.zeros(n_bodies, dtype=torch.float32)
|
||||
|
||||
for j in joints:
|
||||
a, b = j["body_a"], j["body_b"]
|
||||
jtype = j["type"]
|
||||
dof = _JOINT_TYPE_TO_DOF.get(jtype, 0)
|
||||
degree[a] += 1
|
||||
degree[b] += 1
|
||||
dof_removed[a] += dof
|
||||
dof_removed[b] += dof
|
||||
if jtype == "FIXED":
|
||||
fixed_count[a] += 1
|
||||
fixed_count[b] += 1
|
||||
|
||||
# Grounded flag: body 0 if assembly is grounded.
|
||||
grounded_flag = torch.zeros(n_bodies, 1, dtype=torch.float32)
|
||||
if grounded and n_bodies > 0:
|
||||
grounded_flag[0, 0] = 1.0
|
||||
|
||||
# Degree normalized.
|
||||
degree_norm = (degree / 10.0).unsqueeze(1)
|
||||
|
||||
# Degree bucket one-hot [N, 6]: buckets 0, 1, 2, 3, 4, 5+.
|
||||
bucket = degree.clamp(max=5).long()
|
||||
degree_onehot = torch.zeros(n_bodies, 6, dtype=torch.float32)
|
||||
degree_onehot.scatter_(1, bucket.unsqueeze(1), 1.0)
|
||||
|
||||
# Total incident DOF removed, normalized.
|
||||
dof_norm = (dof_removed / 30.0).unsqueeze(1)
|
||||
|
||||
# Fraction of incident joints that are FIXED.
|
||||
safe_degree = degree.clamp(min=1)
|
||||
fixed_frac = (fixed_count / safe_degree).unsqueeze(1)
|
||||
|
||||
# Concatenate: [N, 3+9+1+1+6+1+1] = [N, 22].
|
||||
x = torch.cat(
|
||||
[pos, orient, grounded_flag, degree_norm, degree_onehot, dof_norm, fixed_frac], dim=1
|
||||
)
|
||||
return x
|
||||
|
||||
|
||||
def _encode_edge_features(
|
||||
joints: list[dict[str, Any]],
|
||||
body_positions: list[list[float]],
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Encode edge features and build bidirectional edge_index.
|
||||
|
||||
Returns:
|
||||
edge_index: [2, 2*n_joints] (each joint as two directed edges).
|
||||
edge_attr: [2*n_joints, 22] edge features.
|
||||
"""
|
||||
n_joints = len(joints)
|
||||
if n_joints == 0:
|
||||
return (
|
||||
torch.zeros(2, 0, dtype=torch.long),
|
||||
torch.zeros(0, 22, dtype=torch.float32),
|
||||
)
|
||||
|
||||
pos = body_positions
|
||||
src_list: list[int] = []
|
||||
dst_list: list[int] = []
|
||||
features: list[list[float]] = []
|
||||
|
||||
for j in joints:
|
||||
a, b = j["body_a"], j["body_b"]
|
||||
jtype_name = j["type"]
|
||||
ordinal = _JOINT_TYPE_TO_ORD.get(jtype_name, 0)
|
||||
dof = _JOINT_TYPE_TO_DOF.get(jtype_name, 0)
|
||||
|
||||
# One-hot joint type [11].
|
||||
onehot = [0.0] * _NUM_JOINT_TYPES
|
||||
onehot[ordinal] = 1.0
|
||||
|
||||
# Axis [3].
|
||||
axis = j.get("axis", [0.0, 0.0, 1.0])
|
||||
|
||||
# Anchor offsets relative to body positions (fallback to zeros).
|
||||
anchor_a_raw = j.get("anchor_a")
|
||||
anchor_b_raw = j.get("anchor_b")
|
||||
if anchor_a_raw is not None:
|
||||
anchor_a_off = [anchor_a_raw[k] - pos[a][k] for k in range(3)]
|
||||
else:
|
||||
anchor_a_off = [0.0, 0.0, 0.0]
|
||||
if anchor_b_raw is not None:
|
||||
anchor_b_off = [anchor_b_raw[k] - pos[b][k] for k in range(3)]
|
||||
else:
|
||||
anchor_b_off = [0.0, 0.0, 0.0]
|
||||
|
||||
pitch = j.get("pitch", 0.0)
|
||||
dof_norm = dof / 6.0
|
||||
|
||||
feat = onehot + axis + anchor_a_off + anchor_b_off + [pitch, dof_norm]
|
||||
|
||||
# Bidirectional: a->b and b->a with identical features.
|
||||
src_list.extend([a, b])
|
||||
dst_list.extend([b, a])
|
||||
features.append(feat)
|
||||
features.append(feat)
|
||||
|
||||
edge_index = torch.tensor([src_list, dst_list], dtype=torch.long)
|
||||
edge_attr = torch.tensor(features, dtype=torch.float32)
|
||||
return edge_index, edge_attr
|
||||
|
||||
|
||||
def assembly_to_pyg(
|
||||
example: dict[str, Any],
|
||||
*,
|
||||
include_labels: bool = True,
|
||||
) -> Data:
|
||||
"""Convert a datagen training example dict to a PyG Data object.
|
||||
|
||||
Args:
|
||||
example: Dict from ``SyntheticAssemblyGenerator.generate_training_batch()``.
|
||||
include_labels: Attach ground truth labels to the Data object.
|
||||
|
||||
Returns:
|
||||
``torch_geometric.data.Data`` with node features ``x``, ``edge_index``,
|
||||
``edge_attr``, and optionally label tensors ``y_edge``, ``y_graph``,
|
||||
``y_joint_type``, ``y_dof``, ``y_body_dof``.
|
||||
"""
|
||||
body_positions = example["body_positions"]
|
||||
body_orientations = example["body_orientations"]
|
||||
joints = example["joints"]
|
||||
grounded = example.get("grounded", False)
|
||||
|
||||
x = _encode_node_features(body_positions, body_orientations, joints, grounded)
|
||||
edge_index, edge_attr = _encode_edge_features(joints, body_positions)
|
||||
|
||||
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
||||
data.num_nodes = x.size(0)
|
||||
|
||||
if include_labels:
|
||||
_attach_labels(data, example, joints)
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _attach_labels(
|
||||
data: Data,
|
||||
example: dict[str, Any],
|
||||
joints: list[dict[str, Any]],
|
||||
) -> None:
|
||||
"""Attach ground truth label tensors to a Data object."""
|
||||
joint_labels = example.get("joint_labels", {})
|
||||
labels = example.get("labels", {})
|
||||
|
||||
# Per-edge: binary independent (1) / redundant (0).
|
||||
# Duplicated for bidirectional edges.
|
||||
n_joints = len(joints)
|
||||
edge_labels: list[float] = []
|
||||
joint_type_labels: list[int] = []
|
||||
for j in joints:
|
||||
jid = j["joint_id"]
|
||||
jl = joint_labels.get(jid) or joint_labels.get(str(jid), {})
|
||||
is_independent = 1.0 if jl.get("redundant_constraints", 0) == 0 else 0.0
|
||||
ordinal = _JOINT_TYPE_TO_ORD.get(j["type"], 0)
|
||||
# Bidirectional: duplicate.
|
||||
edge_labels.extend([is_independent, is_independent])
|
||||
joint_type_labels.extend([ordinal, ordinal])
|
||||
|
||||
if n_joints > 0:
|
||||
data.y_edge = torch.tensor(edge_labels, dtype=torch.float32)
|
||||
data.y_joint_type = torch.tensor(joint_type_labels, dtype=torch.long)
|
||||
else:
|
||||
data.y_edge = torch.zeros(0, dtype=torch.float32)
|
||||
data.y_joint_type = torch.zeros(0, dtype=torch.long)
|
||||
|
||||
# Assembly classification.
|
||||
classification = example.get("assembly_classification", "")
|
||||
data.y_graph = torch.tensor(
|
||||
[ASSEMBLY_CLASSES.get(classification, 0)],
|
||||
dtype=torch.long,
|
||||
)
|
||||
|
||||
# Total DOF.
|
||||
assembly_labels = labels.get("assembly", {})
|
||||
data.y_dof = torch.tensor(
|
||||
[float(assembly_labels.get("total_dof", 0))],
|
||||
dtype=torch.float32,
|
||||
)
|
||||
|
||||
# Per-body DOF: [N, 2] (translational, rotational).
|
||||
per_body = labels.get("per_body", [])
|
||||
n_bodies = data.num_nodes
|
||||
body_dof = torch.zeros(n_bodies, 2, dtype=torch.float32)
|
||||
for entry in per_body:
|
||||
bid = entry["body_id"]
|
||||
if 0 <= bid < n_bodies:
|
||||
body_dof[bid, 0] = float(entry["translational_dof"])
|
||||
body_dof[bid, 1] = float(entry["rotational_dof"])
|
||||
data.y_body_dof = body_dof
|
||||
119
solver/models/heads.py
Normal file
119
solver/models/heads.py
Normal file
@@ -0,0 +1,119 @@
|
||||
"""Task-specific prediction heads for assembly GNN."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
__all__ = [
|
||||
"DOFRegressionHead",
|
||||
"DOFTrackingHead",
|
||||
"EdgeClassificationHead",
|
||||
"GraphClassificationHead",
|
||||
"JointTypeHead",
|
||||
]
|
||||
|
||||
|
||||
class EdgeClassificationHead(nn.Module):
|
||||
"""Binary edge classification (independent vs redundant).
|
||||
|
||||
Args:
|
||||
hidden_dim: Input embedding dimension.
|
||||
inner_dim: Internal MLP hidden dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int = 128, inner_dim: int = 64) -> None:
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(hidden_dim, inner_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(inner_dim, 1),
|
||||
)
|
||||
|
||||
def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""Return logits [E, 1]."""
|
||||
return self.mlp(edge_emb)
|
||||
|
||||
|
||||
class GraphClassificationHead(nn.Module):
|
||||
"""Assembly classification (well/under/over-constrained/mixed).
|
||||
|
||||
Args:
|
||||
hidden_dim: Input embedding dimension.
|
||||
num_classes: Number of classification categories.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int = 128, num_classes: int = 4) -> None:
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim, num_classes),
|
||||
)
|
||||
|
||||
def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""Return logits [B, num_classes]."""
|
||||
return self.mlp(graph_emb)
|
||||
|
||||
|
||||
class JointTypeHead(nn.Module):
|
||||
"""Joint type classification from edge embeddings.
|
||||
|
||||
Args:
|
||||
hidden_dim: Input embedding dimension.
|
||||
num_classes: Number of joint types.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int = 128, num_classes: int = 11) -> None:
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim // 2, num_classes),
|
||||
)
|
||||
|
||||
def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""Return logits [E, num_classes]."""
|
||||
return self.mlp(edge_emb)
|
||||
|
||||
|
||||
class DOFRegressionHead(nn.Module):
|
||||
"""Total DOF regression from graph embedding.
|
||||
|
||||
Args:
|
||||
hidden_dim: Input embedding dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int = 128) -> None:
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim // 2, 1),
|
||||
nn.Softplus(),
|
||||
)
|
||||
|
||||
def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""Return non-negative DOF prediction [B, 1]."""
|
||||
return self.mlp(graph_emb)
|
||||
|
||||
|
||||
class DOFTrackingHead(nn.Module):
|
||||
"""Per-body DOF prediction (translational, rotational) from node embeddings.
|
||||
|
||||
Args:
|
||||
hidden_dim: Input embedding dimension.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_dim: int = 128) -> None:
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(hidden_dim, hidden_dim // 2),
|
||||
nn.ReLU(),
|
||||
nn.Linear(hidden_dim // 2, 2),
|
||||
nn.Softplus(),
|
||||
)
|
||||
|
||||
def forward(self, node_emb: torch.Tensor) -> torch.Tensor:
|
||||
"""Return non-negative per-body DOF [N, 2]."""
|
||||
return self.mlp(node_emb)
|
||||
161
solver/models/losses.py
Normal file
161
solver/models/losses.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Uncertainty-weighted multi-task loss (Kendall et al., 2018)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
__all__ = ["MultiTaskLoss"]
|
||||
|
||||
|
||||
class MultiTaskLoss(nn.Module):
|
||||
"""Multi-task loss with learnable uncertainty weighting.
|
||||
|
||||
Each task has a learnable ``log_var`` parameter (log variance) that
|
||||
automatically balances task contributions during training. The loss
|
||||
for task *i* is::
|
||||
|
||||
(1 / (2 * sigma_i^2)) * weight_i * L_i + 0.5 * log(sigma_i^2)
|
||||
|
||||
which simplifies to::
|
||||
|
||||
exp(-log_var_i) * weight_i * L_i + 0.5 * log_var_i
|
||||
|
||||
Args:
|
||||
edge_weight: Initial scale for edge classification loss.
|
||||
graph_weight: Initial scale for graph classification loss.
|
||||
joint_type_weight: Initial scale for joint type loss.
|
||||
dof_weight: Initial scale for DOF regression loss.
|
||||
body_dof_weight: Initial scale for per-body DOF loss.
|
||||
redundant_penalty: Extra weight on redundant edges (label=0) in
|
||||
the edge BCE loss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
edge_weight: float = 1.0,
|
||||
graph_weight: float = 0.5,
|
||||
joint_type_weight: float = 0.3,
|
||||
dof_weight: float = 0.2,
|
||||
body_dof_weight: float = 0.2,
|
||||
redundant_penalty: float = 2.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.weights = {
|
||||
"edge": edge_weight,
|
||||
"graph": graph_weight,
|
||||
"joint_type": joint_type_weight,
|
||||
"dof": dof_weight,
|
||||
"body_dof": body_dof_weight,
|
||||
}
|
||||
self.redundant_penalty = redundant_penalty
|
||||
|
||||
# Learnable log-variance parameters, one per task.
|
||||
# Initialized to 0 → sigma^2 = 1.
|
||||
self.log_vars = nn.ParameterDict(
|
||||
{name: nn.Parameter(torch.zeros(1)) for name in self.weights}
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
predictions: dict[str, torch.Tensor],
|
||||
targets: dict[str, torch.Tensor],
|
||||
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
|
||||
"""Compute total loss and per-task breakdown.
|
||||
|
||||
Args:
|
||||
predictions: Dict with keys from AssemblyGNN output:
|
||||
``edge_pred``, ``graph_pred``, ``joint_type_pred``,
|
||||
``dof_pred``, ``body_dof_pred``.
|
||||
targets: Dict with label tensors:
|
||||
``y_edge``, ``y_graph``, ``y_joint_type``,
|
||||
``y_dof``, ``y_body_dof``.
|
||||
|
||||
Returns:
|
||||
total_loss: Scalar total loss.
|
||||
breakdown: Dict of per-task raw loss values (before weighting).
|
||||
"""
|
||||
total = torch.tensor(0.0, device=self._device(predictions))
|
||||
breakdown: dict[str, torch.Tensor] = {}
|
||||
|
||||
# Edge classification (BCE with asymmetric redundancy penalty).
|
||||
if "edge_pred" in predictions and "y_edge" in targets:
|
||||
loss = self._edge_loss(predictions["edge_pred"], targets["y_edge"])
|
||||
total = total + self._weighted(loss, "edge")
|
||||
breakdown["edge"] = loss.detach()
|
||||
|
||||
# Graph classification.
|
||||
if "graph_pred" in predictions and "y_graph" in targets:
|
||||
loss = nn.functional.cross_entropy(
|
||||
predictions["graph_pred"],
|
||||
targets["y_graph"],
|
||||
)
|
||||
total = total + self._weighted(loss, "graph")
|
||||
breakdown["graph"] = loss.detach()
|
||||
|
||||
# Joint type classification.
|
||||
if "joint_type_pred" in predictions and "y_joint_type" in targets:
|
||||
loss = nn.functional.cross_entropy(
|
||||
predictions["joint_type_pred"],
|
||||
targets["y_joint_type"],
|
||||
)
|
||||
total = total + self._weighted(loss, "joint_type")
|
||||
breakdown["joint_type"] = loss.detach()
|
||||
|
||||
# DOF regression.
|
||||
if "dof_pred" in predictions and "y_dof" in targets:
|
||||
loss = nn.functional.smooth_l1_loss(
|
||||
predictions["dof_pred"],
|
||||
targets["y_dof"],
|
||||
)
|
||||
total = total + self._weighted(loss, "dof")
|
||||
breakdown["dof"] = loss.detach()
|
||||
|
||||
# Per-body DOF tracking.
|
||||
if "body_dof_pred" in predictions and "y_body_dof" in targets:
|
||||
loss = nn.functional.smooth_l1_loss(
|
||||
predictions["body_dof_pred"],
|
||||
targets["y_body_dof"],
|
||||
)
|
||||
total = total + self._weighted(loss, "body_dof")
|
||||
breakdown["body_dof"] = loss.detach()
|
||||
|
||||
return total, breakdown
|
||||
|
||||
def _edge_loss(
|
||||
self,
|
||||
pred: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""BCE loss with asymmetric weighting for redundant edges."""
|
||||
# pred: [E, 1] logits, target: [E] binary.
|
||||
pred_flat = pred.squeeze(-1)
|
||||
# Weight: redundant (0) gets higher penalty, independent (1) gets 1.0.
|
||||
weight = torch.where(
|
||||
target == 0,
|
||||
torch.tensor(self.redundant_penalty, device=pred.device),
|
||||
torch.tensor(1.0, device=pred.device),
|
||||
)
|
||||
return nn.functional.binary_cross_entropy_with_logits(
|
||||
pred_flat,
|
||||
target,
|
||||
weight=weight,
|
||||
)
|
||||
|
||||
def _weighted(self, loss: torch.Tensor, task_name: str) -> torch.Tensor:
|
||||
"""Apply uncertainty weighting: exp(-log_var) * w * L + 0.5 * log_var."""
|
||||
log_var = self.log_vars[task_name].squeeze()
|
||||
weight = self.weights[task_name]
|
||||
return torch.exp(-log_var) * weight * loss + 0.5 * log_var
|
||||
|
||||
@staticmethod
|
||||
def _device(predictions: dict[str, torch.Tensor]) -> torch.device:
|
||||
"""Infer device from prediction tensors."""
|
||||
for v in predictions.values():
|
||||
return v.device
|
||||
return torch.device("cpu")
|
||||
0
tests/models/__init__.py
Normal file
0
tests/models/__init__.py
Normal file
188
tests/models/test_assembly_gnn.py
Normal file
188
tests/models/test_assembly_gnn.py
Normal file
@@ -0,0 +1,188 @@
|
||||
"""Tests for solver.models.assembly_gnn -- main model wiring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from solver.models.assembly_gnn import AssemblyGNN
|
||||
|
||||
|
||||
def _default_heads_config(dof_tracking: bool = False) -> dict:
|
||||
return {
|
||||
"edge_classification": {"enabled": True, "hidden_dim": 64},
|
||||
"graph_classification": {"enabled": True, "num_classes": 4},
|
||||
"joint_type": {"enabled": True, "num_classes": 11},
|
||||
"dof_regression": {"enabled": True},
|
||||
"dof_tracking": {"enabled": dof_tracking},
|
||||
}
|
||||
|
||||
|
||||
def _random_graph(
|
||||
n_nodes: int = 8,
|
||||
n_edges: int = 16,
|
||||
batch_size: int = 2,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
x = torch.randn(n_nodes, 22)
|
||||
edge_index = torch.randint(0, n_nodes, (2, n_edges))
|
||||
edge_attr = torch.randn(n_edges, 22)
|
||||
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
|
||||
if len(batch) < n_nodes:
|
||||
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
|
||||
return x, edge_index, edge_attr, batch
|
||||
|
||||
|
||||
class TestAssemblyGNNGIN:
|
||||
"""AssemblyGNN with GIN encoder."""
|
||||
|
||||
def test_forward_all_heads(self) -> None:
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gin",
|
||||
encoder_config={"hidden_dim": 64, "num_layers": 2},
|
||||
heads_config=_default_heads_config(),
|
||||
)
|
||||
x, ei, ea, batch = _random_graph()
|
||||
preds = model(x, ei, ea, batch)
|
||||
assert "edge_pred" in preds
|
||||
assert "graph_pred" in preds
|
||||
assert "joint_type_pred" in preds
|
||||
assert "dof_pred" in preds
|
||||
|
||||
def test_output_shapes(self) -> None:
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gin",
|
||||
encoder_config={"hidden_dim": 64, "num_layers": 2},
|
||||
heads_config=_default_heads_config(),
|
||||
)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=20, batch_size=3)
|
||||
preds = model(x, ei, ea, batch)
|
||||
assert preds["edge_pred"].shape == (20, 1)
|
||||
assert preds["graph_pred"].shape == (3, 4)
|
||||
assert preds["joint_type_pred"].shape == (20, 11)
|
||||
assert preds["dof_pred"].shape == (3, 1)
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gin",
|
||||
encoder_config={"hidden_dim": 32, "num_layers": 2},
|
||||
heads_config=_default_heads_config(),
|
||||
)
|
||||
x, ei, ea, batch = _random_graph()
|
||||
x.requires_grad_(True)
|
||||
preds = model(x, ei, ea, batch)
|
||||
total = sum(p.sum() for p in preds.values())
|
||||
total.backward()
|
||||
assert x.grad is not None
|
||||
|
||||
def test_no_heads_returns_empty(self) -> None:
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gin",
|
||||
encoder_config={"hidden_dim": 32, "num_layers": 2},
|
||||
heads_config={},
|
||||
)
|
||||
x, ei, ea, batch = _random_graph()
|
||||
preds = model(x, ei, ea, batch)
|
||||
assert len(preds) == 0
|
||||
|
||||
|
||||
class TestAssemblyGNNGAT:
|
||||
"""AssemblyGNN with GAT encoder."""
|
||||
|
||||
def test_forward_all_heads(self) -> None:
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gat",
|
||||
encoder_config={"hidden_dim": 64, "num_layers": 2, "num_heads": 4},
|
||||
heads_config=_default_heads_config(dof_tracking=True),
|
||||
)
|
||||
x, ei, ea, batch = _random_graph()
|
||||
preds = model(x, ei, ea, batch)
|
||||
assert "edge_pred" in preds
|
||||
assert "graph_pred" in preds
|
||||
assert "joint_type_pred" in preds
|
||||
assert "dof_pred" in preds
|
||||
assert "body_dof_pred" in preds
|
||||
|
||||
def test_body_dof_shape(self) -> None:
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gat",
|
||||
encoder_config={"hidden_dim": 64, "num_layers": 2, "num_heads": 4},
|
||||
heads_config=_default_heads_config(dof_tracking=True),
|
||||
)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=20)
|
||||
preds = model(x, ei, ea, batch)
|
||||
assert preds["body_dof_pred"].shape == (10, 2)
|
||||
|
||||
|
||||
class TestAssemblyGNNEdgeCases:
|
||||
"""Edge cases and error handling."""
|
||||
|
||||
def test_unknown_encoder_raises(self) -> None:
|
||||
with pytest.raises(ValueError, match="Unknown encoder"):
|
||||
AssemblyGNN(encoder_type="transformer")
|
||||
|
||||
def test_selective_heads(self) -> None:
|
||||
"""Only enabled heads produce output."""
|
||||
config = {
|
||||
"edge_classification": {"enabled": True},
|
||||
"graph_classification": {"enabled": False},
|
||||
"joint_type": {"enabled": True, "num_classes": 11},
|
||||
}
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gin",
|
||||
encoder_config={"hidden_dim": 32, "num_layers": 2},
|
||||
heads_config=config,
|
||||
)
|
||||
x, ei, ea, batch = _random_graph()
|
||||
preds = model(x, ei, ea, batch)
|
||||
assert "edge_pred" in preds
|
||||
assert "joint_type_pred" in preds
|
||||
assert "graph_pred" not in preds
|
||||
assert "dof_pred" not in preds
|
||||
|
||||
def test_no_batch_single_graph(self) -> None:
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gin",
|
||||
encoder_config={"hidden_dim": 32, "num_layers": 2},
|
||||
heads_config=_default_heads_config(),
|
||||
)
|
||||
x = torch.randn(6, 22)
|
||||
ei = torch.randint(0, 6, (2, 10))
|
||||
ea = torch.randn(10, 22)
|
||||
preds = model(x, ei, ea)
|
||||
assert preds["graph_pred"].shape == (1, 4)
|
||||
assert preds["dof_pred"].shape == (1, 1)
|
||||
|
||||
def test_parameter_count_reasonable(self) -> None:
|
||||
"""Sanity check that model has learnable parameters."""
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gin",
|
||||
encoder_config={"hidden_dim": 64, "num_layers": 2},
|
||||
heads_config=_default_heads_config(),
|
||||
)
|
||||
n_params = sum(p.numel() for p in model.parameters())
|
||||
assert n_params > 1000 # non-trivial model
|
||||
|
||||
|
||||
class TestAssemblyGNNEndToEnd:
|
||||
"""End-to-end test with datagen pipeline."""
|
||||
|
||||
def test_datagen_to_model(self) -> None:
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
from solver.models.graph_conv import assembly_to_pyg
|
||||
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(batch_size=2, complexity_tier="simple")
|
||||
|
||||
model = AssemblyGNN(
|
||||
encoder_type="gin",
|
||||
encoder_config={"hidden_dim": 32, "num_layers": 2},
|
||||
heads_config=_default_heads_config(),
|
||||
)
|
||||
model.eval()
|
||||
|
||||
for ex in batch:
|
||||
data = assembly_to_pyg(ex)
|
||||
with torch.no_grad():
|
||||
preds = model(data.x, data.edge_index, data.edge_attr)
|
||||
assert "edge_pred" in preds
|
||||
assert preds["edge_pred"].shape[0] == data.edge_index.shape[1]
|
||||
183
tests/models/test_encoder.py
Normal file
183
tests/models/test_encoder.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Tests for solver.models.encoder -- GIN and GAT graph encoders."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from solver.models.encoder import GATEncoder, GINEncoder
|
||||
|
||||
|
||||
def _random_graph(
|
||||
n_nodes: int = 8,
|
||||
n_edges: int = 20,
|
||||
node_dim: int = 22,
|
||||
edge_dim: int = 22,
|
||||
batch_size: int = 2,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Create a random graph for testing."""
|
||||
x = torch.randn(n_nodes, node_dim)
|
||||
edge_index = torch.randint(0, n_nodes, (2, n_edges))
|
||||
edge_attr = torch.randn(n_edges, edge_dim)
|
||||
# Assign nodes to batches roughly evenly.
|
||||
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
|
||||
# Handle remainder nodes.
|
||||
if len(batch) < n_nodes:
|
||||
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
|
||||
return x, edge_index, edge_attr, batch
|
||||
|
||||
|
||||
class TestGINEncoder:
|
||||
"""GINEncoder shape and gradient tests."""
|
||||
|
||||
def test_output_shapes(self) -> None:
|
||||
enc = GINEncoder(node_features_dim=22, edge_features_dim=22, hidden_dim=64, num_layers=2)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape == (10, 64)
|
||||
assert edge_emb.shape == (16, 64)
|
||||
assert graph_emb.shape == (3, 64)
|
||||
|
||||
def test_hidden_dim_property(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=128)
|
||||
assert enc.hidden_dim == 128
|
||||
|
||||
def test_default_dimensions(self) -> None:
|
||||
enc = GINEncoder()
|
||||
x, ei, ea, batch = _random_graph()
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape[1] == 128
|
||||
assert edge_emb.shape[1] == 128
|
||||
assert graph_emb.shape[1] == 128
|
||||
|
||||
def test_no_batch_defaults_to_single_graph(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=64, num_layers=2)
|
||||
x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None)
|
||||
assert graph_emb.shape == (1, 64)
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
x.requires_grad_(True)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
loss = graph_emb.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.abs().sum() > 0
|
||||
|
||||
def test_zero_edges(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||
x = torch.randn(4, 22)
|
||||
ei = torch.zeros(2, 0, dtype=torch.long)
|
||||
ea = torch.zeros(0, 22)
|
||||
batch = torch.tensor([0, 0, 1, 1])
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape == (4, 32)
|
||||
assert edge_emb.shape == (0, 32)
|
||||
assert graph_emb.shape == (2, 32)
|
||||
|
||||
def test_single_node(self) -> None:
|
||||
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||
# Train with a small batch first to populate BN running stats.
|
||||
x_train = torch.randn(4, 22)
|
||||
ei_train = torch.zeros(2, 0, dtype=torch.long)
|
||||
ea_train = torch.zeros(0, 22)
|
||||
batch_train = torch.tensor([0, 0, 1, 1])
|
||||
enc.train()
|
||||
enc(x_train, ei_train, ea_train, batch_train)
|
||||
# Now test single node in eval mode (BN uses running stats).
|
||||
enc.eval()
|
||||
x = torch.randn(1, 22)
|
||||
ei = torch.zeros(2, 0, dtype=torch.long)
|
||||
ea = torch.zeros(0, 22)
|
||||
with torch.no_grad():
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea)
|
||||
assert node_emb.shape == (1, 32)
|
||||
assert graph_emb.shape == (1, 32)
|
||||
|
||||
def test_eval_mode(self) -> None:
|
||||
"""Encoder works in eval mode (BatchNorm uses running stats)."""
|
||||
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||
# Forward pass in train mode to populate BN stats.
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
enc.train()
|
||||
enc(x, ei, ea, batch)
|
||||
# Switch to eval.
|
||||
enc.eval()
|
||||
with torch.no_grad():
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape[1] == 32
|
||||
|
||||
|
||||
class TestGATEncoder:
|
||||
"""GATEncoder shape and gradient tests."""
|
||||
|
||||
def test_output_shapes(self) -> None:
|
||||
enc = GATEncoder(
|
||||
node_features_dim=22,
|
||||
edge_features_dim=22,
|
||||
hidden_dim=64,
|
||||
num_layers=2,
|
||||
num_heads=4,
|
||||
)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape == (10, 64)
|
||||
assert edge_emb.shape == (16, 64)
|
||||
assert graph_emb.shape == (3, 64)
|
||||
|
||||
def test_hidden_dim_property(self) -> None:
|
||||
enc = GATEncoder(hidden_dim=256, num_heads=8)
|
||||
assert enc.hidden_dim == 256
|
||||
|
||||
def test_default_dimensions(self) -> None:
|
||||
enc = GATEncoder()
|
||||
x, ei, ea, batch = _random_graph()
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape[1] == 256
|
||||
assert edge_emb.shape[1] == 256
|
||||
assert graph_emb.shape[1] == 256
|
||||
|
||||
def test_no_batch_defaults_to_single_graph(self) -> None:
|
||||
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||
x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None)
|
||||
assert graph_emb.shape == (1, 64)
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
x.requires_grad_(True)
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
loss = graph_emb.sum()
|
||||
loss.backward()
|
||||
assert x.grad is not None
|
||||
assert x.grad.abs().sum() > 0
|
||||
|
||||
def test_residual_connection(self) -> None:
|
||||
"""With residual=True, output should differ from residual=False."""
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
torch.manual_seed(0)
|
||||
enc_res = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=True)
|
||||
torch.manual_seed(0)
|
||||
enc_no = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=False)
|
||||
with torch.no_grad():
|
||||
n1, _, _ = enc_res(x, ei, ea, batch)
|
||||
n2, _, _ = enc_no(x, ei, ea, batch)
|
||||
# Outputs should generally differ (unless by very unlikely coincidence).
|
||||
assert not torch.allclose(n1, n2, atol=1e-4)
|
||||
|
||||
def test_hidden_dim_must_divide_heads(self) -> None:
|
||||
with pytest.raises(ValueError, match="divisible"):
|
||||
GATEncoder(hidden_dim=100, num_heads=8)
|
||||
|
||||
def test_eval_mode(self) -> None:
|
||||
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||
enc.train()
|
||||
enc(x, ei, ea, batch)
|
||||
enc.eval()
|
||||
with torch.no_grad():
|
||||
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||
assert node_emb.shape[1] == 64
|
||||
110
tests/models/test_factory.py
Normal file
110
tests/models/test_factory.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""Tests for solver.models.factory -- model and loss construction from config."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import yaml
|
||||
|
||||
from solver.models.assembly_gnn import AssemblyGNN
|
||||
from solver.models.factory import build_loss, build_model
|
||||
from solver.models.losses import MultiTaskLoss
|
||||
|
||||
|
||||
def _load_yaml(path: str) -> dict:
|
||||
with open(path) as f:
|
||||
return yaml.safe_load(f)
|
||||
|
||||
|
||||
class TestBuildModel:
|
||||
"""build_model constructs AssemblyGNN from config."""
|
||||
|
||||
def test_baseline_config(self) -> None:
|
||||
config = _load_yaml("configs/model/baseline.yaml")
|
||||
model = build_model(config)
|
||||
assert isinstance(model, AssemblyGNN)
|
||||
assert model.encoder.hidden_dim == 128
|
||||
|
||||
def test_gat_config(self) -> None:
|
||||
config = _load_yaml("configs/model/gat.yaml")
|
||||
model = build_model(config)
|
||||
assert isinstance(model, AssemblyGNN)
|
||||
assert model.encoder.hidden_dim == 256
|
||||
|
||||
def test_baseline_heads_present(self) -> None:
|
||||
config = _load_yaml("configs/model/baseline.yaml")
|
||||
model = build_model(config)
|
||||
assert "edge_pred" in model.heads
|
||||
assert "graph_pred" in model.heads
|
||||
assert "joint_type_pred" in model.heads
|
||||
assert "dof_pred" in model.heads
|
||||
|
||||
def test_gat_has_dof_tracking(self) -> None:
|
||||
config = _load_yaml("configs/model/gat.yaml")
|
||||
model = build_model(config)
|
||||
assert "body_dof_pred" in model.heads
|
||||
|
||||
def test_baseline_no_dof_tracking(self) -> None:
|
||||
config = _load_yaml("configs/model/baseline.yaml")
|
||||
model = build_model(config)
|
||||
assert "body_dof_pred" not in model.heads
|
||||
|
||||
def test_minimal_config(self) -> None:
|
||||
config = {"architecture": "gin"}
|
||||
model = build_model(config)
|
||||
assert isinstance(model, AssemblyGNN)
|
||||
# No heads enabled.
|
||||
assert len(model.heads) == 0
|
||||
|
||||
def test_custom_config(self) -> None:
|
||||
config = {
|
||||
"architecture": "gin",
|
||||
"encoder": {"hidden_dim": 64, "num_layers": 2},
|
||||
"heads": {
|
||||
"edge_classification": {"enabled": True},
|
||||
"graph_classification": {"enabled": True, "num_classes": 4},
|
||||
},
|
||||
}
|
||||
model = build_model(config)
|
||||
assert model.encoder.hidden_dim == 64
|
||||
assert "edge_pred" in model.heads
|
||||
assert "graph_pred" in model.heads
|
||||
assert "joint_type_pred" not in model.heads
|
||||
|
||||
|
||||
class TestBuildLoss:
|
||||
"""build_loss constructs MultiTaskLoss from training config."""
|
||||
|
||||
def test_pretrain_config(self) -> None:
|
||||
config = _load_yaml("configs/training/pretrain.yaml")
|
||||
loss_fn = build_loss(config)
|
||||
assert isinstance(loss_fn, MultiTaskLoss)
|
||||
|
||||
def test_weights_from_config(self) -> None:
|
||||
config = _load_yaml("configs/training/pretrain.yaml")
|
||||
loss_fn = build_loss(config)
|
||||
assert loss_fn.weights["edge"] == 1.0
|
||||
assert loss_fn.weights["graph"] == 0.5
|
||||
assert loss_fn.weights["joint_type"] == 0.3
|
||||
assert loss_fn.weights["dof"] == 0.2
|
||||
|
||||
def test_redundant_penalty_from_config(self) -> None:
|
||||
config = _load_yaml("configs/training/pretrain.yaml")
|
||||
loss_fn = build_loss(config)
|
||||
assert loss_fn.redundant_penalty == 2.0
|
||||
|
||||
def test_empty_config_uses_defaults(self) -> None:
|
||||
loss_fn = build_loss({})
|
||||
assert isinstance(loss_fn, MultiTaskLoss)
|
||||
assert loss_fn.weights["edge"] == 1.0
|
||||
|
||||
def test_custom_weights(self) -> None:
|
||||
config = {
|
||||
"loss": {
|
||||
"edge_weight": 2.0,
|
||||
"graph_weight": 1.0,
|
||||
"redundant_penalty": 5.0,
|
||||
},
|
||||
}
|
||||
loss_fn = build_loss(config)
|
||||
assert loss_fn.weights["edge"] == 2.0
|
||||
assert loss_fn.weights["graph"] == 1.0
|
||||
assert loss_fn.redundant_penalty == 5.0
|
||||
144
tests/models/test_graph_conv.py
Normal file
144
tests/models/test_graph_conv.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Tests for solver.models.graph_conv -- assembly to PyG conversion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
from solver.datagen.types import JointType
|
||||
from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg
|
||||
|
||||
|
||||
def _make_example(n_bodies: int = 4, grounded: bool = True, seed: int = 0) -> dict:
|
||||
"""Generate a single training example via the datagen pipeline."""
|
||||
gen = SyntheticAssemblyGenerator(seed=seed)
|
||||
batch = gen.generate_training_batch(batch_size=1, n_bodies_range=(n_bodies, n_bodies + 1))
|
||||
return batch[0]
|
||||
|
||||
|
||||
class TestAssemblyToPyg:
|
||||
"""assembly_to_pyg converts datagen dicts to PyG Data correctly."""
|
||||
|
||||
def test_node_feature_shape(self) -> None:
|
||||
ex = _make_example(n_bodies=5)
|
||||
data = assembly_to_pyg(ex)
|
||||
assert data.x.shape == (5, 22)
|
||||
|
||||
def test_edge_feature_shape(self) -> None:
|
||||
ex = _make_example(n_bodies=4)
|
||||
data = assembly_to_pyg(ex)
|
||||
n_joints = ex["n_joints"]
|
||||
assert data.edge_attr.shape == (n_joints * 2, 22)
|
||||
|
||||
def test_edge_index_bidirectional(self) -> None:
|
||||
ex = _make_example(n_bodies=4)
|
||||
data = assembly_to_pyg(ex)
|
||||
ei = data.edge_index
|
||||
# Each joint produces 2 directed edges: a->b and b->a.
|
||||
for i in range(0, ei.size(1), 2):
|
||||
assert ei[0, i].item() == ei[1, i + 1].item()
|
||||
assert ei[1, i].item() == ei[0, i + 1].item()
|
||||
|
||||
def test_edge_index_shape(self) -> None:
|
||||
ex = _make_example(n_bodies=4)
|
||||
data = assembly_to_pyg(ex)
|
||||
n_joints = ex["n_joints"]
|
||||
assert data.edge_index.shape == (2, n_joints * 2)
|
||||
|
||||
def test_node_features_centered(self) -> None:
|
||||
ex = _make_example(n_bodies=5)
|
||||
data = assembly_to_pyg(ex)
|
||||
# Positions (dims 0-2) should be centered (mean ~0).
|
||||
pos = data.x[:, :3]
|
||||
assert pos.mean(dim=0).abs().max().item() < 1e-5
|
||||
|
||||
def test_grounded_flag_set(self) -> None:
|
||||
ex = _make_example(n_bodies=4, grounded=True)
|
||||
ex["grounded"] = True
|
||||
data = assembly_to_pyg(ex)
|
||||
assert data.x[0, 12].item() == 1.0
|
||||
|
||||
def test_ungrounded_flag_clear(self) -> None:
|
||||
ex = _make_example(n_bodies=4)
|
||||
ex["grounded"] = False
|
||||
data = assembly_to_pyg(ex)
|
||||
assert (data.x[:, 12] == 0.0).all()
|
||||
|
||||
def test_edge_type_one_hot_valid(self) -> None:
|
||||
ex = _make_example(n_bodies=5)
|
||||
data = assembly_to_pyg(ex)
|
||||
if data.edge_attr.size(0) > 0:
|
||||
onehot = data.edge_attr[:, :11]
|
||||
# Each row should have exactly one 1.0.
|
||||
assert (onehot.sum(dim=1) == 1.0).all()
|
||||
|
||||
def test_labels_present_when_requested(self) -> None:
|
||||
ex = _make_example(n_bodies=4)
|
||||
data = assembly_to_pyg(ex, include_labels=True)
|
||||
assert hasattr(data, "y_edge")
|
||||
assert hasattr(data, "y_graph")
|
||||
assert hasattr(data, "y_joint_type")
|
||||
assert hasattr(data, "y_dof")
|
||||
assert hasattr(data, "y_body_dof")
|
||||
|
||||
def test_labels_absent_when_not_requested(self) -> None:
|
||||
ex = _make_example(n_bodies=4)
|
||||
data = assembly_to_pyg(ex, include_labels=False)
|
||||
assert not hasattr(data, "y_edge")
|
||||
assert not hasattr(data, "y_graph")
|
||||
|
||||
def test_graph_classification_label_mapping(self) -> None:
|
||||
ex = _make_example(n_bodies=4)
|
||||
data = assembly_to_pyg(ex)
|
||||
cls = ex["assembly_classification"]
|
||||
expected = ASSEMBLY_CLASSES[cls]
|
||||
assert data.y_graph.item() == expected
|
||||
|
||||
def test_body_dof_shape(self) -> None:
|
||||
ex = _make_example(n_bodies=5)
|
||||
data = assembly_to_pyg(ex)
|
||||
assert data.y_body_dof.shape == (5, 2)
|
||||
|
||||
def test_edge_labels_binary(self) -> None:
|
||||
ex = _make_example(n_bodies=5)
|
||||
data = assembly_to_pyg(ex)
|
||||
if data.y_edge.numel() > 0:
|
||||
assert ((data.y_edge == 0.0) | (data.y_edge == 1.0)).all()
|
||||
|
||||
def test_dof_removed_normalized(self) -> None:
|
||||
ex = _make_example(n_bodies=4)
|
||||
data = assembly_to_pyg(ex)
|
||||
if data.edge_attr.size(0) > 0:
|
||||
dof_norm = data.edge_attr[:, 21]
|
||||
assert (dof_norm >= 0.0).all()
|
||||
assert (dof_norm <= 1.0).all()
|
||||
|
||||
def test_roundtrip_with_generator(self) -> None:
|
||||
"""Generate a real example and convert -- no crash."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(batch_size=5, complexity_tier="simple")
|
||||
for ex in batch:
|
||||
data = assembly_to_pyg(ex)
|
||||
assert data.x.shape[1] == 22
|
||||
assert data.edge_attr.shape[1] == 22
|
||||
|
||||
|
||||
class TestAssemblyClasses:
|
||||
"""ASSEMBLY_CLASSES covers all classifications."""
|
||||
|
||||
def test_four_classes(self) -> None:
|
||||
assert len(ASSEMBLY_CLASSES) == 4
|
||||
|
||||
def test_values_are_0_to_3(self) -> None:
|
||||
assert set(ASSEMBLY_CLASSES.values()) == {0, 1, 2, 3}
|
||||
|
||||
|
||||
class TestJointTypeNames:
|
||||
"""JOINT_TYPE_NAMES matches the JointType enum."""
|
||||
|
||||
def test_length_matches_enum(self) -> None:
|
||||
assert len(JOINT_TYPE_NAMES) == len(JointType)
|
||||
|
||||
def test_order_matches_ordinal(self) -> None:
|
||||
for i, name in enumerate(JOINT_TYPE_NAMES):
|
||||
assert JointType[name].value[0] == i
|
||||
182
tests/models/test_heads.py
Normal file
182
tests/models/test_heads.py
Normal file
@@ -0,0 +1,182 @@
|
||||
"""Tests for solver.models.heads -- task-specific prediction heads."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from solver.models.heads import (
|
||||
DOFRegressionHead,
|
||||
DOFTrackingHead,
|
||||
EdgeClassificationHead,
|
||||
GraphClassificationHead,
|
||||
JointTypeHead,
|
||||
)
|
||||
|
||||
|
||||
class TestEdgeClassificationHead:
|
||||
"""EdgeClassificationHead produces correct shape and gradients."""
|
||||
|
||||
def test_output_shape(self) -> None:
|
||||
head = EdgeClassificationHead(hidden_dim=128)
|
||||
edge_emb = torch.randn(20, 128)
|
||||
out = head(edge_emb)
|
||||
assert out.shape == (20, 1)
|
||||
|
||||
def test_output_shape_small(self) -> None:
|
||||
head = EdgeClassificationHead(hidden_dim=32)
|
||||
edge_emb = torch.randn(5, 32)
|
||||
out = head(edge_emb)
|
||||
assert out.shape == (5, 1)
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
head = EdgeClassificationHead(hidden_dim=64)
|
||||
edge_emb = torch.randn(10, 64, requires_grad=True)
|
||||
out = head(edge_emb)
|
||||
out.sum().backward()
|
||||
assert edge_emb.grad is not None
|
||||
assert edge_emb.grad.abs().sum() > 0
|
||||
|
||||
def test_zero_edges(self) -> None:
|
||||
head = EdgeClassificationHead(hidden_dim=64)
|
||||
edge_emb = torch.zeros(0, 64)
|
||||
out = head(edge_emb)
|
||||
assert out.shape == (0, 1)
|
||||
|
||||
def test_output_is_logits(self) -> None:
|
||||
"""Output should be unbounded logits (not probabilities)."""
|
||||
head = EdgeClassificationHead(hidden_dim=64)
|
||||
torch.manual_seed(42)
|
||||
edge_emb = torch.randn(100, 64)
|
||||
out = head(edge_emb)
|
||||
# Logits can be negative.
|
||||
assert out.min().item() < 0 or out.max().item() > 1
|
||||
|
||||
|
||||
class TestGraphClassificationHead:
|
||||
"""GraphClassificationHead produces correct shape and gradients."""
|
||||
|
||||
def test_output_shape(self) -> None:
|
||||
head = GraphClassificationHead(hidden_dim=128, num_classes=4)
|
||||
graph_emb = torch.randn(3, 128)
|
||||
out = head(graph_emb)
|
||||
assert out.shape == (3, 4)
|
||||
|
||||
def test_custom_num_classes(self) -> None:
|
||||
head = GraphClassificationHead(hidden_dim=64, num_classes=8)
|
||||
graph_emb = torch.randn(2, 64)
|
||||
out = head(graph_emb)
|
||||
assert out.shape == (2, 8)
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
head = GraphClassificationHead(hidden_dim=64)
|
||||
graph_emb = torch.randn(2, 64, requires_grad=True)
|
||||
out = head(graph_emb)
|
||||
out.sum().backward()
|
||||
assert graph_emb.grad is not None
|
||||
|
||||
def test_single_graph(self) -> None:
|
||||
head = GraphClassificationHead(hidden_dim=128)
|
||||
graph_emb = torch.randn(1, 128)
|
||||
out = head(graph_emb)
|
||||
assert out.shape == (1, 4)
|
||||
|
||||
|
||||
class TestJointTypeHead:
|
||||
"""JointTypeHead produces correct shape and gradients."""
|
||||
|
||||
def test_output_shape(self) -> None:
|
||||
head = JointTypeHead(hidden_dim=128, num_classes=11)
|
||||
edge_emb = torch.randn(20, 128)
|
||||
out = head(edge_emb)
|
||||
assert out.shape == (20, 11)
|
||||
|
||||
def test_custom_classes(self) -> None:
|
||||
head = JointTypeHead(hidden_dim=64, num_classes=7)
|
||||
edge_emb = torch.randn(10, 64)
|
||||
out = head(edge_emb)
|
||||
assert out.shape == (10, 7)
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
head = JointTypeHead(hidden_dim=64)
|
||||
edge_emb = torch.randn(10, 64, requires_grad=True)
|
||||
out = head(edge_emb)
|
||||
out.sum().backward()
|
||||
assert edge_emb.grad is not None
|
||||
|
||||
def test_zero_edges(self) -> None:
|
||||
head = JointTypeHead(hidden_dim=64)
|
||||
edge_emb = torch.zeros(0, 64)
|
||||
out = head(edge_emb)
|
||||
assert out.shape == (0, 11)
|
||||
|
||||
|
||||
class TestDOFRegressionHead:
|
||||
"""DOFRegressionHead produces correct shape and non-negative output."""
|
||||
|
||||
def test_output_shape(self) -> None:
|
||||
head = DOFRegressionHead(hidden_dim=128)
|
||||
graph_emb = torch.randn(3, 128)
|
||||
out = head(graph_emb)
|
||||
assert out.shape == (3, 1)
|
||||
|
||||
def test_output_non_negative(self) -> None:
|
||||
"""Softplus ensures non-negative output."""
|
||||
head = DOFRegressionHead(hidden_dim=64)
|
||||
torch.manual_seed(0)
|
||||
graph_emb = torch.randn(50, 64)
|
||||
out = head(graph_emb)
|
||||
assert (out >= 0).all()
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
head = DOFRegressionHead(hidden_dim=64)
|
||||
graph_emb = torch.randn(2, 64, requires_grad=True)
|
||||
out = head(graph_emb)
|
||||
out.sum().backward()
|
||||
assert graph_emb.grad is not None
|
||||
|
||||
def test_single_graph(self) -> None:
|
||||
head = DOFRegressionHead(hidden_dim=32)
|
||||
graph_emb = torch.randn(1, 32)
|
||||
out = head(graph_emb)
|
||||
assert out.shape == (1, 1)
|
||||
assert out.item() >= 0
|
||||
|
||||
|
||||
class TestDOFTrackingHead:
|
||||
"""DOFTrackingHead produces correct shape and non-negative output."""
|
||||
|
||||
def test_output_shape(self) -> None:
|
||||
head = DOFTrackingHead(hidden_dim=128)
|
||||
node_emb = torch.randn(10, 128)
|
||||
out = head(node_emb)
|
||||
assert out.shape == (10, 2)
|
||||
|
||||
def test_output_non_negative(self) -> None:
|
||||
"""Softplus ensures non-negative output."""
|
||||
head = DOFTrackingHead(hidden_dim=64)
|
||||
torch.manual_seed(0)
|
||||
node_emb = torch.randn(50, 64)
|
||||
out = head(node_emb)
|
||||
assert (out >= 0).all()
|
||||
|
||||
def test_gradients_flow(self) -> None:
|
||||
head = DOFTrackingHead(hidden_dim=64)
|
||||
node_emb = torch.randn(10, 64, requires_grad=True)
|
||||
out = head(node_emb)
|
||||
out.sum().backward()
|
||||
assert node_emb.grad is not None
|
||||
|
||||
def test_single_node(self) -> None:
|
||||
head = DOFTrackingHead(hidden_dim=32)
|
||||
node_emb = torch.randn(1, 32)
|
||||
out = head(node_emb)
|
||||
assert out.shape == (1, 2)
|
||||
assert (out >= 0).all()
|
||||
|
||||
def test_two_columns_independent(self) -> None:
|
||||
"""Translational and rotational DOF are independently predicted."""
|
||||
head = DOFTrackingHead(hidden_dim=64)
|
||||
node_emb = torch.randn(20, 64)
|
||||
out = head(node_emb)
|
||||
# The two columns should generally differ.
|
||||
assert not torch.allclose(out[:, 0], out[:, 1], atol=1e-6)
|
||||
156
tests/models/test_losses.py
Normal file
156
tests/models/test_losses.py
Normal file
@@ -0,0 +1,156 @@
|
||||
"""Tests for solver.models.losses -- uncertainty-weighted multi-task loss."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from solver.models.losses import MultiTaskLoss
|
||||
|
||||
|
||||
def _make_predictions_and_targets(
|
||||
n_edges: int = 20,
|
||||
batch_size: int = 3,
|
||||
n_nodes: int = 10,
|
||||
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||
preds = {
|
||||
"edge_pred": torch.randn(n_edges, 1),
|
||||
"graph_pred": torch.randn(batch_size, 4),
|
||||
"joint_type_pred": torch.randn(n_edges, 11),
|
||||
"dof_pred": torch.rand(batch_size, 1) * 10,
|
||||
"body_dof_pred": torch.rand(n_nodes, 2) * 6,
|
||||
}
|
||||
targets = {
|
||||
"y_edge": torch.randint(0, 2, (n_edges,)).float(),
|
||||
"y_graph": torch.randint(0, 4, (batch_size,)),
|
||||
"y_joint_type": torch.randint(0, 11, (n_edges,)),
|
||||
"y_dof": torch.rand(batch_size, 1) * 10,
|
||||
"y_body_dof": torch.rand(n_nodes, 2) * 6,
|
||||
}
|
||||
return preds, targets
|
||||
|
||||
|
||||
class TestMultiTaskLoss:
|
||||
"""MultiTaskLoss computation tests."""
|
||||
|
||||
def test_returns_scalar_and_breakdown(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
preds, targets = _make_predictions_and_targets()
|
||||
total, breakdown = loss_fn(preds, targets)
|
||||
assert total.dim() == 0 # scalar
|
||||
assert isinstance(breakdown, dict)
|
||||
|
||||
def test_all_tasks_in_breakdown(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
preds, targets = _make_predictions_and_targets()
|
||||
_, breakdown = loss_fn(preds, targets)
|
||||
assert "edge" in breakdown
|
||||
assert "graph" in breakdown
|
||||
assert "joint_type" in breakdown
|
||||
assert "dof" in breakdown
|
||||
assert "body_dof" in breakdown
|
||||
|
||||
def test_total_is_positive(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
preds, targets = _make_predictions_and_targets()
|
||||
total, _ = loss_fn(preds, targets)
|
||||
# With random predictions, loss should be positive.
|
||||
assert total.item() > 0
|
||||
|
||||
def test_skips_missing_predictions(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
preds = {"edge_pred": torch.randn(10, 1)}
|
||||
targets = {"y_edge": torch.randint(0, 2, (10,)).float()}
|
||||
total, breakdown = loss_fn(preds, targets)
|
||||
assert "edge" in breakdown
|
||||
assert "graph" not in breakdown
|
||||
assert "joint_type" not in breakdown
|
||||
|
||||
def test_skips_missing_targets(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
preds = {
|
||||
"edge_pred": torch.randn(10, 1),
|
||||
"graph_pred": torch.randn(2, 4),
|
||||
}
|
||||
targets = {"y_edge": torch.randint(0, 2, (10,)).float()}
|
||||
_, breakdown = loss_fn(preds, targets)
|
||||
assert "edge" in breakdown
|
||||
assert "graph" not in breakdown
|
||||
|
||||
def test_gradients_flow_to_log_vars(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
preds, targets = _make_predictions_and_targets()
|
||||
# Make preds require grad.
|
||||
for k in preds:
|
||||
preds[k] = preds[k].requires_grad_(True)
|
||||
total, _ = loss_fn(preds, targets)
|
||||
total.backward()
|
||||
for name, param in loss_fn.log_vars.items():
|
||||
assert param.grad is not None, f"No gradient for log_var[{name}]"
|
||||
|
||||
def test_gradients_flow_to_predictions(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
preds, targets = _make_predictions_and_targets()
|
||||
for k in preds:
|
||||
preds[k] = preds[k].requires_grad_(True)
|
||||
total, _ = loss_fn(preds, targets)
|
||||
total.backward()
|
||||
for k, v in preds.items():
|
||||
assert v.grad is not None, f"No gradient for prediction[{k}]"
|
||||
|
||||
def test_redundant_penalty_applies(self) -> None:
|
||||
"""Redundant edges (label=0) should have higher loss contribution."""
|
||||
loss_fn = MultiTaskLoss(redundant_penalty=5.0)
|
||||
# All-zero predictions, label=0 (redundant).
|
||||
preds_red = {"edge_pred": torch.zeros(10, 1)}
|
||||
targets_red = {"y_edge": torch.zeros(10)}
|
||||
total_red, _ = loss_fn(preds_red, targets_red)
|
||||
|
||||
loss_fn2 = MultiTaskLoss(redundant_penalty=1.0)
|
||||
total_eq, _ = loss_fn2(preds_red, targets_red)
|
||||
|
||||
# Higher penalty should produce higher loss.
|
||||
assert total_red.item() > total_eq.item()
|
||||
|
||||
def test_empty_predictions_returns_zero(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
total, breakdown = loss_fn({}, {})
|
||||
assert total.item() == 0.0
|
||||
assert len(breakdown) == 0
|
||||
|
||||
|
||||
class TestUncertaintyWeighting:
|
||||
"""Test uncertainty weighting mechanism specifically."""
|
||||
|
||||
def test_log_vars_initialized_to_zero(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
for param in loss_fn.log_vars.values():
|
||||
assert param.item() == 0.0
|
||||
|
||||
def test_log_vars_are_learnable(self) -> None:
|
||||
loss_fn = MultiTaskLoss()
|
||||
params = list(loss_fn.parameters())
|
||||
log_var_params = [p for p in params if p.shape == (1,)]
|
||||
assert len(log_var_params) == 5 # one per task
|
||||
|
||||
def test_weighting_reduces_high_loss_influence(self) -> None:
|
||||
"""After a few gradient steps, log_var for a noisy task should increase."""
|
||||
loss_fn = MultiTaskLoss(edge_weight=1.0, graph_weight=1.0)
|
||||
optimizer = torch.optim.SGD(loss_fn.parameters(), lr=0.1)
|
||||
|
||||
# Simulate: edge task has high loss, graph has low.
|
||||
for _ in range(20):
|
||||
preds = {
|
||||
"edge_pred": torch.randn(10, 1) * 10, # high variance -> high loss
|
||||
"graph_pred": torch.zeros(2, 4), # near-zero loss
|
||||
}
|
||||
targets = {
|
||||
"y_edge": torch.randint(0, 2, (10,)).float(),
|
||||
"y_graph": torch.zeros(2, dtype=torch.long),
|
||||
}
|
||||
optimizer.zero_grad()
|
||||
total, _ = loss_fn(preds, targets)
|
||||
total.backward()
|
||||
optimizer.step()
|
||||
|
||||
# The edge task log_var should have increased (higher uncertainty).
|
||||
assert loss_fn.log_vars["edge"].item() > loss_fn.log_vars["graph"].item()
|
||||
Reference in New Issue
Block a user