package db
|
|
|
|
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"time"

	"github.com/jackc/pgx/v5"
)
|
|
|
|
// DAGNode represents a feature-level node in the dependency graph.
// A node is identified logically by (ItemID, RevisionNumber, NodeKey)
// and physically by ID.
type DAGNode struct {
	ID             string  // database-generated primary key (cast as uuid in queries)
	ItemID         string  // owning item
	RevisionNumber int     // revision of the item this node belongs to
	NodeKey        string  // stable key of the node within the item/revision
	NodeType       string  // kind of node; semantics defined by callers
	PropertiesHash *string // hash of node properties, nil until computed
	ValidationState string // one of: "clean", "dirty", "validating", "failed"
	ValidationMsg  *string // failure message; nil when not failed
	Metadata       map[string]any // free-form JSON metadata column
	CreatedAt      time.Time
	UpdatedAt      time.Time
}
|
|
|
|
// DAGEdge represents a dependency between two nodes.
// The edge is directed from SourceNodeID to TargetNodeID; the default
// EdgeType (applied when empty on insert) is "depends_on".
type DAGEdge struct {
	ID           string // database-generated primary key
	SourceNodeID string // upstream node ID
	TargetNodeID string // downstream (dependent) node ID
	EdgeType     string // e.g. "depends_on"
	Metadata     map[string]any // free-form JSON metadata column
}
|
|
|
|
// DAGCrossEdge represents a dependency between nodes in different items.
// NOTE(review): no method in this file reads or writes this type —
// presumably it is used by code elsewhere in the package; confirm before
// changing its shape.
type DAGCrossEdge struct {
	ID             string  // database-generated primary key
	SourceNodeID   string  // upstream node ID
	TargetNodeID   string  // downstream node ID (in a different item)
	RelationshipID *string // optional link to an item-level relationship record
	EdgeType       string
	Metadata       map[string]any // free-form JSON metadata
}
|
|
|
|
// DAGRepository provides dependency graph database operations.
// All queries go through the shared connection pool on db.
type DAGRepository struct {
	db *DB // shared database handle; provides the pgx pool
}
|
|
|
|
// NewDAGRepository creates a new DAG repository.
|
|
func NewDAGRepository(db *DB) *DAGRepository {
|
|
return &DAGRepository{db: db}
|
|
}
|
|
|
|
// GetNodes returns all DAG nodes for an item at a specific revision.
|
|
func (r *DAGRepository) GetNodes(ctx context.Context, itemID string, revisionNumber int) ([]*DAGNode, error) {
|
|
rows, err := r.db.pool.Query(ctx, `
|
|
SELECT id, item_id, revision_number, node_key, node_type,
|
|
properties_hash, validation_state, validation_msg,
|
|
metadata, created_at, updated_at
|
|
FROM dag_nodes
|
|
WHERE item_id = $1 AND revision_number = $2
|
|
ORDER BY node_key
|
|
`, itemID, revisionNumber)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying DAG nodes: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanDAGNodes(rows)
|
|
}
|
|
|
|
// GetNodeByKey returns a single DAG node by item, revision, and key.
|
|
func (r *DAGRepository) GetNodeByKey(ctx context.Context, itemID string, revisionNumber int, nodeKey string) (*DAGNode, error) {
|
|
n := &DAGNode{}
|
|
var metadataJSON []byte
|
|
err := r.db.pool.QueryRow(ctx, `
|
|
SELECT id, item_id, revision_number, node_key, node_type,
|
|
properties_hash, validation_state, validation_msg,
|
|
metadata, created_at, updated_at
|
|
FROM dag_nodes
|
|
WHERE item_id = $1 AND revision_number = $2 AND node_key = $3
|
|
`, itemID, revisionNumber, nodeKey).Scan(
|
|
&n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType,
|
|
&n.PropertiesHash, &n.ValidationState, &n.ValidationMsg,
|
|
&metadataJSON, &n.CreatedAt, &n.UpdatedAt,
|
|
)
|
|
if err == pgx.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying DAG node: %w", err)
|
|
}
|
|
if metadataJSON != nil {
|
|
if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling node metadata: %w", err)
|
|
}
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// GetNodeByID returns a single DAG node by its ID.
|
|
func (r *DAGRepository) GetNodeByID(ctx context.Context, nodeID string) (*DAGNode, error) {
|
|
n := &DAGNode{}
|
|
var metadataJSON []byte
|
|
err := r.db.pool.QueryRow(ctx, `
|
|
SELECT id, item_id, revision_number, node_key, node_type,
|
|
properties_hash, validation_state, validation_msg,
|
|
metadata, created_at, updated_at
|
|
FROM dag_nodes
|
|
WHERE id = $1
|
|
`, nodeID).Scan(
|
|
&n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType,
|
|
&n.PropertiesHash, &n.ValidationState, &n.ValidationMsg,
|
|
&metadataJSON, &n.CreatedAt, &n.UpdatedAt,
|
|
)
|
|
if err == pgx.ErrNoRows {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying DAG node by ID: %w", err)
|
|
}
|
|
if metadataJSON != nil {
|
|
if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling node metadata: %w", err)
|
|
}
|
|
}
|
|
return n, nil
|
|
}
|
|
|
|
// UpsertNode inserts or updates a single DAG node.
|
|
func (r *DAGRepository) UpsertNode(ctx context.Context, n *DAGNode) error {
|
|
metadataJSON, err := json.Marshal(n.Metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("marshaling metadata: %w", err)
|
|
}
|
|
|
|
err = r.db.pool.QueryRow(ctx, `
|
|
INSERT INTO dag_nodes (item_id, revision_number, node_key, node_type,
|
|
properties_hash, validation_state, validation_msg, metadata)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
ON CONFLICT (item_id, revision_number, node_key)
|
|
DO UPDATE SET
|
|
node_type = EXCLUDED.node_type,
|
|
properties_hash = EXCLUDED.properties_hash,
|
|
validation_state = EXCLUDED.validation_state,
|
|
validation_msg = EXCLUDED.validation_msg,
|
|
metadata = EXCLUDED.metadata,
|
|
updated_at = now()
|
|
RETURNING id, created_at, updated_at
|
|
`, n.ItemID, n.RevisionNumber, n.NodeKey, n.NodeType,
|
|
n.PropertiesHash, n.ValidationState, n.ValidationMsg, metadataJSON,
|
|
).Scan(&n.ID, &n.CreatedAt, &n.UpdatedAt)
|
|
if err != nil {
|
|
return fmt.Errorf("upserting DAG node: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetEdges returns all edges for nodes belonging to an item at a specific revision.
|
|
func (r *DAGRepository) GetEdges(ctx context.Context, itemID string, revisionNumber int) ([]*DAGEdge, error) {
|
|
rows, err := r.db.pool.Query(ctx, `
|
|
SELECT e.id, e.source_node_id, e.target_node_id, e.edge_type, e.metadata
|
|
FROM dag_edges e
|
|
JOIN dag_nodes src ON src.id = e.source_node_id
|
|
WHERE src.item_id = $1 AND src.revision_number = $2
|
|
ORDER BY e.source_node_id, e.target_node_id
|
|
`, itemID, revisionNumber)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying DAG edges: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
|
|
var edges []*DAGEdge
|
|
for rows.Next() {
|
|
e := &DAGEdge{}
|
|
var metadataJSON []byte
|
|
if err := rows.Scan(&e.ID, &e.SourceNodeID, &e.TargetNodeID, &e.EdgeType, &metadataJSON); err != nil {
|
|
return nil, fmt.Errorf("scanning DAG edge: %w", err)
|
|
}
|
|
if metadataJSON != nil {
|
|
if err := json.Unmarshal(metadataJSON, &e.Metadata); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling edge metadata: %w", err)
|
|
}
|
|
}
|
|
edges = append(edges, e)
|
|
}
|
|
return edges, rows.Err()
|
|
}
|
|
|
|
// CreateEdge inserts a new edge between two nodes.
|
|
func (r *DAGRepository) CreateEdge(ctx context.Context, e *DAGEdge) error {
|
|
if e.EdgeType == "" {
|
|
e.EdgeType = "depends_on"
|
|
}
|
|
metadataJSON, err := json.Marshal(e.Metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("marshaling edge metadata: %w", err)
|
|
}
|
|
|
|
err = r.db.pool.QueryRow(ctx, `
|
|
INSERT INTO dag_edges (source_node_id, target_node_id, edge_type, metadata)
|
|
VALUES ($1, $2, $3, $4)
|
|
ON CONFLICT (source_node_id, target_node_id, edge_type) DO NOTHING
|
|
RETURNING id
|
|
`, e.SourceNodeID, e.TargetNodeID, e.EdgeType, metadataJSON).Scan(&e.ID)
|
|
if err == pgx.ErrNoRows {
|
|
// Edge already exists, not an error
|
|
return nil
|
|
}
|
|
if err != nil {
|
|
return fmt.Errorf("creating DAG edge: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// DeleteEdgesForItem removes all edges for nodes belonging to an item/revision.
|
|
func (r *DAGRepository) DeleteEdgesForItem(ctx context.Context, itemID string, revisionNumber int) error {
|
|
_, err := r.db.pool.Exec(ctx, `
|
|
DELETE FROM dag_edges
|
|
WHERE source_node_id IN (
|
|
SELECT id FROM dag_nodes WHERE item_id = $1 AND revision_number = $2
|
|
)
|
|
`, itemID, revisionNumber)
|
|
if err != nil {
|
|
return fmt.Errorf("deleting edges for item: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// GetForwardCone returns all downstream dependent nodes reachable from the
|
|
// given node via edges. This is the key query for interference detection.
|
|
func (r *DAGRepository) GetForwardCone(ctx context.Context, nodeID string) ([]*DAGNode, error) {
|
|
rows, err := r.db.pool.Query(ctx, `
|
|
WITH RECURSIVE forward_cone AS (
|
|
SELECT target_node_id AS node_id
|
|
FROM dag_edges
|
|
WHERE source_node_id = $1
|
|
UNION
|
|
SELECT e.target_node_id
|
|
FROM dag_edges e
|
|
JOIN forward_cone fc ON fc.node_id = e.source_node_id
|
|
)
|
|
SELECT n.id, n.item_id, n.revision_number, n.node_key, n.node_type,
|
|
n.properties_hash, n.validation_state, n.validation_msg,
|
|
n.metadata, n.created_at, n.updated_at
|
|
FROM dag_nodes n
|
|
JOIN forward_cone fc ON n.id = fc.node_id
|
|
ORDER BY n.node_key
|
|
`, nodeID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying forward cone: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanDAGNodes(rows)
|
|
}
|
|
|
|
// GetBackwardCone returns all upstream dependency nodes that the given
|
|
// node depends on.
|
|
func (r *DAGRepository) GetBackwardCone(ctx context.Context, nodeID string) ([]*DAGNode, error) {
|
|
rows, err := r.db.pool.Query(ctx, `
|
|
WITH RECURSIVE backward_cone AS (
|
|
SELECT source_node_id AS node_id
|
|
FROM dag_edges
|
|
WHERE target_node_id = $1
|
|
UNION
|
|
SELECT e.source_node_id
|
|
FROM dag_edges e
|
|
JOIN backward_cone bc ON bc.node_id = e.target_node_id
|
|
)
|
|
SELECT n.id, n.item_id, n.revision_number, n.node_key, n.node_type,
|
|
n.properties_hash, n.validation_state, n.validation_msg,
|
|
n.metadata, n.created_at, n.updated_at
|
|
FROM dag_nodes n
|
|
JOIN backward_cone bc ON n.id = bc.node_id
|
|
ORDER BY n.node_key
|
|
`, nodeID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying backward cone: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanDAGNodes(rows)
|
|
}
|
|
|
|
// GetDirtySubgraph returns all non-clean nodes for an item.
|
|
func (r *DAGRepository) GetDirtySubgraph(ctx context.Context, itemID string) ([]*DAGNode, error) {
|
|
rows, err := r.db.pool.Query(ctx, `
|
|
SELECT id, item_id, revision_number, node_key, node_type,
|
|
properties_hash, validation_state, validation_msg,
|
|
metadata, created_at, updated_at
|
|
FROM dag_nodes
|
|
WHERE item_id = $1 AND validation_state != 'clean'
|
|
ORDER BY node_key
|
|
`, itemID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("querying dirty subgraph: %w", err)
|
|
}
|
|
defer rows.Close()
|
|
return scanDAGNodes(rows)
|
|
}
|
|
|
|
// MarkDirty marks a node and all its downstream dependents as dirty.
|
|
func (r *DAGRepository) MarkDirty(ctx context.Context, nodeID string) (int64, error) {
|
|
result, err := r.db.pool.Exec(ctx, `
|
|
WITH RECURSIVE forward_cone AS (
|
|
SELECT $1::uuid AS node_id
|
|
UNION
|
|
SELECT e.target_node_id
|
|
FROM dag_edges e
|
|
JOIN forward_cone fc ON fc.node_id = e.source_node_id
|
|
)
|
|
UPDATE dag_nodes SET validation_state = 'dirty', updated_at = now()
|
|
WHERE id IN (SELECT node_id FROM forward_cone)
|
|
AND validation_state = 'clean'
|
|
`, nodeID)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("marking dirty: %w", err)
|
|
}
|
|
return result.RowsAffected(), nil
|
|
}
|
|
|
|
// MarkValidating sets a node's state to 'validating'.
|
|
func (r *DAGRepository) MarkValidating(ctx context.Context, nodeID string) error {
|
|
_, err := r.db.pool.Exec(ctx, `
|
|
UPDATE dag_nodes SET validation_state = 'validating', updated_at = now()
|
|
WHERE id = $1
|
|
`, nodeID)
|
|
if err != nil {
|
|
return fmt.Errorf("marking validating: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MarkClean sets a node's state to 'clean' and updates its properties hash.
|
|
func (r *DAGRepository) MarkClean(ctx context.Context, nodeID string, propertiesHash string) error {
|
|
_, err := r.db.pool.Exec(ctx, `
|
|
UPDATE dag_nodes
|
|
SET validation_state = 'clean',
|
|
properties_hash = $2,
|
|
validation_msg = NULL,
|
|
updated_at = now()
|
|
WHERE id = $1
|
|
`, nodeID, propertiesHash)
|
|
if err != nil {
|
|
return fmt.Errorf("marking clean: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// MarkFailed sets a node's state to 'failed' with an error message.
|
|
func (r *DAGRepository) MarkFailed(ctx context.Context, nodeID string, message string) error {
|
|
_, err := r.db.pool.Exec(ctx, `
|
|
UPDATE dag_nodes
|
|
SET validation_state = 'failed',
|
|
validation_msg = $2,
|
|
updated_at = now()
|
|
WHERE id = $1
|
|
`, nodeID, message)
|
|
if err != nil {
|
|
return fmt.Errorf("marking failed: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// HasCycle checks whether adding an edge from sourceID to targetID would
|
|
// create a cycle. It walks upward from sourceID to see if targetID is
|
|
// already an ancestor.
|
|
func (r *DAGRepository) HasCycle(ctx context.Context, sourceID, targetID string) (bool, error) {
|
|
if sourceID == targetID {
|
|
return true, nil
|
|
}
|
|
var hasCycle bool
|
|
err := r.db.pool.QueryRow(ctx, `
|
|
WITH RECURSIVE ancestors AS (
|
|
SELECT source_node_id AS node_id
|
|
FROM dag_edges
|
|
WHERE target_node_id = $1
|
|
UNION
|
|
SELECT e.source_node_id
|
|
FROM dag_edges e
|
|
JOIN ancestors a ON a.node_id = e.target_node_id
|
|
)
|
|
SELECT EXISTS (
|
|
SELECT 1 FROM ancestors WHERE node_id = $2
|
|
)
|
|
`, sourceID, targetID).Scan(&hasCycle)
|
|
if err != nil {
|
|
return false, fmt.Errorf("checking for cycle: %w", err)
|
|
}
|
|
return hasCycle, nil
|
|
}
|
|
|
|
// SyncFeatureTree replaces the entire feature DAG for an item/revision
|
|
// within a single transaction. It upserts nodes, replaces edges, and
|
|
// marks changed nodes dirty.
|
|
func (r *DAGRepository) SyncFeatureTree(ctx context.Context, itemID string, revisionNumber int, nodes []DAGNode, edges []DAGEdge) error {
|
|
tx, err := r.db.pool.Begin(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("beginning transaction: %w", err)
|
|
}
|
|
defer tx.Rollback(ctx)
|
|
|
|
// Upsert all nodes
|
|
for i := range nodes {
|
|
n := &nodes[i]
|
|
n.ItemID = itemID
|
|
n.RevisionNumber = revisionNumber
|
|
if n.ValidationState == "" {
|
|
n.ValidationState = "clean"
|
|
}
|
|
|
|
metadataJSON, err := json.Marshal(n.Metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("marshaling node metadata: %w", err)
|
|
}
|
|
|
|
err = tx.QueryRow(ctx, `
|
|
INSERT INTO dag_nodes (item_id, revision_number, node_key, node_type,
|
|
properties_hash, validation_state, validation_msg, metadata)
|
|
VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
|
|
ON CONFLICT (item_id, revision_number, node_key)
|
|
DO UPDATE SET
|
|
node_type = EXCLUDED.node_type,
|
|
properties_hash = EXCLUDED.properties_hash,
|
|
metadata = EXCLUDED.metadata,
|
|
updated_at = now()
|
|
RETURNING id, created_at, updated_at
|
|
`, n.ItemID, n.RevisionNumber, n.NodeKey, n.NodeType,
|
|
n.PropertiesHash, n.ValidationState, n.ValidationMsg, metadataJSON,
|
|
).Scan(&n.ID, &n.CreatedAt, &n.UpdatedAt)
|
|
if err != nil {
|
|
return fmt.Errorf("upserting node %s: %w", n.NodeKey, err)
|
|
}
|
|
}
|
|
|
|
// Build key→ID map for edge resolution
|
|
keyToID := make(map[string]string, len(nodes))
|
|
for _, n := range nodes {
|
|
keyToID[n.NodeKey] = n.ID
|
|
}
|
|
|
|
// Delete existing edges for this item/revision
|
|
_, err = tx.Exec(ctx, `
|
|
DELETE FROM dag_edges
|
|
WHERE source_node_id IN (
|
|
SELECT id FROM dag_nodes WHERE item_id = $1 AND revision_number = $2
|
|
)
|
|
`, itemID, revisionNumber)
|
|
if err != nil {
|
|
return fmt.Errorf("deleting old edges: %w", err)
|
|
}
|
|
|
|
// Insert new edges
|
|
for i := range edges {
|
|
e := &edges[i]
|
|
if e.EdgeType == "" {
|
|
e.EdgeType = "depends_on"
|
|
}
|
|
|
|
// Resolve source/target from node keys if IDs are not set
|
|
sourceID := e.SourceNodeID
|
|
targetID := e.TargetNodeID
|
|
if sourceID == "" {
|
|
return fmt.Errorf("edge %d: source_node_id is required", i)
|
|
}
|
|
if targetID == "" {
|
|
return fmt.Errorf("edge %d: target_node_id is required", i)
|
|
}
|
|
|
|
metadataJSON, err := json.Marshal(e.Metadata)
|
|
if err != nil {
|
|
return fmt.Errorf("marshaling edge metadata: %w", err)
|
|
}
|
|
|
|
err = tx.QueryRow(ctx, `
|
|
INSERT INTO dag_edges (source_node_id, target_node_id, edge_type, metadata)
|
|
VALUES ($1, $2, $3, $4)
|
|
RETURNING id
|
|
`, sourceID, targetID, e.EdgeType, metadataJSON).Scan(&e.ID)
|
|
if err != nil {
|
|
return fmt.Errorf("creating edge: %w", err)
|
|
}
|
|
}
|
|
|
|
return tx.Commit(ctx)
|
|
}
|
|
|
|
// DeleteNodesForItem removes all DAG nodes (and cascades to edges) for an item/revision.
|
|
func (r *DAGRepository) DeleteNodesForItem(ctx context.Context, itemID string, revisionNumber int) error {
|
|
_, err := r.db.pool.Exec(ctx, `
|
|
DELETE FROM dag_nodes WHERE item_id = $1 AND revision_number = $2
|
|
`, itemID, revisionNumber)
|
|
if err != nil {
|
|
return fmt.Errorf("deleting nodes for item: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func scanDAGNodes(rows pgx.Rows) ([]*DAGNode, error) {
|
|
var nodes []*DAGNode
|
|
for rows.Next() {
|
|
n := &DAGNode{}
|
|
var metadataJSON []byte
|
|
err := rows.Scan(
|
|
&n.ID, &n.ItemID, &n.RevisionNumber, &n.NodeKey, &n.NodeType,
|
|
&n.PropertiesHash, &n.ValidationState, &n.ValidationMsg,
|
|
&metadataJSON, &n.CreatedAt, &n.UpdatedAt,
|
|
)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("scanning DAG node: %w", err)
|
|
}
|
|
if metadataJSON != nil {
|
|
if err := json.Unmarshal(metadataJSON, &n.Metadata); err != nil {
|
|
return nil, fmt.Errorf("unmarshaling node metadata: %w", err)
|
|
}
|
|
}
|
|
nodes = append(nodes, n)
|
|
}
|
|
return nodes, rows.Err()
|
|
}
|