Files
silo/internal/db/db.go
2026-01-24 15:03:17 -06:00

91 lines
2.1 KiB
Go

// Package db provides PostgreSQL database access.
package db
import (
"context"
"fmt"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// Config holds database connection settings.
// Each field is rendered by Connect into a keyword/value connection string
// that is handed to pgxpool.
type Config struct {
// Host is the database server host (the `host` keyword).
Host string
// Port is the database server port (the `port` keyword).
Port int
// Name is the database to connect to (the `dbname` keyword).
Name string
// User is the role to authenticate as (the `user` keyword).
User string
// Password is the password for User (the `password` keyword).
Password string
// SSLMode is the TLS negotiation mode (the `sslmode` keyword), e.g.
// "disable" or "require".
SSLMode string
// MaxConnections is the maximum pool size (the `pool_max_conns` keyword
// understood by pgxpool).
MaxConnections int
}
// DB wraps the connection pool.
// Obtain one from Connect and release it with Close; Pool exposes the
// underlying *pgxpool.Pool for queries not covered by DB's helpers.
type DB struct {
// pool is the pgx connection pool created by Connect.
pool *pgxpool.Pool
}
// Connect establishes a database connection pool.
//
// It renders cfg as a keyword/value connection string, opens a pgxpool
// pool with it, and verifies connectivity with a ping.
//
// Every value that is included in the connection string is single-quoted
// and escaped, so passwords or names containing spaces, quotes or
// backslashes reach the server intact. Empty string fields and non-positive
// Port/MaxConnections values are omitted entirely, so pgx's defaults (and
// the standard PG* environment variables) apply instead of an empty or
// invalid setting.
//
// If the ping fails, the pool is closed before the error is returned so
// that no connections are left open.
func Connect(ctx context.Context, cfg Config) (*DB, error) {
	pool, err := pgxpool.New(ctx, buildDSN(cfg))
	if err != nil {
		return nil, fmt.Errorf("creating connection pool: %w", err)
	}
	if err := pool.Ping(ctx); err != nil {
		// The pool was created successfully; release it since the caller
		// never receives a handle to close it.
		pool.Close()
		return nil, fmt.Errorf("pinging database: %w", err)
	}
	return &DB{pool: pool}, nil
}

// buildDSN renders cfg as a libpq-style keyword/value connection string,
// skipping unset (empty or non-positive) fields.
func buildDSN(cfg Config) string {
	var dsn []byte
	add := func(key, value string) {
		if value == "" {
			return // leave the keyword out so pgx uses its default
		}
		if len(dsn) > 0 {
			dsn = append(dsn, ' ')
		}
		dsn = append(dsn, key...)
		dsn = append(dsn, '=')
		dsn = appendQuotedDSNValue(dsn, value)
	}
	positive := func(n int) string {
		if n <= 0 {
			return ""
		}
		return fmt.Sprintf("%d", n)
	}

	add("host", cfg.Host)
	add("port", positive(cfg.Port))
	add("dbname", cfg.Name)
	add("user", cfg.User)
	add("password", cfg.Password)
	add("sslmode", cfg.SSLMode)
	add("pool_max_conns", positive(cfg.MaxConnections))
	return string(dsn)
}

// appendQuotedDSNValue appends v to dst wrapped in single quotes, escaping
// backslashes and single quotes with a backslash as the keyword/value
// connection-string syntax requires. Iterating bytes is UTF-8 safe because
// both escaped characters are ASCII and never occur inside multi-byte runes.
func appendQuotedDSNValue(dst []byte, v string) []byte {
	dst = append(dst, '\'')
	for i := 0; i < len(v); i++ {
		c := v[i]
		if c == '\\' || c == '\'' {
			dst = append(dst, '\\')
		}
		dst = append(dst, c)
	}
	return append(dst, '\'')
}
// Close shuts down the connection pool wrapped by db by delegating to
// pgxpool.Pool.Close.
func (db *DB) Close() {
	p := db.pool
	p.Close()
}
// Pool exposes the wrapped *pgxpool.Pool so callers can issue queries
// directly when DB's helpers do not cover their needs.
func (db *DB) Pool() *pgxpool.Pool {
	p := db.pool
	return p
}
// Tx executes a function within a transaction.
//
// It begins a transaction on the pool and passes it to fn. If fn returns
// nil, the transaction is committed. If fn returns an error, the transaction
// is rolled back and fn's error is returned; when the rollback itself also
// fails, both errors are reported, and fn's error remains reachable with
// errors.Is/As. If fn panics (or exits via runtime.Goexit), the transaction
// is rolled back on a best-effort basis before the panic propagates, so the
// pooled connection held by the transaction is released.
//
// fn should not commit or roll back tx itself. Presumably the subsequent
// Commit/Rollback here would then report an error. NOTE(review): confirm
// against pgx Tx semantics.
func (db *DB) Tx(ctx context.Context, fn func(tx pgx.Tx) error) error {
	tx, err := db.pool.Begin(ctx)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}

	// finished is set once fn has returned normally and we have taken over
	// ending the transaction. If it is still false when the deferred func
	// runs, fn panicked and the transaction would otherwise stay open,
	// pinning its connection forever.
	finished := false
	defer func() {
		if !finished {
			_ = tx.Rollback(ctx) // best effort; the panic takes precedence
		}
	}()

	if err := fn(tx); err != nil {
		finished = true
		if rbErr := tx.Rollback(ctx); rbErr != nil {
			return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
		}
		return err
	}

	finished = true
	if err := tx.Commit(ctx); err != nil {
		return fmt.Errorf("committing transaction: %w", err)
	}
	return nil
}
// NextSequenceValue returns the next value of the sequence identified by
// schemaName and scope. It does so by calling the database function
// next_sequence_by_name. The schema is addressed by name rather than UUID
// for simpler operation.
//
// NOTE(review): the increment happens inside next_sequence_by_name; its
// atomicity depends on that SQL function's implementation — confirm there.
func (db *DB) NextSequenceValue(ctx context.Context, schemaName string, scope string) (int, error) {
	const query = "SELECT next_sequence_by_name($1, $2)"

	row := db.pool.QueryRow(ctx, query, schemaName, scope)

	var next int
	if err := row.Scan(&next); err != nil {
		return 0, fmt.Errorf("getting next sequence: %w", err)
	}
	return next, nil
}