package db import ( "context" "encoding/json" "fmt" "time" "github.com/jackc/pgx/v5" ) // DAGNode represents a feature-level node in the dependency graph. type DAGNode struct { ID string ItemID string RevisionNumber int NodeKey string NodeType string PropertiesHash *string ValidationState string ValidationMsg *string Metadata map[string]any CreatedAt time.Time UpdatedAt time.Time } // DAGEdge represents a dependency between two nodes. type DAGEdge struct { ID string SourceNodeID string TargetNodeID string EdgeType string Metadata map[string]any } // DAGCrossEdge represents a dependency between nodes in different items. type DAGCrossEdge struct { ID string SourceNodeID string TargetNodeID string RelationshipID *string EdgeType string Metadata map[string]any } // DAGRepository provides dependency graph database operations. type DAGRepository struct { db *DB } // NewDAGRepository creates a new DAG repository. func NewDAGRepository(db *DB) *DAGRepository { return &DAGRepository{db: db} } // GetNodes returns all DAG nodes for an item at a specific revision. func (r *DAGRepository) GetNodes(ctx context.Context, itemID string, revisionNumber int) ([]*DAGNode, error) { rows, err := r.db.pool.Query(ctx, ` SELECT id, item_id, revision_number, node_key, node_type, properties_hash, validation_state, validation_msg, metadata, created_at, updated_at FROM dag_nodes WHERE item_id = $1 AND revision_number = $2 ORDER BY node_key `, itemID, revisionNumber) if err != nil { return nil, fmt.Errorf("querying DAG nodes: %w", err) } defer rows.Close() return scanDAGNodes(rows) } // GetNodeByKey returns a single DAG node by item, revision, and key. func (r *DAGRepository) GetNodeByKey(ctx context.Context, itemID string, revisionNumber int, nodeKey string) (*DAGNode, error) { n := &DAGNode{} var metadataJSON []byte err := r.db.pool.QueryRow(ctx, ` SELECT id, item_id, revision_number, node_key, node_type, properties_hash, validation_state, validation_msg, metadata, created_at, updated_at FROM dag_nodes WHERE item_id = $1 AND revision_number = $2 AND node_key = $3 `, itemID, revisionNumber, nodeKey).Scan( &n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType, &n.PropertiesHash, &n.ValidationState, &n.ValidationMsg, &metadataJSON, &n.CreatedAt, &n.UpdatedAt, ) if err == pgx.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("querying DAG node: %w", err) } if metadataJSON != nil { if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil { return nil, fmt.Errorf("unmarshaling node metadata: %w", err) } } return n, nil } // GetNodeByID returns a single DAG node by its ID. func (r *DAGRepository) GetNodeByID(ctx context.Context, nodeID string) (*DAGNode, error) { n := &DAGNode{} var metadataJSON []byte err := r.db.pool.QueryRow(ctx, ` SELECT id, item_id, revision_number, node_key, node_type, properties_hash, validation_state, validation_msg, metadata, created_at, updated_at FROM dag_nodes WHERE id = $1 `, nodeID).Scan( &n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType, &n.PropertiesHash, &n.ValidationState, &n.ValidationMsg, &metadataJSON, &n.CreatedAt, &n.UpdatedAt, ) if err == pgx.ErrNoRows { return nil, nil } if err != nil { return nil, fmt.Errorf("querying DAG node by ID: %w", err) } if metadataJSON != nil { if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil { return nil, fmt.Errorf("unmarshaling node metadata: %w", err) } } return n, nil } // UpsertNode inserts or updates a single DAG node. func (r *DAGRepository) UpsertNode(ctx context.Context, n *DAGNode) error { metadataJSON, err := json.Marshal(n.Metadata) if err != nil { return fmt.Errorf("marshaling metadata: %w", err) } err = r.db.pool.QueryRow(ctx, ` INSERT INTO dag_nodes (item_id, revision_number, node_key, node_type, properties_hash, validation_state, validation_msg, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (item_id, revision_number, node_key) DO UPDATE SET node_type = EXCLUDED.node_type, properties_hash = EXCLUDED.properties_hash, validation_state = EXCLUDED.validation_state, validation_msg = EXCLUDED.validation_msg, metadata = EXCLUDED.metadata, updated_at = now() RETURNING id, created_at, updated_at `, n.ItemID, n.RevisionNumber, n.NodeKey, n.NodeType, n.PropertiesHash, n.ValidationState, n.ValidationMsg, metadataJSON, ).Scan(&n.ID, &n.CreatedAt, &n.UpdatedAt) if err != nil { return fmt.Errorf("upserting DAG node: %w", err) } return nil } // GetEdges returns all edges for nodes belonging to an item at a specific revision. func (r *DAGRepository) GetEdges(ctx context.Context, itemID string, revisionNumber int) ([]*DAGEdge, error) { rows, err := r.db.pool.Query(ctx, ` SELECT e.id, e.source_node_id, e.target_node_id, e.edge_type, e.metadata FROM dag_edges e JOIN dag_nodes src ON src.id = e.source_node_id WHERE src.item_id = $1 AND src.revision_number = $2 ORDER BY e.source_node_id, e.target_node_id `, itemID, revisionNumber) if err != nil { return nil, fmt.Errorf("querying DAG edges: %w", err) } defer rows.Close() var edges []*DAGEdge for rows.Next() { e := &DAGEdge{} var metadataJSON []byte if err := rows.Scan(&e.ID, &e.SourceNodeID, &e.TargetNodeID, &e.EdgeType, &metadataJSON); err != nil { return nil, fmt.Errorf("scanning DAG edge: %w", err) } if metadataJSON != nil { if err := json.Unmarshal(metadataJSON, &e.Metadata); err != nil { return nil, fmt.Errorf("unmarshaling edge metadata: %w", err) } } edges = append(edges, e) } return edges, rows.Err() } // CreateEdge inserts a new edge between two nodes. func (r *DAGRepository) CreateEdge(ctx context.Context, e *DAGEdge) error { if e.EdgeType == "" { e.EdgeType = "depends_on" } metadataJSON, err := json.Marshal(e.Metadata) if err != nil { return fmt.Errorf("marshaling edge metadata: %w", err) } err = r.db.pool.QueryRow(ctx, ` INSERT INTO dag_edges (source_node_id, target_node_id, edge_type, metadata) VALUES ($1, $2, $3, $4) ON CONFLICT (source_node_id, target_node_id, edge_type) DO NOTHING RETURNING id `, e.SourceNodeID, e.TargetNodeID, e.EdgeType, metadataJSON).Scan(&e.ID) if err == pgx.ErrNoRows { // Edge already exists, not an error return nil } if err != nil { return fmt.Errorf("creating DAG edge: %w", err) } return nil } // DeleteEdgesForItem removes all edges for nodes belonging to an item/revision. func (r *DAGRepository) DeleteEdgesForItem(ctx context.Context, itemID string, revisionNumber int) error { _, err := r.db.pool.Exec(ctx, ` DELETE FROM dag_edges WHERE source_node_id IN ( SELECT id FROM dag_nodes WHERE item_id = $1 AND revision_number = $2 ) `, itemID, revisionNumber) if err != nil { return fmt.Errorf("deleting edges for item: %w", err) } return nil } // GetForwardCone returns all downstream dependent nodes reachable from the // given node via edges. This is the key query for interference detection. func (r *DAGRepository) GetForwardCone(ctx context.Context, nodeID string) ([]*DAGNode, error) { rows, err := r.db.pool.Query(ctx, ` WITH RECURSIVE forward_cone AS ( SELECT target_node_id AS node_id FROM dag_edges WHERE source_node_id = $1 UNION SELECT e.target_node_id FROM dag_edges e JOIN forward_cone fc ON fc.node_id = e.source_node_id ) SELECT n.id, n.item_id, n.revision_number, n.node_key, n.node_type, n.properties_hash, n.validation_state, n.validation_msg, n.metadata, n.created_at, n.updated_at FROM dag_nodes n JOIN forward_cone fc ON n.id = fc.node_id ORDER BY n.node_key `, nodeID) if err != nil { return nil, fmt.Errorf("querying forward cone: %w", err) } defer rows.Close() return scanDAGNodes(rows) } // GetBackwardCone returns all upstream dependency nodes that the given // node depends on. func (r *DAGRepository) GetBackwardCone(ctx context.Context, nodeID string) ([]*DAGNode, error) { rows, err := r.db.pool.Query(ctx, ` WITH RECURSIVE backward_cone AS ( SELECT source_node_id AS node_id FROM dag_edges WHERE target_node_id = $1 UNION SELECT e.source_node_id FROM dag_edges e JOIN backward_cone bc ON bc.node_id = e.target_node_id ) SELECT n.id, n.item_id, n.revision_number, n.node_key, n.node_type, n.properties_hash, n.validation_state, n.validation_msg, n.metadata, n.created_at, n.updated_at FROM dag_nodes n JOIN backward_cone bc ON n.id = bc.node_id ORDER BY n.node_key `, nodeID) if err != nil { return nil, fmt.Errorf("querying backward cone: %w", err) } defer rows.Close() return scanDAGNodes(rows) } // GetDirtySubgraph returns all non-clean nodes for an item. func (r *DAGRepository) GetDirtySubgraph(ctx context.Context, itemID string) ([]*DAGNode, error) { rows, err := r.db.pool.Query(ctx, ` SELECT id, item_id, revision_number, node_key, node_type, properties_hash, validation_state, validation_msg, metadata, created_at, updated_at FROM dag_nodes WHERE item_id = $1 AND validation_state != 'clean' ORDER BY node_key `, itemID) if err != nil { return nil, fmt.Errorf("querying dirty subgraph: %w", err) } defer rows.Close() return scanDAGNodes(rows) } // MarkDirty marks a node and all its downstream dependents as dirty. func (r *DAGRepository) MarkDirty(ctx context.Context, nodeID string) (int64, error) { result, err := r.db.pool.Exec(ctx, ` WITH RECURSIVE forward_cone AS ( SELECT $1::uuid AS node_id UNION SELECT e.target_node_id FROM dag_edges e JOIN forward_cone fc ON fc.node_id = e.source_node_id ) UPDATE dag_nodes SET validation_state = 'dirty', updated_at = now() WHERE id IN (SELECT node_id FROM forward_cone) AND validation_state = 'clean' `, nodeID) if err != nil { return 0, fmt.Errorf("marking dirty: %w", err) } return result.RowsAffected(), nil } // MarkValidating sets a node's state to 'validating'. func (r *DAGRepository) MarkValidating(ctx context.Context, nodeID string) error { _, err := r.db.pool.Exec(ctx, ` UPDATE dag_nodes SET validation_state = 'validating', updated_at = now() WHERE id = $1 `, nodeID) if err != nil { return fmt.Errorf("marking validating: %w", err) } return nil } // MarkClean sets a node's state to 'clean' and updates its properties hash. func (r *DAGRepository) MarkClean(ctx context.Context, nodeID string, propertiesHash string) error { _, err := r.db.pool.Exec(ctx, ` UPDATE dag_nodes SET validation_state = 'clean', properties_hash = $2, validation_msg = NULL, updated_at = now() WHERE id = $1 `, nodeID, propertiesHash) if err != nil { return fmt.Errorf("marking clean: %w", err) } return nil } // MarkFailed sets a node's state to 'failed' with an error message. func (r *DAGRepository) MarkFailed(ctx context.Context, nodeID string, message string) error { _, err := r.db.pool.Exec(ctx, ` UPDATE dag_nodes SET validation_state = 'failed', validation_msg = $2, updated_at = now() WHERE id = $1 `, nodeID, message) if err != nil { return fmt.Errorf("marking failed: %w", err) } return nil } // HasCycle checks whether adding an edge from sourceID to targetID would // create a cycle. It walks upward from sourceID to see if targetID is // already an ancestor. func (r *DAGRepository) HasCycle(ctx context.Context, sourceID, targetID string) (bool, error) { if sourceID == targetID { return true, nil } var hasCycle bool err := r.db.pool.QueryRow(ctx, ` WITH RECURSIVE ancestors AS ( SELECT source_node_id AS node_id FROM dag_edges WHERE target_node_id = $1 UNION SELECT e.source_node_id FROM dag_edges e JOIN ancestors a ON a.node_id = e.target_node_id ) SELECT EXISTS ( SELECT 1 FROM ancestors WHERE node_id = $2 ) `, sourceID, targetID).Scan(&hasCycle) if err != nil { return false, fmt.Errorf("checking for cycle: %w", err) } return hasCycle, nil } // SyncFeatureTree replaces the entire feature DAG for an item/revision // within a single transaction. It upserts nodes, replaces edges, and // marks changed nodes dirty. func (r *DAGRepository) SyncFeatureTree(ctx context.Context, itemID string, revisionNumber int, nodes []DAGNode, edges []DAGEdge) error { tx, err := r.db.pool.Begin(ctx) if err != nil { return fmt.Errorf("beginning transaction: %w", err) } defer tx.Rollback(ctx) // Upsert all nodes for i := range nodes { n := &nodes[i] n.ItemID = itemID n.RevisionNumber = revisionNumber if n.ValidationState == "" { n.ValidationState = "clean" } metadataJSON, err := json.Marshal(n.Metadata) if err != nil { return fmt.Errorf("marshaling node metadata: %w", err) } err = tx.QueryRow(ctx, ` INSERT INTO dag_nodes (item_id, revision_number, node_key, node_type, properties_hash, validation_state, validation_msg, metadata) VALUES ($1, $2, $3, $4, $5, $6, $7, $8) ON CONFLICT (item_id, revision_number, node_key) DO UPDATE SET node_type = EXCLUDED.node_type, properties_hash = EXCLUDED.properties_hash, metadata = EXCLUDED.metadata, updated_at = now() RETURNING id, created_at, updated_at `, n.ItemID, n.RevisionNumber, n.NodeKey, n.NodeType, n.PropertiesHash, n.ValidationState, n.ValidationMsg, metadataJSON, ).Scan(&n.ID, &n.CreatedAt, &n.UpdatedAt) if err != nil { return fmt.Errorf("upserting node %s: %w", n.NodeKey, err) } } // Build key→ID map for edge resolution keyToID := make(map[string]string, len(nodes)) for _, n := range nodes { keyToID[n.NodeKey] = n.ID } // Delete existing edges for this item/revision _, err = tx.Exec(ctx, ` DELETE FROM dag_edges WHERE source_node_id IN ( SELECT id FROM dag_nodes WHERE item_id = $1 AND revision_number = $2 ) `, itemID, revisionNumber) if err != nil { return fmt.Errorf("deleting old edges: %w", err) } // Insert new edges for i := range edges { e := &edges[i] if e.EdgeType == "" { e.EdgeType = "depends_on" } // Resolve source/target from node keys if IDs are not set sourceID := e.SourceNodeID targetID := e.TargetNodeID if sourceID == "" { return fmt.Errorf("edge %d: source_node_id is required", i) } if targetID == "" { return fmt.Errorf("edge %d: target_node_id is required", i) } metadataJSON, err := json.Marshal(e.Metadata) if err != nil { return fmt.Errorf("marshaling edge metadata: %w", err) } err = tx.QueryRow(ctx, ` INSERT INTO dag_edges (source_node_id, target_node_id, edge_type, metadata) VALUES ($1, $2, $3, $4) RETURNING id `, sourceID, targetID, e.EdgeType, metadataJSON).Scan(&e.ID) if err != nil { return fmt.Errorf("creating edge: %w", err) } } return tx.Commit(ctx) } // DeleteNodesForItem removes all DAG nodes (and cascades to edges) for an item/revision. func (r *DAGRepository) DeleteNodesForItem(ctx context.Context, itemID string, revisionNumber int) error { _, err := r.db.pool.Exec(ctx, ` DELETE FROM dag_nodes WHERE item_id = $1 AND revision_number = $2 `, itemID, revisionNumber) if err != nil { return fmt.Errorf("deleting nodes for item: %w", err) } return nil } func scanDAGNodes(rows pgx.Rows) ([]*DAGNode, error) { var nodes []*DAGNode for rows.Next() { n := &DAGNode{} var metadataJSON []byte err := rows.Scan( &n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType, &n.PropertiesHash, &n.ValidationState, &n.ValidationMsg, &metadataJSON, &n.CreatedAt, &n.UpdatedAt, ) if err != nil { return nil, fmt.Errorf("scanning DAG node: %w", err) } if metadataJSON != nil { if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil { return nil, fmt.Errorf("unmarshaling node metadata: %w", err) } } nodes = append(nodes, n) } return nodes, rows.Err() }