ratelimit.go raw
1 package bridge
2
3 import (
4 "fmt"
5 "sync"
6 "time"
7 )
8
// RateLimitConfig configures the outbound email limits enforced by
// RateLimiter. A limit that is zero or negative is treated as disabled.
type RateLimitConfig struct {
	PerUserPerHour int           // per-user cap on sends in the trailing hour (default 10)
	PerUserPerDay  int           // per-user cap on sends in the trailing 24 hours (default 50)
	GlobalPerHour  int           // cap across all users in the trailing hour (default 100)
	GlobalPerDay   int           // cap across all users in the trailing 24 hours (default 500)
	MinInterval    time.Duration // minimum gap between two sends by one user (default 30s)
}

// DefaultRateLimitConfig returns the spec's default limits: 10 per hour and
// 50 per day for each user, 100 per hour and 500 per day globally, and at
// least 30 seconds between consecutive sends from the same user.
func DefaultRateLimitConfig() RateLimitConfig {
	var cfg RateLimitConfig
	cfg.PerUserPerHour, cfg.PerUserPerDay = 10, 50
	cfg.GlobalPerHour, cfg.GlobalPerDay = 100, 500
	cfg.MinInterval = 30 * time.Second
	return cfg
}
28
29 // RateLimiter tracks outbound email sending rates using sliding windows.
30 type RateLimiter struct {
31 cfg RateLimitConfig
32 mu sync.Mutex
33 users map[string]*userWindow
34 global *window
35 }
36
37 // NewRateLimiter creates a rate limiter with the given config.
38 func NewRateLimiter(cfg RateLimitConfig) *RateLimiter {
39 return &RateLimiter{
40 cfg: cfg,
41 users: make(map[string]*userWindow),
42 global: newWindow(),
43 }
44 }
45
46 // Check returns nil if the user is allowed to send, or an error describing
47 // when they can retry.
48 func (rl *RateLimiter) Check(pubkeyHex string) error {
49 rl.mu.Lock()
50 defer rl.mu.Unlock()
51
52 now := time.Now()
53
54 uw := rl.getUser(pubkeyHex)
55
56 // Check minimum interval
57 if rl.cfg.MinInterval > 0 && !uw.lastSend.IsZero() {
58 elapsed := now.Sub(uw.lastSend)
59 if elapsed < rl.cfg.MinInterval {
60 wait := rl.cfg.MinInterval - elapsed
61 return fmt.Errorf("rate limited: wait %v between sends", wait.Round(time.Second))
62 }
63 }
64
65 // Check per-user per-hour
66 if rl.cfg.PerUserPerHour > 0 {
67 count := uw.hourly.countSince(now.Add(-time.Hour))
68 if count >= rl.cfg.PerUserPerHour {
69 return fmt.Errorf("rate limited: %d emails per hour limit reached", rl.cfg.PerUserPerHour)
70 }
71 }
72
73 // Check per-user per-day
74 if rl.cfg.PerUserPerDay > 0 {
75 count := uw.daily.countSince(now.Add(-24 * time.Hour))
76 if count >= rl.cfg.PerUserPerDay {
77 return fmt.Errorf("rate limited: %d emails per day limit reached", rl.cfg.PerUserPerDay)
78 }
79 }
80
81 // Check global per-hour
82 if rl.cfg.GlobalPerHour > 0 {
83 count := rl.global.countSince(now.Add(-time.Hour))
84 if count >= rl.cfg.GlobalPerHour {
85 return fmt.Errorf("rate limited: global hourly limit (%d) reached", rl.cfg.GlobalPerHour)
86 }
87 }
88
89 // Check global per-day
90 if rl.cfg.GlobalPerDay > 0 {
91 count := rl.global.countSince(now.Add(-24 * time.Hour))
92 if count >= rl.cfg.GlobalPerDay {
93 return fmt.Errorf("rate limited: global daily limit (%d) reached", rl.cfg.GlobalPerDay)
94 }
95 }
96
97 return nil
98 }
99
100 // Record records a send event for rate limiting purposes.
101 // Call this after a successful send.
102 func (rl *RateLimiter) Record(pubkeyHex string) {
103 rl.mu.Lock()
104 defer rl.mu.Unlock()
105
106 now := time.Now()
107
108 uw := rl.getUser(pubkeyHex)
109 uw.lastSend = now
110 uw.hourly.add(now)
111 uw.daily.add(now)
112
113 rl.global.add(now)
114 }
115
116 func (rl *RateLimiter) getUser(pubkeyHex string) *userWindow {
117 uw, ok := rl.users[pubkeyHex]
118 if !ok {
119 uw = &userWindow{
120 hourly: newWindow(),
121 daily: newWindow(),
122 }
123 rl.users[pubkeyHex] = uw
124 }
125 return uw
126 }
127
128 // userWindow tracks per-user rate limiting state.
129 type userWindow struct {
130 lastSend time.Time
131 hourly *window
132 daily *window
133 }
134
// window stores event timestamps so callers can count how many fell within
// a trailing period.
type window struct {
	times []time.Time // in the order they were added
}

// newWindow returns an empty window.
func newWindow() *window {
	return new(window)
}

// add appends the event time t.
func (w *window) add(t time.Time) {
	w.times = append(w.times, t)
}

// countSince returns how many recorded events occurred at or after since.
// Events strictly before since are discarded in place, so a later call with
// an earlier cutoff will not see them again.
func (w *window) countSince(since time.Time) int {
	kept := w.times[:0]
	for _, ts := range w.times {
		if ts.Before(since) {
			continue
		}
		kept = append(kept, ts)
	}
	w.times = kept
	return len(kept)
}
163