feat(models): GIN encoder for assembly graphs #19

Closed
opened 2026-02-04 20:22:33 +00:00 by forbes · 0 comments
Owner

Summary

Implement solver/models/gin.py with a GINEncoder(nn.Module) for assembly constraint graphs.

Depends on #18.

Architecture

  • Input: node features [n, 22], edge_index [2, E], edge_attr [E, 22]
  • num_layers GINConv layers with batch normalization and dropout
  • Edge features are incorporated via an edge MLP that transforms edge_attr and adds the result to each neighbor message before aggregation
  • Output: node embeddings [n, hidden_dim] and a graph-level embedding [hidden_dim] per graph, computed via global mean pooling over that graph's nodes

Config (from configs/model/baseline.yaml)

architecture: gin
encoder:
  num_layers: 3
  hidden_dim: 128
  dropout: 0.1
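As a rough illustration of how these values might reach the encoder constructor, the sketch below validates an already-parsed config mapping. `EncoderConfig` and `parse_model_config` are hypothetical names, and the validation bounds are assumptions; the project's actual config loading is not shown in this issue.

```python
from dataclasses import dataclass
from typing import Any, Mapping


@dataclass(frozen=True)
class EncoderConfig:
    """Hypothetical typed view of the `encoder:` block; defaults mirror baseline.yaml."""

    num_layers: int = 3
    hidden_dim: int = 128
    dropout: float = 0.1

    def __post_init__(self) -> None:
        # Assumed sanity bounds, not taken from the issue.
        if self.num_layers < 1:
            raise ValueError("num_layers must be >= 1")
        if self.hidden_dim < 1:
            raise ValueError("hidden_dim must be >= 1")
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError("dropout must be in [0, 1)")


def parse_model_config(cfg: Mapping[str, Any]) -> tuple[str, EncoderConfig]:
    """Split a parsed YAML mapping into the architecture name and encoder settings."""
    return str(cfg["architecture"]), EncoderConfig(**cfg.get("encoder", {}))


# The mapping below is what a YAML parser would produce for the config above.
architecture, encoder_cfg = parse_model_config(
    {"architecture": "gin", "encoder": {"num_layers": 3, "hidden_dim": 128, "dropout": 0.1}}
)
```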

Registration

Register with @register_encoder("gin") so build_model resolves it automatically.

Files

  • solver/models/gin.py
  • tests/models/test_gin.py

Acceptance criteria

  • Forward pass produces the expected output shapes: (n, hidden_dim) for node embeddings and (hidden_dim,) for the graph embedding
  • Works with variable-size graphs via PyG batching
  • Gradients flow to all parameters (no detached ops)
  • Registered and resolved by build_model
  • Passes ruff, mypy, pytest
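
The gradient-flow criterion can be checked mechanically. The helper below is a minimal sketch, assuming the check is "every trainable parameter has a non-`None` `.grad` after one backward pass"; `params_without_grad` and the toy `DetachedToy` module are hypothetical and exist only to show what a failure looks like.

```python
import torch
from torch import nn


def params_without_grad(model: nn.Module, loss: torch.Tensor) -> list[str]:
    """Backpropagate `loss` and list trainable parameters that received no gradient."""
    model.zero_grad(set_to_none=True)
    loss.backward()
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.grad is None
    ]


class DetachedToy(nn.Module):
    """Toy module with a deliberate bug: the first layer's output is detached."""

    def __init__(self) -> None:
        super().__init__()
        self.first = nn.Linear(4, 4)
        self.second = nn.Linear(4, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.second(self.first(x).detach())


toy = DetachedToy()
missing = params_without_grad(toy, toy(torch.randn(8, 4)).sum())
# `missing` names the parameters of `first`, which the detach cut off from the loss.
```

In a test for the encoder, the loss would be something like the sum of both outputs, and the assertion would be that the returned list is empty.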

Reference: kindred/solver#19