feat(models): assembled multi-task model and loss function #22
Summary
Implement the top-level model that composes encoder + heads, and the multi-task loss function.
Depends on #16, #17, #18, #19, #20, #21.
ConstraintSolverModel (`solver/models/solver_model.py`)
- Outputs a `dict[str, Tensor]` keyed by head name:
  - `edge_pred`: `[E, 1]`
  - `graph_pred`: `[batch, 4]`
  - `joint_type_pred`: `[E, 12]`
  - `dof_pred`: `[batch, 1]`
  - `body_dof_pred`: `[n, 2]` (if enabled)
- `build_model()` factory from #18 (feat(models): model config dataclass and architecture registry)
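The output contract above could be checked with a small plain-Python helper, for example from a test such as `tests/models/test_solver_model.py`. This is a minimal sketch. The helper names `expected_output_shapes` and `check_output_shapes` are hypothetical. It assumes `E` is the edge count, `batch` is the number of graphs, and `n` is the node count. Shapes are compared as plain tuples so the check itself does not depend on PyTorch.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def expected_output_shapes(
    num_edges: int,
    batch_size: int,
    num_nodes: int,
    body_dof_enabled: bool = False,
) -> Dict[str, Shape]:
    """Hypothetical helper: the shape each head is documented to produce."""
    shapes: Dict[str, Shape] = {
        "edge_pred": (num_edges, 1),
        "graph_pred": (batch_size, 4),
        "joint_type_pred": (num_edges, 12),
        "dof_pred": (batch_size, 1),
    }
    if body_dof_enabled:
        # Optional head: only present when enabled in the model config.
        shapes["body_dof_pred"] = (num_nodes, 2)
    return shapes


def check_output_shapes(actual: Dict[str, Shape], expected: Dict[str, Shape]) -> None:
    """Raise if the set of heads or any head's shape differs from the contract."""
    missing = expected.keys() - actual.keys()
    extra = actual.keys() - expected.keys()
    if missing or extra:
        raise KeyError(f"missing heads: {sorted(missing)}, unexpected heads: {sorted(extra)}")
    for name, shape in expected.items():
        if tuple(actual[name]) != shape:
            raise ValueError(f"{name}: expected {shape}, got {tuple(actual[name])}")
```

With real model outputs, the actual shapes could be collected as `{name: tuple(t.shape) for name, t in outputs.items()}` before calling the checker.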
MultiTaskLoss (`solver/models/loss.py`)
- Weighted sum of per-head losses with configurable weights:
  - `redundant_penalty=2.0` for class imbalance
- Disabled heads excluded from loss computation
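The loss combination described above can be sketched in plain Python over scalar per-head losses. This is a minimal sketch under several assumptions:
- `redundant_penalty` acts as a positive-class weight on a binary cross-entropy for `edge_pred`, with redundant edges labeled 1.
- Weights missing from the mapping default to `1.0`.
- Heads not listed in `enabled` are treated as on.

The function names `weighted_bce` and `multitask_loss` are hypothetical.

```python
import math
from typing import Dict, Mapping, Sequence


def weighted_bce(
    probs: Sequence[float],
    labels: Sequence[int],
    redundant_penalty: float = 2.0,
    eps: float = 1e-7,
) -> float:
    """Mean binary cross-entropy; positive (redundant) samples are scaled by redundant_penalty."""
    if not probs or len(probs) != len(labels):
        raise ValueError("probs and labels must be non-empty and the same length")
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp so log() never sees 0
        if y == 1:
            total += -redundant_penalty * math.log(p)
        else:
            total += -math.log(1.0 - p)
    return total / len(probs)


def multitask_loss(
    head_losses: Mapping[str, float],
    weights: Mapping[str, float],
    enabled: Mapping[str, bool],
) -> Dict[str, float]:
    """Weighted sum of per-head losses; disabled heads are dropped before summing."""
    parts = {
        name: weights.get(name, 1.0) * loss
        for name, loss in head_losses.items()
        if enabled.get(name, True)
    }
    # Keep per-head terms for logging, plus the combined objective.
    parts["total"] = sum(parts.values())
    return parts
```

In PyTorch, the same class weighting corresponds to the `pos_weight` argument of `torch.nn.BCEWithLogitsLoss`. That loss works on logits and is numerically more stable than clamping probabilities as done here.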
Public API (`solver/models/__init__.py`)
Export: `build_model`, `ModelConfig`, `ConstraintSolverModel`, `MultiTaskLoss`.
Files
- `solver/models/solver_model.py`
- `solver/models/loss.py`
- `solver/models/__init__.py`
- `tests/models/test_solver_model.py`
- `tests/models/test_loss.py`
Acceptance criteria
- `build_model` produces working models for both `baseline` and `gat_solver` configs