feat: dependency DAG and YAML-defined compute jobs #92

Merged
forbes merged 13 commits from feat-dag-workers into main 2026-02-14 19:27:19 +00:00
Showing only changes of commit 671a0aeefe - Show all commits

520
internal/db/dag.go Normal file
View File

@@ -0,0 +1,520 @@
package db
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
)
// DAGNode represents a feature-level node in the dependency graph.
// Its fields correspond to the dag_nodes columns that the queries in this
// file select and scan.
type DAGNode struct {
	// ID is the row identifier (dag_nodes.id). Upserts fill it in from the
	// statement's RETURNING clause.
	ID string
	// ItemID is the item the node belongs to (dag_nodes.item_id).
	ItemID string
	// RevisionNumber is the item revision the node belongs to.
	RevisionNumber int
	// NodeKey names the node within its item revision; upserts conflict on
	// (item_id, revision_number, node_key).
	NodeKey string
	// NodeType is stored verbatim in dag_nodes.node_type; the set of valid
	// values is not defined in this file.
	NodeType string
	// PropertiesHash is optional, hence the pointer. MarkClean overwrites it.
	PropertiesHash *string
	// ValidationState is the node's validation status. The states written by
	// this file are "clean", "dirty", "validating" and "failed".
	ValidationState string
	// ValidationMsg is optional, hence the pointer. MarkFailed sets it and
	// MarkClean resets it to NULL.
	ValidationMsg *string
	// Metadata is free-form data, JSON-encoded into dag_nodes.metadata on
	// write and decoded on read.
	Metadata map[string]any
	// CreatedAt and UpdatedAt are read back from the database; the update
	// statements in this file set updated_at to now().
	CreatedAt time.Time
	UpdatedAt time.Time
}
// DAGEdge represents a dependency between two nodes.
// Its fields correspond to the dag_edges columns used by this file.
type DAGEdge struct {
	// ID is the row identifier (dag_edges.id), filled in from RETURNING id
	// when the edge is inserted.
	ID string
	// SourceNodeID and TargetNodeID are the dag_nodes IDs the edge connects.
	// GetForwardCone walks source -> target to find downstream dependents.
	SourceNodeID string
	TargetNodeID string
	// EdgeType labels the relationship. CreateEdge and SyncFeatureTree
	// replace an empty value with "depends_on".
	EdgeType string
	// Metadata is free-form data, JSON-encoded into dag_edges.metadata.
	Metadata map[string]any
}
// DAGCrossEdge represents a dependency between nodes in different items.
//
// NOTE(review): no query in this part of the file reads or writes this type,
// so the field notes below are inferred from DAGEdge — confirm against the
// code that persists cross edges.
type DAGCrossEdge struct {
	// ID is presumably the row identifier of the cross edge.
	ID string
	// SourceNodeID and TargetNodeID presumably hold dag_nodes IDs, as in DAGEdge.
	SourceNodeID string
	TargetNodeID string
	// RelationshipID is optional, hence the pointer; what it references is
	// not shown here — TODO confirm.
	RelationshipID *string
	// EdgeType labels the relationship.
	EdgeType string
	// Metadata is free-form data attached to the edge.
	Metadata map[string]any
}
// DAGRepository provides dependency graph database operations.
// Every method in this file issues its SQL through db.pool.
type DAGRepository struct {
	// db is the database handle supplied to NewDAGRepository.
	db *DB
}
// NewDAGRepository wraps the given database handle in a DAGRepository.
// The handle is stored as-is; no query is issued here.
func NewDAGRepository(db *DB) *DAGRepository {
	repo := new(DAGRepository)
	repo.db = db
	return repo
}
// GetNodes lists every DAG node recorded for one item at one revision,
// ordered by node key. The slice is nil when no rows match.
func (r *DAGRepository) GetNodes(ctx context.Context, itemID string, revisionNumber int) ([]*DAGNode, error) {
	const query = `
SELECT id, item_id, revision_number, node_key, node_type,
properties_hash, validation_state, validation_msg,
metadata, created_at, updated_at
FROM dag_nodes
WHERE item_id = $1 AND revision_number = $2
ORDER BY node_key
`
	result, queryErr := r.db.pool.Query(ctx, query, itemID, revisionNumber)
	if queryErr != nil {
		return nil, fmt.Errorf("querying DAG nodes: %w", queryErr)
	}
	// scanDAGNodes only iterates; releasing the rows is this function's job.
	defer result.Close()

	return scanDAGNodes(result)
}
// GetNodeByKey returns a single DAG node by item, revision, and key.
//
// It returns (nil, nil) when no matching row exists, so callers must check
// the node for nil even when err is nil. Any other query or decoding
// failure is returned as a wrapped error.
func (r *DAGRepository) GetNodeByKey(ctx context.Context, itemID string, revisionNumber int, nodeKey string) (*DAGNode, error) {
	n := &DAGNode{}
	var metadataJSON []byte
	err := r.db.pool.QueryRow(ctx, `
SELECT id, item_id, revision_number, node_key, node_type,
properties_hash, validation_state, validation_msg,
metadata, created_at, updated_at
FROM dag_nodes
WHERE item_id = $1 AND revision_number = $2 AND node_key = $3
`, itemID, revisionNumber, nodeKey).Scan(
		&n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType,
		&n.PropertiesHash, &n.ValidationState, &n.ValidationMsg,
		&metadataJSON, &n.CreatedAt, &n.UpdatedAt,
	)
	// errors.Is rather than == so a wrapped ErrNoRows still means "not found".
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("querying DAG node: %w", err)
	}
	// A NULL metadata column scans as a nil slice; leave Metadata nil then.
	if metadataJSON != nil {
		if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil {
			return nil, fmt.Errorf("unmarshaling node metadata: %w", err)
		}
	}
	return n, nil
}
// GetNodeByID returns a single DAG node by its ID.
//
// It returns (nil, nil) when no row has that ID, so callers must check the
// node for nil even when err is nil. Any other query or decoding failure is
// returned as a wrapped error.
func (r *DAGRepository) GetNodeByID(ctx context.Context, nodeID string) (*DAGNode, error) {
	n := &DAGNode{}
	var metadataJSON []byte
	err := r.db.pool.QueryRow(ctx, `
SELECT id, item_id, revision_number, node_key, node_type,
properties_hash, validation_state, validation_msg,
metadata, created_at, updated_at
FROM dag_nodes
WHERE id = $1
`, nodeID).Scan(
		&n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType,
		&n.PropertiesHash, &n.ValidationState, &n.ValidationMsg,
		&metadataJSON, &n.CreatedAt, &n.UpdatedAt,
	)
	// errors.Is rather than == so a wrapped ErrNoRows still means "not found".
	if errors.Is(err, pgx.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, fmt.Errorf("querying DAG node by ID: %w", err)
	}
	// A NULL metadata column scans as a nil slice; leave Metadata nil then.
	if metadataJSON != nil {
		if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil {
			return nil, fmt.Errorf("unmarshaling node metadata: %w", err)
		}
	}
	return n, nil
}
// UpsertNode writes one DAG node, inserting it or, when a row with the same
// (item_id, revision_number, node_key) already exists, overwriting that
// row's type, hash, validation fields and metadata.
//
// On success n.ID, n.CreatedAt and n.UpdatedAt are refreshed from the row
// the statement returns.
func (r *DAGRepository) UpsertNode(ctx context.Context, n *DAGNode) error {
	encodedMeta, marshalErr := json.Marshal(n.Metadata)
	if marshalErr != nil {
		return fmt.Errorf("marshaling metadata: %w", marshalErr)
	}

	const query = `
INSERT INTO dag_nodes (item_id, revision_number, node_key, node_type,
properties_hash, validation_state, validation_msg, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (item_id, revision_number, node_key)
DO UPDATE SET
node_type = EXCLUDED.node_type,
properties_hash = EXCLUDED.properties_hash,
validation_state = EXCLUDED.validation_state,
validation_msg = EXCLUDED.validation_msg,
metadata = EXCLUDED.metadata,
updated_at = now()
RETURNING id, created_at, updated_at
`
	returned := r.db.pool.QueryRow(ctx, query,
		n.ItemID, n.RevisionNumber, n.NodeKey, n.NodeType,
		n.PropertiesHash, n.ValidationState, n.ValidationMsg, encodedMeta,
	)
	if scanErr := returned.Scan(&n.ID, &n.CreatedAt, &n.UpdatedAt); scanErr != nil {
		return fmt.Errorf("upserting DAG node: %w", scanErr)
	}
	return nil
}
// GetEdges lists the edges whose source node belongs to the given item at
// the given revision, ordered by source and then target node ID. The slice
// is nil when no rows match.
func (r *DAGRepository) GetEdges(ctx context.Context, itemID string, revisionNumber int) ([]*DAGEdge, error) {
	const query = `
SELECT e.id, e.source_node_id, e.target_node_id, e.edge_type, e.metadata
FROM dag_edges e
JOIN dag_nodes src ON src.id = e.source_node_id
WHERE src.item_id = $1 AND src.revision_number = $2
ORDER BY e.source_node_id, e.target_node_id
`
	result, queryErr := r.db.pool.Query(ctx, query, itemID, revisionNumber)
	if queryErr != nil {
		return nil, fmt.Errorf("querying DAG edges: %w", queryErr)
	}
	defer result.Close()

	var found []*DAGEdge
	for result.Next() {
		var (
			edge    DAGEdge
			rawMeta []byte
		)
		scanErr := result.Scan(&edge.ID, &edge.SourceNodeID, &edge.TargetNodeID, &edge.EdgeType, &rawMeta)
		if scanErr != nil {
			return nil, fmt.Errorf("scanning DAG edge: %w", scanErr)
		}
		// NULL metadata arrives as a nil slice; Metadata then stays nil.
		if rawMeta != nil {
			if decodeErr := json.Unmarshal(rawMeta, &edge.Metadata); decodeErr != nil {
				return nil, fmt.Errorf("unmarshaling edge metadata: %w", decodeErr)
			}
		}
		found = append(found, &edge)
	}
	// Surface any error that ended the iteration early.
	return found, result.Err()
}
// CreateEdge inserts a new edge between two nodes.
//
// An empty e.EdgeType is replaced with "depends_on" (the change is visible
// to the caller). On a fresh insert e.ID is set from the new row. If an edge
// with the same (source, target, type) already exists the call succeeds
// without touching the row, and e.ID is left unchanged.
func (r *DAGRepository) CreateEdge(ctx context.Context, e *DAGEdge) error {
	if e.EdgeType == "" {
		e.EdgeType = "depends_on"
	}
	metadataJSON, err := json.Marshal(e.Metadata)
	if err != nil {
		return fmt.Errorf("marshaling edge metadata: %w", err)
	}
	err = r.db.pool.QueryRow(ctx, `
INSERT INTO dag_edges (source_node_id, target_node_id, edge_type, metadata)
VALUES ($1, $2, $3, $4)
ON CONFLICT (source_node_id, target_node_id, edge_type) DO NOTHING
RETURNING id
`, e.SourceNodeID, e.TargetNodeID, e.EdgeType, metadataJSON).Scan(&e.ID)
	// DO NOTHING returns no row on conflict, which surfaces as ErrNoRows.
	// errors.Is rather than == so a wrapped sentinel is still recognised.
	if errors.Is(err, pgx.ErrNoRows) {
		// Edge already exists, not an error
		return nil
	}
	if err != nil {
		return fmt.Errorf("creating DAG edge: %w", err)
	}
	return nil
}
// DeleteEdgesForItem deletes every edge whose source node belongs to the
// given item at the given revision. Edges that merely point at such a node
// from elsewhere are not matched by this statement.
func (r *DAGRepository) DeleteEdgesForItem(ctx context.Context, itemID string, revisionNumber int) error {
	const stmt = `
DELETE FROM dag_edges
WHERE source_node_id IN (
SELECT id FROM dag_nodes WHERE item_id = $1 AND revision_number = $2
)
`
	if _, execErr := r.db.pool.Exec(ctx, stmt, itemID, revisionNumber); execErr != nil {
		return fmt.Errorf("deleting edges for item: %w", execErr)
	}
	return nil
}
// GetForwardCone collects every node reachable from nodeID by following
// edges from source to target, i.e. its downstream dependents, ordered by
// node key. This is the key query for interference detection.
//
// The starting node is not part of the result unless an edge path leads
// back to it. The recursive step uses UNION, so each node ID is collected
// only once.
func (r *DAGRepository) GetForwardCone(ctx context.Context, nodeID string) ([]*DAGNode, error) {
	const query = `
WITH RECURSIVE forward_cone AS (
SELECT target_node_id AS node_id
FROM dag_edges
WHERE source_node_id = $1
UNION
SELECT e.target_node_id
FROM dag_edges e
JOIN forward_cone fc ON fc.node_id = e.source_node_id
)
SELECT n.id, n.item_id, n.revision_number, n.node_key, n.node_type,
n.properties_hash, n.validation_state, n.validation_msg,
n.metadata, n.created_at, n.updated_at
FROM dag_nodes n
JOIN forward_cone fc ON n.id = fc.node_id
ORDER BY n.node_key
`
	result, queryErr := r.db.pool.Query(ctx, query, nodeID)
	if queryErr != nil {
		return nil, fmt.Errorf("querying forward cone: %w", queryErr)
	}
	defer result.Close()

	return scanDAGNodes(result)
}
// GetBackwardCone collects every node from which nodeID can be reached by
// following edges from target back to source, i.e. its upstream
// dependencies, ordered by node key.
//
// The starting node is not part of the result unless an edge path leads
// back to it. The recursive step uses UNION, so each node ID is collected
// only once.
func (r *DAGRepository) GetBackwardCone(ctx context.Context, nodeID string) ([]*DAGNode, error) {
	const query = `
WITH RECURSIVE backward_cone AS (
SELECT source_node_id AS node_id
FROM dag_edges
WHERE target_node_id = $1
UNION
SELECT e.source_node_id
FROM dag_edges e
JOIN backward_cone bc ON bc.node_id = e.target_node_id
)
SELECT n.id, n.item_id, n.revision_number, n.node_key, n.node_type,
n.properties_hash, n.validation_state, n.validation_msg,
n.metadata, n.created_at, n.updated_at
FROM dag_nodes n
JOIN backward_cone bc ON n.id = bc.node_id
ORDER BY n.node_key
`
	result, queryErr := r.db.pool.Query(ctx, query, nodeID)
	if queryErr != nil {
		return nil, fmt.Errorf("querying backward cone: %w", queryErr)
	}
	defer result.Close()

	return scanDAGNodes(result)
}
// GetDirtySubgraph lists the nodes of an item whose validation state is
// anything other than 'clean', ordered by node key. The query does not
// filter by revision, so nodes from every revision of the item are included.
func (r *DAGRepository) GetDirtySubgraph(ctx context.Context, itemID string) ([]*DAGNode, error) {
	const query = `
SELECT id, item_id, revision_number, node_key, node_type,
properties_hash, validation_state, validation_msg,
metadata, created_at, updated_at
FROM dag_nodes
WHERE item_id = $1 AND validation_state != 'clean'
ORDER BY node_key
`
	result, queryErr := r.db.pool.Query(ctx, query, itemID)
	if queryErr != nil {
		return nil, fmt.Errorf("querying dirty subgraph: %w", queryErr)
	}
	defer result.Close()

	return scanDAGNodes(result)
}
// MarkDirty flags a node and everything downstream of it as 'dirty'.
//
// Only nodes that are currently 'clean' are changed; nodes in any other
// state are left as they are. nodeID is cast to uuid inside the statement.
// The return value is the number of rows that were updated.
func (r *DAGRepository) MarkDirty(ctx context.Context, nodeID string) (int64, error) {
	const stmt = `
WITH RECURSIVE forward_cone AS (
SELECT $1::uuid AS node_id
UNION
SELECT e.target_node_id
FROM dag_edges e
JOIN forward_cone fc ON fc.node_id = e.source_node_id
)
UPDATE dag_nodes SET validation_state = 'dirty', updated_at = now()
WHERE id IN (SELECT node_id FROM forward_cone)
AND validation_state = 'clean'
`
	tag, execErr := r.db.pool.Exec(ctx, stmt, nodeID)
	if execErr != nil {
		return 0, fmt.Errorf("marking dirty: %w", execErr)
	}
	return tag.RowsAffected(), nil
}
// MarkValidating moves a node into the 'validating' state and bumps its
// updated_at timestamp. The affected-row count is not inspected, so an
// unknown nodeID is not reported as an error.
func (r *DAGRepository) MarkValidating(ctx context.Context, nodeID string) error {
	const stmt = `
UPDATE dag_nodes SET validation_state = 'validating', updated_at = now()
WHERE id = $1
`
	if _, execErr := r.db.pool.Exec(ctx, stmt, nodeID); execErr != nil {
		return fmt.Errorf("marking validating: %w", execErr)
	}
	return nil
}
// MarkClean moves a node into the 'clean' state, stores the supplied
// properties hash, clears any validation message and bumps updated_at.
// The affected-row count is not inspected, so an unknown nodeID is not
// reported as an error.
func (r *DAGRepository) MarkClean(ctx context.Context, nodeID string, propertiesHash string) error {
	const stmt = `
UPDATE dag_nodes
SET validation_state = 'clean',
properties_hash = $2,
validation_msg = NULL,
updated_at = now()
WHERE id = $1
`
	if _, execErr := r.db.pool.Exec(ctx, stmt, nodeID, propertiesHash); execErr != nil {
		return fmt.Errorf("marking clean: %w", execErr)
	}
	return nil
}
// MarkFailed moves a node into the 'failed' state, records message as its
// validation message and bumps updated_at. The affected-row count is not
// inspected, so an unknown nodeID is not reported as an error.
func (r *DAGRepository) MarkFailed(ctx context.Context, nodeID string, message string) error {
	const stmt = `
UPDATE dag_nodes
SET validation_state = 'failed',
validation_msg = $2,
updated_at = now()
WHERE id = $1
`
	if _, execErr := r.db.pool.Exec(ctx, stmt, nodeID, message); execErr != nil {
		return fmt.Errorf("marking failed: %w", execErr)
	}
	return nil
}
// HasCycle reports whether an edge sourceID -> targetID would close a loop.
//
// A self-edge is answered immediately without a query. Otherwise the
// existing edges are walked backwards from sourceID; if targetID turns up
// among those ancestors, the new edge would complete a cycle.
func (r *DAGRepository) HasCycle(ctx context.Context, sourceID, targetID string) (bool, error) {
	// An edge from a node to itself is a cycle by definition.
	if sourceID == targetID {
		return true, nil
	}

	const query = `
WITH RECURSIVE ancestors AS (
SELECT source_node_id AS node_id
FROM dag_edges
WHERE target_node_id = $1
UNION
SELECT e.source_node_id
FROM dag_edges e
JOIN ancestors a ON a.node_id = e.target_node_id
)
SELECT EXISTS (
SELECT 1 FROM ancestors WHERE node_id = $2
)
`
	var targetIsAncestor bool
	if scanErr := r.db.pool.QueryRow(ctx, query, sourceID, targetID).Scan(&targetIsAncestor); scanErr != nil {
		return false, fmt.Errorf("checking for cycle: %w", scanErr)
	}
	return targetIsAncestor, nil
}
// SyncFeatureTree writes the feature DAG for an item/revision within a
// single transaction: it upserts every supplied node, deletes the edges
// whose source node belongs to the item/revision, and inserts the supplied
// edges.
//
// Behavior worth knowing:
//   - nodes and edges are modified in place. Each node gets ItemID and
//     RevisionNumber overwritten, an empty ValidationState set to "clean",
//     and ID/CreatedAt/UpdatedAt filled from the database. Each edge gets an
//     empty EdgeType set to "depends_on" and ID filled from the database.
//     These writes are not undone if the transaction later fails.
//   - Edges must already carry node IDs in SourceNodeID and TargetNodeID; no
//     key-to-ID resolution is performed. An edge missing either ID is
//     rejected before any database work starts.
//   - For a node that already exists, only node_type, properties_hash and
//     metadata are overwritten; its validation_state and validation_msg are
//     kept. This method does not mark any node dirty.
//   - Existing nodes that are absent from nodes are not deleted.
//
// Any failure rolls the whole transaction back and is returned wrapped.
func (r *DAGRepository) SyncFeatureTree(ctx context.Context, itemID string, revisionNumber int, nodes []DAGNode, edges []DAGEdge) error {
	// Validate edges first so malformed input fails fast instead of after a
	// transaction has been opened and every node has been written.
	for i := range edges {
		e := &edges[i]
		if e.EdgeType == "" {
			e.EdgeType = "depends_on"
		}
		if e.SourceNodeID == "" {
			return fmt.Errorf("edge %d: source_node_id is required", i)
		}
		if e.TargetNodeID == "" {
			return fmt.Errorf("edge %d: target_node_id is required", i)
		}
	}

	tx, err := r.db.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// Safety net for every early return. The error is deliberately ignored:
	// once Commit has succeeded there is nothing left to roll back.
	defer func() { _ = tx.Rollback(ctx) }()

	// Upsert all nodes
	for i := range nodes {
		n := &nodes[i]
		n.ItemID = itemID
		n.RevisionNumber = revisionNumber
		if n.ValidationState == "" {
			n.ValidationState = "clean"
		}
		metadataJSON, err := json.Marshal(n.Metadata)
		if err != nil {
			return fmt.Errorf("marshaling node metadata: %w", err)
		}
		err = tx.QueryRow(ctx, `
INSERT INTO dag_nodes (item_id, revision_number, node_key, node_type,
properties_hash, validation_state, validation_msg, metadata)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
ON CONFLICT (item_id, revision_number, node_key)
DO UPDATE SET
node_type = EXCLUDED.node_type,
properties_hash = EXCLUDED.properties_hash,
metadata = EXCLUDED.metadata,
updated_at = now()
RETURNING id, created_at, updated_at
`, n.ItemID, n.RevisionNumber, n.NodeKey, n.NodeType,
			n.PropertiesHash, n.ValidationState, n.ValidationMsg, metadataJSON,
		).Scan(&n.ID, &n.CreatedAt, &n.UpdatedAt)
		if err != nil {
			return fmt.Errorf("upserting node %s: %w", n.NodeKey, err)
		}
	}

	// Delete existing edges for this item/revision
	if _, err := tx.Exec(ctx, `
DELETE FROM dag_edges
WHERE source_node_id IN (
SELECT id FROM dag_nodes WHERE item_id = $1 AND revision_number = $2
)
`, itemID, revisionNumber); err != nil {
		return fmt.Errorf("deleting old edges: %w", err)
	}

	// Insert new edges (already validated and defaulted above)
	for i := range edges {
		e := &edges[i]
		metadataJSON, err := json.Marshal(e.Metadata)
		if err != nil {
			return fmt.Errorf("marshaling edge metadata: %w", err)
		}
		err = tx.QueryRow(ctx, `
INSERT INTO dag_edges (source_node_id, target_node_id, edge_type, metadata)
VALUES ($1, $2, $3, $4)
RETURNING id
`, e.SourceNodeID, e.TargetNodeID, e.EdgeType, metadataJSON).Scan(&e.ID)
		if err != nil {
			return fmt.Errorf("creating edge: %w", err)
		}
	}

	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("committing DAG sync: %w", err)
	}
	return nil
}
// DeleteNodesForItem deletes every DAG node stored for the given item at the
// given revision.
//
// NOTE(review): edges touching those nodes are expected to disappear through
// a cascading foreign key; the schema is not visible here, so confirm.
func (r *DAGRepository) DeleteNodesForItem(ctx context.Context, itemID string, revisionNumber int) error {
	const stmt = `
DELETE FROM dag_nodes WHERE item_id = $1 AND revision_number = $2
`
	if _, execErr := r.db.pool.Exec(ctx, stmt, itemID, revisionNumber); execErr != nil {
		return fmt.Errorf("deleting nodes for item: %w", execErr)
	}
	return nil
}
// scanDAGNodes reads every remaining row into a DAGNode. Each row must
// provide the eleven dag_nodes columns in the order the queries in this file
// select them: id, item_id, revision_number, node_key, node_type,
// properties_hash, validation_state, validation_msg, metadata, created_at,
// updated_at.
//
// The result is nil when there are no rows. rows is not closed here; that is
// left to the caller.
func scanDAGNodes(rows pgx.Rows) ([]*DAGNode, error) {
	var collected []*DAGNode
	for rows.Next() {
		var (
			node    DAGNode
			rawMeta []byte
		)
		scanErr := rows.Scan(
			&node.ID, &node.ItemID, &node.RevisionNumber, &node.NodeKey, &node.NodeType,
			&node.PropertiesHash, &node.ValidationState, &node.ValidationMsg,
			&rawMeta, &node.CreatedAt, &node.UpdatedAt,
		)
		if scanErr != nil {
			return nil, fmt.Errorf("scanning DAG node: %w", scanErr)
		}
		// NULL metadata arrives as a nil slice; Metadata then stays nil.
		if rawMeta != nil {
			if decodeErr := json.Unmarshal(rawMeta, &node.Metadata); decodeErr != nil {
				return nil, fmt.Errorf("unmarshaling node metadata: %w", decodeErr)
			}
		}
		collected = append(collected, &node)
	}
	// Surface any error that ended the iteration early.
	return collected, rows.Err()
}