handle-count.go raw

   1  package app
   2  
   3  import (
   4  	"context"
   5  	"errors"
   6  	"fmt"
   7  	"time"
   8  
   9  	"next.orly.dev/pkg/lol/chk"
  10  	"next.orly.dev/pkg/lol/log"
  11  	"next.orly.dev/pkg/acl"
  12  	"next.orly.dev/pkg/nostr/crypto/ec/schnorr"
  13  	"next.orly.dev/pkg/nostr/encoders/envelopes/authenvelope"
  14  	"next.orly.dev/pkg/nostr/encoders/envelopes/countenvelope"
  15  	"next.orly.dev/pkg/nostr/utils/normalize"
  16  )
  17  
  18  // HandleCount processes a COUNT envelope by parsing the request, verifying
  19  // permissions, invoking the database CountEvents for each provided filter, and
  20  // responding with a COUNT response containing the aggregate count.
  21  func (l *Listener) HandleCount(msg []byte) (err error) {
  22  	log.D.F("HandleCount: START processing from %s", l.remote)
  23  
  24  	// Parse the COUNT request
  25  	env := countenvelope.New()
  26  	if _, err = env.Unmarshal(msg); chk.E(err) {
  27  		return normalize.Error.Errorf(err.Error())
  28  	}
  29  	log.D.C(func() string { return fmt.Sprintf("COUNT sub=%s filters=%d", env.Subscription, len(env.Filters)) })
  30  
  31  	// If ACL is active, auth is required, or AuthToWrite is enabled, send a challenge (same as REQ path)
  32  	if len(l.authedPubkey.Load()) != schnorr.PubKeyBytesLen && (acl.Registry.GetMode() != "none" || l.Config.AuthRequired || l.Config.AuthToWrite) {
  33  		if err = authenvelope.NewChallengeWith(l.challenge.Load()).Write(l); chk.E(err) {
  34  			return
  35  		}
  36  	}
  37  
  38  	// Check read permissions
  39  	accessLevel := acl.Registry.GetAccessLevel(l.authedPubkey.Load(), l.remote)
  40  
  41  	// If auth is required but user is not authenticated, deny access
  42  	if l.Config.AuthRequired && len(l.authedPubkey.Load()) == 0 {
  43  		return errors.New("authentication required")
  44  	}
  45  
  46  	// If AuthToWrite is enabled, allow COUNT without auth (but still check ACL)
  47  	if l.Config.AuthToWrite && len(l.authedPubkey.Load()) == 0 {
  48  		// Allow unauthenticated COUNT when AuthToWrite is enabled
  49  		// but still respect ACL access levels if ACL is active
  50  		if acl.Registry.GetMode() != "none" {
  51  			switch accessLevel {
  52  			case "none", "blocked", "banned":
  53  				return errors.New("auth required: user not authed or has no read access")
  54  			}
  55  		}
  56  		// Allow the request to proceed without authentication
  57  	} else {
  58  		// Only check ACL access level if not already handled by AuthToWrite
  59  		switch accessLevel {
  60  		case "none":
  61  			return errors.New("auth required: user not authed or has no read access")
  62  		default:
  63  			// allowed to read
  64  		}
  65  	}
  66  
  67  	// Use a bounded context for counting, isolated from the connection context
  68  	// to prevent count timeouts from affecting the long-lived websocket connection
  69  	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
  70  	defer cancel()
  71  
  72  	// Aggregate count across all provided filters
  73  	var total int
  74  	var approx bool // database returns false per implementation
  75  	for _, f := range env.Filters {
  76  		if f == nil {
  77  			continue
  78  		}
  79  		var cnt int
  80  		var a bool
  81  		cnt, a, err = l.DB.CountEvents(ctx, f)
  82  		if chk.E(err) {
  83  			return
  84  		}
  85  		total += cnt
  86  		approx = approx || a
  87  	}
  88  
  89  	// Build and send COUNT response
  90  	var res *countenvelope.Response
  91  	if res, err = countenvelope.NewResponseFrom(env.Subscription, total, approx); chk.E(err) {
  92  		return
  93  	}
  94  	if err = res.Write(l); chk.E(err) {
  95  		return
  96  	}
  97  
  98  	log.D.F("HandleCount: COMPLETED processing from %s count=%d approx=%v", l.remote, total, approx)
  99  	return nil
 100  }
 101