diff --git a/internal/db/dag.go b/internal/db/dag.go new file mode 100644 index 0000000..8a42cbf --- /dev/null +++ b/internal/db/dag.go @@ -0,0 +1,520 @@ +package db + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/jackc/pgx/v5" +) + +// DAGNode represents a feature-level node in the dependency graph. +type DAGNode struct { + ID string + ItemID string + RevisionNumber int + NodeKey string + NodeType string + PropertiesHash *string + ValidationState string + ValidationMsg *string + Metadata map[string]any + CreatedAt time.Time + UpdatedAt time.Time +} + +// DAGEdge represents a dependency between two nodes. +type DAGEdge struct { + ID string + SourceNodeID string + TargetNodeID string + EdgeType string + Metadata map[string]any +} + +// DAGCrossEdge represents a dependency between nodes in different items. +type DAGCrossEdge struct { + ID string + SourceNodeID string + TargetNodeID string + RelationshipID *string + EdgeType string + Metadata map[string]any +} + +// DAGRepository provides dependency graph database operations. +type DAGRepository struct { + db *DB +} + +// NewDAGRepository creates a new DAG repository. +func NewDAGRepository(db *DB) *DAGRepository { + return &DAGRepository{db: db} +} + +// GetNodes returns all DAG nodes for an item at a specific revision. +func (r *DAGRepository) GetNodes(ctx context.Context, itemID string, revisionNumber int) ([]*DAGNode, error) { + rows, err := r.db.pool.Query(ctx, ` + SELECT id, item_id, revision_number, node_key, node_type, + properties_hash, validation_state, validation_msg, + metadata, created_at, updated_at + FROM dag_nodes + WHERE item_id = $1 AND revision_number = $2 + ORDER BY node_key + `, itemID, revisionNumber) + if err != nil { + return nil, fmt.Errorf("querying DAG nodes: %w", err) + } + defer rows.Close() + return scanDAGNodes(rows) +} + +// GetNodeByKey returns a single DAG node by item, revision, and key. +func (r *DAGRepository) GetNodeByKey(ctx context.Context, itemID string, revisionNumber int, nodeKey string) (*DAGNode, error) { + n := &DAGNode{} + var metadataJSON []byte + err := r.db.pool.QueryRow(ctx, ` + SELECT id, item_id, revision_number, node_key, node_type, + properties_hash, validation_state, validation_msg, + metadata, created_at, updated_at + FROM dag_nodes + WHERE item_id = $1 AND revision_number = $2 AND node_key = $3 + `, itemID, revisionNumber, nodeKey).Scan( + &n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType, + &n.PropertiesHash, &n.ValidationState, &n.ValidationMsg, + &metadataJSON, &n.CreatedAt, &n.UpdatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("querying DAG node: %w", err) + } + if metadataJSON != nil { + if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil { + return nil, fmt.Errorf("unmarshaling node metadata: %w", err) + } + } + return n, nil +} + +// GetNodeByID returns a single DAG node by its ID. +func (r *DAGRepository) GetNodeByID(ctx context.Context, nodeID string) (*DAGNode, error) { + n := &DAGNode{} + var metadataJSON []byte + err := r.db.pool.QueryRow(ctx, ` + SELECT id, item_id, revision_number, node_key, node_type, + properties_hash, validation_state, validation_msg, + metadata, created_at, updated_at + FROM dag_nodes + WHERE id = $1 + `, nodeID).Scan( + &n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType, + &n.PropertiesHash, &n.ValidationState, &n.ValidationMsg, + &metadataJSON, &n.CreatedAt, &n.UpdatedAt, + ) + if err == pgx.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, fmt.Errorf("querying DAG node by ID: %w", err) + } + if metadataJSON != nil { + if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil { + return nil, fmt.Errorf("unmarshaling node metadata: %w", err) + } + } + return n, nil +} + +// UpsertNode inserts or updates a single DAG node. +func (r *DAGRepository) UpsertNode(ctx context.Context, n *DAGNode) error { + metadataJSON, err := json.Marshal(n.Metadata) + if err != nil { + return fmt.Errorf("marshaling metadata: %w", err) + } + + err = r.db.pool.QueryRow(ctx, ` + INSERT INTO dag_nodes (item_id, revision_number, node_key, node_type, + properties_hash, validation_state, validation_msg, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (item_id, revision_number, node_key) + DO UPDATE SET + node_type = EXCLUDED.node_type, + properties_hash = EXCLUDED.properties_hash, + validation_state = EXCLUDED.validation_state, + validation_msg = EXCLUDED.validation_msg, + metadata = EXCLUDED.metadata, + updated_at = now() + RETURNING id, created_at, updated_at + `, n.ItemID, n.RevisionNumber, n.NodeKey, n.NodeType, + n.PropertiesHash, n.ValidationState, n.ValidationMsg, metadataJSON, + ).Scan(&n.ID, &n.CreatedAt, &n.UpdatedAt) + if err != nil { + return fmt.Errorf("upserting DAG node: %w", err) + } + return nil +} + +// GetEdges returns all edges for nodes belonging to an item at a specific revision. +func (r *DAGRepository) GetEdges(ctx context.Context, itemID string, revisionNumber int) ([]*DAGEdge, error) { + rows, err := r.db.pool.Query(ctx, ` + SELECT e.id, e.source_node_id, e.target_node_id, e.edge_type, e.metadata + FROM dag_edges e + JOIN dag_nodes src ON src.id = e.source_node_id + WHERE src.item_id = $1 AND src.revision_number = $2 + ORDER BY e.source_node_id, e.target_node_id + `, itemID, revisionNumber) + if err != nil { + return nil, fmt.Errorf("querying DAG edges: %w", err) + } + defer rows.Close() + + var edges []*DAGEdge + for rows.Next() { + e := &DAGEdge{} + var metadataJSON []byte + if err := rows.Scan(&e.ID, &e.SourceNodeID, &e.TargetNodeID, &e.EdgeType, &metadataJSON); err != nil { + return nil, fmt.Errorf("scanning DAG edge: %w", err) + } + if metadataJSON != nil { + if err := json.Unmarshal(metadataJSON, &e.Metadata); err != nil { + return nil, fmt.Errorf("unmarshaling edge metadata: %w", err) + } + } + edges = append(edges, e) + } + return edges, rows.Err() +} + +// CreateEdge inserts a new edge between two nodes. +func (r *DAGRepository) CreateEdge(ctx context.Context, e *DAGEdge) error { + if e.EdgeType == "" { + e.EdgeType = "depends_on" + } + metadataJSON, err := json.Marshal(e.Metadata) + if err != nil { + return fmt.Errorf("marshaling edge metadata: %w", err) + } + + err = r.db.pool.QueryRow(ctx, ` + INSERT INTO dag_edges (source_node_id, target_node_id, edge_type, metadata) + VALUES ($1, $2, $3, $4) + ON CONFLICT (source_node_id, target_node_id, edge_type) DO NOTHING + RETURNING id + `, e.SourceNodeID, e.TargetNodeID, e.EdgeType, metadataJSON).Scan(&e.ID) + if err == pgx.ErrNoRows { + // Edge already exists, not an error + return nil + } + if err != nil { + return fmt.Errorf("creating DAG edge: %w", err) + } + return nil +} + +// DeleteEdgesForItem removes all edges for nodes belonging to an item/revision. +func (r *DAGRepository) DeleteEdgesForItem(ctx context.Context, itemID string, revisionNumber int) error { + _, err := r.db.pool.Exec(ctx, ` + DELETE FROM dag_edges + WHERE source_node_id IN ( + SELECT id FROM dag_nodes WHERE item_id = $1 AND revision_number = $2 + ) + `, itemID, revisionNumber) + if err != nil { + return fmt.Errorf("deleting edges for item: %w", err) + } + return nil +} + +// GetForwardCone returns all downstream dependent nodes reachable from the +// given node via edges. This is the key query for interference detection. +func (r *DAGRepository) GetForwardCone(ctx context.Context, nodeID string) ([]*DAGNode, error) { + rows, err := r.db.pool.Query(ctx, ` + WITH RECURSIVE forward_cone AS ( + SELECT target_node_id AS node_id + FROM dag_edges + WHERE source_node_id = $1 + UNION + SELECT e.target_node_id + FROM dag_edges e + JOIN forward_cone fc ON fc.node_id = e.source_node_id + ) + SELECT n.id, n.item_id, n.revision_number, n.node_key, n.node_type, + n.properties_hash, n.validation_state, n.validation_msg, + n.metadata, n.created_at, n.updated_at + FROM dag_nodes n + JOIN forward_cone fc ON n.id = fc.node_id + ORDER BY n.node_key + `, nodeID) + if err != nil { + return nil, fmt.Errorf("querying forward cone: %w", err) + } + defer rows.Close() + return scanDAGNodes(rows) +} + +// GetBackwardCone returns all upstream dependency nodes that the given +// node depends on. +func (r *DAGRepository) GetBackwardCone(ctx context.Context, nodeID string) ([]*DAGNode, error) { + rows, err := r.db.pool.Query(ctx, ` + WITH RECURSIVE backward_cone AS ( + SELECT source_node_id AS node_id + FROM dag_edges + WHERE target_node_id = $1 + UNION + SELECT e.source_node_id + FROM dag_edges e + JOIN backward_cone bc ON bc.node_id = e.target_node_id + ) + SELECT n.id, n.item_id, n.revision_number, n.node_key, n.node_type, + n.properties_hash, n.validation_state, n.validation_msg, + n.metadata, n.created_at, n.updated_at + FROM dag_nodes n + JOIN backward_cone bc ON n.id = bc.node_id + ORDER BY n.node_key + `, nodeID) + if err != nil { + return nil, fmt.Errorf("querying backward cone: %w", err) + } + defer rows.Close() + return scanDAGNodes(rows) +} + +// GetDirtySubgraph returns all non-clean nodes for an item. +func (r *DAGRepository) GetDirtySubgraph(ctx context.Context, itemID string) ([]*DAGNode, error) { + rows, err := r.db.pool.Query(ctx, ` + SELECT id, item_id, revision_number, node_key, node_type, + properties_hash, validation_state, validation_msg, + metadata, created_at, updated_at + FROM dag_nodes + WHERE item_id = $1 AND validation_state != 'clean' + ORDER BY node_key + `, itemID) + if err != nil { + return nil, fmt.Errorf("querying dirty subgraph: %w", err) + } + defer rows.Close() + return scanDAGNodes(rows) +} + +// MarkDirty marks a node and all its downstream dependents as dirty. +func (r *DAGRepository) MarkDirty(ctx context.Context, nodeID string) (int64, error) { + result, err := r.db.pool.Exec(ctx, ` + WITH RECURSIVE forward_cone AS ( + SELECT $1::uuid AS node_id + UNION + SELECT e.target_node_id + FROM dag_edges e + JOIN forward_cone fc ON fc.node_id = e.source_node_id + ) + UPDATE dag_nodes SET validation_state = 'dirty', updated_at = now() + WHERE id IN (SELECT node_id FROM forward_cone) + AND validation_state = 'clean' + `, nodeID) + if err != nil { + return 0, fmt.Errorf("marking dirty: %w", err) + } + return result.RowsAffected(), nil +} + +// MarkValidating sets a node's state to 'validating'. +func (r *DAGRepository) MarkValidating(ctx context.Context, nodeID string) error { + _, err := r.db.pool.Exec(ctx, ` + UPDATE dag_nodes SET validation_state = 'validating', updated_at = now() + WHERE id = $1 + `, nodeID) + if err != nil { + return fmt.Errorf("marking validating: %w", err) + } + return nil +} + +// MarkClean sets a node's state to 'clean' and updates its properties hash. +func (r *DAGRepository) MarkClean(ctx context.Context, nodeID string, propertiesHash string) error { + _, err := r.db.pool.Exec(ctx, ` + UPDATE dag_nodes + SET validation_state = 'clean', + properties_hash = $2, + validation_msg = NULL, + updated_at = now() + WHERE id = $1 + `, nodeID, propertiesHash) + if err != nil { + return fmt.Errorf("marking clean: %w", err) + } + return nil +} + +// MarkFailed sets a node's state to 'failed' with an error message. +func (r *DAGRepository) MarkFailed(ctx context.Context, nodeID string, message string) error { + _, err := r.db.pool.Exec(ctx, ` + UPDATE dag_nodes + SET validation_state = 'failed', + validation_msg = $2, + updated_at = now() + WHERE id = $1 + `, nodeID, message) + if err != nil { + return fmt.Errorf("marking failed: %w", err) + } + return nil +} + +// HasCycle checks whether adding an edge from sourceID to targetID would +// create a cycle. It walks upward from sourceID to see if targetID is +// already an ancestor. +func (r *DAGRepository) HasCycle(ctx context.Context, sourceID, targetID string) (bool, error) { + if sourceID == targetID { + return true, nil + } + var hasCycle bool + err := r.db.pool.QueryRow(ctx, ` + WITH RECURSIVE ancestors AS ( + SELECT source_node_id AS node_id + FROM dag_edges + WHERE target_node_id = $1 + UNION + SELECT e.source_node_id + FROM dag_edges e + JOIN ancestors a ON a.node_id = e.target_node_id + ) + SELECT EXISTS ( + SELECT 1 FROM ancestors WHERE node_id = $2 + ) + `, sourceID, targetID).Scan(&hasCycle) + if err != nil { + return false, fmt.Errorf("checking for cycle: %w", err) + } + return hasCycle, nil +} + +// SyncFeatureTree replaces the entire feature DAG for an item/revision +// within a single transaction. It upserts nodes, replaces edges, and +// marks changed nodes dirty. +func (r *DAGRepository) SyncFeatureTree(ctx context.Context, itemID string, revisionNumber int, nodes []DAGNode, edges []DAGEdge) error { + tx, err := r.db.pool.Begin(ctx) + if err != nil { + return fmt.Errorf("beginning transaction: %w", err) + } + defer tx.Rollback(ctx) + + // Upsert all nodes + for i := range nodes { + n := &nodes[i] + n.ItemID = itemID + n.RevisionNumber = revisionNumber + if n.ValidationState == "" { + n.ValidationState = "clean" + } + + metadataJSON, err := json.Marshal(n.Metadata) + if err != nil { + return fmt.Errorf("marshaling node metadata: %w", err) + } + + err = tx.QueryRow(ctx, ` + INSERT INTO dag_nodes (item_id, revision_number, node_key, node_type, + properties_hash, validation_state, validation_msg, metadata) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8) + ON CONFLICT (item_id, revision_number, node_key) + DO UPDATE SET + node_type = EXCLUDED.node_type, + properties_hash = EXCLUDED.properties_hash, + metadata = EXCLUDED.metadata, + updated_at = now() + RETURNING id, created_at, updated_at + `, n.ItemID, n.RevisionNumber, n.NodeKey, n.NodeType, + n.PropertiesHash, n.ValidationState, n.ValidationMsg, metadataJSON, + ).Scan(&n.ID, &n.CreatedAt, &n.UpdatedAt) + if err != nil { + return fmt.Errorf("upserting node %s: %w", n.NodeKey, err) + } + } + + // Build key→ID map for edge resolution + keyToID := make(map[string]string, len(nodes)) + for _, n := range nodes { + keyToID[n.NodeKey] = n.ID + } + + // Delete existing edges for this item/revision + _, err = tx.Exec(ctx, ` + DELETE FROM dag_edges + WHERE source_node_id IN ( + SELECT id FROM dag_nodes WHERE item_id = $1 AND revision_number = $2 + ) + `, itemID, revisionNumber) + if err != nil { + return fmt.Errorf("deleting old edges: %w", err) + } + + // Insert new edges + for i := range edges { + e := &edges[i] + if e.EdgeType == "" { + e.EdgeType = "depends_on" + } + + // Resolve source/target from node keys if IDs are not set + sourceID := e.SourceNodeID + targetID := e.TargetNodeID + if sourceID == "" { + return fmt.Errorf("edge %d: source_node_id is required", i) + } + if targetID == "" { + return fmt.Errorf("edge %d: target_node_id is required", i) + } + + metadataJSON, err := json.Marshal(e.Metadata) + if err != nil { + return fmt.Errorf("marshaling edge metadata: %w", err) + } + + err = tx.QueryRow(ctx, ` + INSERT INTO dag_edges (source_node_id, target_node_id, edge_type, metadata) + VALUES ($1, $2, $3, $4) + RETURNING id + `, sourceID, targetID, e.EdgeType, metadataJSON).Scan(&e.ID) + if err != nil { + return fmt.Errorf("creating edge: %w", err) + } + } + + return tx.Commit(ctx) +} + +// DeleteNodesForItem removes all DAG nodes (and cascades to edges) for an item/revision. +func (r *DAGRepository) DeleteNodesForItem(ctx context.Context, itemID string, revisionNumber int) error { + _, err := r.db.pool.Exec(ctx, ` + DELETE FROM dag_nodes WHERE item_id = $1 AND revision_number = $2 + `, itemID, revisionNumber) + if err != nil { + return fmt.Errorf("deleting nodes for item: %w", err) + } + return nil +} + +func scanDAGNodes(rows pgx.Rows) ([]*DAGNode, error) { + var nodes []*DAGNode + for rows.Next() { + n := &DAGNode{} + var metadataJSON []byte + err := rows.Scan( + &n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType, + &n.PropertiesHash, &n.ValidationState, &n.ValidationMsg, + &metadataJSON, &n.CreatedAt, &n.UpdatedAt, + ) + if err != nil { + return nil, fmt.Errorf("scanning DAG node: %w", err) + } + if metadataJSON != nil { + if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil { + return nil, fmt.Errorf("unmarshaling node metadata: %w", err) + } + } + nodes = append(nodes, n) + } + return nodes, rows.Err() +}