From e7da3ee94d77b8d6f47101d0113e1aab36d78bbc Mon Sep 17 00:00:00 2001 From: Forbes Date: Sun, 1 Mar 2026 10:04:01 -0600 Subject: [PATCH] feat(sse): per-connection filtering with user and workstation context - Extend sseClient with userID, workstationID, and item filter set - Update Subscribe() to accept userID and workstationID params - Add WatchItem/UnwatchItem/IsWatchingItem methods on sseClient - Add PublishToItem, PublishToWorkstation, PublishToUser targeted delivery - Targeted events get IDs but skip history ring buffer (real-time only) - Update HandleEvents to pass auth user ID and workstation_id query param - Touch workstation last_seen on SSE connect - Existing Publish() broadcast unchanged; all current callers unaffected - Add 5 new tests for targeted delivery and item watch lifecycle Closes #162 --- internal/api/broker.go | 82 +++++++++++++++++-- internal/api/broker_test.go | 139 ++++++++++++++++++++++++++++++-- internal/api/servermode_test.go | 2 +- internal/api/sse_handler.go | 14 +++- 4 files changed, 223 insertions(+), 14 deletions(-) diff --git a/internal/api/broker.go b/internal/api/broker.go index bf084aa..382bc48 100644 --- a/internal/api/broker.go +++ b/internal/api/broker.go @@ -18,8 +18,34 @@ type Event struct { // sseClient represents a single connected SSE consumer. type sseClient struct { - ch chan Event - closed chan struct{} + ch chan Event + closed chan struct{} + userID string + workstationID string + mu sync.RWMutex + itemFilters map[string]struct{} +} + +// WatchItem adds an item ID to this client's filter set. +func (c *sseClient) WatchItem(itemID string) { + c.mu.Lock() + c.itemFilters[itemID] = struct{}{} + c.mu.Unlock() +} + +// UnwatchItem removes an item ID from this client's filter set. +func (c *sseClient) UnwatchItem(itemID string) { + c.mu.Lock() + delete(c.itemFilters, itemID) + c.mu.Unlock() +} + +// IsWatchingItem returns whether this client is watching a specific item. +func (c *sseClient) IsWatchingItem(itemID string) bool { + c.mu.RLock() + _, ok := c.itemFilters[itemID] + c.mu.RUnlock() + return ok } const ( @@ -52,10 +78,13 @@ func NewBroker(logger zerolog.Logger) *Broker { } // Subscribe adds a new client and returns it. The caller must call Unsubscribe when done. -func (b *Broker) Subscribe() *sseClient { +func (b *Broker) Subscribe(userID, workstationID string) *sseClient { c := &sseClient{ - ch: make(chan Event, clientChanSize), - closed: make(chan struct{}), + ch: make(chan Event, clientChanSize), + closed: make(chan struct{}), + userID: userID, + workstationID: workstationID, + itemFilters: make(map[string]struct{}), } b.mu.Lock() b.clients[c] = struct{}{} @@ -106,6 +135,49 @@ func (b *Broker) Publish(eventType string, data string) { b.mu.RUnlock() } +// publishTargeted sends an event only to clients matching the predicate. +// Targeted events get an ID but are not stored in the history ring buffer. +func (b *Broker) publishTargeted(eventType, data string, match func(*sseClient) bool) { + ev := Event{ + ID: b.eventID.Add(1), + Type: eventType, + Data: data, + } + + b.mu.RLock() + for c := range b.clients { + if match(c) { + select { + case c.ch <- ev: + default: + b.logger.Warn().Uint64("event_id", ev.ID).Str("type", eventType).Msg("dropped targeted event for slow client") + } + } + } + b.mu.RUnlock() +} + +// PublishToItem sends an event only to clients watching a specific item. +func (b *Broker) PublishToItem(itemID, eventType, data string) { + b.publishTargeted(eventType, data, func(c *sseClient) bool { + return c.IsWatchingItem(itemID) + }) +} + +// PublishToWorkstation sends an event only to the specified workstation. +func (b *Broker) PublishToWorkstation(workstationID, eventType, data string) { + b.publishTargeted(eventType, data, func(c *sseClient) bool { + return c.workstationID == workstationID + }) +} + +// PublishToUser sends an event to all connections for a specific user. +func (b *Broker) PublishToUser(userID, eventType, data string) { + b.publishTargeted(eventType, data, func(c *sseClient) bool { + return c.userID == userID + }) +} + // ClientCount returns the number of connected SSE clients. func (b *Broker) ClientCount() int { b.mu.RLock() diff --git a/internal/api/broker_test.go b/internal/api/broker_test.go index 36874ce..105efa9 100644 --- a/internal/api/broker_test.go +++ b/internal/api/broker_test.go @@ -10,7 +10,7 @@ import ( func TestBrokerSubscribeUnsubscribe(t *testing.T) { b := NewBroker(zerolog.Nop()) - c := b.Subscribe() + c := b.Subscribe("", "") if b.ClientCount() != 1 { t.Fatalf("expected 1 client, got %d", b.ClientCount()) } @@ -23,7 +23,7 @@ func TestBrokerSubscribeUnsubscribe(t *testing.T) { func TestBrokerPublish(t *testing.T) { b := NewBroker(zerolog.Nop()) - c := b.Subscribe() + c := b.Subscribe("", "") defer b.Unsubscribe(c) b.Publish("item.created", `{"part_number":"F01-0001"}`) @@ -46,7 +46,7 @@ func TestBrokerPublish(t *testing.T) { func TestBrokerPublishDropsSlow(t *testing.T) { b := NewBroker(zerolog.Nop()) - c := b.Subscribe() + c := b.Subscribe("", "") defer b.Unsubscribe(c) // Fill the client's channel @@ -89,9 +89,9 @@ func TestBrokerEventsSince(t *testing.T) { func TestBrokerClientCount(t *testing.T) { b := NewBroker(zerolog.Nop()) - c1 := b.Subscribe() - c2 := b.Subscribe() - c3 := b.Subscribe() + c1 := b.Subscribe("", "") + c2 := b.Subscribe("", "") + c3 := b.Subscribe("", "") if b.ClientCount() != 3 { t.Fatalf("expected 3 clients, got %d", b.ClientCount()) @@ -111,7 +111,7 @@ func TestBrokerClientCount(t *testing.T) { func TestBrokerShutdown(t *testing.T) { b := NewBroker(zerolog.Nop()) - c := b.Subscribe() + c := b.Subscribe("", "") b.Shutdown() @@ -145,3 +145,128 @@ func TestBrokerMonotonicIDs(t *testing.T) { } } } + +func TestWatchUnwatchItem(t *testing.T) { + b := NewBroker(zerolog.Nop()) + c := b.Subscribe("user1", "ws1") + defer b.Unsubscribe(c) + + if c.IsWatchingItem("item-abc") { + t.Fatal("should not be watching item-abc before WatchItem") + } + + c.WatchItem("item-abc") + if !c.IsWatchingItem("item-abc") { + t.Fatal("should be watching item-abc after WatchItem") + } + + c.UnwatchItem("item-abc") + if c.IsWatchingItem("item-abc") { + t.Fatal("should not be watching item-abc after UnwatchItem") + } +} + +func TestPublishToItem(t *testing.T) { + b := NewBroker(zerolog.Nop()) + watcher := b.Subscribe("user1", "ws1") + defer b.Unsubscribe(watcher) + bystander := b.Subscribe("user2", "ws2") + defer b.Unsubscribe(bystander) + + watcher.WatchItem("item-abc") + b.PublishToItem("item-abc", "edit.started", `{"item_id":"item-abc"}`) + + // Watcher should receive the event. + select { + case ev := <-watcher.ch: + if ev.Type != "edit.started" { + t.Fatalf("expected edit.started, got %s", ev.Type) + } + case <-time.After(time.Second): + t.Fatal("watcher did not receive targeted event") + } + + // Bystander should not. + select { + case ev := <-bystander.ch: + t.Fatalf("bystander should not receive targeted event, got %s", ev.Type) + case <-time.After(50 * time.Millisecond): + // expected + } +} + +func TestPublishToWorkstation(t *testing.T) { + b := NewBroker(zerolog.Nop()) + target := b.Subscribe("user1", "ws-target") + defer b.Unsubscribe(target) + other := b.Subscribe("user1", "ws-other") + defer b.Unsubscribe(other) + + b.PublishToWorkstation("ws-target", "sync.update", `{"data":"x"}`) + + select { + case ev := <-target.ch: + if ev.Type != "sync.update" { + t.Fatalf("expected sync.update, got %s", ev.Type) + } + case <-time.After(time.Second): + t.Fatal("target workstation did not receive event") + } + + select { + case ev := <-other.ch: + t.Fatalf("other workstation should not receive event, got %s", ev.Type) + case <-time.After(50 * time.Millisecond): + // expected + } +} + +func TestPublishToUser(t *testing.T) { + b := NewBroker(zerolog.Nop()) + c1 := b.Subscribe("user1", "ws1") + defer b.Unsubscribe(c1) + c2 := b.Subscribe("user1", "ws2") + defer b.Unsubscribe(c2) + c3 := b.Subscribe("user2", "ws3") + defer b.Unsubscribe(c3) + + b.PublishToUser("user1", "user.notify", `{"msg":"hello"}`) + + // Both user1 connections should receive. + for _, c := range []*sseClient{c1, c2} { + select { + case ev := <-c.ch: + if ev.Type != "user.notify" { + t.Fatalf("expected user.notify, got %s", ev.Type) + } + case <-time.After(time.Second): + t.Fatal("user1 client did not receive event") + } + } + + // user2 should not. + select { + case ev := <-c3.ch: + t.Fatalf("user2 should not receive event, got %s", ev.Type) + case <-time.After(50 * time.Millisecond): + // expected + } +} + +func TestTargetedEventsNotInHistory(t *testing.T) { + b := NewBroker(zerolog.Nop()) + c := b.Subscribe("user1", "ws1") + defer b.Unsubscribe(c) + c.WatchItem("item-abc") + + b.Publish("broadcast", `{}`) + b.PublishToItem("item-abc", "targeted", `{}`) + + events := b.EventsSince(0) + if len(events) != 1 { + t.Fatalf("expected 1 event in history (broadcast only), got %d", len(events)) + } + if events[0].Type != "broadcast" { + t.Fatalf("expected broadcast event in history, got %s", events[0].Type) + } +} diff --git a/internal/api/servermode_test.go b/internal/api/servermode_test.go index 489bb39..77eaf20 100644 --- a/internal/api/servermode_test.go +++ b/internal/api/servermode_test.go @@ -76,7 +76,7 @@ func TestServerStateToggleReadOnly(t *testing.T) { func TestServerStateBroadcastsOnTransition(t *testing.T) { b := NewBroker(zerolog.Nop()) - c := b.Subscribe() + c := b.Subscribe("", "") defer b.Unsubscribe(c) ss := NewServerState(zerolog.Nop(), nil, b) diff --git a/internal/api/sse_handler.go b/internal/api/sse_handler.go index 3734e85..611f75e 100644 --- a/internal/api/sse_handler.go +++ b/internal/api/sse_handler.go @@ -5,6 +5,8 @@ import ( "net/http" "strconv" "time" + + "github.com/kindredsystems/silo/internal/auth" ) // HandleEvents serves the SSE event stream. @@ -31,9 +33,19 @@ func (s *Server) HandleEvents(w http.ResponseWriter, r *http.Request) { w.Header().Set("Connection", "keep-alive") w.Header().Set("X-Accel-Buffering", "no") // nginx: disable proxy buffering - client := s.broker.Subscribe() + userID := "" + if user := auth.UserFromContext(r.Context()); user != nil { + userID = user.ID + } + wsID := r.URL.Query().Get("workstation_id") + + client := s.broker.Subscribe(userID, wsID) defer s.broker.Unsubscribe(client) + if wsID != "" { + s.workstations.Touch(r.Context(), wsID) + } + // Replay missed events if Last-Event-ID is present. if lastIDStr := r.Header.Get("Last-Event-ID"); lastIDStr != "" { if lastID, err := strconv.ParseUint(lastIDStr, 10, 64); err == nil { -- 2.49.1