Merge pull request 'feat(sse): per-connection filtering with user and workstation context' (#171) from feat/sse-per-connection-filtering into main

Reviewed-on: #171
This commit was merged in pull request #171.
This commit is contained in:
2026-03-01 16:05:34 +00:00
4 changed files with 223 additions and 14 deletions

View File

@@ -18,8 +18,34 @@ type Event struct {
// sseClient represents a single connected SSE consumer.
type sseClient struct {
ch chan Event
closed chan struct{}
ch chan Event
closed chan struct{}
userID string
workstationID string
mu sync.RWMutex
itemFilters map[string]struct{}
}
// WatchItem adds an item ID to this client's filter set.
func (c *sseClient) WatchItem(itemID string) {
c.mu.Lock()
c.itemFilters[itemID] = struct{}{}
c.mu.Unlock()
}
// UnwatchItem removes an item ID from this client's filter set.
func (c *sseClient) UnwatchItem(itemID string) {
c.mu.Lock()
delete(c.itemFilters, itemID)
c.mu.Unlock()
}
// IsWatchingItem returns whether this client is watching a specific item.
func (c *sseClient) IsWatchingItem(itemID string) bool {
c.mu.RLock()
_, ok := c.itemFilters[itemID]
c.mu.RUnlock()
return ok
}
const (
@@ -52,10 +78,13 @@ func NewBroker(logger zerolog.Logger) *Broker {
}
// Subscribe adds a new client and returns it. The caller must call Unsubscribe when done.
func (b *Broker) Subscribe() *sseClient {
func (b *Broker) Subscribe(userID, workstationID string) *sseClient {
c := &sseClient{
ch: make(chan Event, clientChanSize),
closed: make(chan struct{}),
ch: make(chan Event, clientChanSize),
closed: make(chan struct{}),
userID: userID,
workstationID: workstationID,
itemFilters: make(map[string]struct{}),
}
b.mu.Lock()
b.clients[c] = struct{}{}
@@ -106,6 +135,49 @@ func (b *Broker) Publish(eventType string, data string) {
b.mu.RUnlock()
}
// publishTargeted sends an event only to clients matching the predicate.
// Targeted events get an ID but are not stored in the history ring buffer.
func (b *Broker) publishTargeted(eventType, data string, match func(*sseClient) bool) {
ev := Event{
ID: b.eventID.Add(1),
Type: eventType,
Data: data,
}
b.mu.RLock()
for c := range b.clients {
if match(c) {
select {
case c.ch <- ev:
default:
b.logger.Warn().Uint64("event_id", ev.ID).Str("type", eventType).Msg("dropped targeted event for slow client")
}
}
}
b.mu.RUnlock()
}
// PublishToItem sends an event only to clients watching a specific item.
func (b *Broker) PublishToItem(itemID, eventType, data string) {
b.publishTargeted(eventType, data, func(c *sseClient) bool {
return c.IsWatchingItem(itemID)
})
}
// PublishToWorkstation sends an event only to the specified workstation.
func (b *Broker) PublishToWorkstation(workstationID, eventType, data string) {
b.publishTargeted(eventType, data, func(c *sseClient) bool {
return c.workstationID == workstationID
})
}
// PublishToUser sends an event to all connections for a specific user.
func (b *Broker) PublishToUser(userID, eventType, data string) {
b.publishTargeted(eventType, data, func(c *sseClient) bool {
return c.userID == userID
})
}
// ClientCount returns the number of connected SSE clients.
func (b *Broker) ClientCount() int {
b.mu.RLock()

View File

@@ -10,7 +10,7 @@ import (
func TestBrokerSubscribeUnsubscribe(t *testing.T) {
b := NewBroker(zerolog.Nop())
c := b.Subscribe()
c := b.Subscribe("", "")
if b.ClientCount() != 1 {
t.Fatalf("expected 1 client, got %d", b.ClientCount())
}
@@ -23,7 +23,7 @@ func TestBrokerSubscribeUnsubscribe(t *testing.T) {
func TestBrokerPublish(t *testing.T) {
b := NewBroker(zerolog.Nop())
c := b.Subscribe()
c := b.Subscribe("", "")
defer b.Unsubscribe(c)
b.Publish("item.created", `{"part_number":"F01-0001"}`)
@@ -46,7 +46,7 @@ func TestBrokerPublish(t *testing.T) {
func TestBrokerPublishDropsSlow(t *testing.T) {
b := NewBroker(zerolog.Nop())
c := b.Subscribe()
c := b.Subscribe("", "")
defer b.Unsubscribe(c)
// Fill the client's channel
@@ -89,9 +89,9 @@ func TestBrokerEventsSince(t *testing.T) {
func TestBrokerClientCount(t *testing.T) {
b := NewBroker(zerolog.Nop())
c1 := b.Subscribe()
c2 := b.Subscribe()
c3 := b.Subscribe()
c1 := b.Subscribe("", "")
c2 := b.Subscribe("", "")
c3 := b.Subscribe("", "")
if b.ClientCount() != 3 {
t.Fatalf("expected 3 clients, got %d", b.ClientCount())
@@ -111,7 +111,7 @@ func TestBrokerClientCount(t *testing.T) {
func TestBrokerShutdown(t *testing.T) {
b := NewBroker(zerolog.Nop())
c := b.Subscribe()
c := b.Subscribe("", "")
b.Shutdown()
@@ -145,3 +145,128 @@ func TestBrokerMonotonicIDs(t *testing.T) {
}
}
}
func TestWatchUnwatchItem(t *testing.T) {
b := NewBroker(zerolog.Nop())
c := b.Subscribe("user1", "ws1")
defer b.Unsubscribe(c)
if c.IsWatchingItem("item-abc") {
t.Fatal("should not be watching item-abc before WatchItem")
}
c.WatchItem("item-abc")
if !c.IsWatchingItem("item-abc") {
t.Fatal("should be watching item-abc after WatchItem")
}
c.UnwatchItem("item-abc")
if c.IsWatchingItem("item-abc") {
t.Fatal("should not be watching item-abc after UnwatchItem")
}
}
func TestPublishToItem(t *testing.T) {
b := NewBroker(zerolog.Nop())
watcher := b.Subscribe("user1", "ws1")
defer b.Unsubscribe(watcher)
bystander := b.Subscribe("user2", "ws2")
defer b.Unsubscribe(bystander)
watcher.WatchItem("item-abc")
b.PublishToItem("item-abc", "edit.started", `{"item_id":"item-abc"}`)
// Watcher should receive the event.
select {
case ev := <-watcher.ch:
if ev.Type != "edit.started" {
t.Fatalf("expected edit.started, got %s", ev.Type)
}
case <-time.After(time.Second):
t.Fatal("watcher did not receive targeted event")
}
// Bystander should not.
select {
case ev := <-bystander.ch:
t.Fatalf("bystander should not receive targeted event, got %s", ev.Type)
case <-time.After(50 * time.Millisecond):
// expected
}
}
func TestPublishToWorkstation(t *testing.T) {
b := NewBroker(zerolog.Nop())
target := b.Subscribe("user1", "ws-target")
defer b.Unsubscribe(target)
other := b.Subscribe("user1", "ws-other")
defer b.Unsubscribe(other)
b.PublishToWorkstation("ws-target", "sync.update", `{"data":"x"}`)
select {
case ev := <-target.ch:
if ev.Type != "sync.update" {
t.Fatalf("expected sync.update, got %s", ev.Type)
}
case <-time.After(time.Second):
t.Fatal("target workstation did not receive event")
}
select {
case ev := <-other.ch:
t.Fatalf("other workstation should not receive event, got %s", ev.Type)
case <-time.After(50 * time.Millisecond):
// expected
}
}
func TestPublishToUser(t *testing.T) {
b := NewBroker(zerolog.Nop())
c1 := b.Subscribe("user1", "ws1")
defer b.Unsubscribe(c1)
c2 := b.Subscribe("user1", "ws2")
defer b.Unsubscribe(c2)
c3 := b.Subscribe("user2", "ws3")
defer b.Unsubscribe(c3)
b.PublishToUser("user1", "user.notify", `{"msg":"hello"}`)
// Both user1 connections should receive.
for _, c := range []*sseClient{c1, c2} {
select {
case ev := <-c.ch:
if ev.Type != "user.notify" {
t.Fatalf("expected user.notify, got %s", ev.Type)
}
case <-time.After(time.Second):
t.Fatal("user1 client did not receive event")
}
}
// user2 should not.
select {
case ev := <-c3.ch:
t.Fatalf("user2 should not receive event, got %s", ev.Type)
case <-time.After(50 * time.Millisecond):
// expected
}
}
func TestTargetedEventsNotInHistory(t *testing.T) {
b := NewBroker(zerolog.Nop())
c := b.Subscribe("user1", "ws1")
defer b.Unsubscribe(c)
c.WatchItem("item-abc")
b.Publish("broadcast", `{}`)
b.PublishToItem("item-abc", "targeted", `{}`)
events := b.EventsSince(0)
if len(events) != 1 {
t.Fatalf("expected 1 event in history (broadcast only), got %d", len(events))
}
if events[0].Type != "broadcast" {
t.Fatalf("expected broadcast event in history, got %s", events[0].Type)
}
}

View File

@@ -76,7 +76,7 @@ func TestServerStateToggleReadOnly(t *testing.T) {
func TestServerStateBroadcastsOnTransition(t *testing.T) {
b := NewBroker(zerolog.Nop())
c := b.Subscribe()
c := b.Subscribe("", "")
defer b.Unsubscribe(c)
ss := NewServerState(zerolog.Nop(), nil, b)

View File

@@ -5,6 +5,8 @@ import (
"net/http"
"strconv"
"time"
"github.com/kindredsystems/silo/internal/auth"
)
// HandleEvents serves the SSE event stream.
@@ -31,9 +33,19 @@ func (s *Server) HandleEvents(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Connection", "keep-alive")
w.Header().Set("X-Accel-Buffering", "no") // nginx: disable proxy buffering
client := s.broker.Subscribe()
userID := ""
if user := auth.UserFromContext(r.Context()); user != nil {
userID = user.ID
}
wsID := r.URL.Query().Get("workstation_id")
client := s.broker.Subscribe(userID, wsID)
defer s.broker.Unsubscribe(client)
if wsID != "" {
s.workstations.Touch(r.Context(), wsID)
}
// Replay missed events if Last-Event-ID is present.
if lastIDStr := r.Header.Get("Last-Event-ID"); lastIDStr != "" {
if lastID, err := strconv.ParseUint(lastIDStr, 10, 64); err == nil {