listener.go raw
1 package app
2
3 import (
4 "bytes"
5 "context"
6 "net/http"
7 "strings"
8 "sync"
9 "sync/atomic"
10 "time"
11
12 "github.com/gorilla/websocket"
13 "next.orly.dev/pkg/lol/errorf"
14 "next.orly.dev/pkg/lol/log"
15 "next.orly.dev/app/config"
16 "next.orly.dev/pkg/acl"
17 "next.orly.dev/pkg/database"
18 "next.orly.dev/pkg/nostr/encoders/event"
19 "next.orly.dev/pkg/nostr/encoders/filter"
20 "next.orly.dev/pkg/protocol/publish"
21 "next.orly.dev/pkg/utils"
22 atomicutils "next.orly.dev/pkg/utils/atomic"
23 )
24
25 type Listener struct {
26 // Server is the embedded server reference.
27 // Deprecated: Prefer using accessor methods (ServerConfig, ServerDatabase, etc.)
28 // instead of accessing Server fields directly.
29 *Server
30 conn *websocket.Conn
31 ctx context.Context
32 cancel context.CancelFunc // Cancel function for this listener's context
33 remote string
34 connectionID string // Unique identifier for this connection (for access tracking)
35 req *http.Request
36 challenge atomicutils.Bytes
37 authedPubkey atomicutils.Bytes
38 startTime time.Time
39 isBlacklisted bool // Marker to identify blacklisted IPs
40 blacklistTimeout time.Time // When to timeout blacklisted connections
41 writeChan chan publish.WriteRequest // Channel for write requests (back to queued approach)
42 writeDone chan struct{} // Closed when write worker exits
43 // Message processing queue for async handling
44 messageQueue chan messageRequest // Buffered channel for message processing
45 processingDone chan struct{} // Closed when message processor exits
46 handlerWg sync.WaitGroup // Tracks spawned message handler goroutines
47 handlerSem chan struct{} // Limits concurrent message handlers per connection
48 authProcessing sync.RWMutex // Ensures AUTH completes before other messages check authentication
49 // Flow control counters (atomic for concurrent access)
50 droppedMessages atomic.Int64 // Messages dropped due to full queue
51 queryCostAccumulator atomic.Int64 // Accumulated query cost for this connection (units: multiplier * 100)
52 // Diagnostics: per-connection counters
53 msgCount int
54 reqCount int
55 eventCount int
56 // Subscription tracking for cleanup
57 subscriptions map[string]context.CancelFunc // Map of subscription ID to cancel function
58 subscriptionsMu sync.Mutex // Protects subscriptions map
59 }
60
// messageRequest is one inbound message waiting in a Listener's
// messageQueue until messageProcessor picks it up.
type messageRequest struct {
	data   []byte // message bytes exactly as passed to QueueMessage
	remote string // remote identifier passed to QueueMessage and on to HandleMessage
}
65
66 // Ctx returns the listener's context, but creates a new context for each operation
67 // to prevent cancellation from affecting subsequent operations
68 func (l *Listener) Ctx() context.Context {
69 return l.ctx
70 }
71
72 // ServerContext returns the server's context (distinct from the listener's own context).
73 func (l *Listener) ServerContext() context.Context {
74 return l.Server.Context()
75 }
76
77 // ServerConfig returns the server's configuration.
78 func (l *Listener) ServerConfig() *config.C {
79 return l.Server.GetConfig()
80 }
81
82 // ServerDatabase returns the server's database instance.
83 func (l *Listener) ServerDatabase() database.Database {
84 return l.Server.Database()
85 }
86
87 // DroppedMessages returns the total number of messages that were dropped
88 // because the message processing queue was full.
89 func (l *Listener) DroppedMessages() int {
90 return int(l.droppedMessages.Load())
91 }
92
93 // RemainingCapacity returns the number of slots available in the message processing queue.
94 func (l *Listener) RemainingCapacity() int {
95 return cap(l.messageQueue) - len(l.messageQueue)
96 }
97
98 // QueueMessage queues a message for asynchronous processing.
99 // Returns true if the message was queued, false if the queue was full.
100 func (l *Listener) QueueMessage(data []byte, remote string) bool {
101 req := messageRequest{data: data, remote: remote}
102 select {
103 case l.messageQueue <- req:
104 return true
105 default:
106 l.droppedMessages.Add(1)
107 return false
108 }
109 }
110
111
112 func (l *Listener) Write(p []byte) (n int, err error) {
113 // Defensive: recover from any panic when sending to closed channel
114 defer func() {
115 if r := recover(); r != nil {
116 log.D.F("ws->%s write panic recovered (channel likely closed): %v", l.remote, r)
117 err = errorf.E("write channel closed")
118 n = 0
119 }
120 }()
121
122 // Send write request to channel - non-blocking with timeout
123 select {
124 case <-l.ctx.Done():
125 return 0, l.ctx.Err()
126 case l.writeChan <- publish.WriteRequest{Data: p, MsgType: websocket.TextMessage, IsControl: false}:
127 return len(p), nil
128 case <-time.After(DefaultWriteTimeout):
129 log.E.F("ws->%s write channel timeout", l.remote)
130 return 0, errorf.E("write channel timeout")
131 }
132 }
133
134 // SendEvent sends an event to the client. Implements archive.EventDeliveryChannel.
135 func (l *Listener) SendEvent(ev *event.E) error {
136 if ev == nil {
137 return nil
138 }
139 // Serialize the event as an EVENT envelope
140 data := ev.Serialize()
141 // Use Write to send
142 _, err := l.Write(data)
143 return err
144 }
145
146 // IsConnected returns whether the client connection is still active.
147 // Implements archive.EventDeliveryChannel.
148 func (l *Listener) IsConnected() bool {
149 select {
150 case <-l.ctx.Done():
151 return false
152 default:
153 return l.conn != nil
154 }
155 }
156
157 // WriteControl sends a control message through the write channel
158 func (l *Listener) WriteControl(messageType int, data []byte, deadline time.Time) (err error) {
159 // Defensive: recover from any panic when sending to closed channel
160 defer func() {
161 if r := recover(); r != nil {
162 log.D.F("ws->%s writeControl panic recovered (channel likely closed): %v", l.remote, r)
163 err = errorf.E("write channel closed")
164 }
165 }()
166
167 select {
168 case <-l.ctx.Done():
169 return l.ctx.Err()
170 case l.writeChan <- publish.WriteRequest{Data: data, MsgType: messageType, IsControl: true, Deadline: deadline}:
171 return nil
172 case <-time.After(DefaultWriteTimeout):
173 log.E.F("ws->%s writeControl channel timeout", l.remote)
174 return errorf.E("writeControl channel timeout")
175 }
176 }
177
178 // writeWorker is the single goroutine that handles all writes to the websocket connection.
179 // This serializes all writes to prevent concurrent write panics and allows pings to interrupt writes.
180 func (l *Listener) writeWorker() {
181 defer func() {
182 // Only unregister write channel if connection is actually dead/closing
183 // Unregister if:
184 // 1. Context is cancelled (connection closing)
185 // 2. Channel was closed (connection closing)
186 // 3. Connection error occurred (already handled inline)
187 if l.ctx.Err() != nil {
188 // Connection is closing - safe to unregister
189 if socketPub := l.publishers.GetSocketPublisher(); socketPub != nil {
190 log.D.F("ws->%s write worker: unregistering write channel (connection closing)", l.remote)
191 socketPub.SetWriteChan(l.conn, nil)
192 }
193 } else {
194 // Exiting for other reasons (timeout, etc.) but connection may still be valid
195 log.D.F("ws->%s write worker exiting unexpectedly", l.remote)
196 }
197 close(l.writeDone)
198 }()
199
200 for {
201 select {
202 case <-l.ctx.Done():
203 log.D.F("ws->%s write worker context cancelled", l.remote)
204 return
205 case req, ok := <-l.writeChan:
206 if !ok {
207 log.D.F("ws->%s write channel closed", l.remote)
208 return
209 }
210
211 // Skip writes if no connection (unit tests)
212 if l.conn == nil {
213 log.T.F("ws->%s skipping write (no connection)", l.remote)
214 continue
215 }
216
217 // Handle the write request
218 var err error
219 if req.IsPing {
220 // Special handling for ping messages
221 log.D.F("sending PING #%d", req.MsgType)
222 deadline := time.Now().Add(DefaultWriteTimeout)
223 err = l.conn.WriteControl(websocket.PingMessage, nil, deadline)
224 if err != nil {
225 if !strings.HasSuffix(err.Error(), "use of closed network connection") {
226 log.E.F("error writing ping: %v; closing websocket", err)
227 }
228 return
229 }
230 } else if req.IsControl {
231 // Control message
232 err = l.conn.WriteControl(req.MsgType, req.Data, req.Deadline)
233 if err != nil {
234 log.E.F("ws->%s control write failed: %v", l.remote, err)
235 return
236 }
237 } else {
238 // Regular message
239 l.conn.SetWriteDeadline(time.Now().Add(DefaultWriteTimeout))
240 err = l.conn.WriteMessage(req.MsgType, req.Data)
241 if err != nil {
242 log.E.F("ws->%s write failed: %v", l.remote, err)
243 return
244 }
245 }
246 }
247 }
248 }
249
250 // messageProcessor is the goroutine that processes messages asynchronously.
251 // This prevents the websocket read loop from blocking on message processing.
252 func (l *Listener) messageProcessor() {
253 defer func() {
254 close(l.processingDone)
255 }()
256
257 for {
258 select {
259 case <-l.ctx.Done():
260 log.D.F("ws->%s message processor context cancelled", l.remote)
261 return
262 case req, ok := <-l.messageQueue:
263 if !ok {
264 log.D.F("ws->%s message queue closed", l.remote)
265 return
266 }
267
268 // Lock immediately to ensure AUTH is processed before subsequent messages
269 // are dequeued. This prevents race conditions where EVENT checks authentication
270 // before AUTH completes.
271 l.authProcessing.Lock()
272
273 // Check if this is an AUTH message by looking for the ["AUTH" prefix
274 isAuthMessage := len(req.data) > 7 && bytes.HasPrefix(req.data, []byte(`["AUTH"`))
275
276 if isAuthMessage {
277 // Process AUTH message synchronously while holding lock
278 // This blocks the messageProcessor from dequeuing the next message
279 // until authentication is complete and authedPubkey is set
280 log.D.F("ws->%s processing AUTH synchronously with lock", req.remote)
281 l.HandleMessage(req.data, req.remote)
282 // Unlock after AUTH completes so subsequent messages see updated authedPubkey
283 l.authProcessing.Unlock()
284 } else {
285 // Not AUTH - unlock immediately and process concurrently
286 // The next message can now be dequeued (possibly another non-AUTH to process concurrently)
287 l.authProcessing.Unlock()
288
289 // Acquire semaphore to limit concurrent handlers (blocking with context awareness)
290 select {
291 case l.handlerSem <- struct{}{}:
292 // Semaphore acquired
293 case <-l.ctx.Done():
294 return
295 }
296 l.handlerWg.Add(1)
297 go func(data []byte, remote string) {
298 defer func() {
299 <-l.handlerSem // Release semaphore
300 l.handlerWg.Done()
301 }()
302 l.HandleMessage(data, remote)
303 }(req.data, req.remote)
304 }
305 }
306 }
307 }
308
309 // getManagedACL returns the managed ACL instance if available
310 func (l *Listener) getManagedACL() *database.ManagedACL {
311 // Get the managed ACL instance from the ACL registry
312 for _, aclInstance := range acl.Registry.ACLs() {
313 if aclInstance.Type() == "managed" {
314 if managed, ok := aclInstance.(*acl.Managed); ok {
315 return managed.GetManagedACL()
316 }
317 }
318 }
319 return nil
320 }
321
322 // getFollowsThrottleDelay returns the progressive throttle delay for follows or social ACL mode.
323 // Returns 0 if not in a throttle-enabled mode, throttle is disabled, or user is exempt.
324 func (l *Listener) getFollowsThrottleDelay(ev *event.E) time.Duration {
325 mode := acl.Registry.GetMode()
326 switch mode {
327 case "follows":
328 for _, aclInstance := range acl.Registry.ACLs() {
329 if follows, ok := aclInstance.(*acl.Follows); ok {
330 return follows.GetThrottleDelay(ev.Pubkey, l.remote)
331 }
332 }
333 case "social":
334 for _, aclInstance := range acl.Registry.ACLs() {
335 if social, ok := aclInstance.(*acl.Social); ok {
336 return social.GetThrottleDelay(ev.Pubkey, l.remote)
337 }
338 }
339 }
340 return 0
341 }
342
343 // QueryEvents queries events using the database QueryEvents method
344 func (l *Listener) QueryEvents(ctx context.Context, f *filter.F) (event.S, error) {
345 return l.DB.QueryEvents(ctx, f)
346 }
347
348 // QueryAllVersions queries events using the database QueryAllVersions method
349 func (l *Listener) QueryAllVersions(ctx context.Context, f *filter.F) (event.S, error) {
350 return l.DB.QueryAllVersions(ctx, f)
351 }
352
353 // canSeePrivateEvent checks if the authenticated user can see an event with a private tag
354 func (l *Listener) canSeePrivateEvent(authedPubkey, privatePubkey []byte) (canSee bool) {
355 // If no authenticated user, deny access
356 if len(authedPubkey) == 0 {
357 return false
358 }
359
360 // If the authenticated user matches the private tag pubkey, allow access
361 if len(privatePubkey) > 0 && utils.FastEqual(authedPubkey, privatePubkey) {
362 return true
363 }
364
365 // Check if user is an admin or owner (they can see all private events)
366 accessLevel := acl.Registry.GetAccessLevel(authedPubkey, l.remote)
367 if accessLevel == "admin" || accessLevel == "owner" {
368 return true
369 }
370
371 // Default deny
372 return false
373 }
374