Files
silo/internal/api/broker.go
Forbes e7da3ee94d feat(sse): per-connection filtering with user and workstation context
- Extend sseClient with userID, workstationID, and item filter set
- Update Subscribe() to accept userID and workstationID params
- Add WatchItem/UnwatchItem/IsWatchingItem methods on sseClient
- Add PublishToItem, PublishToWorkstation, PublishToUser targeted delivery
- Targeted events get IDs but skip history ring buffer (real-time only)
- Update HandleEvents to pass auth user ID and workstation_id query param
- Touch workstation last_seen on SSE connect
- Existing Publish() broadcast unchanged; all current callers unaffected
- Add 5 new tests for targeted delivery and item watch lifecycle

Closes #162
2026-03-01 10:04:01 -06:00

241 lines
5.7 KiB
Go

package api
import (
"encoding/json"
"sync"
"sync/atomic"
"time"
"github.com/rs/zerolog"
)
// Event is a single server-sent event handled by Broker.
//
// ID is drawn from the broker's shared counter and is what EventsSince
// compares against for Last-Event-ID replay. Type names the event (for
// example "heartbeat") and Data carries its payload as a string.
type Event struct {
	ID         uint64
	Type, Data string
}
// sseClient holds the broker-side state for one connected SSE consumer.
type sseClient struct {
	// ch is the buffered outbound event queue. Publishers send to it without
	// blocking and drop the event when it is full.
	ch chan Event

	// closed is closed by Unsubscribe or Shutdown when the client is removed
	// from the broker.
	closed chan struct{}

	// userID and workstationID identify the connection for PublishToUser and
	// PublishToWorkstation.
	userID, workstationID string

	// mu guards itemFilters, the set of item IDs consulted by PublishToItem.
	mu          sync.RWMutex
	itemFilters map[string]struct{}
}
// WatchItem subscribes this client to item-scoped events for itemID, so that
// PublishToItem(itemID, ...) delivers to it. Adding an already-watched ID is a
// no-op.
func (c *sseClient) WatchItem(itemID string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.itemFilters[itemID] = struct{}{}
}
// UnwatchItem stops item-scoped delivery of itemID to this client. Removing an
// ID that is not being watched is a no-op.
func (c *sseClient) UnwatchItem(itemID string) {
	c.mu.Lock()
	defer c.mu.Unlock()
	delete(c.itemFilters, itemID)
}
// IsWatchingItem reports whether itemID is currently in this client's filter
// set. It takes only the read lock, so concurrent lookups do not serialize.
func (c *sseClient) IsWatchingItem(itemID string) bool {
	c.mu.RLock()
	defer c.mu.RUnlock()
	_, watching := c.itemFilters[itemID]
	return watching
}
// Buffer capacities for per-client queues and the broadcast replay history.
const (
	clientChanSize = 64  // events buffered per client before drops begin
	historySize    = 256 // broadcast events retained for EventsSince replay
)

// heartbeatInterval is how often StartHeartbeat publishes a "heartbeat" event.
const heartbeatInterval = time.Second * 30
// Broker tracks connected SSE clients, fans events out to them, and keeps a
// bounded history of broadcast events for Last-Event-ID replay.
type Broker struct {
	logger zerolog.Logger

	// mu guards clients.
	mu      sync.RWMutex
	clients map[*sseClient]struct{}

	// eventID is the counter that every publish (broadcast or targeted)
	// increments to obtain a new Event.ID.
	eventID atomic.Uint64

	// historyMu guards history, which holds at most historySize broadcast
	// events in insertion order.
	historyMu sync.RWMutex
	history   []Event

	// done is closed by Shutdown to stop the heartbeat goroutine.
	done chan struct{}
}
// NewBroker returns a Broker with no clients, an empty history pre-sized to
// historySize, and a logger tagged with component=sse-broker. The heartbeat
// is not started; call StartHeartbeat for that.
func NewBroker(logger zerolog.Logger) *Broker {
	scoped := logger.With().Str("component", "sse-broker").Logger()
	return &Broker{
		done:    make(chan struct{}),
		history: make([]Event, 0, historySize),
		clients: make(map[*sseClient]struct{}),
		logger:  scoped,
	}
}
// Subscribe registers a new client for the given user and workstation and
// returns it. The client starts with an empty item filter set and an event
// channel buffered to clientChanSize. The caller must call Unsubscribe when
// done.
//
// If the broker has already been shut down, the client is NOT registered and
// its closed channel is already closed. Anything selecting on c.closed is
// therefore released immediately, instead of holding a connection that
// Shutdown has already passed over.
func (b *Broker) Subscribe(userID, workstationID string) *sseClient {
	c := &sseClient{
		ch:            make(chan Event, clientChanSize),
		closed:        make(chan struct{}),
		userID:        userID,
		workstationID: workstationID,
		itemFilters:   make(map[string]struct{}),
	}

	b.mu.Lock()
	// Checking done while holding b.mu guarantees one of two outcomes. Either
	// a concurrent Shutdown finds this client in b.clients and closes it, or
	// this call observes done already closed.
	select {
	case <-b.done:
		b.mu.Unlock()
		close(c.closed)
		b.logger.Warn().Msg("client subscribed after broker shutdown; closing immediately")
		return c
	default:
	}
	b.clients[c] = struct{}{}
	count := len(b.clients)
	b.mu.Unlock()

	b.logger.Info().Int("clients", count).Msg("client connected")
	return c
}
// Unsubscribe removes c from the broker and closes c.closed. It does not close
// c.ch, so a concurrent publish can never send on a closed channel.
//
// It is safe to call more than once, including after Shutdown has already
// removed the client. Only the call that actually removes c closes c.closed
// and logs the disconnect; later calls are silent no-ops.
func (b *Broker) Unsubscribe(c *sseClient) {
	b.mu.Lock()
	_, registered := b.clients[c]
	if registered {
		delete(b.clients, c)
		close(c.closed)
	}
	count := len(b.clients)
	b.mu.Unlock()

	if !registered {
		// Already removed (double Unsubscribe or Shutdown); logging again would
		// report a disconnect that did not happen here.
		return
	}
	b.logger.Info().Int("clients", count).Msg("client disconnected")
}
// Publish broadcasts an event of the given type to every connected client and
// appends it to the replay history. The history is capped at historySize and
// drops the oldest entry first.
//
// Sends never block. If a client's channel is full, the event is dropped for
// that client and a warning is logged.
//
// Note: the ID is allocated before historyMu is taken. Concurrent Publish
// calls may therefore append to history, and deliver to clients, in an order
// that differs from their IDs.
func (b *Broker) Publish(eventType string, data string) {
	ev := Event{Type: eventType, Data: data, ID: b.eventID.Add(1)}

	b.historyMu.Lock()
	kept := b.history
	if len(kept) >= historySize {
		kept = kept[1:] // evict the oldest entry
	}
	b.history = append(kept, ev)
	b.historyMu.Unlock()

	b.mu.RLock()
	defer b.mu.RUnlock()
	for client := range b.clients {
		select {
		case client.ch <- ev:
		default:
			b.logger.Warn().Uint64("event_id", ev.ID).Str("type", eventType).Msg("dropped event for slow client")
		}
	}
}
// publishTargeted allocates the next event ID and delivers the event only to
// clients for which match returns true. The ID is consumed even if no client
// matches. match runs while b.mu is read-locked.
//
// Targeted events are never written to the history ring buffer, so
// EventsSince cannot replay them. Delivery is non-blocking: a full client
// channel drops the event and logs a warning.
func (b *Broker) publishTargeted(eventType, data string, match func(*sseClient) bool) {
	ev := Event{Type: eventType, Data: data, ID: b.eventID.Add(1)}

	b.mu.RLock()
	defer b.mu.RUnlock()
	for client := range b.clients {
		if !match(client) {
			continue
		}
		select {
		case client.ch <- ev:
		default:
			b.logger.Warn().Uint64("event_id", ev.ID).Str("type", eventType).Msg("dropped targeted event for slow client")
		}
	}
}
// PublishToItem delivers an event only to clients whose filter set contains
// itemID (see WatchItem). The event is not stored in the replay history.
func (b *Broker) PublishToItem(itemID, eventType, data string) {
	watching := func(c *sseClient) bool { return c.IsWatchingItem(itemID) }
	b.publishTargeted(eventType, data, watching)
}
// PublishToWorkstation delivers an event only to clients that subscribed with
// the given workstationID. The event is not stored in the replay history.
func (b *Broker) PublishToWorkstation(workstationID, eventType, data string) {
	onWorkstation := func(c *sseClient) bool { return c.workstationID == workstationID }
	b.publishTargeted(eventType, data, onWorkstation)
}
// PublishToUser delivers an event to every connection that subscribed with the
// given userID. The event is not stored in the replay history.
func (b *Broker) PublishToUser(userID, eventType, data string) {
	ownedByUser := func(c *sseClient) bool { return c.userID == userID }
	b.publishTargeted(eventType, data, ownedByUser)
}
// ClientCount reports how many clients are currently registered with the
// broker, read under the client-set read lock.
func (b *Broker) ClientCount() int {
	b.mu.RLock()
	n := len(b.clients)
	b.mu.RUnlock()
	return n
}
// EventsSince returns the retained broadcast events whose ID is strictly
// greater than lastID, in history (insertion) order, for Last-Event-ID replay.
// It returns nil when nothing qualifies. Targeted events are never retained,
// so they are never returned.
func (b *Broker) EventsSince(lastID uint64) []Event {
	b.historyMu.RLock()
	defer b.historyMu.RUnlock()

	var newer []Event
	for i := range b.history {
		if b.history[i].ID <= lastID {
			continue
		}
		newer = append(newer, b.history[i])
	}
	return newer
}
// StartHeartbeat launches a goroutine that broadcasts a "heartbeat" event with
// payload "{}" every heartbeatInterval, until Shutdown closes b.done. Each call
// starts an additional goroutine.
func (b *Broker) StartHeartbeat() {
	go func() {
		tick := time.NewTicker(heartbeatInterval)
		defer tick.Stop()
		for {
			select {
			case <-b.done:
				return
			case <-tick.C:
			}
			b.Publish("heartbeat", "{}")
		}
	}()
}
// Shutdown closes b.done, which stops any heartbeat goroutine. It then removes
// every registered client and closes each client's closed channel.
//
// Shutdown is idempotent: the check-and-close of b.done happens under b.mu, so
// repeated or concurrent calls after the first do nothing instead of panicking
// on a double close.
func (b *Broker) Shutdown() {
	b.mu.Lock()
	select {
	case <-b.done:
		// Already shut down.
		b.mu.Unlock()
		return
	default:
	}
	close(b.done)
	for c := range b.clients {
		delete(b.clients, c)
		close(c.closed)
	}
	b.mu.Unlock()

	b.logger.Info().Msg("broker shut down")
}
// mustMarshal encodes v as JSON and returns it as a string, for use as SSE
// event data. It panics with a string message if encoding fails. Use it only
// with values known to be JSON-encodable (plain structs, maps, slices); values
// such as channels or funcs will panic.
func mustMarshal(v any) string {
	raw, err := json.Marshal(v)
	if err == nil {
		return string(raw)
	}
	panic("api: failed to marshal SSE event data: " + err.Error())
}