dm_ratelimit.go raw
1 package app
2
3 import (
4 "context"
5 "sync"
6 "time"
7
8 "next.orly.dev/pkg/database"
9 "next.orly.dev/pkg/lol/chk"
10 "next.orly.dev/pkg/lol/log"
11 "next.orly.dev/pkg/nostr/encoders/event"
12 "next.orly.dev/pkg/nostr/encoders/filter"
13 hexenc "next.orly.dev/pkg/nostr/encoders/hex"
14 "next.orly.dev/pkg/nostr/encoders/kind"
15 "next.orly.dev/pkg/nostr/encoders/tag"
16 )
17
18 const (
19 // dmStrangerLimit is the max DMs a sender can send to a recipient who hasn't
20 // replied yet. After this limit, further DMs are rejected until the recipient
21 // sends a DM back (bidirectional = implicit whitelist).
22 dmStrangerLimit = 3
23
24 // dmBidirectionalCacheTTL is how long we cache the "has recipient replied" check.
25 dmBidirectionalCacheTTL = 5 * time.Minute
26 )
27
// dmPairKey names one direction of DM traffic, from sender to recipient.
// Both fields hold hex-encoded pubkeys. (A→B and B→A are distinct keys.)
type dmPairKey struct {
	sender, recipient string
}
33
34 // dmPairState tracks the state of a sender→recipient DM pair.
35 type dmPairState struct {
36 count int // messages sent in this direction
37 bidirectional bool // true if recipient has replied (cached)
38 checkedAt time.Time // when bidirectional was last checked
39 }
40
41 // DMRateLimiter enforces a per-pair message limit for DMs to strangers.
42 // Once the recipient replies (bidirectional traffic), the limit is lifted.
43 type DMRateLimiter struct {
44 db database.Database
45 mu sync.Mutex
46 pairs map[dmPairKey]*dmPairState
47 }
48
49 // NewDMRateLimiter creates a new DM rate limiter.
50 func NewDMRateLimiter(db database.Database) *DMRateLimiter {
51 return &DMRateLimiter{
52 db: db,
53 pairs: make(map[dmPairKey]*dmPairState),
54 }
55 }
56
57 // CheckDM checks whether a DM event should be allowed. Returns true if allowed,
58 // false with a reason message if rejected.
59 //
60 // DM kinds: 4 (EncryptedDirectMessage), 1059 (GiftWrap)
61 func (r *DMRateLimiter) CheckDM(ctx context.Context, ev *event.E) (allowed bool, reason string) {
62 // Only apply to DM kinds
63 if ev.Kind != kind.EncryptedDirectMessage.K && ev.Kind != kind.GiftWrap.K {
64 return true, ""
65 }
66
67 senderHex := hexenc.Enc(ev.Pubkey)
68
69 // Extract recipient from #p tag
70 recipientHex := extractRecipient(ev)
71 if recipientHex == "" {
72 // No recipient found — allow (might be malformed)
73 return true, ""
74 }
75
76 // Same person — allow
77 if senderHex == recipientHex {
78 return true, ""
79 }
80
81 key := dmPairKey{sender: senderHex, recipient: recipientHex}
82
83 r.mu.Lock()
84 defer r.mu.Unlock()
85
86 state, exists := r.pairs[key]
87 if !exists {
88 state = &dmPairState{}
89 r.pairs[key] = state
90 }
91
92 // Check if this pair is bidirectional (recipient has replied)
93 if state.bidirectional && time.Since(state.checkedAt) < dmBidirectionalCacheTTL {
94 // Cached bidirectional — allow unlimited
95 return true, ""
96 }
97
98 // Check DB for bidirectional traffic (does recipient→sender exist?)
99 if r.checkBidirectional(ctx, recipientHex, senderHex) {
100 state.bidirectional = true
101 state.checkedAt = time.Now()
102 // Also mark the reverse direction as bidirectional
103 reverseKey := dmPairKey{sender: recipientHex, recipient: senderHex}
104 if reverseState, ok := r.pairs[reverseKey]; ok {
105 reverseState.bidirectional = true
106 reverseState.checkedAt = time.Now()
107 }
108 return true, ""
109 }
110
111 state.checkedAt = time.Now()
112
113 // Enforce stranger limit
114 state.count++
115 if state.count > dmStrangerLimit {
116 log.D.F("DM rate limit: %s → %s rejected (count=%d, limit=%d)",
117 senderHex[:12], recipientHex[:12], state.count, dmStrangerLimit)
118 return false, "restricted: DM limit reached, recipient has not accepted your messages"
119 }
120
121 return true, ""
122 }
123
124 // OnDMIngested should be called after a DM is successfully saved.
125 // It updates the bidirectional cache if the DM was from a tracked recipient.
126 func (r *DMRateLimiter) OnDMIngested(ev *event.E) {
127 if ev.Kind != kind.EncryptedDirectMessage.K && ev.Kind != kind.GiftWrap.K {
128 return
129 }
130
131 senderHex := hexenc.Enc(ev.Pubkey)
132 recipientHex := extractRecipient(ev)
133 if recipientHex == "" || senderHex == recipientHex {
134 return
135 }
136
137 r.mu.Lock()
138 defer r.mu.Unlock()
139
140 // If someone sends a reply, mark the reverse direction as bidirectional
141 reverseKey := dmPairKey{sender: recipientHex, recipient: senderHex}
142 if reverseState, ok := r.pairs[reverseKey]; ok {
143 reverseState.bidirectional = true
144 reverseState.checkedAt = time.Now()
145 }
146
147 // Also mark the forward direction
148 forwardKey := dmPairKey{sender: senderHex, recipient: recipientHex}
149 if forwardState, ok := r.pairs[forwardKey]; ok {
150 forwardState.bidirectional = true
151 forwardState.checkedAt = time.Now()
152 }
153 }
154
155 // checkBidirectional queries the DB to see if the recipient has ever sent a DM
156 // to the sender (i.e., a kind 4 or 1059 event from recipient with #p tag containing sender).
157 func (r *DMRateLimiter) checkBidirectional(ctx context.Context, recipientHex, senderHex string) bool {
158 recipientBytes, err := hexenc.Dec(recipientHex)
159 if err != nil {
160 return false
161 }
162 senderBytes, err := hexenc.Dec(senderHex)
163 if err != nil {
164 return false
165 }
166
167 // Query for kind 4 from recipient to sender
168 for _, k := range []*kind.K{kind.EncryptedDirectMessage, kind.GiftWrap} {
169 f := filter.New()
170 f.Kinds = kind.NewS(k)
171 f.Authors = tag.NewFromBytesSlice(recipientBytes)
172
173 // Filter by #p tag = sender
174 pTag := tag.NewFromBytesSlice([]byte("p"), senderBytes)
175 f.Tags = tag.NewSWithCap(1)
176 *f.Tags = append(*f.Tags, pTag)
177
178 limit := uint(1)
179 f.Limit = &limit
180
181 queryCtx, cancel := context.WithTimeout(ctx, 3*time.Second)
182 events, qErr := r.db.QueryEvents(queryCtx, f)
183 cancel()
184
185 if chk.E(qErr) {
186 continue
187 }
188 if len(events) > 0 {
189 return true
190 }
191 }
192
193 return false
194 }
195
196 // extractRecipient gets the recipient pubkey (hex) from an event's first #p tag.
197 func extractRecipient(ev *event.E) string {
198 if ev.Tags == nil {
199 return ""
200 }
201 pTags := ev.Tags.GetAll([]byte("p"))
202 for _, pt := range pTags {
203 if pt.Len() >= 2 {
204 val := pt.ValueHex()
205 if len(val) > 0 {
206 return string(val)
207 }
208 }
209 }
210 return ""
211 }
212