// handle-websocket.go
1 package app
2
3 import (
4 "context"
5 "crypto/rand"
6 "fmt"
7 "net/http"
8 "strings"
9 "time"
10
11 "github.com/gorilla/websocket"
12 "next.orly.dev/pkg/lol/chk"
13 "next.orly.dev/pkg/lol/log"
14 "next.orly.dev/pkg/nostr/encoders/envelopes/authenvelope"
15 "next.orly.dev/pkg/nostr/encoders/hex"
16 "next.orly.dev/pkg/protocol/publish"
17 "next.orly.dev/pkg/nostr/utils/units"
18 )
19
const (
	// DefaultWriteWait bounds how long a single websocket write may take.
	// NOTE(review): not referenced within this file — presumably consumed by
	// the write worker elsewhere in the package; confirm before removing.
	DefaultWriteWait = 10 * time.Second
	// DefaultPongWait is the read deadline: the connection is considered dead
	// if no message or pong arrives within this window (the pong handler in
	// HandleWebsocket extends it on every pong).
	DefaultPongWait = 60 * time.Second
	// DefaultPingWait is the ping interval; half the pong window so at least
	// two pings fit inside every read-deadline period.
	DefaultPingWait = DefaultPongWait / 2
	// DefaultWriteTimeout is how long the pinger waits to queue a ping on the
	// write channel before giving up on the connection.
	DefaultWriteTimeout = 3 * time.Second
	// DefaultMaxMessageSize is the maximum message size for WebSocket connections
	// Increased from 512KB to 10MB to support large kind 3 follow lists (10k+ follows)
	// and other large events without truncation
	DefaultMaxMessageSize = 10 * 1024 * 1024 // 10MB
	// ClientMessageSizeLimit is the maximum message size that clients can handle
	// This is set to 100MB to allow large messages
	ClientMessageSizeLimit = 100 * 1024 * 1024 // 100MB
)
33
// upgrader is a package-level websocket upgrader with permissive origin
// checking for reverse-proxy setups.
// NOTE(review): HandleWebsocket in this file builds its own per-connection
// upgrader and does not use this one — verify whether any other file in the
// package still references it; if not, it is dead state.
var upgrader = websocket.Upgrader{
	ReadBufferSize:  1024,
	WriteBufferSize: 1024,
	CheckOrigin: func(r *http.Request) bool {
		return true // Allow all origins for proxy compatibility
	},
}
41
// HandleWebsocket upgrades an HTTP request to a WebSocket connection and
// drives its entire lifecycle: IP whitelist and per-IP/global connection
// limits, adaptive load shedding, the NIP-42 AUTH challenge, the write-worker
// / message-processor / pinger goroutines, the blocking read loop, and an
// ordered teardown that waits for all spawned handlers before closing the
// write channel.
func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
	remote := GetRemoteFromReq(r)

	// Log comprehensive proxy information for debugging
	LogProxyInfo(r, "WebSocket connection from "+remote)
	// When a whitelist is configured, only remotes matching one of its
	// prefixes may proceed. NOTE(review): rejection returns without writing
	// any HTTP status — the client sees a silently dropped connection;
	// confirm that is intentional.
	if len(s.Config.IPWhitelist) > 0 {
		for _, ip := range s.Config.IPWhitelist {
			log.T.F("checking IP whitelist: %s", ip)
			if strings.HasPrefix(remote, ip) {
				log.T.F("IP whitelisted %s", remote)
				goto whitelist
			}
		}
		log.T.F("IP not whitelisted: %s", remote)
		return
	}
whitelist:
	// Extract IP from remote (strip port)
	// NOTE(review): LastIndex(":") on a bracketed IPv6 remote such as
	// "[::1]:port" yields "[::1]", which will not match the "::1" localhost
	// check below — confirm the format GetRemoteFromReq produces.
	ip := remote
	if idx := strings.LastIndex(remote, ":"); idx != -1 {
		ip = remote[:idx]
	}

	// Check per-IP connection limit (hard limit 10)
	maxConnPerIP := s.Config.MaxConnectionsPerIP
	if maxConnPerIP <= 0 {
		maxConnPerIP = 10
	}
	if maxConnPerIP > 10 {
		maxConnPerIP = 10 // Hard limit
	}

	// Reserve a per-IP slot under the lock before doing any slow work so the
	// limit cannot be raced past by concurrent handshakes.
	s.connPerIPMu.Lock()
	currentConns := s.connPerIP[ip]
	if currentConns >= maxConnPerIP {
		s.connPerIPMu.Unlock()
		log.W.F("connection limit exceeded for IP %s: %d/%d connections", ip, currentConns, maxConnPerIP)
		http.Error(w, "too many connections from your IP", http.StatusTooManyRequests)
		return
	}
	s.connPerIP[ip]++
	s.connPerIPMu.Unlock()

	// Track global connection count
	s.activeConnCount.Add(1)

	// Decrement connection counts when this function returns
	defer func() {
		s.activeConnCount.Add(-1)
		s.connPerIPMu.Lock()
		s.connPerIP[ip]--
		// Drop the map entry entirely once the last connection from this IP
		// is gone, so the map does not grow without bound.
		if s.connPerIP[ip] <= 0 {
			delete(s.connPerIP, ip)
		}
		s.connPerIPMu.Unlock()
	}()

	// Localhost connections are exempt from rate limiting — split-IPC internal
	// connections must never be refused, even during emergency mode.
	isLocalhost := ip == "127.0.0.1" || ip == "::1" || ip == "localhost"

	// Global adaptive load check — refuse or delay connections under load
	if !isLocalhost && s.rateLimiter != nil && s.rateLimiter.IsEnabled() {
		s.rateLimiter.SetActiveConnections(s.activeConnCount.Load())

		if !s.rateLimiter.ShouldAcceptConnection() {
			log.W.F("refusing connection from %s: system overloaded", ip)
			http.Error(w, "server overloaded, try later", http.StatusServiceUnavailable)
			return
		}

		if delay := s.rateLimiter.ConnectionDelay(); delay > 0 {
			log.D.F("delaying connection from %s by %v (load mitigation)", ip, delay)
			time.Sleep(delay)
		}
	}

	// Progressive per-IP delay: each additional connection from the same IP adds delay
	// (200ms per existing connection, capped at 2s) to discourage connection floods.
	if currentConns > 0 {
		perIPDelay := time.Duration(currentConns) * 200 * time.Millisecond
		if perIPDelay > 2*time.Second {
			perIPDelay = 2 * time.Second
		}
		log.D.F("per-IP delay for %s: %v (%d connections)", ip, perIPDelay, currentConns+1)
		time.Sleep(perIPDelay)
	}

	// Create an independent context for this connection
	// This context will be cancelled when the connection closes or server shuts down
	ctx, cancel := context.WithCancel(s.Ctx)
	defer cancel()
	var err error
	var conn *websocket.Conn

	// Create a per-connection upgrader to avoid racing on global state.
	// Use 64KB buffers instead of max message size (10MB) to limit memory.
	connUpgrader := websocket.Upgrader{
		ReadBufferSize:  64 * 1024,
		WriteBufferSize: 64 * 1024,
		CheckOrigin: func(r *http.Request) bool {
			return true
		},
	}

	if conn, err = connUpgrader.Upgrade(w, r, nil); chk.E(err) {
		log.E.F("websocket accept failed from %s: %v", remote, err)
		return
	}
	log.T.F("websocket accepted from %s path=%s", remote, r.URL.String())

	// Set read limit immediately after connection is established
	conn.SetReadLimit(DefaultMaxMessageSize)
	log.D.F("set read limit to %d bytes (%d MB) for %s", DefaultMaxMessageSize, DefaultMaxMessageSize/units.Mb, remote)

	// Set initial read deadline - pong handler will extend it when pongs are received
	conn.SetReadDeadline(time.Now().Add(DefaultPongWait))

	// Add pong handler to extend read deadline when client responds to pings
	conn.SetPongHandler(func(string) error {
		log.T.F("received PONG from %s, extending read deadline", remote)
		return conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
	})

	defer conn.Close()
	// Determine handler semaphore size from config
	handlerSemSize := s.Config.MaxHandlersPerConnection
	if handlerSemSize <= 0 {
		handlerSemSize = 100 // Default if not configured
	}

	now := time.Now()
	listener := &Listener{
		ctx:            ctx,
		cancel:         cancel,
		Server:         s,
		conn:           conn,
		remote:         remote,
		connectionID:   fmt.Sprintf("%s-%d", remote, now.UnixNano()), // Unique connection ID for access tracking
		req:            r,
		startTime:      now,
		writeChan:      make(chan publish.WriteRequest, 100), // Buffered channel for writes
		writeDone:      make(chan struct{}),
		messageQueue:   make(chan messageRequest, 100), // Buffered channel for message processing
		processingDone: make(chan struct{}),
		handlerSem:     make(chan struct{}, handlerSemSize), // Limits concurrent handlers
		subscriptions:  make(map[string]context.CancelFunc),
	}

	// Start write worker goroutine
	go listener.writeWorker()

	// Start message processor goroutine
	go listener.messageProcessor()

	// Register write channel with publisher
	if socketPub := listener.publishers.GetSocketPublisher(); socketPub != nil {
		socketPub.SetWriteChan(conn, listener.writeChan)
	}

	// Check for blacklisted IPs
	listener.isBlacklisted = s.isIPBlacklisted(remote)
	if listener.isBlacklisted {
		log.W.F("detected blacklisted IP %s, marking connection for timeout", remote)
		listener.blacklistTimeout = time.Now().Add(time.Minute) // Timeout after 1 minute
	}
	// Generate a 32-byte random challenge, hex-encoded, for NIP-42 AUTH.
	chal := make([]byte, 32)
	if _, err = rand.Read(chal); err != nil {
		log.E.F("failed to generate auth challenge: %v", err)
		return
	}
	listener.challenge.Store([]byte(hex.Enc(chal)))
	// Always send AUTH challenge - channel kinds (40-44) require authentication
	// regardless of ACL mode, and NIP-42 AUTH is harmless for clients that don't need it
	{
		log.D.F("sending AUTH challenge to %s", remote)
		if err = authenvelope.NewChallengeWith(listener.challenge.Load()).
			Write(listener); chk.E(err) {
			log.E.F("failed to send AUTH challenge to %s: %v", remote, err)
			return
		}
		log.D.F("AUTH challenge sent successfully to %s", remote)
	}
	ticker := time.NewTicker(DefaultPingWait)
	// Don't pass cancel to Pinger - it should not be able to cancel the connection context
	go s.Pinger(ctx, listener, ticker)
	// Teardown order matters: cancel subscriptions, stop the pinger, close the
	// message queue and wait for the processor, wait for all handlers, and only
	// then close the write channel — otherwise a late handler could send on a
	// closed channel and panic.
	defer func() {
		log.D.F("closing websocket connection from %s", remote)

		// Cancel all active subscriptions first
		listener.subscriptionsMu.Lock()
		for subID, cancelFunc := range listener.subscriptions {
			log.D.F("cancelling subscription %s for %s", subID, remote)
			cancelFunc()
		}
		listener.subscriptions = nil
		listener.subscriptionsMu.Unlock()

		// Cancel context and stop pinger
		cancel()
		ticker.Stop()

		// Cancel all subscriptions for this connection at publisher level
		log.D.F("removing subscriptions from publisher for %s", remote)
		listener.publishers.Receive(&W{
			Cancel: true,
			Conn:   listener.conn,
			remote: listener.remote,
		})

		// Log detailed connection statistics
		dur := time.Since(listener.startTime)
		log.D.F(
			"ws connection closed %s: msgs=%d, REQs=%d, EVENTs=%d, dropped=%d, duration=%v",
			remote, listener.msgCount, listener.reqCount, listener.eventCount,
			listener.DroppedMessages(), dur,
		)

		// Log any remaining connection state
		if listener.authedPubkey.Load() != nil {
			log.D.F("ws connection %s was authenticated", remote)
		} else {
			log.D.F("ws connection %s was not authenticated", remote)
		}

		// Close message queue to signal processor to exit
		close(listener.messageQueue)
		// Wait for message processor to finish
		<-listener.processingDone

		// Wait for all spawned message handlers to complete
		// This is critical to prevent "send on closed channel" panics
		log.D.F("ws->%s waiting for message handlers to complete", remote)
		listener.handlerWg.Wait()
		log.D.F("ws->%s all message handlers completed", remote)

		// Close write channel to signal worker to exit
		close(listener.writeChan)
		// Wait for write worker to finish
		<-listener.writeDone
	}()
	// Read loop: blocks in ReadMessage; liveness is enforced by the pinger and
	// the pong-extended read deadline rather than per-iteration deadlines.
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		// Check if blacklisted connection has timed out
		if listener.isBlacklisted && time.Now().After(listener.blacklistTimeout) {
			log.W.F("blacklisted IP %s timeout reached, closing connection", remote)
			return
		}

		var typ int
		var msg []byte
		log.T.F("waiting for message from %s", remote)

		// Don't set read deadline here - it's set initially and extended by pong handler
		// This prevents premature timeouts on idle connections with active subscriptions
		if ctx.Err() != nil {
			return
		}

		// Block waiting for message; rely on pings and context cancellation to detect dead peers
		// The read deadline is managed by the pong handler which extends it when pongs are received
		typ, msg, err = conn.ReadMessage()

		if err != nil {
			// Only unexpected close codes are logged; the listed codes are
			// normal/expected disconnects.
			if websocket.IsUnexpectedCloseError(
				err,
				websocket.CloseNormalClosure,    // 1000
				websocket.CloseGoingAway,        // 1001
				websocket.CloseNoStatusReceived, // 1005
				websocket.CloseAbnormalClosure,  // 1006
				4537, // some client seems to send many of these
			) {
				log.D.F("websocket connection closed from %s: %v", remote, err)
			}
			cancel() // Cancel context like khatru does
			return
		}
		if typ == websocket.PingMessage {
			log.D.F("received PING from %s, sending PONG", remote)
			// Send pong directly (like khatru does)
			if err = conn.WriteMessage(websocket.PongMessage, nil); err != nil {
				log.E.F("failed to send PONG to %s: %v", remote, err)
				return
			}
			continue
		}
		// Log message size for debugging
		if len(msg) > 1000 { // Only log for larger messages
			log.D.F("received large message from %s: %d bytes", remote, len(msg))
		}
		// log.T.F("received message from %s: %s", remote, string(msg))

		// Queue message for asynchronous processing; QueueMessage reports false
		// when the buffered queue is full, in which case the message is dropped.
		if !listener.QueueMessage(msg, remote) {
			log.D.F("ws->%s message queue full, dropping message (capacity=%d)", remote, cap(listener.messageQueue))
		}
	}
}
344
345 func (s *Server) Pinger(
346 ctx context.Context, listener *Listener, ticker *time.Ticker,
347 ) {
348 defer func() {
349 log.D.F("pinger shutting down")
350 ticker.Stop()
351 // Recover from panic if channel is closed
352 if r := recover(); r != nil {
353 log.D.F("pinger recovered from panic (channel likely closed): %v", r)
354 }
355 }()
356 pingCount := 0
357 for {
358 select {
359 case <-ctx.Done():
360 log.T.F("pinger context cancelled after %d pings", pingCount)
361 return
362 case <-ticker.C:
363 pingCount++
364 // Send ping request through write channel - this allows pings to interrupt other writes
365 select {
366 case <-ctx.Done():
367 return
368 case listener.writeChan <- publish.WriteRequest{IsPing: true, MsgType: pingCount}:
369 // Ping request queued successfully
370 case <-time.After(DefaultWriteTimeout):
371 log.E.F("ping #%d channel timeout - connection may be overloaded", pingCount)
372 return
373 }
374 }
375 }
376 }
377
378