dm_ratelimit.go raw

   1  package app
   2  
   3  import (
   4  	"context"
   5  	"sync"
   6  	"time"
   7  
   8  	"next.orly.dev/pkg/database"
   9  	"next.orly.dev/pkg/lol/chk"
  10  	"next.orly.dev/pkg/lol/log"
  11  	"next.orly.dev/pkg/nostr/encoders/event"
  12  	"next.orly.dev/pkg/nostr/encoders/filter"
  13  	hexenc "next.orly.dev/pkg/nostr/encoders/hex"
  14  	"next.orly.dev/pkg/nostr/encoders/kind"
  15  	"next.orly.dev/pkg/nostr/encoders/tag"
  16  )
  17  
  18  const (
  19  	// dmStrangerLimit is the max DMs a sender can send to a recipient who hasn't
  20  	// replied yet. After this limit, further DMs are rejected until the recipient
  21  	// sends a DM back (bidirectional = implicit whitelist).
  22  	dmStrangerLimit = 3
  23  
  24  	// dmBidirectionalCacheTTL is how long we cache the "has recipient replied" check.
  25  	dmBidirectionalCacheTTL = 5 * time.Minute
  26  )
  27  
  28  // dmPairKey identifies a sender→recipient DM direction.
  29  type dmPairKey struct {
  30  	sender    string // hex pubkey
  31  	recipient string // hex pubkey
  32  }
  33  
  34  // dmPairState tracks the state of a sender→recipient DM pair.
  35  type dmPairState struct {
  36  	count        int       // messages sent in this direction
  37  	bidirectional bool     // true if recipient has replied (cached)
  38  	checkedAt    time.Time // when bidirectional was last checked
  39  }
  40  
  41  // DMRateLimiter enforces a per-pair message limit for DMs to strangers.
  42  // Once the recipient replies (bidirectional traffic), the limit is lifted.
  43  type DMRateLimiter struct {
  44  	db    database.Database
  45  	mu    sync.Mutex
  46  	pairs map[dmPairKey]*dmPairState
  47  }
  48  
  49  // NewDMRateLimiter creates a new DM rate limiter.
  50  func NewDMRateLimiter(db database.Database) *DMRateLimiter {
  51  	return &DMRateLimiter{
  52  		db:    db,
  53  		pairs: make(map[dmPairKey]*dmPairState),
  54  	}
  55  }
  56  
  57  // CheckDM checks whether a DM event should be allowed. Returns true if allowed,
  58  // false with a reason message if rejected.
  59  //
  60  // DM kinds: 4 (EncryptedDirectMessage), 1059 (GiftWrap)
  61  func (r *DMRateLimiter) CheckDM(ctx context.Context, ev *event.E) (allowed bool, reason string) {
  62  	// Only apply to DM kinds
  63  	if ev.Kind != kind.EncryptedDirectMessage.K && ev.Kind != kind.GiftWrap.K {
  64  		return true, ""
  65  	}
  66  
  67  	senderHex := hexenc.Enc(ev.Pubkey)
  68  
  69  	// Extract recipient from #p tag
  70  	recipientHex := extractRecipient(ev)
  71  	if recipientHex == "" {
  72  		// No recipient found — allow (might be malformed)
  73  		return true, ""
  74  	}
  75  
  76  	// Same person — allow
  77  	if senderHex == recipientHex {
  78  		return true, ""
  79  	}
  80  
  81  	key := dmPairKey{sender: senderHex, recipient: recipientHex}
  82  
  83  	r.mu.Lock()
  84  	defer r.mu.Unlock()
  85  
  86  	state, exists := r.pairs[key]
  87  	if !exists {
  88  		state = &dmPairState{}
  89  		r.pairs[key] = state
  90  	}
  91  
  92  	// Check if this pair is bidirectional (recipient has replied)
  93  	if state.bidirectional && time.Since(state.checkedAt) < dmBidirectionalCacheTTL {
  94  		// Cached bidirectional — allow unlimited
  95  		return true, ""
  96  	}
  97  
  98  	// Check DB for bidirectional traffic (does recipient→sender exist?)
  99  	if r.checkBidirectional(ctx, recipientHex, senderHex) {
 100  		state.bidirectional = true
 101  		state.checkedAt = time.Now()
 102  		// Also mark the reverse direction as bidirectional
 103  		reverseKey := dmPairKey{sender: recipientHex, recipient: senderHex}
 104  		if reverseState, ok := r.pairs[reverseKey]; ok {
 105  			reverseState.bidirectional = true
 106  			reverseState.checkedAt = time.Now()
 107  		}
 108  		return true, ""
 109  	}
 110  
 111  	state.checkedAt = time.Now()
 112  
 113  	// Enforce stranger limit
 114  	state.count++
 115  	if state.count > dmStrangerLimit {
 116  		log.D.F("DM rate limit: %s → %s rejected (count=%d, limit=%d)",
 117  			senderHex[:12], recipientHex[:12], state.count, dmStrangerLimit)
 118  		return false, "restricted: DM limit reached, recipient has not accepted your messages"
 119  	}
 120  
 121  	return true, ""
 122  }
 123  
 124  // OnDMIngested should be called after a DM is successfully saved.
 125  // It updates the bidirectional cache if the DM was from a tracked recipient.
 126  func (r *DMRateLimiter) OnDMIngested(ev *event.E) {
 127  	if ev.Kind != kind.EncryptedDirectMessage.K && ev.Kind != kind.GiftWrap.K {
 128  		return
 129  	}
 130  
 131  	senderHex := hexenc.Enc(ev.Pubkey)
 132  	recipientHex := extractRecipient(ev)
 133  	if recipientHex == "" || senderHex == recipientHex {
 134  		return
 135  	}
 136  
 137  	r.mu.Lock()
 138  	defer r.mu.Unlock()
 139  
 140  	// If someone sends a reply, mark the reverse direction as bidirectional
 141  	reverseKey := dmPairKey{sender: recipientHex, recipient: senderHex}
 142  	if reverseState, ok := r.pairs[reverseKey]; ok {
 143  		reverseState.bidirectional = true
 144  		reverseState.checkedAt = time.Now()
 145  	}
 146  
 147  	// Also mark the forward direction
 148  	forwardKey := dmPairKey{sender: senderHex, recipient: recipientHex}
 149  	if forwardState, ok := r.pairs[forwardKey]; ok {
 150  		forwardState.bidirectional = true
 151  		forwardState.checkedAt = time.Now()
 152  	}
 153  }
 154  
 155  // checkBidirectional queries the DB to see if the recipient has ever sent a DM
 156  // to the sender (i.e., a kind 4 or 1059 event from recipient with #p tag containing sender).
 157  func (r *DMRateLimiter) checkBidirectional(ctx context.Context, recipientHex, senderHex string) bool {
 158  	recipientBytes, err := hexenc.Dec(recipientHex)
 159  	if err != nil {
 160  		return false
 161  	}
 162  	senderBytes, err := hexenc.Dec(senderHex)
 163  	if err != nil {
 164  		return false
 165  	}
 166  
 167  	// Query for kind 4 from recipient to sender
 168  	for _, k := range []*kind.K{kind.EncryptedDirectMessage, kind.GiftWrap} {
 169  		f := filter.New()
 170  		f.Kinds = kind.NewS(k)
 171  		f.Authors = tag.NewFromBytesSlice(recipientBytes)
 172  
 173  		// Filter by #p tag = sender
 174  		pTag := tag.NewFromBytesSlice([]byte("p"), senderBytes)
 175  		f.Tags = tag.NewSWithCap(1)
 176  		*f.Tags = append(*f.Tags, pTag)
 177  
 178  		limit := uint(1)
 179  		f.Limit = &limit
 180  
 181  		queryCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
 182  		events, qErr := r.db.QueryEvents(queryCtx, f)
 183  		cancel()
 184  
 185  		if chk.E(qErr) {
 186  			continue
 187  		}
 188  		if len(events) > 0 {
 189  			return true
 190  		}
 191  	}
 192  
 193  	return false
 194  }
 195  
 196  // extractRecipient gets the recipient pubkey (hex) from an event's first #p tag.
 197  func extractRecipient(ev *event.E) string {
 198  	if ev.Tags == nil {
 199  		return ""
 200  	}
 201  	pTags := ev.Tags.GetAll([]byte("p"))
 202  	for _, pt := range pTags {
 203  		if pt.Len() >= 2 {
 204  			val := pt.ValueHex()
 205  			if len(val) > 0 {
 206  				return string(val)
 207  			}
 208  		}
 209  	}
 210  	return ""
 211  }
 212