test: add comprehensive test suite for backend

Add 56 tests covering the core backend packages:

Unit tests (no database required):
- internal/partnum: 7 tests for part number generation logic
  (sequence, format templates, enum validation, constants)
- internal/schema: 8 tests for YAML schema loading, property
  merging, validation, and default application

Integration tests (require TEST_DATABASE_URL):
- internal/db/items: 10 tests for item CRUD, archive/unarchive,
  revisions, and thumbnail operations
- internal/db/relationships: 10 tests for BOM CRUD, cycle detection,
  self-reference blocking, where-used, expanded/flat BOM
- internal/db/projects: 5 tests for project CRUD and item association
- internal/api/bom_handlers: 6 HTTP handler tests for BOM endpoints
  including flat BOM, cost calculation, add/delete entries
- internal/api/items: 5 HTTP handler tests for item CRUD endpoints

Infrastructure:
- internal/testutil: shared helpers for test DB pool setup,
  migration runner, and table truncation
- internal/db/helpers_test.go: DB wrapper for integration tests
- internal/db/db.go: add NewFromPool constructor
- Makefile: add test-integration target with default DSN

Integration tests skip gracefully when TEST_DATABASE_URL is unset.
Dev-mode auth (nil authConfig) used for API handler tests.

Fixes: fmt.Errorf Go vet warning in partnum/generator.go

Closes #2
This commit is contained in:
Forbes
2026-02-07 01:57:10 -06:00
parent 3704adb584
commit d08b178466
12 changed files with 1616 additions and 6 deletions

View File

@@ -1,4 +1,4 @@
.PHONY: build run test clean migrate fmt lint \ .PHONY: build run test test-integration clean migrate fmt lint \
docker-build docker-up docker-down docker-logs docker-ps \ docker-build docker-up docker-down docker-logs docker-ps \
docker-clean docker-rebuild \ docker-clean docker-rebuild \
web-install web-dev web-build web-install web-dev web-build
@@ -7,8 +7,8 @@
# Local Development # Local Development
# ============================================================================= # =============================================================================
# Build all binaries # Build all binaries (frontend + backend)
build: build: web-build
go build -o silo ./cmd/silo go build -o silo ./cmd/silo
go build -o silod ./cmd/silod go build -o silod ./cmd/silod
@@ -20,14 +20,19 @@ run:
cli: cli:
go run ./cmd/silo $(ARGS) go run ./cmd/silo $(ARGS)
# Run tests # Run unit tests (integration tests skipped without TEST_DATABASE_URL)
test: test:
go test -v ./... go test -v ./...
# Run all tests including integration tests (requires PostgreSQL)
test-integration:
TEST_DATABASE_URL="postgres://silo:silodev@localhost:5432/silo_test?sslmode=disable" go test -v -count=1 ./...
# Clean build artifacts # Clean build artifacts
clean: clean:
rm -f silo silod rm -f silo silod
rm -f *.out rm -f *.out
rm -rf web/dist
# Format code # Format code
fmt: fmt:
@@ -153,7 +158,8 @@ help:
@echo " build - Build CLI and server binaries" @echo " build - Build CLI and server binaries"
@echo " run - Run API server locally" @echo " run - Run API server locally"
@echo " cli ARGS=... - Run CLI with arguments" @echo " cli ARGS=... - Run CLI with arguments"
@echo " test - Run tests" @echo " test - Run unit tests (integration tests skip without DB)"
@echo " test-integration - Run all tests including integration (needs PostgreSQL)"
@echo " fmt - Format code" @echo " fmt - Format code"
@echo " lint - Run linter" @echo " lint - Run linter"
@echo " tidy - Tidy go.mod" @echo " tidy - Tidy go.mod"

View File

@@ -0,0 +1,238 @@
package api
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
"github.com/kindredsystems/silo/internal/auth"
"github.com/kindredsystems/silo/internal/db"
"github.com/kindredsystems/silo/internal/schema"
"github.com/kindredsystems/silo/internal/testutil"
"github.com/rs/zerolog"
)
// newTestServer creates a Server backed by a real test DB with no auth.
func newTestServer(t *testing.T) *Server {
t.Helper()
pool := testutil.MustConnectTestPool(t)
database := db.NewFromPool(pool)
return NewServer(
zerolog.Nop(),
database,
map[string]*schema.Schema{},
"", // schemasDir
nil, // storage
nil, // authService
nil, // sessionManager
nil, // oidcBackend
nil, // authConfig (nil = dev mode)
)
}
// newTestRouter creates a chi router with BOM routes for testing.
func newTestRouter(s *Server) http.Handler {
r := chi.NewRouter()
r.Route("/api/items/{partNumber}", func(r chi.Router) {
r.Get("/bom", s.HandleGetBOM)
r.Get("/bom/flat", s.HandleGetFlatBOM)
r.Get("/bom/cost", s.HandleGetBOMCost)
r.Post("/bom", s.HandleAddBOMEntry)
r.Delete("/bom/{childPartNumber}", s.HandleDeleteBOMEntry)
})
return r
}
// createItemDirect creates an item directly via the DB for test setup.
func createItemDirect(t *testing.T, s *Server, pn, desc string, cost *float64) {
t.Helper()
item := &db.Item{
PartNumber: pn,
ItemType: "part",
Description: desc,
StandardCost: cost,
}
if err := s.items.Create(context.Background(), item, nil); err != nil {
t.Fatalf("creating item %s: %v", pn, err)
}
}
// authRequest returns a copy of the request with an admin user in context.
func authRequest(r *http.Request) *http.Request {
u := &auth.User{
ID: "test-admin-id",
Username: "testadmin",
DisplayName: "Test Admin",
Role: auth.RoleAdmin,
AuthSource: "local",
}
return r.WithContext(auth.ContextWithUser(r.Context(), u))
}
// addBOMDirect adds a BOM relationship directly via the DB.
func addBOMDirect(t *testing.T, s *Server, parentPN, childPN string, qty float64) {
t.Helper()
ctx := context.Background()
parent, _ := s.items.GetByPartNumber(ctx, parentPN)
child, _ := s.items.GetByPartNumber(ctx, childPN)
if parent == nil || child == nil {
t.Fatalf("parent or child not found: %s, %s", parentPN, childPN)
}
rel := &db.Relationship{
ParentItemID: parent.ID,
ChildItemID: child.ID,
RelType: "component",
Quantity: &qty,
}
if err := s.relationships.Create(ctx, rel); err != nil {
t.Fatalf("adding BOM %s→%s: %v", parentPN, childPN, err)
}
}
func TestHandleGetBOM(t *testing.T) {
s := newTestServer(t)
router := newTestRouter(s)
createItemDirect(t, s, "API-P1", "parent", nil)
createItemDirect(t, s, "API-C1", "child1", nil)
createItemDirect(t, s, "API-C2", "child2", nil)
addBOMDirect(t, s, "API-P1", "API-C1", 2)
addBOMDirect(t, s, "API-P1", "API-C2", 5)
req := httptest.NewRequest("GET", "/api/items/API-P1/bom", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
}
var entries []BOMEntryResponse
if err := json.Unmarshal(w.Body.Bytes(), &entries); err != nil {
t.Fatalf("decoding response: %v", err)
}
if len(entries) != 2 {
t.Errorf("expected 2 BOM entries, got %d", len(entries))
}
}
func TestHandleGetFlatBOM(t *testing.T) {
s := newTestServer(t)
router := newTestRouter(s)
// A(qty 1) → B(qty 2) → X(qty 3) = X total 6
createItemDirect(t, s, "FA", "assembly A", nil)
createItemDirect(t, s, "FB", "sub B", nil)
createItemDirect(t, s, "FX", "leaf X", nil)
addBOMDirect(t, s, "FA", "FB", 2)
addBOMDirect(t, s, "FB", "FX", 3)
req := httptest.NewRequest("GET", "/api/items/FA/bom/flat", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
}
var resp FlatBOMResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if len(resp.FlatBOM) != 1 {
t.Fatalf("expected 1 leaf part, got %d", len(resp.FlatBOM))
}
if resp.FlatBOM[0].TotalQuantity != 6 {
t.Errorf("total quantity: got %.1f, want 6.0", resp.FlatBOM[0].TotalQuantity)
}
}
func TestHandleGetBOMCost(t *testing.T) {
s := newTestServer(t)
router := newTestRouter(s)
cost10 := 10.0
cost5 := 5.0
createItemDirect(t, s, "CA", "assembly", nil)
createItemDirect(t, s, "CX", "part X", &cost10)
createItemDirect(t, s, "CY", "part Y", &cost5)
addBOMDirect(t, s, "CA", "CX", 3)
addBOMDirect(t, s, "CA", "CY", 2)
req := httptest.NewRequest("GET", "/api/items/CA/bom/cost", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
}
var resp CostResponse
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
// 3*10 + 2*5 = 40
if resp.TotalCost != 40 {
t.Errorf("total cost: got %.2f, want 40.00", resp.TotalCost)
}
}
func TestHandleGetFlatBOMNotFound(t *testing.T) {
s := newTestServer(t)
router := newTestRouter(s)
req := httptest.NewRequest("GET", "/api/items/NONEXISTENT/bom/flat", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("status: got %d, want %d", w.Code, http.StatusNotFound)
}
}
func TestHandleAddBOMEntry(t *testing.T) {
s := newTestServer(t)
router := newTestRouter(s)
createItemDirect(t, s, "ADD-P", "parent", nil)
createItemDirect(t, s, "ADD-C", "child", nil)
body := `{"child_part_number":"ADD-C","rel_type":"component","quantity":7}`
req := authRequest(httptest.NewRequest("POST", "/api/items/ADD-P/bom", strings.NewReader(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusCreated, w.Body.String())
}
var entry BOMEntryResponse
if err := json.Unmarshal(w.Body.Bytes(), &entry); err != nil {
t.Fatalf("decoding response: %v", err)
}
if entry.ChildPartNumber != "ADD-C" {
t.Errorf("child_part_number: got %q, want %q", entry.ChildPartNumber, "ADD-C")
}
}
func TestHandleDeleteBOMEntry(t *testing.T) {
s := newTestServer(t)
router := newTestRouter(s)
createItemDirect(t, s, "DEL-P", "parent", nil)
createItemDirect(t, s, "DEL-C", "child", nil)
addBOMDirect(t, s, "DEL-P", "DEL-C", 1)
req := authRequest(httptest.NewRequest("DELETE", "/api/items/DEL-P/bom/DEL-C", nil))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNoContent {
t.Errorf("status: got %d, want %d; body: %s", w.Code, http.StatusNoContent, w.Body.String())
}
}

133
internal/api/items_test.go Normal file
View File

@@ -0,0 +1,133 @@
package api
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/go-chi/chi/v5"
)
// newItemRouter creates a chi router with item routes for testing.
func newItemRouter(s *Server) http.Handler {
r := chi.NewRouter()
r.Route("/api/items", func(r chi.Router) {
r.Get("/", s.HandleListItems)
r.Post("/", s.HandleCreateItem)
r.Route("/{partNumber}", func(r chi.Router) {
r.Get("/", s.HandleGetItem)
r.Put("/", s.HandleUpdateItem)
r.Delete("/", s.HandleDeleteItem)
})
})
return r
}
func TestHandleCreateItem(t *testing.T) {
s := newTestServer(t)
router := newItemRouter(s)
body := `{"part_number":"NEW-001","item_type":"part","description":"new item"}`
req := authRequest(httptest.NewRequest("POST", "/api/items", strings.NewReader(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusCreated {
t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusCreated, w.Body.String())
}
var resp map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("decoding response: %v", err)
}
if resp["part_number"] != "NEW-001" {
t.Errorf("part_number: got %v, want %q", resp["part_number"], "NEW-001")
}
}
func TestHandleGetItem(t *testing.T) {
s := newTestServer(t)
router := newItemRouter(s)
createItemDirect(t, s, "GET-001", "get test", nil)
req := httptest.NewRequest("GET", "/api/items/GET-001", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
}
var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp)
if resp["part_number"] != "GET-001" {
t.Errorf("part_number: got %v, want %q", resp["part_number"], "GET-001")
}
}
func TestHandleGetItemNotFound(t *testing.T) {
s := newTestServer(t)
router := newItemRouter(s)
req := httptest.NewRequest("GET", "/api/items/NOPE-999", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNotFound {
t.Errorf("status: got %d, want %d", w.Code, http.StatusNotFound)
}
}
func TestHandleListItems(t *testing.T) {
s := newTestServer(t)
router := newItemRouter(s)
createItemDirect(t, s, "LST-001", "list item 1", nil)
createItemDirect(t, s, "LST-002", "list item 2", nil)
req := httptest.NewRequest("GET", "/api/items", nil)
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String())
}
var resp map[string]any
json.Unmarshal(w.Body.Bytes(), &resp)
items, ok := resp["items"].([]any)
if !ok {
t.Fatalf("expected items array in response, got: %s", w.Body.String())
}
if len(items) < 2 {
t.Errorf("expected at least 2 items, got %d", len(items))
}
}
func TestHandleDeleteItem(t *testing.T) {
s := newTestServer(t)
router := newItemRouter(s)
createItemDirect(t, s, "DEL-ITEM-001", "deletable", nil)
req := authRequest(httptest.NewRequest("DELETE", "/api/items/DEL-ITEM-001", nil))
w := httptest.NewRecorder()
router.ServeHTTP(w, req)
if w.Code != http.StatusNoContent && w.Code != http.StatusOK {
t.Fatalf("status: got %d, want 200 or 204; body: %s", w.Code, w.Body.String())
}
// Should be gone (archived)
req2 := httptest.NewRequest("GET", "/api/items/DEL-ITEM-001", nil)
w2 := httptest.NewRecorder()
router.ServeHTTP(w2, req2)
if w2.Code != http.StatusNotFound {
t.Errorf("after delete, expected 404, got %d", w2.Code)
}
}

View File

@@ -44,6 +44,23 @@ func Connect(ctx context.Context, cfg Config) (*DB, error) {
return &DB{pool: pool}, nil return &DB{pool: pool}, nil
} }
// ConnectDSN establishes a database connection pool from a DSN string.
func ConnectDSN(ctx context.Context, dsn string) (*DB, error) {
pool, err := pgxpool.New(ctx, dsn)
if err != nil {
return nil, fmt.Errorf("creating connection pool: %w", err)
}
if err := pool.Ping(ctx); err != nil {
return nil, fmt.Errorf("pinging database: %w", err)
}
return &DB{pool: pool}, nil
}
// NewFromPool wraps an existing connection pool in a DB handle.
func NewFromPool(pool *pgxpool.Pool) *DB {
return &DB{pool: pool}
}
// Close closes the connection pool. // Close closes the connection pool.
func (db *DB) Close() { func (db *DB) Close() {
db.pool.Close() db.pool.Close()

View File

@@ -0,0 +1,15 @@
package db
import (
"testing"
"github.com/kindredsystems/silo/internal/testutil"
)
// mustConnectTestDB returns a *DB backed by a real test Postgres instance.
// It skips the test if TEST_DATABASE_URL is not set.
func mustConnectTestDB(t *testing.T) *DB {
t.Helper()
pool := testutil.MustConnectTestPool(t)
return NewFromPool(pool)
}

272
internal/db/items_test.go Normal file
View File

@@ -0,0 +1,272 @@
package db
import (
"context"
"fmt"
"testing"
)
func TestItemCreate(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
item := &Item{
PartNumber: "TEST-0001",
ItemType: "part",
Description: "Test item",
}
err := repo.Create(ctx, item, map[string]any{"color": "red"})
if err != nil {
t.Fatalf("Create: %v", err)
}
if item.ID == "" {
t.Error("expected item ID to be set")
}
if item.CurrentRevision != 1 {
t.Errorf("current revision: got %d, want 1", item.CurrentRevision)
}
}
func TestItemGetByPartNumber(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
item := &Item{PartNumber: "GET-PN-001", ItemType: "part", Description: "get by pn test"}
if err := repo.Create(ctx, item, nil); err != nil {
t.Fatalf("Create: %v", err)
}
got, err := repo.GetByPartNumber(ctx, "GET-PN-001")
if err != nil {
t.Fatalf("GetByPartNumber: %v", err)
}
if got == nil {
t.Fatal("expected item, got nil")
}
if got.Description != "get by pn test" {
t.Errorf("description: got %q, want %q", got.Description, "get by pn test")
}
// Non-existent should return nil, not error
missing, err := repo.GetByPartNumber(ctx, "DOES-NOT-EXIST")
if err != nil {
t.Fatalf("GetByPartNumber (missing): %v", err)
}
if missing != nil {
t.Error("expected nil for missing item")
}
}
func TestItemGetByID(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
item := &Item{PartNumber: "GET-ID-001", ItemType: "assembly", Description: "get by id"}
if err := repo.Create(ctx, item, nil); err != nil {
t.Fatalf("Create: %v", err)
}
got, err := repo.GetByID(ctx, item.ID)
if err != nil {
t.Fatalf("GetByID: %v", err)
}
if got == nil {
t.Fatal("expected item, got nil")
}
if got.PartNumber != "GET-ID-001" {
t.Errorf("part_number: got %q, want %q", got.PartNumber, "GET-ID-001")
}
}
func TestItemList(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
for i := 0; i < 3; i++ {
item := &Item{
PartNumber: fmt.Sprintf("LIST-%04d", i),
ItemType: "part",
Description: fmt.Sprintf("list item %d", i),
}
if err := repo.Create(ctx, item, nil); err != nil {
t.Fatalf("Create #%d: %v", i, err)
}
}
items, err := repo.List(ctx, ListOptions{})
if err != nil {
t.Fatalf("List: %v", err)
}
if len(items) != 3 {
t.Errorf("expected 3 items, got %d", len(items))
}
}
func TestItemListByType(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
repo.Create(ctx, &Item{PartNumber: "TYPE-P-001", ItemType: "part", Description: "a part"}, nil)
repo.Create(ctx, &Item{PartNumber: "TYPE-A-001", ItemType: "assembly", Description: "an assembly"}, nil)
repo.Create(ctx, &Item{PartNumber: "TYPE-P-002", ItemType: "part", Description: "another part"}, nil)
items, err := repo.List(ctx, ListOptions{ItemType: "part"})
if err != nil {
t.Fatalf("List: %v", err)
}
if len(items) != 2 {
t.Errorf("expected 2 parts, got %d", len(items))
}
}
func TestItemUpdate(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
item := &Item{PartNumber: "UPD-001", ItemType: "part", Description: "original"}
if err := repo.Create(ctx, item, nil); err != nil {
t.Fatalf("Create: %v", err)
}
cost := 42.50
err := repo.Update(ctx, item.ID, UpdateItemFields{
PartNumber: "UPD-001",
ItemType: "part",
Description: "updated",
StandardCost: &cost,
})
if err != nil {
t.Fatalf("Update: %v", err)
}
got, _ := repo.GetByID(ctx, item.ID)
if got.Description != "updated" {
t.Errorf("description: got %q, want %q", got.Description, "updated")
}
if got.StandardCost == nil || *got.StandardCost != 42.50 {
t.Errorf("standard_cost: got %v, want 42.50", got.StandardCost)
}
}
func TestItemArchiveUnarchive(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
item := &Item{PartNumber: "ARC-001", ItemType: "part", Description: "archivable"}
if err := repo.Create(ctx, item, nil); err != nil {
t.Fatalf("Create: %v", err)
}
// Archive
if err := repo.Archive(ctx, item.ID); err != nil {
t.Fatalf("Archive: %v", err)
}
// Should not appear in GetByPartNumber (excludes archived)
got, _ := repo.GetByPartNumber(ctx, "ARC-001")
if got != nil {
t.Error("archived item should not be returned by GetByPartNumber")
}
// But should still be accessible by ID
gotByID, _ := repo.GetByID(ctx, item.ID)
if gotByID == nil {
t.Fatal("archived item should still be accessible by GetByID")
}
if gotByID.ArchivedAt == nil {
t.Error("archived_at should be set")
}
// Unarchive
if err := repo.Unarchive(ctx, item.ID); err != nil {
t.Fatalf("Unarchive: %v", err)
}
got, _ = repo.GetByPartNumber(ctx, "ARC-001")
if got == nil {
t.Error("unarchived item should be returned by GetByPartNumber")
}
}
func TestItemCreateRevision(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
item := &Item{PartNumber: "REV-001", ItemType: "part", Description: "revisable"}
if err := repo.Create(ctx, item, map[string]any{"v": 1}); err != nil {
t.Fatalf("Create: %v", err)
}
// Create second revision
rev := &Revision{
ItemID: item.ID,
Properties: map[string]any{"v": 2},
}
if err := repo.CreateRevision(ctx, rev); err != nil {
t.Fatalf("CreateRevision: %v", err)
}
if rev.RevisionNumber != 2 {
t.Errorf("revision number: got %d, want 2", rev.RevisionNumber)
}
// Item's current_revision should be updated by trigger
got, _ := repo.GetByPartNumber(ctx, "REV-001")
if got.CurrentRevision != 2 {
t.Errorf("current_revision: got %d, want 2", got.CurrentRevision)
}
}
func TestItemGetRevisions(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
item := &Item{PartNumber: "REVS-001", ItemType: "part", Description: "multi rev"}
if err := repo.Create(ctx, item, map[string]any{"step": "initial"}); err != nil {
t.Fatalf("Create: %v", err)
}
comment := "second revision"
repo.CreateRevision(ctx, &Revision{
ItemID: item.ID, Properties: map[string]any{"step": "updated"}, Comment: &comment,
})
revisions, err := repo.GetRevisions(ctx, item.ID)
if err != nil {
t.Fatalf("GetRevisions: %v", err)
}
if len(revisions) != 2 {
t.Errorf("expected 2 revisions, got %d", len(revisions))
}
// Revisions are returned newest first
if revisions[0].RevisionNumber != 2 {
t.Errorf("first revision should be #2 (newest), got #%d", revisions[0].RevisionNumber)
}
}
func TestItemSetThumbnailKey(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewItemRepository(database)
ctx := context.Background()
item := &Item{PartNumber: "THUMB-001", ItemType: "part", Description: "thumbnail test"}
if err := repo.Create(ctx, item, nil); err != nil {
t.Fatalf("Create: %v", err)
}
if err := repo.SetThumbnailKey(ctx, item.ID, "items/thumb.png"); err != nil {
t.Fatalf("SetThumbnailKey: %v", err)
}
got, _ := repo.GetByID(ctx, item.ID)
if got.ThumbnailKey == nil || *got.ThumbnailKey != "items/thumb.png" {
t.Errorf("thumbnail_key: got %v, want %q", got.ThumbnailKey, "items/thumb.png")
}
}

View File

@@ -0,0 +1,119 @@
package db
import (
"context"
"testing"
)
func TestProjectCreate(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewProjectRepository(database)
ctx := context.Background()
p := &Project{Code: "TPRJ", Name: "Test Project"}
if err := repo.Create(ctx, p); err != nil {
t.Fatalf("Create: %v", err)
}
if p.ID == "" {
t.Error("expected project ID to be set")
}
}
func TestProjectGet(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewProjectRepository(database)
ctx := context.Background()
repo.Create(ctx, &Project{Code: "GPRJ", Name: "Get Project"})
got, err := repo.GetByCode(ctx, "GPRJ")
if err != nil {
t.Fatalf("GetByCode: %v", err)
}
if got == nil {
t.Fatal("expected project, got nil")
}
if got.Name != "Get Project" {
t.Errorf("name: got %q, want %q", got.Name, "Get Project")
}
// Missing should return nil
missing, err := repo.GetByCode(ctx, "NOPE")
if err != nil {
t.Fatalf("GetByCode (missing): %v", err)
}
if missing != nil {
t.Error("expected nil for missing project")
}
}
func TestProjectList(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewProjectRepository(database)
ctx := context.Background()
repo.Create(ctx, &Project{Code: "AA", Name: "Alpha"})
repo.Create(ctx, &Project{Code: "BB", Name: "Beta"})
projects, err := repo.List(ctx)
if err != nil {
t.Fatalf("List: %v", err)
}
if len(projects) != 2 {
t.Errorf("expected 2 projects, got %d", len(projects))
}
}
func TestProjectDelete(t *testing.T) {
database := mustConnectTestDB(t)
repo := NewProjectRepository(database)
ctx := context.Background()
repo.Create(ctx, &Project{Code: "DEL", Name: "Deletable"})
if err := repo.Delete(ctx, "DEL"); err != nil {
t.Fatalf("Delete: %v", err)
}
got, _ := repo.GetByCode(ctx, "DEL")
if got != nil {
t.Error("deleted project should not be found")
}
}
func TestProjectItemAssociation(t *testing.T) {
database := mustConnectTestDB(t)
projRepo := NewProjectRepository(database)
itemRepo := NewItemRepository(database)
ctx := context.Background()
proj := &Project{Code: "ASSC", Name: "Assoc Project"}
projRepo.Create(ctx, proj)
item := &Item{PartNumber: "ASSC-001", ItemType: "part", Description: "associated item"}
itemRepo.Create(ctx, item, nil)
// Add item to project
if err := projRepo.AddItemToProject(ctx, item.ID, proj.ID); err != nil {
t.Fatalf("AddItemToProject: %v", err)
}
// Get items for project (takes project UUID)
items, err := projRepo.GetItemsForProject(ctx, proj.ID)
if err != nil {
t.Fatalf("GetItemsForProject: %v", err)
}
if len(items) != 1 {
t.Errorf("expected 1 item in project, got %d", len(items))
}
// Remove item from project
if err := projRepo.RemoveItemFromProject(ctx, item.ID, proj.ID); err != nil {
t.Fatalf("RemoveItemFromProject: %v", err)
}
items, _ = projRepo.GetItemsForProject(ctx, proj.ID)
if len(items) != 0 {
t.Errorf("expected 0 items after removal, got %d", len(items))
}
}

View File

@@ -0,0 +1,278 @@
package db
import (
"context"
"strings"
"testing"
)
// createTestItem is a helper that creates a minimal item for BOM tests.
func createTestItem(t *testing.T, repo *ItemRepository, pn, desc string) *Item {
t.Helper()
item := &Item{PartNumber: pn, ItemType: "part", Description: desc}
if err := repo.Create(context.Background(), item, nil); err != nil {
t.Fatalf("creating test item %s: %v", pn, err)
}
return item
}
func TestBOMCreate(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
parent := createTestItem(t, items, "BOM-P-001", "parent")
child := createTestItem(t, items, "BOM-C-001", "child")
qty := 3.0
rel := &Relationship{
ParentItemID: parent.ID,
ChildItemID: child.ID,
RelType: "component",
Quantity: &qty,
}
if err := rels.Create(ctx, rel); err != nil {
t.Fatalf("Create: %v", err)
}
if rel.ID == "" {
t.Error("expected relationship ID to be set")
}
}
func TestBOMGetBOM(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
parent := createTestItem(t, items, "BOM-G-001", "parent")
child1 := createTestItem(t, items, "BOM-G-002", "child1")
child2 := createTestItem(t, items, "BOM-G-003", "child2")
qty1, qty2 := 2.0, 5.0
rels.Create(ctx, &Relationship{ParentItemID: parent.ID, ChildItemID: child1.ID, RelType: "component", Quantity: &qty1})
rels.Create(ctx, &Relationship{ParentItemID: parent.ID, ChildItemID: child2.ID, RelType: "component", Quantity: &qty2})
bom, err := rels.GetBOM(ctx, parent.ID)
if err != nil {
t.Fatalf("GetBOM: %v", err)
}
if len(bom) != 2 {
t.Errorf("expected 2 BOM entries, got %d", len(bom))
}
}
func TestBOMSelfReference(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
item := createTestItem(t, items, "BOM-SELF-001", "self-referencing")
err := rels.Create(ctx, &Relationship{
ParentItemID: item.ID,
ChildItemID: item.ID,
RelType: "component",
})
if err == nil {
t.Fatal("expected error for self-reference, got nil")
}
}
func TestBOMCycleDetection(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
a := createTestItem(t, items, "CYC-A", "A")
b := createTestItem(t, items, "CYC-B", "B")
c := createTestItem(t, items, "CYC-C", "C")
// A → B → C
rels.Create(ctx, &Relationship{ParentItemID: a.ID, ChildItemID: b.ID, RelType: "component"})
rels.Create(ctx, &Relationship{ParentItemID: b.ID, ChildItemID: c.ID, RelType: "component"})
// C → A should be detected as a cycle
err := rels.Create(ctx, &Relationship{
ParentItemID: c.ID,
ChildItemID: a.ID,
RelType: "component",
})
if err == nil {
t.Fatal("expected cycle error, got nil")
}
if !strings.Contains(err.Error(), "cycle") {
t.Errorf("expected cycle error, got: %v", err)
}
}
func TestBOMDelete(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
parent := createTestItem(t, items, "BOM-D-001", "parent")
child := createTestItem(t, items, "BOM-D-002", "child")
rel := &Relationship{ParentItemID: parent.ID, ChildItemID: child.ID, RelType: "component"}
rels.Create(ctx, rel)
if err := rels.Delete(ctx, rel.ID); err != nil {
t.Fatalf("Delete: %v", err)
}
bom, _ := rels.GetBOM(ctx, parent.ID)
if len(bom) != 0 {
t.Errorf("expected 0 BOM entries after delete, got %d", len(bom))
}
}
func TestBOMUpdate(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
parent := createTestItem(t, items, "BOM-U-001", "parent")
child := createTestItem(t, items, "BOM-U-002", "child")
qty := 1.0
rel := &Relationship{ParentItemID: parent.ID, ChildItemID: child.ID, RelType: "component", Quantity: &qty}
rels.Create(ctx, rel)
newQty := 10.0
if err := rels.Update(ctx, rel.ID, nil, &newQty, nil, nil, nil, nil, nil); err != nil {
t.Fatalf("Update: %v", err)
}
bom, _ := rels.GetBOM(ctx, parent.ID)
if len(bom) != 1 {
t.Fatalf("expected 1 BOM entry, got %d", len(bom))
}
if bom[0].Quantity == nil || *bom[0].Quantity != 10.0 {
t.Errorf("quantity: got %v, want 10.0", bom[0].Quantity)
}
}
func TestBOMWhereUsed(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
parent1 := createTestItem(t, items, "WU-P1", "parent1")
parent2 := createTestItem(t, items, "WU-P2", "parent2")
child := createTestItem(t, items, "WU-C1", "shared child")
rels.Create(ctx, &Relationship{ParentItemID: parent1.ID, ChildItemID: child.ID, RelType: "component"})
rels.Create(ctx, &Relationship{ParentItemID: parent2.ID, ChildItemID: child.ID, RelType: "component"})
wu, err := rels.GetWhereUsed(ctx, child.ID)
if err != nil {
t.Fatalf("GetWhereUsed: %v", err)
}
if len(wu) != 2 {
t.Errorf("expected 2 where-used entries, got %d", len(wu))
}
}
func TestBOMExpandedBOM(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
// A → B → C (3 levels)
a := createTestItem(t, items, "EXP-A", "top assembly")
b := createTestItem(t, items, "EXP-B", "sub assembly")
c := createTestItem(t, items, "EXP-C", "leaf part")
qty2, qty3 := 2.0, 3.0
rels.Create(ctx, &Relationship{ParentItemID: a.ID, ChildItemID: b.ID, RelType: "component", Quantity: &qty2})
rels.Create(ctx, &Relationship{ParentItemID: b.ID, ChildItemID: c.ID, RelType: "component", Quantity: &qty3})
expanded, err := rels.GetExpandedBOM(ctx, a.ID, 10)
if err != nil {
t.Fatalf("GetExpandedBOM: %v", err)
}
if len(expanded) != 2 {
t.Errorf("expected 2 expanded entries (B and C), got %d", len(expanded))
}
// Verify depths
for _, e := range expanded {
if e.ChildPartNumber == "EXP-B" && e.Depth != 1 {
t.Errorf("EXP-B depth: got %d, want 1", e.Depth)
}
if e.ChildPartNumber == "EXP-C" && e.Depth != 2 {
t.Errorf("EXP-C depth: got %d, want 2", e.Depth)
}
}
}
func TestBOMFlatBOM(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
// Assembly A (qty 1)
// ├── Sub-assembly B (qty 2)
// │ ├── Part X (qty 3) → total 6
// │ └── Part Y (qty 1) → total 2
// └── Part X (qty 4) → total 4 (+ 6 = 10 total for X)
a := createTestItem(t, items, "FLAT-A", "top")
b := createTestItem(t, items, "FLAT-B", "sub")
x := createTestItem(t, items, "FLAT-X", "leaf X")
y := createTestItem(t, items, "FLAT-Y", "leaf Y")
q2, q3, q1, q4 := 2.0, 3.0, 1.0, 4.0
rels.Create(ctx, &Relationship{ParentItemID: a.ID, ChildItemID: b.ID, RelType: "component", Quantity: &q2})
rels.Create(ctx, &Relationship{ParentItemID: a.ID, ChildItemID: x.ID, RelType: "component", Quantity: &q4})
rels.Create(ctx, &Relationship{ParentItemID: b.ID, ChildItemID: x.ID, RelType: "component", Quantity: &q3})
rels.Create(ctx, &Relationship{ParentItemID: b.ID, ChildItemID: y.ID, RelType: "component", Quantity: &q1})
flat, err := rels.GetFlatBOM(ctx, a.ID)
if err != nil {
t.Fatalf("GetFlatBOM: %v", err)
}
if len(flat) != 2 {
t.Errorf("expected 2 leaf parts, got %d", len(flat))
}
for _, e := range flat {
switch e.PartNumber {
case "FLAT-X":
if e.TotalQuantity != 10.0 {
t.Errorf("FLAT-X total qty: got %.1f, want 10.0", e.TotalQuantity)
}
case "FLAT-Y":
if e.TotalQuantity != 2.0 {
t.Errorf("FLAT-Y total qty: got %.1f, want 2.0", e.TotalQuantity)
}
default:
t.Errorf("unexpected part in flat BOM: %s", e.PartNumber)
}
}
}
func TestBOMFlatBOMEmpty(t *testing.T) {
database := mustConnectTestDB(t)
items := NewItemRepository(database)
rels := NewRelationshipRepository(database)
ctx := context.Background()
item := createTestItem(t, items, "FLAT-EMPTY", "no children")
flat, err := rels.GetFlatBOM(ctx, item.ID)
if err != nil {
t.Fatalf("GetFlatBOM: %v", err)
}
if len(flat) != 0 {
t.Errorf("expected 0 leaf parts for item with no BOM, got %d", len(flat))
}
}

View File

@@ -135,7 +135,7 @@ func (g *Generator) formatString(seg *schema.Segment, val string) (string, error
if msg == "" { if msg == "" {
msg = fmt.Sprintf("value does not match pattern %s", seg.Validation.Pattern) msg = fmt.Sprintf("value does not match pattern %s", seg.Validation.Pattern)
} }
return "", fmt.Errorf(msg) return "", fmt.Errorf("%s", msg)
} }
} }

View File

@@ -0,0 +1,167 @@
package partnum
import (
"context"
"fmt"
"testing"
"github.com/kindredsystems/silo/internal/schema"
)
// mockSeqStore implements SequenceStore for testing.
type mockSeqStore struct {
counter int
}
func (m *mockSeqStore) NextValue(_ context.Context, _ string, _ string) (int, error) {
m.counter++
return m.counter, nil
}
func testSchema() *schema.Schema {
return &schema.Schema{
Name: "test",
Version: 1,
Separator: "-",
Segments: []schema.Segment{
{
Name: "category",
Type: "enum",
Required: true,
Values: map[string]string{
"F01": "Fasteners",
"R01": "Resistors",
},
},
{
Name: "serial",
Type: "serial",
Length: 4,
Scope: "{category}",
},
},
}
}
func TestGenerateBasic(t *testing.T) {
s := testSchema()
gen := NewGenerator(map[string]*schema.Schema{"test": s}, &mockSeqStore{})
pn, err := gen.Generate(context.Background(), Input{
SchemaName: "test",
Values: map[string]string{"category": "F01"},
})
if err != nil {
t.Fatalf("Generate returned error: %v", err)
}
if pn != "F01-0001" {
t.Errorf("got %q, want %q", pn, "F01-0001")
}
}
func TestGenerateSequentialNumbers(t *testing.T) {
s := testSchema()
seq := &mockSeqStore{}
gen := NewGenerator(map[string]*schema.Schema{"test": s}, seq)
for i := 1; i <= 3; i++ {
pn, err := gen.Generate(context.Background(), Input{
SchemaName: "test",
Values: map[string]string{"category": "F01"},
})
if err != nil {
t.Fatalf("Generate #%d returned error: %v", i, err)
}
want := fmt.Sprintf("F01-%04d", i)
if pn != want {
t.Errorf("Generate #%d: got %q, want %q", i, pn, want)
}
}
}
func TestGenerateWithFormat(t *testing.T) {
s := &schema.Schema{
Name: "formatted",
Version: 1,
Format: "{prefix}/{category}-{serial}",
Segments: []schema.Segment{
{Name: "prefix", Type: "constant", Value: "KS"},
{Name: "category", Type: "enum", Required: true, Values: map[string]string{"A": "Alpha"}},
{Name: "serial", Type: "serial", Length: 3},
},
}
gen := NewGenerator(map[string]*schema.Schema{"formatted": s}, &mockSeqStore{})
pn, err := gen.Generate(context.Background(), Input{
SchemaName: "formatted",
Values: map[string]string{"category": "A"},
})
if err != nil {
t.Fatalf("Generate returned error: %v", err)
}
if pn != "KS/A-001" {
t.Errorf("got %q, want %q", pn, "KS/A-001")
}
}
func TestGenerateUnknownSchema(t *testing.T) {
gen := NewGenerator(map[string]*schema.Schema{}, &mockSeqStore{})
_, err := gen.Generate(context.Background(), Input{
SchemaName: "nonexistent",
Values: map[string]string{},
})
if err == nil {
t.Fatal("expected error for unknown schema, got nil")
}
}
func TestGenerateMissingRequiredEnum(t *testing.T) {
s := testSchema()
gen := NewGenerator(map[string]*schema.Schema{"test": s}, &mockSeqStore{})
_, err := gen.Generate(context.Background(), Input{
SchemaName: "test",
Values: map[string]string{}, // missing required "category"
})
if err == nil {
t.Fatal("expected error for missing required enum, got nil")
}
}
func TestGenerateInvalidEnumValue(t *testing.T) {
s := testSchema()
gen := NewGenerator(map[string]*schema.Schema{"test": s}, &mockSeqStore{})
_, err := gen.Generate(context.Background(), Input{
SchemaName: "test",
Values: map[string]string{"category": "INVALID"},
})
if err == nil {
t.Fatal("expected error for invalid enum value, got nil")
}
}
func TestGenerateConstantSegment(t *testing.T) {
s := &schema.Schema{
Name: "const-test",
Version: 1,
Separator: "-",
Segments: []schema.Segment{
{Name: "prefix", Type: "constant", Value: "KS"},
{Name: "serial", Type: "serial", Length: 4},
},
}
gen := NewGenerator(map[string]*schema.Schema{"const-test": s}, &mockSeqStore{})
pn, err := gen.Generate(context.Background(), Input{
SchemaName: "const-test",
Values: map[string]string{},
})
if err != nil {
t.Fatalf("Generate returned error: %v", err)
}
if pn != "KS-0001" {
t.Errorf("got %q, want %q", pn, "KS-0001")
}
}

View File

@@ -0,0 +1,251 @@
package schema
import (
"os"
"path/filepath"
"testing"
)
// findSchemasDir walks upward to find the project root and returns
// the path to the schemas/ directory.
func findSchemasDir(t *testing.T) string {
t.Helper()
dir, err := os.Getwd()
if err != nil {
t.Fatalf("getting working directory: %v", err)
}
for {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
return filepath.Join(dir, "schemas")
}
parent := filepath.Dir(dir)
if parent == dir {
t.Fatal("could not find project root")
}
dir = parent
}
}
func TestLoadSchema(t *testing.T) {
schemasDir := findSchemasDir(t)
path := filepath.Join(schemasDir, "kindred-rd.yaml")
s, err := Load(path)
if err != nil {
t.Fatalf("Load returned error: %v", err)
}
if s.Name != "kindred-rd" {
t.Errorf("schema name: got %q, want %q", s.Name, "kindred-rd")
}
if s.Separator != "-" {
t.Errorf("separator: got %q, want %q", s.Separator, "-")
}
if len(s.Segments) == 0 {
t.Fatal("expected at least one segment")
}
// First segment should be the category enum
cat := s.Segments[0]
if cat.Name != "category" {
t.Errorf("first segment name: got %q, want %q", cat.Name, "category")
}
if cat.Type != "enum" {
t.Errorf("first segment type: got %q, want %q", cat.Type, "enum")
}
if len(cat.Values) == 0 {
t.Error("category segment has no values")
}
}
func TestLoadSchemaDir(t *testing.T) {
schemasDir := findSchemasDir(t)
schemas, err := LoadAll(schemasDir)
if err != nil {
t.Fatalf("LoadAll returned error: %v", err)
}
if len(schemas) == 0 {
t.Fatal("expected at least one schema")
}
if _, ok := schemas["kindred-rd"]; !ok {
t.Error("kindred-rd schema not found in loaded schemas")
}
}
func TestLoadSchemaValidation(t *testing.T) {
schemasDir := findSchemasDir(t)
schemas, err := LoadAll(schemasDir)
if err != nil {
t.Fatalf("LoadAll returned error: %v", err)
}
for name, s := range schemas {
if s.Name == "" {
continue // non-part-numbering schemas (e.g., location_schema)
}
if err := s.Validate(); err != nil {
t.Errorf("schema %q failed validation: %v", name, err)
}
}
}
func TestGetPropertiesForCategory(t *testing.T) {
ps := &PropertySchemas{
Version: 1,
Defaults: map[string]PropertyDefinition{
"weight": {Type: "number", Unit: "kg"},
"color": {Type: "string"},
},
Categories: map[string]map[string]PropertyDefinition{
"F": {
"thread_size": {Type: "string", Required: true},
"weight": {Type: "number", Unit: "g"}, // override default
},
},
}
props := ps.GetPropertiesForCategory("F01")
// Should have all three: weight (overridden), color (default), thread_size (category)
if len(props) != 3 {
t.Errorf("expected 3 properties, got %d", len(props))
}
if props["weight"].Unit != "g" {
t.Errorf("weight unit: got %q, want %q (should be overridden by category)", props["weight"].Unit, "g")
}
if props["color"].Type != "string" {
t.Errorf("color type: got %q, want %q", props["color"].Type, "string")
}
if !props["thread_size"].Required {
t.Error("thread_size should be required")
}
}
func TestGetPropertiesForUnknownCategory(t *testing.T) {
ps := &PropertySchemas{
Version: 1,
Defaults: map[string]PropertyDefinition{
"weight": {Type: "number"},
},
Categories: map[string]map[string]PropertyDefinition{
"F": {"thread_size": {Type: "string"}},
},
}
props := ps.GetPropertiesForCategory("Z99")
// Only defaults, no category-specific properties
if len(props) != 1 {
t.Errorf("expected 1 property (defaults only), got %d", len(props))
}
if _, ok := props["weight"]; !ok {
t.Error("default property 'weight' should be present")
}
}
func TestApplyDefaults(t *testing.T) {
ps := &PropertySchemas{
Version: 1,
Defaults: map[string]PropertyDefinition{
"status": {Type: "string", Default: "draft"},
"weight": {Type: "number"},
},
}
props := map[string]any{"custom": "value"}
result := ps.ApplyDefaults(props, "X")
if result["status"] != "draft" {
t.Errorf("status: got %v, want %q", result["status"], "draft")
}
if result["custom"] != "value" {
t.Errorf("custom: got %v, want %q", result["custom"], "value")
}
// weight has no default, should not be added
if _, ok := result["weight"]; ok {
t.Error("weight should not be added (no default value)")
}
}
func TestSchemaValidate(t *testing.T) {
tests := []struct {
name string
schema Schema
wantErr bool
}{
{
name: "empty name",
schema: Schema{Name: "", Segments: []Segment{{Name: "s", Type: "constant", Value: "X"}}},
wantErr: true,
},
{
name: "no segments",
schema: Schema{Name: "test", Segments: nil},
wantErr: true,
},
{
name: "valid minimal",
schema: Schema{
Name: "test",
Segments: []Segment{
{Name: "prefix", Type: "constant", Value: "X"},
},
},
wantErr: false,
},
{
name: "enum without values",
schema: Schema{
Name: "test",
Segments: []Segment{
{Name: "cat", Type: "enum"},
},
},
wantErr: true,
},
{
name: "serial without length",
schema: Schema{
Name: "test",
Segments: []Segment{
{Name: "seq", Type: "serial", Length: 0},
},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.schema.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func TestGetSegment(t *testing.T) {
s := &Schema{
Segments: []Segment{
{Name: "prefix", Type: "constant", Value: "KS"},
{Name: "serial", Type: "serial", Length: 4},
},
}
seg := s.GetSegment("serial")
if seg == nil {
t.Fatal("GetSegment returned nil for existing segment")
}
if seg.Type != "serial" {
t.Errorf("segment type: got %q, want %q", seg.Type, "serial")
}
if s.GetSegment("nonexistent") != nil {
t.Error("GetSegment should return nil for nonexistent segment")
}
}

View File

@@ -0,0 +1,114 @@
// Package testutil provides shared helpers for Silo tests.
package testutil
import (
"context"
"fmt"
"os"
"path/filepath"
"sort"
"testing"
"github.com/jackc/pgx/v5/pgxpool"
)
// MustConnectTestPool connects to a test database using TEST_DATABASE_URL.
// If the env var is unset the test is skipped. Migrations are applied and
// all data tables are truncated before returning. The pool is closed
// automatically when the test finishes.
func MustConnectTestPool(t *testing.T) *pgxpool.Pool {
t.Helper()
dsn := os.Getenv("TEST_DATABASE_URL")
if dsn == "" {
t.Skip("TEST_DATABASE_URL not set, skipping integration test")
}
pool, err := pgxpool.New(context.Background(), dsn)
if err != nil {
t.Fatalf("connecting to test database: %v", err)
}
if err := pool.Ping(context.Background()); err != nil {
pool.Close()
t.Fatalf("pinging test database: %v", err)
}
t.Cleanup(func() { pool.Close() })
RunMigrations(t, pool)
TruncateAll(t, pool)
return pool
}
// RunMigrations applies all SQL migration files from the migrations/
// directory. It walks upward from the current working directory to find
// the project root (containing go.mod).
func RunMigrations(t *testing.T, pool *pgxpool.Pool) {
t.Helper()
root := findProjectRoot(t)
migDir := filepath.Join(root, "migrations")
entries, err := os.ReadDir(migDir)
if err != nil {
t.Fatalf("reading migrations directory: %v", err)
}
// Sort by filename to apply in order.
sort.Slice(entries, func(i, j int) bool {
return entries[i].Name() < entries[j].Name()
})
for _, e := range entries {
if e.IsDir() || filepath.Ext(e.Name()) != ".sql" {
continue
}
sql, err := os.ReadFile(filepath.Join(migDir, e.Name()))
if err != nil {
t.Fatalf("reading migration %s: %v", e.Name(), err)
}
if _, err := pool.Exec(context.Background(), string(sql)); err != nil {
// Migrations may contain IF NOT EXISTS / CREATE OR REPLACE,
// so most "already exists" errors are fine. Log and continue.
t.Logf("migration %s: %v (may be OK if already applied)", e.Name(), err)
}
}
}
// TruncateAll removes all data from tables, leaving schema intact.
func TruncateAll(t *testing.T, pool *pgxpool.Pool) {
t.Helper()
_, err := pool.Exec(context.Background(), `
TRUNCATE
audit_log, sync_log, api_tokens, sessions, item_files,
item_projects, relationships, revisions, inventory, items,
projects, sequences_by_name, users, property_migrations
CASCADE
`)
if err != nil {
t.Fatalf("truncating tables: %v", err)
}
}
// findProjectRoot walks upward from cwd to find the directory containing go.mod.
func findProjectRoot(t *testing.T) string {
t.Helper()
dir, err := os.Getwd()
if err != nil {
t.Fatalf("getting working directory: %v", err)
}
for {
if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil {
return dir
}
parent := filepath.Dir(dir)
if parent == dir {
t.Fatalf("could not find project root (go.mod)")
}
dir = parent
}
panic(fmt.Sprintf("unreachable"))
}