package app

import (
	"bytes"
	"context"
	"net/http"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gorilla/websocket"
	"next.orly.dev/pkg/lol/errorf"
	"next.orly.dev/pkg/lol/log"

	"next.orly.dev/app/config"
	"next.orly.dev/pkg/acl"
	"next.orly.dev/pkg/database"
	"next.orly.dev/pkg/nostr/encoders/event"
	"next.orly.dev/pkg/nostr/encoders/filter"
	"next.orly.dev/pkg/protocol/publish"
	"next.orly.dev/pkg/utils"
	atomicutils "next.orly.dev/pkg/utils/atomic"
)

// Listener holds the state kept for a single websocket connection: the
// connection and its context, the outbound write queue, the inbound message
// queue, flow-control counters and the subscription cancel functions.
type Listener struct {
	// Server is the embedded server reference.
	// Deprecated: Prefer using accessor methods (ServerConfig, ServerDatabase, etc.)
	// instead of accessing Server fields directly.
	*Server

	conn *websocket.Conn
	ctx  context.Context

	// cancel is the cancel function for this listener's context.
	cancel context.CancelFunc

	remote string

	// connectionID is a unique identifier for this connection (for access
	// tracking).
	connectionID string

	req          *http.Request
	challenge    atomicutils.Bytes
	authedPubkey atomicutils.Bytes
	startTime    time.Time

	// isBlacklisted is a marker to identify blacklisted IPs.
	isBlacklisted bool

	// blacklistTimeout is when to time out a blacklisted connection.
	blacklistTimeout time.Time

	// writeChan is the channel for write requests (queued approach).
	writeChan chan publish.WriteRequest

	// writeDone is closed when the write worker exits.
	writeDone chan struct{}

	// --- Message processing queue for async handling ---

	// messageQueue is the buffered channel for message processing.
	messageQueue chan messageRequest

	// processingDone is closed when the message processor exits.
	processingDone chan struct{}

	// handlerWg tracks spawned message handler goroutines.
	handlerWg sync.WaitGroup

	// handlerSem limits concurrent message handlers per connection.
	handlerSem chan struct{}

	// authProcessing ensures AUTH completes before other messages check
	// authentication.
	authProcessing sync.RWMutex

	// --- Flow control counters (atomic for concurrent access) ---

	// droppedMessages counts messages dropped due to a full queue.
	droppedMessages atomic.Int64

	// queryCostAccumulator is the accumulated query cost for this connection
	// (units: multiplier * 100).
	queryCostAccumulator atomic.Int64

	// --- Diagnostics: per-connection counters ---

	msgCount   int
	reqCount   int
	eventCount int

	// --- Subscription tracking for cleanup ---

	// subscriptions maps a subscription ID to its cancel function.
	subscriptions map[string]context.CancelFunc

	// subscriptionsMu protects the subscriptions map.
	subscriptionsMu sync.Mutex
}

// messageRequest is one entry of the message processing queue: the raw
// message bytes together with the remote string supplied when it was queued.
type messageRequest struct {
	data   []byte
	remote string
}

// Ctx returns the listener's own context. The stored context is returned
// as-is on every call; no new context is derived per operation.
func (l *Listener) Ctx() context.Context {
	return l.ctx
}

// ServerContext returns the server's context (distinct from the listener's own context).
func (l *Listener) ServerContext() context.Context {
	return l.Server.Context()
}

// ServerConfig returns the server's configuration.
func (l *Listener) ServerConfig() *config.C {
	return l.Server.GetConfig()
}

// ServerDatabase returns the server's database instance.
func (l *Listener) ServerDatabase() database.Database {
	return l.Server.Database()
}

// DroppedMessages reports how many messages have been counted as dropped
// because the message processing queue was full.
func (l *Listener) DroppedMessages() int {
	dropped := l.droppedMessages.Load()
	return int(dropped)
}

// RemainingCapacity reports how many free slots the message processing queue
// currently has (capacity minus queued length).
func (l *Listener) RemainingCapacity() int {
	queued := len(l.messageQueue)
	return cap(l.messageQueue) - queued
}

// QueueMessage queues a message for asynchronous processing.
// Returns true if the message was queued, false if the queue was full.
func (l *Listener) QueueMessage(data []byte, remote string) bool { req := messageRequest{data: data, remote: remote} select { case l.messageQueue <- req: return true default: l.droppedMessages.Add(1) return false } } func (l *Listener) Write(p []byte) (n int, err error) { // Defensive: recover from any panic when sending to closed channel defer func() { if r := recover(); r != nil { log.D.F("ws->%s write panic recovered (channel likely closed): %v", l.remote, r) err = errorf.E("write channel closed") n = 0 } }() // Send write request to channel - non-blocking with timeout select { case <-l.ctx.Done(): return 0, l.ctx.Err() case l.writeChan <- publish.WriteRequest{Data: p, MsgType: websocket.TextMessage, IsControl: false}: return len(p), nil case <-time.After(DefaultWriteTimeout): log.E.F("ws->%s write channel timeout", l.remote) return 0, errorf.E("write channel timeout") } } // SendEvent sends an event to the client. Implements archive.EventDeliveryChannel. func (l *Listener) SendEvent(ev *event.E) error { if ev == nil { return nil } // Serialize the event as an EVENT envelope data := ev.Serialize() // Use Write to send _, err := l.Write(data) return err } // IsConnected returns whether the client connection is still active. // Implements archive.EventDeliveryChannel. 
func (l *Listener) IsConnected() bool { select { case <-l.ctx.Done(): return false default: return l.conn != nil } } // WriteControl sends a control message through the write channel func (l *Listener) WriteControl(messageType int, data []byte, deadline time.Time) (err error) { // Defensive: recover from any panic when sending to closed channel defer func() { if r := recover(); r != nil { log.D.F("ws->%s writeControl panic recovered (channel likely closed): %v", l.remote, r) err = errorf.E("write channel closed") } }() select { case <-l.ctx.Done(): return l.ctx.Err() case l.writeChan <- publish.WriteRequest{Data: data, MsgType: messageType, IsControl: true, Deadline: deadline}: return nil case <-time.After(DefaultWriteTimeout): log.E.F("ws->%s writeControl channel timeout", l.remote) return errorf.E("writeControl channel timeout") } } // writeWorker is the single goroutine that handles all writes to the websocket connection. // This serializes all writes to prevent concurrent write panics and allows pings to interrupt writes. func (l *Listener) writeWorker() { defer func() { // Only unregister write channel if connection is actually dead/closing // Unregister if: // 1. Context is cancelled (connection closing) // 2. Channel was closed (connection closing) // 3. Connection error occurred (already handled inline) if l.ctx.Err() != nil { // Connection is closing - safe to unregister if socketPub := l.publishers.GetSocketPublisher(); socketPub != nil { log.D.F("ws->%s write worker: unregistering write channel (connection closing)", l.remote) socketPub.SetWriteChan(l.conn, nil) } } else { // Exiting for other reasons (timeout, etc.) 
but connection may still be valid log.D.F("ws->%s write worker exiting unexpectedly", l.remote) } close(l.writeDone) }() for { select { case <-l.ctx.Done(): log.D.F("ws->%s write worker context cancelled", l.remote) return case req, ok := <-l.writeChan: if !ok { log.D.F("ws->%s write channel closed", l.remote) return } // Skip writes if no connection (unit tests) if l.conn == nil { log.T.F("ws->%s skipping write (no connection)", l.remote) continue } // Handle the write request var err error if req.IsPing { // Special handling for ping messages log.D.F("sending PING #%d", req.MsgType) deadline := time.Now().Add(DefaultWriteTimeout) err = l.conn.WriteControl(websocket.PingMessage, nil, deadline) if err != nil { if !strings.HasSuffix(err.Error(), "use of closed network connection") { log.E.F("error writing ping: %v; closing websocket", err) } return } } else if req.IsControl { // Control message err = l.conn.WriteControl(req.MsgType, req.Data, req.Deadline) if err != nil { log.E.F("ws->%s control write failed: %v", l.remote, err) return } } else { // Regular message l.conn.SetWriteDeadline(time.Now().Add(DefaultWriteTimeout)) err = l.conn.WriteMessage(req.MsgType, req.Data) if err != nil { log.E.F("ws->%s write failed: %v", l.remote, err) return } } } } } // messageProcessor is the goroutine that processes messages asynchronously. // This prevents the websocket read loop from blocking on message processing. func (l *Listener) messageProcessor() { defer func() { close(l.processingDone) }() for { select { case <-l.ctx.Done(): log.D.F("ws->%s message processor context cancelled", l.remote) return case req, ok := <-l.messageQueue: if !ok { log.D.F("ws->%s message queue closed", l.remote) return } // Lock immediately to ensure AUTH is processed before subsequent messages // are dequeued. This prevents race conditions where EVENT checks authentication // before AUTH completes. 
l.authProcessing.Lock() // Check if this is an AUTH message by looking for the ["AUTH" prefix isAuthMessage := len(req.data) > 7 && bytes.HasPrefix(req.data, []byte(`["AUTH"`)) if isAuthMessage { // Process AUTH message synchronously while holding lock // This blocks the messageProcessor from dequeuing the next message // until authentication is complete and authedPubkey is set log.D.F("ws->%s processing AUTH synchronously with lock", req.remote) l.HandleMessage(req.data, req.remote) // Unlock after AUTH completes so subsequent messages see updated authedPubkey l.authProcessing.Unlock() } else { // Not AUTH - unlock immediately and process concurrently // The next message can now be dequeued (possibly another non-AUTH to process concurrently) l.authProcessing.Unlock() // Acquire semaphore to limit concurrent handlers (blocking with context awareness) select { case l.handlerSem <- struct{}{}: // Semaphore acquired case <-l.ctx.Done(): return } l.handlerWg.Add(1) go func(data []byte, remote string) { defer func() { <-l.handlerSem // Release semaphore l.handlerWg.Done() }() l.HandleMessage(data, remote) }(req.data, req.remote) } } } } // getManagedACL returns the managed ACL instance if available func (l *Listener) getManagedACL() *database.ManagedACL { // Get the managed ACL instance from the ACL registry for _, aclInstance := range acl.Registry.ACLs() { if aclInstance.Type() == "managed" { if managed, ok := aclInstance.(*acl.Managed); ok { return managed.GetManagedACL() } } } return nil } // getFollowsThrottleDelay returns the progressive throttle delay for follows or social ACL mode. // Returns 0 if not in a throttle-enabled mode, throttle is disabled, or user is exempt. 
func (l *Listener) getFollowsThrottleDelay(ev *event.E) time.Duration { mode := acl.Registry.GetMode() switch mode { case "follows": for _, aclInstance := range acl.Registry.ACLs() { if follows, ok := aclInstance.(*acl.Follows); ok { return follows.GetThrottleDelay(ev.Pubkey, l.remote) } } case "social": for _, aclInstance := range acl.Registry.ACLs() { if social, ok := aclInstance.(*acl.Social); ok { return social.GetThrottleDelay(ev.Pubkey, l.remote) } } } return 0 } // QueryEvents queries events using the database QueryEvents method func (l *Listener) QueryEvents(ctx context.Context, f *filter.F) (event.S, error) { return l.DB.QueryEvents(ctx, f) } // QueryAllVersions queries events using the database QueryAllVersions method func (l *Listener) QueryAllVersions(ctx context.Context, f *filter.F) (event.S, error) { return l.DB.QueryAllVersions(ctx, f) } // canSeePrivateEvent checks if the authenticated user can see an event with a private tag func (l *Listener) canSeePrivateEvent(authedPubkey, privatePubkey []byte) (canSee bool) { // If no authenticated user, deny access if len(authedPubkey) == 0 { return false } // If the authenticated user matches the private tag pubkey, allow access if len(privatePubkey) > 0 && utils.FastEqual(authedPubkey, privatePubkey) { return true } // Check if user is an admin or owner (they can see all private events) accessLevel := acl.Registry.GetAccessLevel(authedPubkey, l.remote) if accessLevel == "admin" || accessLevel == "owner" { return true } // Default deny return false }