feat(models): GAT encoder with multi-head attention and residual connections #20

Closed
opened 2026-02-04 20:22:41 +00:00 by forbes · 0 comments

Summary

Implement solver/models/gat.py with a GATEncoder(nn.Module) using GATv2Conv.

Depends on #18.

Architecture

  • Input: node features [n, 22], edge_index [2, E], edge_attr [E, 22]
  • num_layers GATv2Conv layers with num_heads attention heads
  • Residual connections between layers (togglable via residual config flag)
  • Edge features via GATv2Conv's edge_dim parameter
  • Output: node embeddings [n, hidden_dim] and a graph-level embedding [hidden_dim] per graph ([num_graphs, hidden_dim] for a batch)

Config (from configs/model/gat.yaml)

architecture: gat
encoder:
  num_layers: 4
  hidden_dim: 256
  num_heads: 8
  dropout: 0.1
  residual: true
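One way to give the encoder block a typed, validated shape is a frozen dataclass. GATEncoderConfig and from_dict are hypothetical names; the sketch takes the already-parsed dict (for example from yaml.safe_load) so it has no YAML dependency, and the divisibility check assumes heads are concatenated back to hidden_dim.

```python
from __future__ import annotations

from dataclasses import dataclass, fields
from typing import Any


@dataclass(frozen=True)
class GATEncoderConfig:
    """Hypothetical typed view of the `encoder` section of configs/model/gat.yaml."""

    num_layers: int = 4
    hidden_dim: int = 256
    num_heads: int = 8
    dropout: float = 0.1
    residual: bool = True

    def __post_init__(self) -> None:
        # Assumed head layout: concatenated heads must tile hidden_dim exactly.
        if self.hidden_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_dim={self.hidden_dim} not divisible by num_heads={self.num_heads}"
            )
        if not 0.0 <= self.dropout < 1.0:
            raise ValueError(f"dropout must be in [0, 1), got {self.dropout}")

    @property
    def head_dim(self) -> int:
        # Per-head output width passed to GATv2Conv as out_channels.
        return self.hidden_dim // self.num_heads

    @classmethod
    def from_dict(cls, cfg: dict[str, Any]) -> GATEncoderConfig:
        # Reject unknown keys so config typos fail loudly instead of being ignored.
        known = {f.name for f in fields(cls)}
        unknown = set(cfg) - known
        if unknown:
            raise KeyError(f"unknown encoder keys: {sorted(unknown)}")
        return cls(**cfg)


# Usage with the dict that the YAML above would parse into.
raw = {
    "architecture": "gat",
    "encoder": {
        "num_layers": 4,
        "hidden_dim": 256,
        "num_heads": 8,
        "dropout": 0.1,
        "residual": True,
    },
}
cfg = GATEncoderConfig.from_dict(raw["encoder"])
```

With the values above, each of the 8 heads is 32 wide.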

Registration

Register with @register_encoder("gat").

Files

  • solver/models/gat.py
  • tests/models/test_gat.py

Acceptance criteria

  • Forward pass produces correct output shapes
  • Attention weights accessible for interpretability
  • Residual connections togglable via config
  • Works with variable-size batched graphs
  • Registered and resolved by build_model
  • Passes ruff, mypy, pytest
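For the variable-size batching criterion: PyG's Batch.from_data_list merges graphs into one disjoint graph by stacking nodes, shifting each graph's edge indices by the running node count, and recording a batch vector of graph ids. The pure-Python sketch below only illustrates that idea and how a mean readout then yields one embedding per graph; collate and mean_pool are hypothetical names, and mean pooling is an assumed readout. Edge attributes would simply be concatenated, since they need no offset.

```python
from __future__ import annotations

Graph = tuple[list[list[float]], list[tuple[int, int]]]  # (node features, edges)


def collate(graphs: list[Graph]) -> tuple[list[list[float]], list[tuple[int, int]], list[int]]:
    """Merge graphs into one disjoint graph: stack nodes, shift edge ids, record graph ids."""
    x: list[list[float]] = []
    edges: list[tuple[int, int]] = []
    batch: list[int] = []
    offset = 0
    for gid, (nodes, graph_edges) in enumerate(graphs):
        x.extend(nodes)
        # Shift local node ids so they index into the stacked node list.
        edges.extend((src + offset, dst + offset) for src, dst in graph_edges)
        batch.extend([gid] * len(nodes))
        offset += len(nodes)
    return x, edges, batch


def mean_pool(h: list[list[float]], batch: list[int]) -> list[list[float]]:
    """Average node embeddings per graph id, giving one row per graph."""
    num_graphs = max(batch) + 1
    dim = len(h[0])
    sums = [[0.0] * dim for _ in range(num_graphs)]
    counts = [0] * num_graphs
    for row, gid in zip(h, batch):
        counts[gid] += 1
        for j, v in enumerate(row):
            sums[gid][j] += v
    return [[s / counts[g] for s in sums[g]] for g in range(num_graphs)]
```

Because edges never cross graph boundaries after the offset, attention in each layer stays within a graph, and the batch vector is all the readout needs to split the node embeddings back into per-graph vectors.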
Reference: kindred/solver#20