listener.go raw

   1  package app
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"net/http"
   7  	"strings"
   8  	"sync"
   9  	"sync/atomic"
  10  	"time"
  11  
  12  	"github.com/gorilla/websocket"
  13  	"next.orly.dev/pkg/lol/errorf"
  14  	"next.orly.dev/pkg/lol/log"
  15  	"next.orly.dev/app/config"
  16  	"next.orly.dev/pkg/acl"
  17  	"next.orly.dev/pkg/database"
  18  	"next.orly.dev/pkg/nostr/encoders/event"
  19  	"next.orly.dev/pkg/nostr/encoders/filter"
  20  	"next.orly.dev/pkg/protocol/publish"
  21  	"next.orly.dev/pkg/utils"
  22  	atomicutils "next.orly.dev/pkg/utils/atomic"
  23  )
  24  
  25  type Listener struct {
  26  	// Server is the embedded server reference.
  27  	// Deprecated: Prefer using accessor methods (ServerConfig, ServerDatabase, etc.)
  28  	// instead of accessing Server fields directly.
  29  	*Server
  30  	conn *websocket.Conn
  31  	ctx              context.Context
  32  	cancel           context.CancelFunc // Cancel function for this listener's context
  33  	remote           string
  34  	connectionID     string // Unique identifier for this connection (for access tracking)
  35  	req              *http.Request
  36  	challenge        atomicutils.Bytes
  37  	authedPubkey     atomicutils.Bytes
  38  	startTime        time.Time
  39  	isBlacklisted    bool      // Marker to identify blacklisted IPs
  40  	blacklistTimeout time.Time // When to timeout blacklisted connections
  41  	writeChan        chan publish.WriteRequest // Channel for write requests (back to queued approach)
  42  	writeDone        chan struct{}     // Closed when write worker exits
  43  	// Message processing queue for async handling
  44  	messageQueue     chan messageRequest // Buffered channel for message processing
  45  	processingDone   chan struct{}       // Closed when message processor exits
  46  	handlerWg        sync.WaitGroup      // Tracks spawned message handler goroutines
  47  	handlerSem       chan struct{}       // Limits concurrent message handlers per connection
  48  	authProcessing   sync.RWMutex        // Ensures AUTH completes before other messages check authentication
  49  	// Flow control counters (atomic for concurrent access)
  50  	droppedMessages      atomic.Int64 // Messages dropped due to full queue
  51  	queryCostAccumulator atomic.Int64 // Accumulated query cost for this connection (units: multiplier * 100)
  52  	// Diagnostics: per-connection counters
  53  	msgCount   int
  54  	reqCount   int
  55  	eventCount int
  56  	// Subscription tracking for cleanup
  57  	subscriptions    map[string]context.CancelFunc // Map of subscription ID to cancel function
  58  	subscriptionsMu  sync.Mutex                     // Protects subscriptions map
  59  }
  60  
// messageRequest is one inbound message waiting on the listener's
// messageQueue for asynchronous handling.
type messageRequest struct {
	data   []byte // raw message bytes, passed unchanged to HandleMessage
	remote string // remote identifier, passed alongside data to HandleMessage
}
  65  
// Ctx returns the listener's own context. The same stored context is returned
// on every call; no new context is derived per operation, so once it is
// cancelled every caller observes the cancellation.
func (l *Listener) Ctx() context.Context {
	return l.ctx
}
  71  
// ServerContext returns the embedded server's context, as reported by
// Server.Context (distinct from the listener's own context returned by Ctx).
func (l *Listener) ServerContext() context.Context {
	return l.Server.Context()
}
  76  
// ServerConfig returns the embedded server's configuration, as reported by
// Server.GetConfig.
func (l *Listener) ServerConfig() *config.C {
	return l.Server.GetConfig()
}
  81  
// ServerDatabase returns the embedded server's database instance, as reported
// by Server.Database.
func (l *Listener) ServerDatabase() database.Database {
	return l.Server.Database()
}
  86  
  87  // DroppedMessages returns the total number of messages that were dropped
  88  // because the message processing queue was full.
  89  func (l *Listener) DroppedMessages() int {
  90  	return int(l.droppedMessages.Load())
  91  }
  92  
  93  // RemainingCapacity returns the number of slots available in the message processing queue.
  94  func (l *Listener) RemainingCapacity() int {
  95  	return cap(l.messageQueue) - len(l.messageQueue)
  96  }
  97  
  98  // QueueMessage queues a message for asynchronous processing.
  99  // Returns true if the message was queued, false if the queue was full.
 100  func (l *Listener) QueueMessage(data []byte, remote string) bool {
 101  	req := messageRequest{data: data, remote: remote}
 102  	select {
 103  	case l.messageQueue <- req:
 104  		return true
 105  	default:
 106  		l.droppedMessages.Add(1)
 107  		return false
 108  	}
 109  }
 110  
 111  
 112  func (l *Listener) Write(p []byte) (n int, err error) {
 113  	// Defensive: recover from any panic when sending to closed channel
 114  	defer func() {
 115  		if r := recover(); r != nil {
 116  			log.D.F("ws->%s write panic recovered (channel likely closed): %v", l.remote, r)
 117  			err = errorf.E("write channel closed")
 118  			n = 0
 119  		}
 120  	}()
 121  
 122  	// Send write request to channel - non-blocking with timeout
 123  	select {
 124  	case <-l.ctx.Done():
 125  		return 0, l.ctx.Err()
 126  	case l.writeChan <- publish.WriteRequest{Data: p, MsgType: websocket.TextMessage, IsControl: false}:
 127  		return len(p), nil
 128  	case <-time.After(DefaultWriteTimeout):
 129  		log.E.F("ws->%s write channel timeout", l.remote)
 130  		return 0, errorf.E("write channel timeout")
 131  	}
 132  }
 133  
 134  // SendEvent sends an event to the client. Implements archive.EventDeliveryChannel.
 135  func (l *Listener) SendEvent(ev *event.E) error {
 136  	if ev == nil {
 137  		return nil
 138  	}
 139  	// Serialize the event as an EVENT envelope
 140  	data := ev.Serialize()
 141  	// Use Write to send
 142  	_, err := l.Write(data)
 143  	return err
 144  }
 145  
 146  // IsConnected returns whether the client connection is still active.
 147  // Implements archive.EventDeliveryChannel.
 148  func (l *Listener) IsConnected() bool {
 149  	select {
 150  	case <-l.ctx.Done():
 151  		return false
 152  	default:
 153  		return l.conn != nil
 154  	}
 155  }
 156  
 157  // WriteControl sends a control message through the write channel
 158  func (l *Listener) WriteControl(messageType int, data []byte, deadline time.Time) (err error) {
 159  	// Defensive: recover from any panic when sending to closed channel
 160  	defer func() {
 161  		if r := recover(); r != nil {
 162  			log.D.F("ws->%s writeControl panic recovered (channel likely closed): %v", l.remote, r)
 163  			err = errorf.E("write channel closed")
 164  		}
 165  	}()
 166  
 167  	select {
 168  	case <-l.ctx.Done():
 169  		return l.ctx.Err()
 170  	case l.writeChan <- publish.WriteRequest{Data: data, MsgType: messageType, IsControl: true, Deadline: deadline}:
 171  		return nil
 172  	case <-time.After(DefaultWriteTimeout):
 173  		log.E.F("ws->%s writeControl channel timeout", l.remote)
 174  		return errorf.E("writeControl channel timeout")
 175  	}
 176  }
 177  
 178  // writeWorker is the single goroutine that handles all writes to the websocket connection.
 179  // This serializes all writes to prevent concurrent write panics and allows pings to interrupt writes.
 180  func (l *Listener) writeWorker() {
 181  	defer func() {
 182  		// Only unregister write channel if connection is actually dead/closing
 183  		// Unregister if:
 184  		// 1. Context is cancelled (connection closing)
 185  		// 2. Channel was closed (connection closing)
 186  		// 3. Connection error occurred (already handled inline)
 187  		if l.ctx.Err() != nil {
 188  			// Connection is closing - safe to unregister
 189  			if socketPub := l.publishers.GetSocketPublisher(); socketPub != nil {
 190  				log.D.F("ws->%s write worker: unregistering write channel (connection closing)", l.remote)
 191  				socketPub.SetWriteChan(l.conn, nil)
 192  			}
 193  		} else {
 194  			// Exiting for other reasons (timeout, etc.) but connection may still be valid
 195  			log.D.F("ws->%s write worker exiting unexpectedly", l.remote)
 196  		}
 197  		close(l.writeDone)
 198  	}()
 199  
 200  	for {
 201  		select {
 202  		case <-l.ctx.Done():
 203  			log.D.F("ws->%s write worker context cancelled", l.remote)
 204  			return
 205  		case req, ok := <-l.writeChan:
 206  			if !ok {
 207  				log.D.F("ws->%s write channel closed", l.remote)
 208  				return
 209  			}
 210  
 211  			// Skip writes if no connection (unit tests)
 212  			if l.conn == nil {
 213  				log.T.F("ws->%s skipping write (no connection)", l.remote)
 214  				continue
 215  			}
 216  
 217  			// Handle the write request
 218  			var err error
 219  			if req.IsPing {
 220  				// Special handling for ping messages
 221  				log.D.F("sending PING #%d", req.MsgType)
 222  				deadline := time.Now().Add(DefaultWriteTimeout)
 223  				err = l.conn.WriteControl(websocket.PingMessage, nil, deadline)
 224  				if err != nil {
 225  					if !strings.HasSuffix(err.Error(), "use of closed network connection") {
 226  						log.E.F("error writing ping: %v; closing websocket", err)
 227  					}
 228  					return
 229  				}
 230  			} else if req.IsControl {
 231  				// Control message
 232  				err = l.conn.WriteControl(req.MsgType, req.Data, req.Deadline)
 233  				if err != nil {
 234  					log.E.F("ws->%s control write failed: %v", l.remote, err)
 235  					return
 236  				}
 237  			} else {
 238  				// Regular message
 239  				l.conn.SetWriteDeadline(time.Now().Add(DefaultWriteTimeout))
 240  				err = l.conn.WriteMessage(req.MsgType, req.Data)
 241  				if err != nil {
 242  					log.E.F("ws->%s write failed: %v", l.remote, err)
 243  					return
 244  				}
 245  			}
 246  		}
 247  	}
 248  }
 249  
 250  // messageProcessor is the goroutine that processes messages asynchronously.
 251  // This prevents the websocket read loop from blocking on message processing.
 252  func (l *Listener) messageProcessor() {
 253  	defer func() {
 254  		close(l.processingDone)
 255  	}()
 256  
 257  	for {
 258  		select {
 259  		case <-l.ctx.Done():
 260  			log.D.F("ws->%s message processor context cancelled", l.remote)
 261  			return
 262  		case req, ok := <-l.messageQueue:
 263  			if !ok {
 264  				log.D.F("ws->%s message queue closed", l.remote)
 265  				return
 266  			}
 267  
 268  			// Lock immediately to ensure AUTH is processed before subsequent messages
 269  			// are dequeued. This prevents race conditions where EVENT checks authentication
 270  			// before AUTH completes.
 271  			l.authProcessing.Lock()
 272  
 273  			// Check if this is an AUTH message by looking for the ["AUTH" prefix
 274  			isAuthMessage := len(req.data) > 7 && bytes.HasPrefix(req.data, []byte(`["AUTH"`))
 275  
 276  			if isAuthMessage {
 277  				// Process AUTH message synchronously while holding lock
 278  				// This blocks the messageProcessor from dequeuing the next message
 279  				// until authentication is complete and authedPubkey is set
 280  				log.D.F("ws->%s processing AUTH synchronously with lock", req.remote)
 281  				l.HandleMessage(req.data, req.remote)
 282  				// Unlock after AUTH completes so subsequent messages see updated authedPubkey
 283  				l.authProcessing.Unlock()
 284  			} else {
 285  				// Not AUTH - unlock immediately and process concurrently
 286  				// The next message can now be dequeued (possibly another non-AUTH to process concurrently)
 287  				l.authProcessing.Unlock()
 288  
 289  				// Acquire semaphore to limit concurrent handlers (blocking with context awareness)
 290  				select {
 291  				case l.handlerSem <- struct{}{}:
 292  					// Semaphore acquired
 293  				case <-l.ctx.Done():
 294  					return
 295  				}
 296  				l.handlerWg.Add(1)
 297  				go func(data []byte, remote string) {
 298  					defer func() {
 299  						<-l.handlerSem // Release semaphore
 300  						l.handlerWg.Done()
 301  					}()
 302  					l.HandleMessage(data, remote)
 303  				}(req.data, req.remote)
 304  			}
 305  		}
 306  	}
 307  }
 308  
 309  // getManagedACL returns the managed ACL instance if available
 310  func (l *Listener) getManagedACL() *database.ManagedACL {
 311  	// Get the managed ACL instance from the ACL registry
 312  	for _, aclInstance := range acl.Registry.ACLs() {
 313  		if aclInstance.Type() == "managed" {
 314  			if managed, ok := aclInstance.(*acl.Managed); ok {
 315  				return managed.GetManagedACL()
 316  			}
 317  		}
 318  	}
 319  	return nil
 320  }
 321  
 322  // getFollowsThrottleDelay returns the progressive throttle delay for follows or social ACL mode.
 323  // Returns 0 if not in a throttle-enabled mode, throttle is disabled, or user is exempt.
 324  func (l *Listener) getFollowsThrottleDelay(ev *event.E) time.Duration {
 325  	mode := acl.Registry.GetMode()
 326  	switch mode {
 327  	case "follows":
 328  		for _, aclInstance := range acl.Registry.ACLs() {
 329  			if follows, ok := aclInstance.(*acl.Follows); ok {
 330  				return follows.GetThrottleDelay(ev.Pubkey, l.remote)
 331  			}
 332  		}
 333  	case "social":
 334  		for _, aclInstance := range acl.Registry.ACLs() {
 335  			if social, ok := aclInstance.(*acl.Social); ok {
 336  				return social.GetThrottleDelay(ev.Pubkey, l.remote)
 337  			}
 338  		}
 339  	}
 340  	return 0
 341  }
 342  
// QueryEvents runs the filter f against the database by delegating to
// l.DB.QueryEvents with the supplied context, returning its result unchanged.
func (l *Listener) QueryEvents(ctx context.Context, f *filter.F) (event.S, error) {
	return l.DB.QueryEvents(ctx, f)
}
 347  
// QueryAllVersions runs the filter f against the database by delegating to
// l.DB.QueryAllVersions with the supplied context, returning its result
// unchanged.
func (l *Listener) QueryAllVersions(ctx context.Context, f *filter.F) (event.S, error) {
	return l.DB.QueryAllVersions(ctx, f)
}
 352  
 353  // canSeePrivateEvent checks if the authenticated user can see an event with a private tag
 354  func (l *Listener) canSeePrivateEvent(authedPubkey, privatePubkey []byte) (canSee bool) {
 355  	// If no authenticated user, deny access
 356  	if len(authedPubkey) == 0 {
 357  		return false
 358  	}
 359  
 360  	// If the authenticated user matches the private tag pubkey, allow access
 361  	if len(privatePubkey) > 0 && utils.FastEqual(authedPubkey, privatePubkey) {
 362  		return true
 363  	}
 364  
 365  	// Check if user is an admin or owner (they can see all private events)
 366  	accessLevel := acl.Registry.GetAccessLevel(authedPubkey, l.remote)
 367  	if accessLevel == "admin" || accessLevel == "owner" {
 368  		return true
 369  	}
 370  
 371  	// Default deny
 372  	return false
 373  }
 374