feat(models): multi-task prediction heads for constraint analysis #21

Closed
opened 2026-02-04 20:22:53 +00:00 by forbes · 0 comments
Owner

Summary

Implement solver/models/heads.py with standalone prediction head modules.

Depends on #18, #19, #20.

Heads

1. EdgeClassificationHead

  • Input: concatenated source and target node embeddings [E, 2*hidden_dim]
  • Output: one binary logit per edge [E, 1]
  • Implemented as an MLP with a configurable hidden_dim
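A minimal sketch of the edge head follows. It assumes PyTorch `nn.Module`s, since the issue does not name the framework, and a two-layer ReLU MLP, whose depth and activation are also assumptions. The head consumes the already-concatenated pair tensor, so gathering the endpoint embeddings happens outside it. The `edge_index` layout in the usage part (`[2, E]`, source row then target row) is likewise assumed.

```python
import torch
from torch import nn


class EdgeClassificationHead(nn.Module):
    """Sketch: one binary logit per edge from concatenated endpoint embeddings."""

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        # Two-layer MLP (assumed depth); input width is 2 * hidden_dim.
        self.mlp = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, pair_emb: torch.Tensor) -> torch.Tensor:
        # pair_emb: [E, 2 * hidden_dim] -> raw logits [E, 1] (no sigmoid applied).
        out: torch.Tensor = self.mlp(pair_emb)
        return out


# Usage: build the concatenated input from node embeddings and an edge list.
hidden_dim = 32
node_emb = torch.randn(6, hidden_dim)                     # [N, hidden_dim]
edge_index = torch.tensor([[0, 1, 2, 4], [1, 2, 3, 5]])   # [2, E], hypothetical edges
src, dst = edge_index
pair_emb = torch.cat([node_emb[src], node_emb[dst]], dim=-1)  # [E, 2 * hidden_dim]

head = EdgeClassificationHead(hidden_dim)
logits = head(pair_emb)
print(logits.shape)  # torch.Size([4, 1])
```

Returning raw logits pairs naturally with `torch.nn.BCEWithLogitsLoss`, which is more numerically stable than applying a sigmoid and then `BCELoss`.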

2. GraphClassificationHead

  • Input: graph-level embedding [hidden_dim]
  • Output: num_classes = 4 logits [4], one per constraint state (rigid / underconstrained / overconstrained / mixed)
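A sketch of the graph classifier, again assuming PyTorch and a two-layer MLP. The class order in `CONSTRAINT_CLASSES` simply follows the order listed in the issue and is an assumption. Mean pooling over node embeddings stands in for whatever readout actually produces the graph-level embedding, which the issue does not specify.

```python
import torch
from torch import nn

# Assumed label order, taken from the order the issue lists the classes in.
CONSTRAINT_CLASSES = ("rigid", "underconstrained", "overconstrained", "mixed")


class GraphClassificationHead(nn.Module):
    """Sketch: constraint-state logits from a graph-level embedding."""

    def __init__(self, hidden_dim: int, num_classes: int = len(CONSTRAINT_CLASSES)) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
        # [hidden_dim] -> [num_classes]; batched [B, hidden_dim] -> [B, num_classes].
        out: torch.Tensor = self.mlp(graph_emb)
        return out


hidden_dim = 32
node_emb = torch.randn(9, hidden_dim)
# Mean pooling is an assumed readout, used here only to get a [hidden_dim] vector.
graph_emb = node_emb.mean(dim=0)

head = GraphClassificationHead(hidden_dim)
logits = head(graph_emb)                              # [4]
predicted = CONSTRAINT_CLASSES[int(logits.argmax())]  # one of the four labels
batched = head(torch.randn(3, hidden_dim))            # [3, 4]
```

Because `nn.Linear` acts on the last dimension, the same head accepts a single `[hidden_dim]` embedding or a batch of graphs without changes.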

3. JointTypeHead

  • Input: edge embeddings [E, hidden_dim], produced by an edge MLP over the concatenated source and target node embeddings
  • Output: num_classes = 12 joint-type logits per edge [E, 12]
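The sketch below shows where the `[E, hidden_dim]` input could come from and how the `[E, 12]` output would feed a multi-class loss. It assumes PyTorch. The `edge_mlp` is a hypothetical stand-in for the edge MLP mentioned above, and the targets are random labels for illustration only, since the issue does not list the 12 joint types.

```python
import torch
import torch.nn.functional as F
from torch import nn

NUM_JOINT_TYPES = 12  # from the issue; the concrete joint-type labels are not listed


class JointTypeHead(nn.Module):
    """Sketch: joint-type logits per edge from precomputed edge embeddings."""

    def __init__(self, hidden_dim: int, num_classes: int = NUM_JOINT_TYPES) -> None:
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
        # edge_emb: [E, hidden_dim] -> logits [E, num_classes]
        out: torch.Tensor = self.classifier(edge_emb)
        return out


hidden_dim = 32
node_emb = torch.randn(5, hidden_dim)
edge_index = torch.tensor([[0, 0, 1, 3], [1, 2, 3, 4]])  # hypothetical [2, E]
src, dst = edge_index

# Hypothetical edge MLP: concatenated source+target embeddings -> [E, hidden_dim].
edge_mlp = nn.Sequential(nn.Linear(2 * hidden_dim, hidden_dim), nn.ReLU())
edge_emb = edge_mlp(torch.cat([node_emb[src], node_emb[dst]], dim=-1))

head = JointTypeHead(hidden_dim)
logits = head(edge_emb)  # [4, 12]
# Random class indices in [0, 12), for illustration only.
targets = torch.randint(0, NUM_JOINT_TYPES, (logits.shape[0],))
loss = F.cross_entropy(logits, targets)  # scalar, averaged over edges
```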

4. DofRegressionHead

  • Input: graph-level embedding [hidden_dim]
  • Output: scalar degrees-of-freedom (DOF) prediction [1]
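A sketch of the regression head, assuming PyTorch. Mean squared error is used only as an example loss, and the target value is hypothetical; the issue specifies neither.

```python
import torch
import torch.nn.functional as F
from torch import nn


class DofRegressionHead(nn.Module):
    """Sketch: one scalar DOF estimate per graph."""

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
        # [hidden_dim] -> [1]; batched [B, hidden_dim] -> [B, 1]. Output is unbounded.
        out: torch.Tensor = self.mlp(graph_emb)
        return out


hidden_dim = 32
head = DofRegressionHead(hidden_dim)
pred = head(torch.randn(hidden_dim))  # [1]
target = torch.tensor([6.0])          # hypothetical ground-truth DOF
loss = F.mse_loss(pred, target)       # example loss choice
batched = head(torch.randn(5, hidden_dim))  # [5, 1]
```

The output is left unbounded here. A DOF count cannot be negative, so applying `softplus` to the output is one option, but the issue does not call for it.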

5. DofTrackingHead (GAT only)

  • Input: node embeddings [n, hidden_dim]
  • Output: per-body DOF [n, 2] (translational, rotational)
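A sketch of the per-body head, assuming PyTorch and 3D bodies. A free rigid body in 3D has at most 3 translational and 3 rotational DOF, so this sketch optionally squashes each output into (0, 3) with a scaled sigmoid. That bound, and the `max_dof` parameter, are assumptions rather than anything the issue specifies. The "GAT only" restriction is not modeled: the head only needs node embeddings, so which encoders enable it would presumably be a configuration decision.

```python
from typing import Optional

import torch
from torch import nn


class DofTrackingHead(nn.Module):
    """Sketch: per-body (translational, rotational) DOF from node embeddings."""

    def __init__(self, hidden_dim: int, max_dof: Optional[float] = 3.0) -> None:
        super().__init__()
        self.max_dof = max_dof  # hypothetical bound; None leaves outputs unbounded
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 2),  # column 0: translational, column 1: rotational
        )

    def forward(self, node_emb: torch.Tensor) -> torch.Tensor:
        # node_emb: [n, hidden_dim] -> [n, 2]
        out: torch.Tensor = self.mlp(node_emb)
        if self.max_dof is not None:
            # Assumed bound: squash each component into (0, max_dof).
            out = self.max_dof * torch.sigmoid(out)
        return out


hidden_dim = 32
head = DofTrackingHead(hidden_dim)
dof = head(torch.randn(7, hidden_dim))  # [7, 2]
translational, rotational = dof[:, 0], dof[:, 1]
unbounded = DofTrackingHead(hidden_dim, max_dof=None)(torch.randn(4, hidden_dim))  # [4, 2]
```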

Each head is enabled or disabled via HeadConfig.enabled; a disabled head returns None instead of a tensor.
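One way to implement the enable/disable switch is sketched below, assuming PyTorch. Only `HeadConfig.enabled` comes from the issue. The `hidden_dim` field, the `OptionalHead` wrapper, and its factory argument are hypothetical. The factory is called only when the head is enabled, so a disabled head allocates no parameters, and downstream code such as a multi-task loss can skip outputs that are `None`.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import torch
from torch import nn


@dataclass
class HeadConfig:
    enabled: bool = True  # named in the issue
    hidden_dim: int = 32  # hypothetical extra field


class OptionalHead(nn.Module):
    """Hypothetical wrapper: runs the built head when enabled, else returns None."""

    def __init__(self, config: HeadConfig, build: Callable[[HeadConfig], nn.Module]) -> None:
        super().__init__()
        # The factory only runs when enabled, so a disabled head owns no parameters.
        self.head: Optional[nn.Module] = build(config) if config.enabled else None

    def forward(self, x: torch.Tensor) -> Optional[torch.Tensor]:
        if self.head is None:
            return None
        out: torch.Tensor = self.head(x)
        return out


on_cfg = HeadConfig(enabled=True)
off_cfg = HeadConfig(enabled=False)
heads = {
    "graph_class": OptionalHead(on_cfg, lambda c: nn.Linear(c.hidden_dim, 4)),
    "dof": OptionalHead(off_cfg, lambda c: nn.Linear(c.hidden_dim, 1)),
}
graph_emb = torch.randn(32)
outputs: Dict[str, Optional[torch.Tensor]] = {name: h(graph_emb) for name, h in heads.items()}
# Downstream code keeps only heads that produced a tensor.
active = [name for name, out in outputs.items() if out is not None]
```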

Files

  • solver/models/heads.py
  • tests/models/test_heads.py

Acceptance criteria

  • Each head produces the correct output shape for variable-size graphs
  • Heads work independently (testable in isolation)
  • Disabled heads return None
  • Passes ruff, mypy, pytest
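The shape and isolation criteria could be exercised in tests/models/test_heads.py roughly as follows. The tests are plain functions that pytest would collect. Each test builds a single head and feeds it random embeddings, so no encoder is involved. The `make_*` factories are hypothetical stand-ins; a real test would import the heads from `solver.models.heads` instead. The graph sizes are illustrative choices, and the sketch assumes PyTorch.

```python
import torch
from torch import nn

HIDDEN_DIM = 16
# Illustrative graph sizes, including a single-body graph.
GRAPH_SIZES = (1, 5, 23)


def random_graph(num_nodes: int, num_edges: int) -> tuple[torch.Tensor, torch.Tensor]:
    """Random node embeddings plus a random [2, E] edge list; no encoder needed."""
    node_emb = torch.randn(num_nodes, HIDDEN_DIM)
    edge_index = torch.randint(0, num_nodes, (2, num_edges))
    return node_emb, edge_index


def make_edge_head() -> nn.Module:
    # Stand-in for EdgeClassificationHead; a real test would import it instead.
    return nn.Sequential(
        nn.Linear(2 * HIDDEN_DIM, HIDDEN_DIM), nn.ReLU(), nn.Linear(HIDDEN_DIM, 1)
    )


def make_dof_tracking_head() -> nn.Module:
    # Stand-in for DofTrackingHead (same import caveat as above).
    return nn.Linear(HIDDEN_DIM, 2)


def test_edge_head_shape_on_variable_graphs() -> None:
    head = make_edge_head()
    for n in GRAPH_SIZES:
        num_edges = 2 * n
        node_emb, edge_index = random_graph(n, num_edges)
        src, dst = edge_index
        logits = head(torch.cat([node_emb[src], node_emb[dst]], dim=-1))
        assert logits.shape == (num_edges, 1)


def test_dof_tracking_head_shape_on_variable_graphs() -> None:
    head = make_dof_tracking_head()
    for n in GRAPH_SIZES:
        node_emb, _ = random_graph(n, num_edges=0)
        assert head(node_emb).shape == (n, 2)
```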
Reference: kindred/solver#21