handle-websocket.go raw

   1  package app
   2  
import (
	"context"
	"crypto/rand"
	"fmt"
	"net"
	"net/http"
	"strings"
	"time"

	"github.com/gorilla/websocket"
	"next.orly.dev/pkg/lol/chk"
	"next.orly.dev/pkg/lol/log"
	"next.orly.dev/pkg/nostr/encoders/envelopes/authenvelope"
	"next.orly.dev/pkg/nostr/encoders/hex"
	"next.orly.dev/pkg/nostr/utils/units"
	"next.orly.dev/pkg/protocol/publish"
)
  19  
  20  const (
  21  	DefaultWriteWait      = 10 * time.Second
  22  	DefaultPongWait       = 60 * time.Second
  23  	DefaultPingWait       = DefaultPongWait / 2
  24  	DefaultWriteTimeout   = 3 * time.Second
  25  	// DefaultMaxMessageSize is the maximum message size for WebSocket connections
  26  	// Increased from 512KB to 10MB to support large kind 3 follow lists (10k+ follows)
  27  	// and other large events without truncation
  28  	DefaultMaxMessageSize = 10 * 1024 * 1024 // 10MB
  29  	// ClientMessageSizeLimit is the maximum message size that clients can handle
  30  	// This is set to 100MB to allow large messages
  31  	ClientMessageSizeLimit = 100 * 1024 * 1024 // 100MB
  32  )
  33  
  34  var upgrader = websocket.Upgrader{
  35  	ReadBufferSize:  1024,
  36  	WriteBufferSize: 1024,
  37  	CheckOrigin: func(r *http.Request) bool {
  38  		return true // Allow all origins for proxy compatibility
  39  	},
  40  }
  41  
  42  func (s *Server) HandleWebsocket(w http.ResponseWriter, r *http.Request) {
  43  	remote := GetRemoteFromReq(r)
  44  
  45  	// Log comprehensive proxy information for debugging
  46  	LogProxyInfo(r, "WebSocket connection from "+remote)
  47  	if len(s.Config.IPWhitelist) > 0 {
  48  		for _, ip := range s.Config.IPWhitelist {
  49  			log.T.F("checking IP whitelist: %s", ip)
  50  			if strings.HasPrefix(remote, ip) {
  51  				log.T.F("IP whitelisted %s", remote)
  52  				goto whitelist
  53  			}
  54  		}
  55  		log.T.F("IP not whitelisted: %s", remote)
  56  		return
  57  	}
  58  whitelist:
  59  	// Extract IP from remote (strip port)
  60  	ip := remote
  61  	if idx := strings.LastIndex(remote, ":"); idx != -1 {
  62  		ip = remote[:idx]
  63  	}
  64  
  65  	// Check per-IP connection limit (hard limit 10)
  66  	maxConnPerIP := s.Config.MaxConnectionsPerIP
  67  	if maxConnPerIP <= 0 {
  68  		maxConnPerIP = 10
  69  	}
  70  	if maxConnPerIP > 10 {
  71  		maxConnPerIP = 10 // Hard limit
  72  	}
  73  
  74  	s.connPerIPMu.Lock()
  75  	currentConns := s.connPerIP[ip]
  76  	if currentConns >= maxConnPerIP {
  77  		s.connPerIPMu.Unlock()
  78  		log.W.F("connection limit exceeded for IP %s: %d/%d connections", ip, currentConns, maxConnPerIP)
  79  		http.Error(w, "too many connections from your IP", http.StatusTooManyRequests)
  80  		return
  81  	}
  82  	s.connPerIP[ip]++
  83  	s.connPerIPMu.Unlock()
  84  
  85  	// Track global connection count
  86  	s.activeConnCount.Add(1)
  87  
  88  	// Decrement connection counts when this function returns
  89  	defer func() {
  90  		s.activeConnCount.Add(-1)
  91  		s.connPerIPMu.Lock()
  92  		s.connPerIP[ip]--
  93  		if s.connPerIP[ip] <= 0 {
  94  			delete(s.connPerIP, ip)
  95  		}
  96  		s.connPerIPMu.Unlock()
  97  	}()
  98  
  99  	// Localhost connections are exempt from rate limiting — split-IPC internal
 100  	// connections must never be refused, even during emergency mode.
 101  	isLocalhost := ip == "127.0.0.1" || ip == "::1" || ip == "localhost"
 102  
 103  	// Global adaptive load check — refuse or delay connections under load
 104  	if !isLocalhost && s.rateLimiter != nil && s.rateLimiter.IsEnabled() {
 105  		s.rateLimiter.SetActiveConnections(s.activeConnCount.Load())
 106  
 107  		if !s.rateLimiter.ShouldAcceptConnection() {
 108  			log.W.F("refusing connection from %s: system overloaded", ip)
 109  			http.Error(w, "server overloaded, try later", http.StatusServiceUnavailable)
 110  			return
 111  		}
 112  
 113  		if delay := s.rateLimiter.ConnectionDelay(); delay > 0 {
 114  			log.D.F("delaying connection from %s by %v (load mitigation)", ip, delay)
 115  			time.Sleep(delay)
 116  		}
 117  	}
 118  
 119  	// Progressive per-IP delay: each additional connection from the same IP adds delay
 120  	if currentConns > 0 {
 121  		perIPDelay := time.Duration(currentConns) * 200 * time.Millisecond
 122  		if perIPDelay > 2*time.Second {
 123  			perIPDelay = 2 * time.Second
 124  		}
 125  		log.D.F("per-IP delay for %s: %v (%d connections)", ip, perIPDelay, currentConns+1)
 126  		time.Sleep(perIPDelay)
 127  	}
 128  
 129  	// Create an independent context for this connection
 130  	// This context will be cancelled when the connection closes or server shuts down
 131  	ctx, cancel := context.WithCancel(s.Ctx)
 132  	defer cancel()
 133  	var err error
 134  	var conn *websocket.Conn
 135  
 136  	// Create a per-connection upgrader to avoid racing on global state.
 137  	// Use 64KB buffers instead of max message size (10MB) to limit memory.
 138  	connUpgrader := websocket.Upgrader{
 139  		ReadBufferSize:  64 * 1024,
 140  		WriteBufferSize: 64 * 1024,
 141  		CheckOrigin: func(r *http.Request) bool {
 142  			return true
 143  		},
 144  	}
 145  
 146  	if conn, err = connUpgrader.Upgrade(w, r, nil); chk.E(err) {
 147  		log.E.F("websocket accept failed from %s: %v", remote, err)
 148  		return
 149  	}
 150  	log.T.F("websocket accepted from %s path=%s", remote, r.URL.String())
 151  
 152  	// Set read limit immediately after connection is established
 153  	conn.SetReadLimit(DefaultMaxMessageSize)
 154  	log.D.F("set read limit to %d bytes (%d MB) for %s", DefaultMaxMessageSize, DefaultMaxMessageSize/units.Mb, remote)
 155  
 156  	// Set initial read deadline - pong handler will extend it when pongs are received
 157  	conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
 158  
 159  	// Add pong handler to extend read deadline when client responds to pings
 160  	conn.SetPongHandler(func(string) error {
 161  		log.T.F("received PONG from %s, extending read deadline", remote)
 162  		return conn.SetReadDeadline(time.Now().Add(DefaultPongWait))
 163  	})
 164  
 165  	defer conn.Close()
 166  	// Determine handler semaphore size from config
 167  	handlerSemSize := s.Config.MaxHandlersPerConnection
 168  	if handlerSemSize <= 0 {
 169  		handlerSemSize = 100 // Default if not configured
 170  	}
 171  
 172  	now := time.Now()
 173  	listener := &Listener{
 174  		ctx:            ctx,
 175  		cancel:         cancel,
 176  		Server:         s,
 177  		conn:           conn,
 178  		remote:         remote,
 179  		connectionID:   fmt.Sprintf("%s-%d", remote, now.UnixNano()), // Unique connection ID for access tracking
 180  		req:            r,
 181  		startTime:      now,
 182  		writeChan:      make(chan publish.WriteRequest, 100), // Buffered channel for writes
 183  		writeDone:      make(chan struct{}),
 184  		messageQueue:   make(chan messageRequest, 100), // Buffered channel for message processing
 185  		processingDone: make(chan struct{}),
 186  		handlerSem:     make(chan struct{}, handlerSemSize), // Limits concurrent handlers
 187  		subscriptions:  make(map[string]context.CancelFunc),
 188  	}
 189  
 190  	// Start write worker goroutine
 191  	go listener.writeWorker()
 192  
 193  	// Start message processor goroutine
 194  	go listener.messageProcessor()
 195  
 196  	// Register write channel with publisher
 197  	if socketPub := listener.publishers.GetSocketPublisher(); socketPub != nil {
 198  		socketPub.SetWriteChan(conn, listener.writeChan)
 199  	}
 200  
 201  	// Check for blacklisted IPs
 202  	listener.isBlacklisted = s.isIPBlacklisted(remote)
 203  	if listener.isBlacklisted {
 204  		log.W.F("detected blacklisted IP %s, marking connection for timeout", remote)
 205  		listener.blacklistTimeout = time.Now().Add(time.Minute) // Timeout after 1 minute
 206  	}
 207  	chal := make([]byte, 32)
 208  	if _, err = rand.Read(chal); err != nil {
 209  		log.E.F("failed to generate auth challenge: %v", err)
 210  		return
 211  	}
 212  	listener.challenge.Store([]byte(hex.Enc(chal)))
 213  	// Always send AUTH challenge - channel kinds (40-44) require authentication
 214  	// regardless of ACL mode, and NIP-42 AUTH is harmless for clients that don't need it
 215  	{
 216  		log.D.F("sending AUTH challenge to %s", remote)
 217  		if err = authenvelope.NewChallengeWith(listener.challenge.Load()).
 218  			Write(listener); chk.E(err) {
 219  			log.E.F("failed to send AUTH challenge to %s: %v", remote, err)
 220  			return
 221  		}
 222  		log.D.F("AUTH challenge sent successfully to %s", remote)
 223  	}
 224  	ticker := time.NewTicker(DefaultPingWait)
 225  	// Don't pass cancel to Pinger - it should not be able to cancel the connection context
 226  	go s.Pinger(ctx, listener, ticker)
 227  	defer func() {
 228  		log.D.F("closing websocket connection from %s", remote)
 229  
 230  		// Cancel all active subscriptions first
 231  		listener.subscriptionsMu.Lock()
 232  		for subID, cancelFunc := range listener.subscriptions {
 233  			log.D.F("cancelling subscription %s for %s", subID, remote)
 234  			cancelFunc()
 235  		}
 236  		listener.subscriptions = nil
 237  		listener.subscriptionsMu.Unlock()
 238  
 239  		// Cancel context and stop pinger
 240  		cancel()
 241  		ticker.Stop()
 242  
 243  		// Cancel all subscriptions for this connection at publisher level
 244  		log.D.F("removing subscriptions from publisher for %s", remote)
 245  		listener.publishers.Receive(&W{
 246  			Cancel: true,
 247  			Conn:   listener.conn,
 248  			remote: listener.remote,
 249  		})
 250  
 251  		// Log detailed connection statistics
 252  		dur := time.Since(listener.startTime)
 253  		log.D.F(
 254  			"ws connection closed %s: msgs=%d, REQs=%d, EVENTs=%d, dropped=%d, duration=%v",
 255  			remote, listener.msgCount, listener.reqCount, listener.eventCount,
 256  			listener.DroppedMessages(), dur,
 257  		)
 258  
 259  		// Log any remaining connection state
 260  		if listener.authedPubkey.Load() != nil {
 261  			log.D.F("ws connection %s was authenticated", remote)
 262  		} else {
 263  			log.D.F("ws connection %s was not authenticated", remote)
 264  		}
 265  
 266  		// Close message queue to signal processor to exit
 267  		close(listener.messageQueue)
 268  		// Wait for message processor to finish
 269  		<-listener.processingDone
 270  
 271  		// Wait for all spawned message handlers to complete
 272  		// This is critical to prevent "send on closed channel" panics
 273  		log.D.F("ws->%s waiting for message handlers to complete", remote)
 274  		listener.handlerWg.Wait()
 275  		log.D.F("ws->%s all message handlers completed", remote)
 276  
 277  		// Close write channel to signal worker to exit
 278  		close(listener.writeChan)
 279  		// Wait for write worker to finish
 280  		<-listener.writeDone
 281  	}()
 282  	for {
 283  		select {
 284  		case <-ctx.Done():
 285  			return
 286  		default:
 287  		}
 288  
 289  		// Check if blacklisted connection has timed out
 290  		if listener.isBlacklisted && time.Now().After(listener.blacklistTimeout) {
 291  			log.W.F("blacklisted IP %s timeout reached, closing connection", remote)
 292  			return
 293  		}
 294  
 295  		var typ int
 296  		var msg []byte
 297  		log.T.F("waiting for message from %s", remote)
 298  
 299  		// Don't set read deadline here - it's set initially and extended by pong handler
 300  		// This prevents premature timeouts on idle connections with active subscriptions
 301  		if ctx.Err() != nil {
 302  			return
 303  		}
 304  
 305  		// Block waiting for message; rely on pings and context cancellation to detect dead peers
 306  		// The read deadline is managed by the pong handler which extends it when pongs are received
 307  		typ, msg, err = conn.ReadMessage()
 308  
 309  		if err != nil {
 310  			if websocket.IsUnexpectedCloseError(
 311  				err,
 312  				websocket.CloseNormalClosure,    // 1000
 313  				websocket.CloseGoingAway,        // 1001
 314  				websocket.CloseNoStatusReceived, // 1005
 315  				websocket.CloseAbnormalClosure,  // 1006
 316  				4537,                            // some client seems to send many of these
 317  			) {
 318  				log.D.F("websocket connection closed from %s: %v", remote, err)
 319  			}
 320  			cancel() // Cancel context like khatru does
 321  			return
 322  		}
 323  		if typ == websocket.PingMessage {
 324  			log.D.F("received PING from %s, sending PONG", remote)
 325  			// Send pong directly (like khatru does)
 326  			if err = conn.WriteMessage(websocket.PongMessage, nil); err != nil {
 327  				log.E.F("failed to send PONG to %s: %v", remote, err)
 328  				return
 329  			}
 330  			continue
 331  		}
 332  		// Log message size for debugging
 333  		if len(msg) > 1000 { // Only log for larger messages
 334  			log.D.F("received large message from %s: %d bytes", remote, len(msg))
 335  		}
 336  		// log.T.F("received message from %s: %s", remote, string(msg))
 337  
 338  		// Queue message for asynchronous processing
 339  		if !listener.QueueMessage(msg, remote) {
 340  			log.D.F("ws->%s message queue full, dropping message (capacity=%d)", remote, cap(listener.messageQueue))
 341  		}
 342  	}
 343  }
 344  
 345  func (s *Server) Pinger(
 346  	ctx context.Context, listener *Listener, ticker *time.Ticker,
 347  ) {
 348  	defer func() {
 349  		log.D.F("pinger shutting down")
 350  		ticker.Stop()
 351  		// Recover from panic if channel is closed
 352  		if r := recover(); r != nil {
 353  			log.D.F("pinger recovered from panic (channel likely closed): %v", r)
 354  		}
 355  	}()
 356  	pingCount := 0
 357  	for {
 358  		select {
 359  		case <-ctx.Done():
 360  			log.T.F("pinger context cancelled after %d pings", pingCount)
 361  			return
 362  		case <-ticker.C:
 363  			pingCount++
 364  			// Send ping request through write channel - this allows pings to interrupt other writes
 365  			select {
 366  			case <-ctx.Done():
 367  				return
 368  			case listener.writeChan <- publish.WriteRequest{IsPing: true, MsgType: pingCount}:
 369  				// Ping request queued successfully
 370  			case <-time.After(DefaultWriteTimeout):
 371  				log.E.F("ping #%d channel timeout - connection may be overloaded", pingCount)
 372  				return
 373  			}
 374  		}
 375  	}
 376  }
 377  
 378