diff --git a/Makefile b/Makefile index 293d1f3..c7bb968 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: build run test clean migrate fmt lint \ +.PHONY: build run test test-integration clean migrate fmt lint \ docker-build docker-up docker-down docker-logs docker-ps \ docker-clean docker-rebuild \ web-install web-dev web-build @@ -7,8 +7,8 @@ # Local Development # ============================================================================= -# Build all binaries -build: +# Build all binaries (frontend + backend) +build: web-build go build -o silo ./cmd/silo go build -o silod ./cmd/silod @@ -20,14 +20,19 @@ run: cli: go run ./cmd/silo $(ARGS) -# Run tests +# Run unit tests (integration tests skipped without TEST_DATABASE_URL) test: go test -v ./... +# Run all tests including integration tests (requires PostgreSQL) +test-integration: + TEST_DATABASE_URL="postgres://silo:silodev@localhost:5432/silo_test?sslmode=disable" go test -v -count=1 ./... + # Clean build artifacts clean: rm -f silo silod rm -f *.out + rm -rf web/dist # Format code fmt: @@ -153,7 +158,8 @@ help: @echo " build - Build CLI and server binaries" @echo " run - Run API server locally" @echo " cli ARGS=... - Run CLI with arguments" - @echo " test - Run tests" + @echo " test - Run unit tests (integration tests skip without DB)" + @echo " test-integration - Run all tests including integration (needs PostgreSQL)" @echo " fmt - Format code" @echo " lint - Run linter" @echo " tidy - Tidy go.mod" diff --git a/internal/api/bom_handlers_test.go b/internal/api/bom_handlers_test.go new file mode 100644 index 0000000..c31ee10 --- /dev/null +++ b/internal/api/bom_handlers_test.go @@ -0,0 +1,238 @@ +package api + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" + "github.com/kindredsystems/silo/internal/auth" + "github.com/kindredsystems/silo/internal/db" + "github.com/kindredsystems/silo/internal/schema" + "github.com/kindredsystems/silo/internal/testutil" + "github.com/rs/zerolog" +) + +// newTestServer creates a Server backed by a real test DB with no auth. +func newTestServer(t *testing.T) *Server { + t.Helper() + pool := testutil.MustConnectTestPool(t) + database := db.NewFromPool(pool) + return NewServer( + zerolog.Nop(), + database, + map[string]*schema.Schema{}, + "", // schemasDir + nil, // storage + nil, // authService + nil, // sessionManager + nil, // oidcBackend + nil, // authConfig (nil = dev mode) + ) +} + +// newTestRouter creates a chi router with BOM routes for testing. +func newTestRouter(s *Server) http.Handler { + r := chi.NewRouter() + r.Route("/api/items/{partNumber}", func(r chi.Router) { + r.Get("/bom", s.HandleGetBOM) + r.Get("/bom/flat", s.HandleGetFlatBOM) + r.Get("/bom/cost", s.HandleGetBOMCost) + r.Post("/bom", s.HandleAddBOMEntry) + r.Delete("/bom/{childPartNumber}", s.HandleDeleteBOMEntry) + }) + return r +} + +// createItemDirect creates an item directly via the DB for test setup. +func createItemDirect(t *testing.T, s *Server, pn, desc string, cost *float64) { + t.Helper() + item := &db.Item{ + PartNumber: pn, + ItemType: "part", + Description: desc, + StandardCost: cost, + } + if err := s.items.Create(context.Background(), item, nil); err != nil { + t.Fatalf("creating item %s: %v", pn, err) + } +} + +// authRequest returns a copy of the request with an admin user in context. +func authRequest(r *http.Request) *http.Request { + u := &auth.User{ + ID: "test-admin-id", + Username: "testadmin", + DisplayName: "Test Admin", + Role: auth.RoleAdmin, + AuthSource: "local", + } + return r.WithContext(auth.ContextWithUser(r.Context(), u)) +} + +// addBOMDirect adds a BOM relationship directly via the DB. +func addBOMDirect(t *testing.T, s *Server, parentPN, childPN string, qty float64) { + t.Helper() + ctx := context.Background() + parent, _ := s.items.GetByPartNumber(ctx, parentPN) + child, _ := s.items.GetByPartNumber(ctx, childPN) + if parent == nil || child == nil { + t.Fatalf("parent or child not found: %s, %s", parentPN, childPN) + } + rel := &db.Relationship{ + ParentItemID: parent.ID, + ChildItemID: child.ID, + RelType: "component", + Quantity: &qty, + } + if err := s.relationships.Create(ctx, rel); err != nil { + t.Fatalf("adding BOM %s→%s: %v", parentPN, childPN, err) + } +} + +func TestHandleGetBOM(t *testing.T) { + s := newTestServer(t) + router := newTestRouter(s) + + createItemDirect(t, s, "API-P1", "parent", nil) + createItemDirect(t, s, "API-C1", "child1", nil) + createItemDirect(t, s, "API-C2", "child2", nil) + addBOMDirect(t, s, "API-P1", "API-C1", 2) + addBOMDirect(t, s, "API-P1", "API-C2", 5) + + req := httptest.NewRequest("GET", "/api/items/API-P1/bom", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var entries []BOMEntryResponse + if err := json.Unmarshal(w.Body.Bytes(), &entries); err != nil { + t.Fatalf("decoding response: %v", err) + } + if len(entries) != 2 { + t.Errorf("expected 2 BOM entries, got %d", len(entries)) + } +} + +func TestHandleGetFlatBOM(t *testing.T) { + s := newTestServer(t) + router := newTestRouter(s) + + // A(qty 1) → B(qty 2) → X(qty 3) = X total 6 + createItemDirect(t, s, "FA", "assembly A", nil) + createItemDirect(t, s, "FB", "sub B", nil) + createItemDirect(t, s, "FX", "leaf X", nil) + addBOMDirect(t, s, "FA", "FB", 2) + addBOMDirect(t, s, "FB", "FX", 3) + + req := httptest.NewRequest("GET", "/api/items/FA/bom/flat", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp FlatBOMResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("decoding response: %v", err) + } + if len(resp.FlatBOM) != 1 { + t.Fatalf("expected 1 leaf part, got %d", len(resp.FlatBOM)) + } + if resp.FlatBOM[0].TotalQuantity != 6 { + t.Errorf("total quantity: got %.1f, want 6.0", resp.FlatBOM[0].TotalQuantity) + } +} + +func TestHandleGetBOMCost(t *testing.T) { + s := newTestServer(t) + router := newTestRouter(s) + + cost10 := 10.0 + cost5 := 5.0 + createItemDirect(t, s, "CA", "assembly", nil) + createItemDirect(t, s, "CX", "part X", &cost10) + createItemDirect(t, s, "CY", "part Y", &cost5) + addBOMDirect(t, s, "CA", "CX", 3) + addBOMDirect(t, s, "CA", "CY", 2) + + req := httptest.NewRequest("GET", "/api/items/CA/bom/cost", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp CostResponse + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("decoding response: %v", err) + } + // 3*10 + 2*5 = 40 + if resp.TotalCost != 40 { + t.Errorf("total cost: got %.2f, want 40.00", resp.TotalCost) + } +} + +func TestHandleGetFlatBOMNotFound(t *testing.T) { + s := newTestServer(t) + router := newTestRouter(s) + + req := httptest.NewRequest("GET", "/api/items/NONEXISTENT/bom/flat", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("status: got %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestHandleAddBOMEntry(t *testing.T) { + s := newTestServer(t) + router := newTestRouter(s) + + createItemDirect(t, s, "ADD-P", "parent", nil) + createItemDirect(t, s, "ADD-C", "child", nil) + + body := `{"child_part_number":"ADD-C","rel_type":"component","quantity":7}` + req := authRequest(httptest.NewRequest("POST", "/api/items/ADD-P/bom", strings.NewReader(body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusCreated, w.Body.String()) + } + + var entry BOMEntryResponse + if err := json.Unmarshal(w.Body.Bytes(), &entry); err != nil { + t.Fatalf("decoding response: %v", err) + } + if entry.ChildPartNumber != "ADD-C" { + t.Errorf("child_part_number: got %q, want %q", entry.ChildPartNumber, "ADD-C") + } +} + +func TestHandleDeleteBOMEntry(t *testing.T) { + s := newTestServer(t) + router := newTestRouter(s) + + createItemDirect(t, s, "DEL-P", "parent", nil) + createItemDirect(t, s, "DEL-C", "child", nil) + addBOMDirect(t, s, "DEL-P", "DEL-C", 1) + + req := authRequest(httptest.NewRequest("DELETE", "/api/items/DEL-P/bom/DEL-C", nil)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent { + t.Errorf("status: got %d, want %d; body: %s", w.Code, http.StatusNoContent, w.Body.String()) + } +} diff --git a/internal/api/items_test.go b/internal/api/items_test.go new file mode 100644 index 0000000..77cd6a9 --- /dev/null +++ b/internal/api/items_test.go @@ -0,0 +1,133 @@ +package api + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/go-chi/chi/v5" +) + +// newItemRouter creates a chi router with item routes for testing. +func newItemRouter(s *Server) http.Handler { + r := chi.NewRouter() + r.Route("/api/items", func(r chi.Router) { + r.Get("/", s.HandleListItems) + r.Post("/", s.HandleCreateItem) + r.Route("/{partNumber}", func(r chi.Router) { + r.Get("/", s.HandleGetItem) + r.Put("/", s.HandleUpdateItem) + r.Delete("/", s.HandleDeleteItem) + }) + }) + return r +} + +func TestHandleCreateItem(t *testing.T) { + s := newTestServer(t) + router := newItemRouter(s) + + body := `{"part_number":"NEW-001","item_type":"part","description":"new item"}` + req := authRequest(httptest.NewRequest("POST", "/api/items", strings.NewReader(body))) + req.Header.Set("Content-Type", "application/json") + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusCreated { + t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusCreated, w.Body.String()) + } + + var resp map[string]any + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("decoding response: %v", err) + } + if resp["part_number"] != "NEW-001" { + t.Errorf("part_number: got %v, want %q", resp["part_number"], "NEW-001") + } +} + +func TestHandleGetItem(t *testing.T) { + s := newTestServer(t) + router := newItemRouter(s) + + createItemDirect(t, s, "GET-001", "get test", nil) + + req := httptest.NewRequest("GET", "/api/items/GET-001", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]any + json.Unmarshal(w.Body.Bytes(), &resp) + if resp["part_number"] != "GET-001" { + t.Errorf("part_number: got %v, want %q", resp["part_number"], "GET-001") + } +} + +func TestHandleGetItemNotFound(t *testing.T) { + s := newTestServer(t) + router := newItemRouter(s) + + req := httptest.NewRequest("GET", "/api/items/NOPE-999", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNotFound { + t.Errorf("status: got %d, want %d", w.Code, http.StatusNotFound) + } +} + +func TestHandleListItems(t *testing.T) { + s := newTestServer(t) + router := newItemRouter(s) + + createItemDirect(t, s, "LST-001", "list item 1", nil) + createItemDirect(t, s, "LST-002", "list item 2", nil) + + req := httptest.NewRequest("GET", "/api/items", nil) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("status: got %d, want %d; body: %s", w.Code, http.StatusOK, w.Body.String()) + } + + var resp map[string]any + json.Unmarshal(w.Body.Bytes(), &resp) + items, ok := resp["items"].([]any) + if !ok { + t.Fatalf("expected items array in response, got: %s", w.Body.String()) + } + if len(items) < 2 { + t.Errorf("expected at least 2 items, got %d", len(items)) + } +} + +func TestHandleDeleteItem(t *testing.T) { + s := newTestServer(t) + router := newItemRouter(s) + + createItemDirect(t, s, "DEL-ITEM-001", "deletable", nil) + + req := authRequest(httptest.NewRequest("DELETE", "/api/items/DEL-ITEM-001", nil)) + w := httptest.NewRecorder() + router.ServeHTTP(w, req) + + if w.Code != http.StatusNoContent && w.Code != http.StatusOK { + t.Fatalf("status: got %d, want 200 or 204; body: %s", w.Code, w.Body.String()) + } + + // Should be gone (archived) + req2 := httptest.NewRequest("GET", "/api/items/DEL-ITEM-001", nil) + w2 := httptest.NewRecorder() + router.ServeHTTP(w2, req2) + + if w2.Code != http.StatusNotFound { + t.Errorf("after delete, expected 404, got %d", w2.Code) + } +} diff --git a/internal/db/db.go b/internal/db/db.go index 42bd05f..0bc0745 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -44,6 +44,23 @@ func Connect(ctx context.Context, cfg Config) (*DB, error) { return &DB{pool: pool}, nil } +// ConnectDSN establishes a database connection pool from a DSN string. +func ConnectDSN(ctx context.Context, dsn string) (*DB, error) { + pool, err := pgxpool.New(ctx, dsn) + if err != nil { + return nil, fmt.Errorf("creating connection pool: %w", err) + } + if err := pool.Ping(ctx); err != nil { + return nil, fmt.Errorf("pinging database: %w", err) + } + return &DB{pool: pool}, nil +} + +// NewFromPool wraps an existing connection pool in a DB handle. +func NewFromPool(pool *pgxpool.Pool) *DB { + return &DB{pool: pool} +} + // Close closes the connection pool. func (db *DB) Close() { db.pool.Close() diff --git a/internal/db/helpers_test.go b/internal/db/helpers_test.go new file mode 100644 index 0000000..d521e06 --- /dev/null +++ b/internal/db/helpers_test.go @@ -0,0 +1,15 @@ +package db + +import ( + "testing" + + "github.com/kindredsystems/silo/internal/testutil" +) + +// mustConnectTestDB returns a *DB backed by a real test Postgres instance. +// It skips the test if TEST_DATABASE_URL is not set. +func mustConnectTestDB(t *testing.T) *DB { + t.Helper() + pool := testutil.MustConnectTestPool(t) + return NewFromPool(pool) +} diff --git a/internal/db/items_test.go b/internal/db/items_test.go new file mode 100644 index 0000000..9c7efc6 --- /dev/null +++ b/internal/db/items_test.go @@ -0,0 +1,272 @@ +package db + +import ( + "context" + "fmt" + "testing" +) + +func TestItemCreate(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + item := &Item{ + PartNumber: "TEST-0001", + ItemType: "part", + Description: "Test item", + } + err := repo.Create(ctx, item, map[string]any{"color": "red"}) + if err != nil { + t.Fatalf("Create: %v", err) + } + if item.ID == "" { + t.Error("expected item ID to be set") + } + if item.CurrentRevision != 1 { + t.Errorf("current revision: got %d, want 1", item.CurrentRevision) + } +} + +func TestItemGetByPartNumber(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + item := &Item{PartNumber: "GET-PN-001", ItemType: "part", Description: "get by pn test"} + if err := repo.Create(ctx, item, nil); err != nil { + t.Fatalf("Create: %v", err) + } + + got, err := repo.GetByPartNumber(ctx, "GET-PN-001") + if err != nil { + t.Fatalf("GetByPartNumber: %v", err) + } + if got == nil { + t.Fatal("expected item, got nil") + } + if got.Description != "get by pn test" { + t.Errorf("description: got %q, want %q", got.Description, "get by pn test") + } + + // Non-existent should return nil, not error + missing, err := repo.GetByPartNumber(ctx, "DOES-NOT-EXIST") + if err != nil { + t.Fatalf("GetByPartNumber (missing): %v", err) + } + if missing != nil { + t.Error("expected nil for missing item") + } +} + +func TestItemGetByID(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + item := &Item{PartNumber: "GET-ID-001", ItemType: "assembly", Description: "get by id"} + if err := repo.Create(ctx, item, nil); err != nil { + t.Fatalf("Create: %v", err) + } + + got, err := repo.GetByID(ctx, item.ID) + if err != nil { + t.Fatalf("GetByID: %v", err) + } + if got == nil { + t.Fatal("expected item, got nil") + } + if got.PartNumber != "GET-ID-001" { + t.Errorf("part_number: got %q, want %q", got.PartNumber, "GET-ID-001") + } +} + +func TestItemList(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + for i := 0; i < 3; i++ { + item := &Item{ + PartNumber: fmt.Sprintf("LIST-%04d", i), + ItemType: "part", + Description: fmt.Sprintf("list item %d", i), + } + if err := repo.Create(ctx, item, nil); err != nil { + t.Fatalf("Create #%d: %v", i, err) + } + } + + items, err := repo.List(ctx, ListOptions{}) + if err != nil { + t.Fatalf("List: %v", err) + } + if len(items) != 3 { + t.Errorf("expected 3 items, got %d", len(items)) + } +} + +func TestItemListByType(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + repo.Create(ctx, &Item{PartNumber: "TYPE-P-001", ItemType: "part", Description: "a part"}, nil) + repo.Create(ctx, &Item{PartNumber: "TYPE-A-001", ItemType: "assembly", Description: "an assembly"}, nil) + repo.Create(ctx, &Item{PartNumber: "TYPE-P-002", ItemType: "part", Description: "another part"}, nil) + + items, err := repo.List(ctx, ListOptions{ItemType: "part"}) + if err != nil { + t.Fatalf("List: %v", err) + } + if len(items) != 2 { + t.Errorf("expected 2 parts, got %d", len(items)) + } +} + +func TestItemUpdate(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + item := &Item{PartNumber: "UPD-001", ItemType: "part", Description: "original"} + if err := repo.Create(ctx, item, nil); err != nil { + t.Fatalf("Create: %v", err) + } + + cost := 42.50 + err := repo.Update(ctx, item.ID, UpdateItemFields{ + PartNumber: "UPD-001", + ItemType: "part", + Description: "updated", + StandardCost: &cost, + }) + if err != nil { + t.Fatalf("Update: %v", err) + } + + got, _ := repo.GetByID(ctx, item.ID) + if got.Description != "updated" { + t.Errorf("description: got %q, want %q", got.Description, "updated") + } + if got.StandardCost == nil || *got.StandardCost != 42.50 { + t.Errorf("standard_cost: got %v, want 42.50", got.StandardCost) + } +} + +func TestItemArchiveUnarchive(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + item := &Item{PartNumber: "ARC-001", ItemType: "part", Description: "archivable"} + if err := repo.Create(ctx, item, nil); err != nil { + t.Fatalf("Create: %v", err) + } + + // Archive + if err := repo.Archive(ctx, item.ID); err != nil { + t.Fatalf("Archive: %v", err) + } + + // Should not appear in GetByPartNumber (excludes archived) + got, _ := repo.GetByPartNumber(ctx, "ARC-001") + if got != nil { + t.Error("archived item should not be returned by GetByPartNumber") + } + + // But should still be accessible by ID + gotByID, _ := repo.GetByID(ctx, item.ID) + if gotByID == nil { + t.Fatal("archived item should still be accessible by GetByID") + } + if gotByID.ArchivedAt == nil { + t.Error("archived_at should be set") + } + + // Unarchive + if err := repo.Unarchive(ctx, item.ID); err != nil { + t.Fatalf("Unarchive: %v", err) + } + got, _ = repo.GetByPartNumber(ctx, "ARC-001") + if got == nil { + t.Error("unarchived item should be returned by GetByPartNumber") + } +} + +func TestItemCreateRevision(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + item := &Item{PartNumber: "REV-001", ItemType: "part", Description: "revisable"} + if err := repo.Create(ctx, item, map[string]any{"v": 1}); err != nil { + t.Fatalf("Create: %v", err) + } + + // Create second revision + rev := &Revision{ + ItemID: item.ID, + Properties: map[string]any{"v": 2}, + } + if err := repo.CreateRevision(ctx, rev); err != nil { + t.Fatalf("CreateRevision: %v", err) + } + if rev.RevisionNumber != 2 { + t.Errorf("revision number: got %d, want 2", rev.RevisionNumber) + } + + // Item's current_revision should be updated by trigger + got, _ := repo.GetByPartNumber(ctx, "REV-001") + if got.CurrentRevision != 2 { + t.Errorf("current_revision: got %d, want 2", got.CurrentRevision) + } +} + +func TestItemGetRevisions(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + item := &Item{PartNumber: "REVS-001", ItemType: "part", Description: "multi rev"} + if err := repo.Create(ctx, item, map[string]any{"step": "initial"}); err != nil { + t.Fatalf("Create: %v", err) + } + + comment := "second revision" + repo.CreateRevision(ctx, &Revision{ + ItemID: item.ID, Properties: map[string]any{"step": "updated"}, Comment: &comment, + }) + + revisions, err := repo.GetRevisions(ctx, item.ID) + if err != nil { + t.Fatalf("GetRevisions: %v", err) + } + if len(revisions) != 2 { + t.Errorf("expected 2 revisions, got %d", len(revisions)) + } + // Revisions are returned newest first + if revisions[0].RevisionNumber != 2 { + t.Errorf("first revision should be #2 (newest), got #%d", revisions[0].RevisionNumber) + } +} + +func TestItemSetThumbnailKey(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewItemRepository(database) + ctx := context.Background() + + item := &Item{PartNumber: "THUMB-001", ItemType: "part", Description: "thumbnail test"} + if err := repo.Create(ctx, item, nil); err != nil { + t.Fatalf("Create: %v", err) + } + + if err := repo.SetThumbnailKey(ctx, item.ID, "items/thumb.png"); err != nil { + t.Fatalf("SetThumbnailKey: %v", err) + } + + got, _ := repo.GetByID(ctx, item.ID) + if got.ThumbnailKey == nil || *got.ThumbnailKey != "items/thumb.png" { + t.Errorf("thumbnail_key: got %v, want %q", got.ThumbnailKey, "items/thumb.png") + } +} diff --git a/internal/db/projects_test.go b/internal/db/projects_test.go new file mode 100644 index 0000000..0d24e62 --- /dev/null +++ b/internal/db/projects_test.go @@ -0,0 +1,119 @@ +package db + +import ( + "context" + "testing" +) + +func TestProjectCreate(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewProjectRepository(database) + ctx := context.Background() + + p := &Project{Code: "TPRJ", Name: "Test Project"} + if err := repo.Create(ctx, p); err != nil { + t.Fatalf("Create: %v", err) + } + if p.ID == "" { + t.Error("expected project ID to be set") + } +} + +func TestProjectGet(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewProjectRepository(database) + ctx := context.Background() + + repo.Create(ctx, &Project{Code: "GPRJ", Name: "Get Project"}) + + got, err := repo.GetByCode(ctx, "GPRJ") + if err != nil { + t.Fatalf("GetByCode: %v", err) + } + if got == nil { + t.Fatal("expected project, got nil") + } + if got.Name != "Get Project" { + t.Errorf("name: got %q, want %q", got.Name, "Get Project") + } + + // Missing should return nil + missing, err := repo.GetByCode(ctx, "NOPE") + if err != nil { + t.Fatalf("GetByCode (missing): %v", err) + } + if missing != nil { + t.Error("expected nil for missing project") + } +} + +func TestProjectList(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewProjectRepository(database) + ctx := context.Background() + + repo.Create(ctx, &Project{Code: "AA", Name: "Alpha"}) + repo.Create(ctx, &Project{Code: "BB", Name: "Beta"}) + + projects, err := repo.List(ctx) + if err != nil { + t.Fatalf("List: %v", err) + } + if len(projects) != 2 { + t.Errorf("expected 2 projects, got %d", len(projects)) + } +} + +func TestProjectDelete(t *testing.T) { + database := mustConnectTestDB(t) + repo := NewProjectRepository(database) + ctx := context.Background() + + repo.Create(ctx, &Project{Code: "DEL", Name: "Deletable"}) + + if err := repo.Delete(ctx, "DEL"); err != nil { + t.Fatalf("Delete: %v", err) + } + + got, _ := repo.GetByCode(ctx, "DEL") + if got != nil { + t.Error("deleted project should not be found") + } +} + +func TestProjectItemAssociation(t *testing.T) { + database := mustConnectTestDB(t) + projRepo := NewProjectRepository(database) + itemRepo := NewItemRepository(database) + ctx := context.Background() + + proj := &Project{Code: "ASSC", Name: "Assoc Project"} + projRepo.Create(ctx, proj) + + item := &Item{PartNumber: "ASSC-001", ItemType: "part", Description: "associated item"} + itemRepo.Create(ctx, item, nil) + + // Add item to project + if err := projRepo.AddItemToProject(ctx, item.ID, proj.ID); err != nil { + t.Fatalf("AddItemToProject: %v", err) + } + + // Get items for project (takes project UUID) + items, err := projRepo.GetItemsForProject(ctx, proj.ID) + if err != nil { + t.Fatalf("GetItemsForProject: %v", err) + } + if len(items) != 1 { + t.Errorf("expected 1 item in project, got %d", len(items)) + } + + // Remove item from project + if err := projRepo.RemoveItemFromProject(ctx, item.ID, proj.ID); err != nil { + t.Fatalf("RemoveItemFromProject: %v", err) + } + + items, _ = projRepo.GetItemsForProject(ctx, proj.ID) + if len(items) != 0 { + t.Errorf("expected 0 items after removal, got %d", len(items)) + } +} diff --git a/internal/db/relationships_test.go b/internal/db/relationships_test.go new file mode 100644 index 0000000..693e3a7 --- /dev/null +++ b/internal/db/relationships_test.go @@ -0,0 +1,278 @@ +package db + +import ( + "context" + "strings" + "testing" +) + +// createTestItem is a helper that creates a minimal item for BOM tests. +func createTestItem(t *testing.T, repo *ItemRepository, pn, desc string) *Item { + t.Helper() + item := &Item{PartNumber: pn, ItemType: "part", Description: desc} + if err := repo.Create(context.Background(), item, nil); err != nil { + t.Fatalf("creating test item %s: %v", pn, err) + } + return item +} + +func TestBOMCreate(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + parent := createTestItem(t, items, "BOM-P-001", "parent") + child := createTestItem(t, items, "BOM-C-001", "child") + + qty := 3.0 + rel := &Relationship{ + ParentItemID: parent.ID, + ChildItemID: child.ID, + RelType: "component", + Quantity: &qty, + } + if err := rels.Create(ctx, rel); err != nil { + t.Fatalf("Create: %v", err) + } + if rel.ID == "" { + t.Error("expected relationship ID to be set") + } +} + +func TestBOMGetBOM(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + parent := createTestItem(t, items, "BOM-G-001", "parent") + child1 := createTestItem(t, items, "BOM-G-002", "child1") + child2 := createTestItem(t, items, "BOM-G-003", "child2") + + qty1, qty2 := 2.0, 5.0 + rels.Create(ctx, &Relationship{ParentItemID: parent.ID, ChildItemID: child1.ID, RelType: "component", Quantity: &qty1}) + rels.Create(ctx, &Relationship{ParentItemID: parent.ID, ChildItemID: child2.ID, RelType: "component", Quantity: &qty2}) + + bom, err := rels.GetBOM(ctx, parent.ID) + if err != nil { + t.Fatalf("GetBOM: %v", err) + } + if len(bom) != 2 { + t.Errorf("expected 2 BOM entries, got %d", len(bom)) + } +} + +func TestBOMSelfReference(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + item := createTestItem(t, items, "BOM-SELF-001", "self-referencing") + + err := rels.Create(ctx, &Relationship{ + ParentItemID: item.ID, + ChildItemID: item.ID, + RelType: "component", + }) + if err == nil { + t.Fatal("expected error for self-reference, got nil") + } +} + +func TestBOMCycleDetection(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + a := createTestItem(t, items, "CYC-A", "A") + b := createTestItem(t, items, "CYC-B", "B") + c := createTestItem(t, items, "CYC-C", "C") + + // A → B → C + rels.Create(ctx, &Relationship{ParentItemID: a.ID, ChildItemID: b.ID, RelType: "component"}) + rels.Create(ctx, &Relationship{ParentItemID: b.ID, ChildItemID: c.ID, RelType: "component"}) + + // C → A should be detected as a cycle + err := rels.Create(ctx, &Relationship{ + ParentItemID: c.ID, + ChildItemID: a.ID, + RelType: "component", + }) + if err == nil { + t.Fatal("expected cycle error, got nil") + } + if !strings.Contains(err.Error(), "cycle") { + t.Errorf("expected cycle error, got: %v", err) + } +} + +func TestBOMDelete(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + parent := createTestItem(t, items, "BOM-D-001", "parent") + child := createTestItem(t, items, "BOM-D-002", "child") + + rel := &Relationship{ParentItemID: parent.ID, ChildItemID: child.ID, RelType: "component"} + rels.Create(ctx, rel) + + if err := rels.Delete(ctx, rel.ID); err != nil { + t.Fatalf("Delete: %v", err) + } + + bom, _ := rels.GetBOM(ctx, parent.ID) + if len(bom) != 0 { + t.Errorf("expected 0 BOM entries after delete, got %d", len(bom)) + } +} + +func TestBOMUpdate(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + parent := createTestItem(t, items, "BOM-U-001", "parent") + child := createTestItem(t, items, "BOM-U-002", "child") + + qty := 1.0 + rel := &Relationship{ParentItemID: parent.ID, ChildItemID: child.ID, RelType: "component", Quantity: &qty} + rels.Create(ctx, rel) + + newQty := 10.0 + if err := rels.Update(ctx, rel.ID, nil, &newQty, nil, nil, nil, nil, nil); err != nil { + t.Fatalf("Update: %v", err) + } + + bom, _ := rels.GetBOM(ctx, parent.ID) + if len(bom) != 1 { + t.Fatalf("expected 1 BOM entry, got %d", len(bom)) + } + if bom[0].Quantity == nil || *bom[0].Quantity != 10.0 { + t.Errorf("quantity: got %v, want 10.0", bom[0].Quantity) + } +} + +func TestBOMWhereUsed(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + parent1 := createTestItem(t, items, "WU-P1", "parent1") + parent2 := createTestItem(t, items, "WU-P2", "parent2") + child := createTestItem(t, items, "WU-C1", "shared child") + + rels.Create(ctx, &Relationship{ParentItemID: parent1.ID, ChildItemID: child.ID, RelType: "component"}) + rels.Create(ctx, &Relationship{ParentItemID: parent2.ID, ChildItemID: child.ID, RelType: "component"}) + + wu, err := rels.GetWhereUsed(ctx, child.ID) + if err != nil { + t.Fatalf("GetWhereUsed: %v", err) + } + if len(wu) != 2 { + t.Errorf("expected 2 where-used entries, got %d", len(wu)) + } +} + +func TestBOMExpandedBOM(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + // A → B → C (3 levels) + a := createTestItem(t, items, "EXP-A", "top assembly") + b := createTestItem(t, items, "EXP-B", "sub assembly") + c := createTestItem(t, items, "EXP-C", "leaf part") + + qty2, qty3 := 2.0, 3.0 + rels.Create(ctx, &Relationship{ParentItemID: a.ID, ChildItemID: b.ID, RelType: "component", Quantity: &qty2}) + rels.Create(ctx, &Relationship{ParentItemID: b.ID, ChildItemID: c.ID, RelType: "component", Quantity: &qty3}) + + expanded, err := rels.GetExpandedBOM(ctx, a.ID, 10) + if err != nil { + t.Fatalf("GetExpandedBOM: %v", err) + } + if len(expanded) != 2 { + t.Errorf("expected 2 expanded entries (B and C), got %d", len(expanded)) + } + + // Verify depths + for _, e := range expanded { + if e.ChildPartNumber == "EXP-B" && e.Depth != 1 { + t.Errorf("EXP-B depth: got %d, want 1", e.Depth) + } + if e.ChildPartNumber == "EXP-C" && e.Depth != 2 { + t.Errorf("EXP-C depth: got %d, want 2", e.Depth) + } + } +} + +func TestBOMFlatBOM(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + // Assembly A (qty 1) + // ├── Sub-assembly B (qty 2) + // │ ├── Part X (qty 3) → total 6 + // │ └── Part Y (qty 1) → total 2 + // └── Part X (qty 4) → total 4 (+ 6 = 10 total for X) + a := createTestItem(t, items, "FLAT-A", "top") + b := createTestItem(t, items, "FLAT-B", "sub") + x := createTestItem(t, items, "FLAT-X", "leaf X") + y := createTestItem(t, items, "FLAT-Y", "leaf Y") + + q2, q3, q1, q4 := 2.0, 3.0, 1.0, 4.0 + rels.Create(ctx, &Relationship{ParentItemID: a.ID, ChildItemID: b.ID, RelType: "component", Quantity: &q2}) + rels.Create(ctx, &Relationship{ParentItemID: a.ID, ChildItemID: x.ID, RelType: "component", Quantity: &q4}) + rels.Create(ctx, &Relationship{ParentItemID: b.ID, ChildItemID: x.ID, RelType: "component", Quantity: &q3}) + rels.Create(ctx, &Relationship{ParentItemID: b.ID, ChildItemID: y.ID, RelType: "component", Quantity: &q1}) + + flat, err := rels.GetFlatBOM(ctx, a.ID) + if err != nil { + t.Fatalf("GetFlatBOM: %v", err) + } + if len(flat) != 2 { + t.Errorf("expected 2 leaf parts, got %d", len(flat)) + } + + for _, e := range flat { + switch e.PartNumber { + case "FLAT-X": + if e.TotalQuantity != 10.0 { + t.Errorf("FLAT-X total qty: got %.1f, want 10.0", e.TotalQuantity) + } + case "FLAT-Y": + if e.TotalQuantity != 2.0 { + t.Errorf("FLAT-Y total qty: got %.1f, want 2.0", e.TotalQuantity) + } + default: + t.Errorf("unexpected part in flat BOM: %s", e.PartNumber) + } + } +} + +func TestBOMFlatBOMEmpty(t *testing.T) { + database := mustConnectTestDB(t) + items := NewItemRepository(database) + rels := NewRelationshipRepository(database) + ctx := context.Background() + + item := createTestItem(t, items, "FLAT-EMPTY", "no children") + + flat, err := rels.GetFlatBOM(ctx, item.ID) + if err != nil { + t.Fatalf("GetFlatBOM: %v", err) + } + if len(flat) != 0 { + t.Errorf("expected 0 leaf parts for item with no BOM, got %d", len(flat)) + } +} diff --git a/internal/partnum/generator.go b/internal/partnum/generator.go index ebcc843..9cc7d21 100644 --- a/internal/partnum/generator.go +++ b/internal/partnum/generator.go @@ -135,7 +135,7 @@ func (g *Generator) formatString(seg *schema.Segment, val string) (string, error if msg == "" { msg = fmt.Sprintf("value does not match pattern %s", seg.Validation.Pattern) } - return "", fmt.Errorf(msg) + return "", fmt.Errorf("%s", msg) } } diff --git a/internal/partnum/generator_test.go b/internal/partnum/generator_test.go new file mode 100644 index 0000000..00dcf2b --- /dev/null +++ b/internal/partnum/generator_test.go @@ -0,0 +1,167 @@ +package partnum + +import ( + "context" + "fmt" + "testing" + + "github.com/kindredsystems/silo/internal/schema" +) + +// mockSeqStore implements SequenceStore for testing. +type mockSeqStore struct { + counter int +} + +func (m *mockSeqStore) NextValue(_ context.Context, _ string, _ string) (int, error) { + m.counter++ + return m.counter, nil +} + +func testSchema() *schema.Schema { + return &schema.Schema{ + Name: "test", + Version: 1, + Separator: "-", + Segments: []schema.Segment{ + { + Name: "category", + Type: "enum", + Required: true, + Values: map[string]string{ + "F01": "Fasteners", + "R01": "Resistors", + }, + }, + { + Name: "serial", + Type: "serial", + Length: 4, + Scope: "{category}", + }, + }, + } +} + +func TestGenerateBasic(t *testing.T) { + s := testSchema() + gen := NewGenerator(map[string]*schema.Schema{"test": s}, &mockSeqStore{}) + + pn, err := gen.Generate(context.Background(), Input{ + SchemaName: "test", + Values: map[string]string{"category": "F01"}, + }) + if err != nil { + t.Fatalf("Generate returned error: %v", err) + } + if pn != "F01-0001" { + t.Errorf("got %q, want %q", pn, "F01-0001") + } +} + +func TestGenerateSequentialNumbers(t *testing.T) { + s := testSchema() + seq := &mockSeqStore{} + gen := NewGenerator(map[string]*schema.Schema{"test": s}, seq) + + for i := 1; i <= 3; i++ { + pn, err := gen.Generate(context.Background(), Input{ + SchemaName: "test", + Values: map[string]string{"category": "F01"}, + }) + if err != nil { + t.Fatalf("Generate #%d returned error: %v", i, err) + } + want := fmt.Sprintf("F01-%04d", i) + if pn != want { + t.Errorf("Generate #%d: got %q, want %q", i, pn, want) + } + } +} + +func TestGenerateWithFormat(t *testing.T) { + s := &schema.Schema{ + Name: "formatted", + Version: 1, + Format: "{prefix}/{category}-{serial}", + Segments: []schema.Segment{ + {Name: "prefix", Type: "constant", Value: "KS"}, + {Name: "category", Type: "enum", Required: true, Values: map[string]string{"A": "Alpha"}}, + {Name: "serial", Type: "serial", Length: 3}, + }, + } + gen := NewGenerator(map[string]*schema.Schema{"formatted": s}, &mockSeqStore{}) + + pn, err := gen.Generate(context.Background(), Input{ + SchemaName: "formatted", + Values: map[string]string{"category": "A"}, + }) + if err != nil { + t.Fatalf("Generate returned error: %v", err) + } + if pn != "KS/A-001" { + t.Errorf("got %q, want %q", pn, "KS/A-001") + } +} + +func TestGenerateUnknownSchema(t *testing.T) { + gen := NewGenerator(map[string]*schema.Schema{}, &mockSeqStore{}) + + _, err := gen.Generate(context.Background(), Input{ + SchemaName: "nonexistent", + Values: map[string]string{}, + }) + if err == nil { + t.Fatal("expected error for unknown schema, got nil") + } +} + +func TestGenerateMissingRequiredEnum(t *testing.T) { + s := testSchema() + gen := NewGenerator(map[string]*schema.Schema{"test": s}, &mockSeqStore{}) + + _, err := gen.Generate(context.Background(), Input{ + SchemaName: "test", + Values: map[string]string{}, // missing required "category" + }) + if err == nil { + t.Fatal("expected error for missing required enum, got nil") + } +} + +func TestGenerateInvalidEnumValue(t *testing.T) { + s := testSchema() + gen := NewGenerator(map[string]*schema.Schema{"test": s}, &mockSeqStore{}) + + _, err := gen.Generate(context.Background(), Input{ + SchemaName: "test", + Values: map[string]string{"category": "INVALID"}, + }) + if err == nil { + t.Fatal("expected error for invalid enum value, got nil") + } +} + +func TestGenerateConstantSegment(t *testing.T) { + s := &schema.Schema{ + Name: "const-test", + Version: 1, + Separator: "-", + Segments: []schema.Segment{ + {Name: "prefix", Type: "constant", Value: "KS"}, + {Name: "serial", Type: "serial", Length: 4}, + }, + } + gen := NewGenerator(map[string]*schema.Schema{"const-test": s}, &mockSeqStore{}) + + pn, err := gen.Generate(context.Background(), Input{ + SchemaName: "const-test", + Values: map[string]string{}, + }) + if err != nil { + t.Fatalf("Generate returned error: %v", err) + } + if pn != "KS-0001" { + t.Errorf("got %q, want %q", pn, "KS-0001") + } +} diff --git a/internal/schema/schema_test.go b/internal/schema/schema_test.go new file mode 100644 index 0000000..3c3a0b7 --- /dev/null +++ b/internal/schema/schema_test.go @@ -0,0 +1,251 @@ +package schema + +import ( + "os" + "path/filepath" + "testing" +) + +// findSchemasDir walks upward to find the project root and returns +// the path to the schemas/ directory. +func findSchemasDir(t *testing.T) string { + t.Helper() + dir, err := os.Getwd() + if err != nil { + t.Fatalf("getting working directory: %v", err) + } + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return filepath.Join(dir, "schemas") + } + parent := filepath.Dir(dir) + if parent == dir { + t.Fatal("could not find project root") + } + dir = parent + } +} + +func TestLoadSchema(t *testing.T) { + schemasDir := findSchemasDir(t) + path := filepath.Join(schemasDir, "kindred-rd.yaml") + + s, err := Load(path) + if err != nil { + t.Fatalf("Load returned error: %v", err) + } + + if s.Name != "kindred-rd" { + t.Errorf("schema name: got %q, want %q", s.Name, "kindred-rd") + } + if s.Separator != "-" { + t.Errorf("separator: got %q, want %q", s.Separator, "-") + } + if len(s.Segments) == 0 { + t.Fatal("expected at least one segment") + } + + // First segment should be the category enum + cat := s.Segments[0] + if cat.Name != "category" { + t.Errorf("first segment name: got %q, want %q", cat.Name, "category") + } + if cat.Type != "enum" { + t.Errorf("first segment type: got %q, want %q", cat.Type, "enum") + } + if len(cat.Values) == 0 { + t.Error("category segment has no values") + } +} + +func TestLoadSchemaDir(t *testing.T) { + schemasDir := findSchemasDir(t) + + schemas, err := LoadAll(schemasDir) + if err != nil { + t.Fatalf("LoadAll returned error: %v", err) + } + + if len(schemas) == 0 { + t.Fatal("expected at least one schema") + } + + if _, ok := schemas["kindred-rd"]; !ok { + t.Error("kindred-rd schema not found in loaded schemas") + } +} + +func TestLoadSchemaValidation(t *testing.T) { + schemasDir := findSchemasDir(t) + + schemas, err := LoadAll(schemasDir) + if err != nil { + t.Fatalf("LoadAll returned error: %v", err) + } + + for name, s := range schemas { + if s.Name == "" { + continue // non-part-numbering schemas (e.g., location_schema) + } + if err := s.Validate(); err != nil { + t.Errorf("schema %q failed validation: %v", name, err) + } + } +} + +func TestGetPropertiesForCategory(t *testing.T) { + ps := &PropertySchemas{ + Version: 1, + Defaults: map[string]PropertyDefinition{ + "weight": {Type: "number", Unit: "kg"}, + "color": {Type: "string"}, + }, + Categories: map[string]map[string]PropertyDefinition{ + "F": { + "thread_size": {Type: "string", Required: true}, + "weight": {Type: "number", Unit: "g"}, // override default + }, + }, + } + + props := ps.GetPropertiesForCategory("F01") + + // Should have all three: weight (overridden), color (default), thread_size (category) + if len(props) != 3 { + t.Errorf("expected 3 properties, got %d", len(props)) + } + if props["weight"].Unit != "g" { + t.Errorf("weight unit: got %q, want %q (should be overridden by category)", props["weight"].Unit, "g") + } + if props["color"].Type != "string" { + t.Errorf("color type: got %q, want %q", props["color"].Type, "string") + } + if !props["thread_size"].Required { + t.Error("thread_size should be required") + } +} + +func TestGetPropertiesForUnknownCategory(t *testing.T) { + ps := &PropertySchemas{ + Version: 1, + Defaults: map[string]PropertyDefinition{ + "weight": {Type: "number"}, + }, + Categories: map[string]map[string]PropertyDefinition{ + "F": {"thread_size": {Type: "string"}}, + }, + } + + props := ps.GetPropertiesForCategory("Z99") + + // Only defaults, no category-specific properties + if len(props) != 1 { + t.Errorf("expected 1 property (defaults only), got %d", len(props)) + } + if _, ok := props["weight"]; !ok { + t.Error("default property 'weight' should be present") + } +} + +func TestApplyDefaults(t *testing.T) { + ps := &PropertySchemas{ + Version: 1, + Defaults: map[string]PropertyDefinition{ + "status": {Type: "string", Default: "draft"}, + "weight": {Type: "number"}, + }, + } + + props := map[string]any{"custom": "value"} + result := ps.ApplyDefaults(props, "X") + + if result["status"] != "draft" { + t.Errorf("status: got %v, want %q", result["status"], "draft") + } + if result["custom"] != "value" { + t.Errorf("custom: got %v, want %q", result["custom"], "value") + } + // weight has no default, should not be added + if _, ok := result["weight"]; ok { + t.Error("weight should not be added (no default value)") + } +} + +func TestSchemaValidate(t *testing.T) { + tests := []struct { + name string + schema Schema + wantErr bool + }{ + { + name: "empty name", + schema: Schema{Name: "", Segments: []Segment{{Name: "s", Type: "constant", Value: "X"}}}, + wantErr: true, + }, + { + name: "no segments", + schema: Schema{Name: "test", Segments: nil}, + wantErr: true, + }, + { + name: "valid minimal", + schema: Schema{ + Name: "test", + Segments: []Segment{ + {Name: "prefix", Type: "constant", Value: "X"}, + }, + }, + wantErr: false, + }, + { + name: "enum without values", + schema: Schema{ + Name: "test", + Segments: []Segment{ + {Name: "cat", Type: "enum"}, + }, + }, + wantErr: true, + }, + { + name: "serial without length", + schema: Schema{ + Name: "test", + Segments: []Segment{ + {Name: "seq", Type: "serial", Length: 0}, + }, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.schema.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestGetSegment(t *testing.T) { + s := &Schema{ + Segments: []Segment{ + {Name: "prefix", Type: "constant", Value: "KS"}, + {Name: "serial", Type: "serial", Length: 4}, + }, + } + + seg := s.GetSegment("serial") + if seg == nil { + t.Fatal("GetSegment returned nil for existing segment") + } + if seg.Type != "serial" { + t.Errorf("segment type: got %q, want %q", seg.Type, "serial") + } + + if s.GetSegment("nonexistent") != nil { + t.Error("GetSegment should return nil for nonexistent segment") + } +} diff --git a/internal/testutil/testutil.go b/internal/testutil/testutil.go new file mode 100644 index 0000000..d7c62c9 --- /dev/null +++ b/internal/testutil/testutil.go @@ -0,0 +1,114 @@ +// Package testutil provides shared helpers for Silo tests. +package testutil + +import ( + "context" + "fmt" + "os" + "path/filepath" + "sort" + "testing" + + "github.com/jackc/pgx/v5/pgxpool" +) + +// MustConnectTestPool connects to a test database using TEST_DATABASE_URL. +// If the env var is unset the test is skipped. Migrations are applied and +// all data tables are truncated before returning. The pool is closed +// automatically when the test finishes. +func MustConnectTestPool(t *testing.T) *pgxpool.Pool { + t.Helper() + + dsn := os.Getenv("TEST_DATABASE_URL") + if dsn == "" { + t.Skip("TEST_DATABASE_URL not set, skipping integration test") + } + + pool, err := pgxpool.New(context.Background(), dsn) + if err != nil { + t.Fatalf("connecting to test database: %v", err) + } + if err := pool.Ping(context.Background()); err != nil { + pool.Close() + t.Fatalf("pinging test database: %v", err) + } + t.Cleanup(func() { pool.Close() }) + + RunMigrations(t, pool) + TruncateAll(t, pool) + + return pool +} + +// RunMigrations applies all SQL migration files from the migrations/ +// directory. It walks upward from the current working directory to find +// the project root (containing go.mod). +func RunMigrations(t *testing.T, pool *pgxpool.Pool) { + t.Helper() + + root := findProjectRoot(t) + migDir := filepath.Join(root, "migrations") + entries, err := os.ReadDir(migDir) + if err != nil { + t.Fatalf("reading migrations directory: %v", err) + } + + // Sort by filename to apply in order. + sort.Slice(entries, func(i, j int) bool { + return entries[i].Name() < entries[j].Name() + }) + + for _, e := range entries { + if e.IsDir() || filepath.Ext(e.Name()) != ".sql" { + continue + } + sql, err := os.ReadFile(filepath.Join(migDir, e.Name())) + if err != nil { + t.Fatalf("reading migration %s: %v", e.Name(), err) + } + if _, err := pool.Exec(context.Background(), string(sql)); err != nil { + // Migrations may contain IF NOT EXISTS / CREATE OR REPLACE, + // so most "already exists" errors are fine. Log and continue. + t.Logf("migration %s: %v (may be OK if already applied)", e.Name(), err) + } + } +} + +// TruncateAll removes all data from tables, leaving schema intact. +func TruncateAll(t *testing.T, pool *pgxpool.Pool) { + t.Helper() + + _, err := pool.Exec(context.Background(), ` + TRUNCATE + audit_log, sync_log, api_tokens, sessions, item_files, + item_projects, relationships, revisions, inventory, items, + projects, sequences_by_name, users, property_migrations + CASCADE + `) + if err != nil { + t.Fatalf("truncating tables: %v", err) + } +} + +// findProjectRoot walks upward from cwd to find the directory containing go.mod. +func findProjectRoot(t *testing.T) string { + t.Helper() + + dir, err := os.Getwd() + if err != nil { + t.Fatalf("getting working directory: %v", err) + } + + for { + if _, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil { + return dir + } + parent := filepath.Dir(dir) + if parent == dir { + t.Fatalf("could not find project root (go.mod)") + } + dir = parent + } + + panic(fmt.Sprintf("unreachable")) +}