feat(models): assembled multi-task model and loss function #22

Closed
opened 2026-02-04 20:23:08 +00:00 by forbes · 0 comments
Owner

Summary

Implement the top-level model that composes encoder + heads, and the multi-task loss function.

Depends on #16, #17, #18, #19, #20, #21.

ConstraintSolverModel (solver/models/solver_model.py)

  • Composes an encoder (GIN or GAT) with enabled prediction heads
  • Forward pass returns dict[str, Tensor] keyed by head name:
    • edge_pred: [E, 1]
    • graph_pred: [batch, 4]
    • joint_type_pred: [E, 12]
    • dof_pred: [batch, 1]
    • body_dof_pred: [n, 2] (if enabled)
  • Wire into the build_model() factory from #18 (feat(models): model config dataclass and architecture registry)
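As a rough illustration of the output contract above, here is a minimal PyTorch sketch. It is not the project's implementation: `TinyEncoder`, `SketchSolverModel`, the endpoint-concatenation for edge heads, the mean pooling, and the `forward` signature are all assumptions made for the example. Only the output keys and shapes come from this issue.

```python
import torch
from torch import Tensor, nn


class TinyEncoder(nn.Module):
    """Hypothetical stand-in for the GIN/GAT encoder (a single linear layer)."""

    def __init__(self, in_dim: int, hidden: int) -> None:
        super().__init__()
        self.lin = nn.Linear(in_dim, hidden)

    def forward(self, x: Tensor) -> Tensor:
        return torch.relu(self.lin(x))  # [n, hidden]


class SketchSolverModel(nn.Module):
    """Illustrative composition of an encoder with per-task heads."""

    def __init__(self, in_dim: int = 8, hidden: int = 16, body_dof: bool = True) -> None:
        super().__init__()
        self.encoder = TinyEncoder(in_dim, hidden)
        # Assumption: edge-level heads see both endpoint embeddings concatenated.
        self.edge_head = nn.Linear(2 * hidden, 1)
        self.joint_type_head = nn.Linear(2 * hidden, 12)
        # Assumption: graph-level heads see a mean-pooled graph embedding.
        self.graph_head = nn.Linear(hidden, 4)
        self.dof_head = nn.Linear(hidden, 1)
        # A disabled head is simply not constructed.
        self.body_dof_head = nn.Linear(hidden, 2) if body_dof else None

    def forward(
        self, x: Tensor, edge_index: Tensor, batch: Tensor, num_graphs: int
    ) -> dict[str, Tensor]:
        h = self.encoder(x)  # [n, hidden]
        src, dst = edge_index  # each of shape [E]
        h_edge = torch.cat([h[src], h[dst]], dim=-1)  # [E, 2 * hidden]

        # Mean-pool node embeddings per graph using the batch assignment vector.
        summed = h.new_zeros(num_graphs, h.size(1)).index_add(0, batch, h)
        counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1)
        h_graph = summed / counts.unsqueeze(1)  # [batch, hidden]

        out = {
            "edge_pred": self.edge_head(h_edge),  # [E, 1]
            "graph_pred": self.graph_head(h_graph),  # [batch, 4]
            "joint_type_pred": self.joint_type_head(h_edge),  # [E, 12]
            "dof_pred": self.dof_head(h_graph),  # [batch, 1]
        }
        if self.body_dof_head is not None:
            out["body_dof_pred"] = self.body_dof_head(h)  # [n, 2]
        return out
```

Omitting a disabled head's key from the returned dict, rather than returning a placeholder tensor, lets the loss skip it by key lookup alone.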

MultiTaskLoss (solver/models/loss.py)

Weighted sum of per-head losses with configurable weights:

| Head | Loss | Weight (default) |
|------|------|------------------|
| Edge classification | BCE with logits, redundant_penalty=2.0 for class imbalance | 1.0 |
| Graph classification | Cross-entropy | 0.5 |
| Joint type | Cross-entropy | 0.3 |
| DOF regression | Huber loss | 0.2 |
| DOF tracking | MSE per body | 0.2 |

Disabled heads excluded from loss computation.
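The weighting scheme can be sketched independently of any framework. In the snippet below, `combine_losses` and `DEFAULT_WEIGHTS` are hypothetical names, and the mapping from table rows to the `*_pred` output keys is an assumption. The function only needs values that support `*` and `+`, so it works on plain floats as well as scalar tensors.

```python
from typing import Mapping

# Default weights from the table above, keyed by the (assumed) head output names.
DEFAULT_WEIGHTS: dict[str, float] = {
    "edge_pred": 1.0,
    "graph_pred": 0.5,
    "joint_type_pred": 0.3,
    "dof_pred": 0.2,
    "body_dof_pred": 0.2,
}


def combine_losses(per_head, weights: Mapping[str, float] = DEFAULT_WEIGHTS):
    """Return the weighted sum of the per-head losses that are present.

    `per_head` maps head name -> scalar loss. A disabled head is expected to be
    absent from the mapping, so it contributes nothing to the total.
    """
    total = 0.0
    for head, loss in per_head.items():
        total = total + weights[head] * loss
    return total


# Example: only the edge and DOF heads are enabled.
example_total = combine_losses({"edge_pred": 2.0, "dof_pred": 1.0})
```

Iterating over the losses that are present, rather than over the full weight table, means a disabled head needs no special flag in the loss.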

Public API (solver/models/__init__.py)

Export: build_model, ModelConfig, ConstraintSolverModel, MultiTaskLoss.

Files

  • solver/models/solver_model.py
  • solver/models/loss.py
  • solver/models/__init__.py
  • tests/models/test_solver_model.py
  • tests/models/test_loss.py

Acceptance criteria

  • End-to-end forward pass: raw datagen example -> features -> Data -> model -> predictions
  • Loss computes and backpropagates through all enabled heads
  • build_model produces working models for both baseline and gat_solver configs
  • Disabled heads excluded from loss
  • Passes ruff, mypy, pytest
Reference: kindred/solver#22