91 lines
2.1 KiB
Go
91 lines
2.1 KiB
Go
// Package db provides PostgreSQL database access.
|
|
package db
|
|
|
|
import (
	"context"
	"fmt"
	"strconv"
	"strings"

	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
)
|
|
|
|
// Config carries the PostgreSQL connection settings that Connect turns into
// a pgxpool connection string.
type Config struct {
	// Host is the database server host (sent as "host").
	Host string
	// Port is the database server port (sent as "port").
	Port int
	// Name is the database to connect to (sent as "dbname").
	Name string
	// User is the role to authenticate as (sent as "user").
	User string
	// Password is the password for User (sent as "password").
	Password string
	// SSLMode is the libpq-style SSL mode, e.g. "disable" or "require"
	// (sent as "sslmode").
	SSLMode string
	// MaxConnections caps the size of the connection pool
	// (sent as "pool_max_conns").
	MaxConnections int
}
|
|
|
|
// DB wraps the connection pool.
|
|
type DB struct {
|
|
pool *pgxpool.Pool
|
|
}
|
|
|
|
// Connect establishes a database connection pool.
|
|
func Connect(ctx context.Context, cfg Config) (*DB, error) {
|
|
dsn := fmt.Sprintf(
|
|
"host=%s port=%d dbname=%s user=%s password=%s sslmode=%s pool_max_conns=%d",
|
|
cfg.Host, cfg.Port, cfg.Name, cfg.User, cfg.Password, cfg.SSLMode, cfg.MaxConnections,
|
|
)
|
|
|
|
pool, err := pgxpool.New(ctx, dsn)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("creating connection pool: %w", err)
|
|
}
|
|
|
|
if err := pool.Ping(ctx); err != nil {
|
|
return nil, fmt.Errorf("pinging database: %w", err)
|
|
}
|
|
|
|
return &DB{pool: pool}, nil
|
|
}
|
|
|
|
// Close closes the connection pool.
|
|
func (db *DB) Close() {
|
|
db.pool.Close()
|
|
}
|
|
|
|
// Pool returns the underlying connection pool for direct access.
|
|
func (db *DB) Pool() *pgxpool.Pool {
|
|
return db.pool
|
|
}
|
|
|
|
// Tx executes a function within a transaction.
|
|
func (db *DB) Tx(ctx context.Context, fn func(tx pgx.Tx) error) error {
|
|
tx, err := db.pool.Begin(ctx)
|
|
if err != nil {
|
|
return fmt.Errorf("beginning transaction: %w", err)
|
|
}
|
|
|
|
if err := fn(tx); err != nil {
|
|
if rbErr := tx.Rollback(ctx); rbErr != nil {
|
|
return fmt.Errorf("rollback failed: %v (original error: %w)", rbErr, err)
|
|
}
|
|
return err
|
|
}
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
return fmt.Errorf("committing transaction: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// NextSequenceValue atomically increments and returns the next sequence value.
|
|
// Uses schema name (not UUID) for simpler operation.
|
|
func (db *DB) NextSequenceValue(ctx context.Context, schemaName string, scope string) (int, error) {
|
|
var val int
|
|
err := db.pool.QueryRow(ctx,
|
|
"SELECT next_sequence_by_name($1, $2)",
|
|
schemaName, scope,
|
|
).Scan(&val)
|
|
if err != nil {
|
|
return 0, fmt.Errorf("getting next sequence: %w", err)
|
|
}
|
|
return val, nil
|
|
}
|