ratelimit.go raw

   1  package bridge
   2  
   3  import (
   4  	"fmt"
   5  	"sync"
   6  	"time"
   7  )
   8  
   9  // RateLimitConfig holds rate limit configuration values.
  10  type RateLimitConfig struct {
  11  	PerUserPerHour  int           // Max emails per user per hour (default: 10)
  12  	PerUserPerDay   int           // Max emails per user per day (default: 50)
  13  	GlobalPerHour   int           // Max emails globally per hour (default: 100)
  14  	GlobalPerDay    int           // Max emails globally per day (default: 500)
  15  	MinInterval     time.Duration // Min time between sends per user (default: 30s)
  16  }
  17  
  18  // DefaultRateLimitConfig returns sensible defaults per the spec.
  19  func DefaultRateLimitConfig() RateLimitConfig {
  20  	return RateLimitConfig{
  21  		PerUserPerHour: 10,
  22  		PerUserPerDay:  50,
  23  		GlobalPerHour:  100,
  24  		GlobalPerDay:   500,
  25  		MinInterval:    30 * time.Second,
  26  	}
  27  }
  28  
  29  // RateLimiter tracks outbound email sending rates using sliding windows.
  30  type RateLimiter struct {
  31  	cfg    RateLimitConfig
  32  	mu     sync.Mutex
  33  	users  map[string]*userWindow
  34  	global *window
  35  }
  36  
  37  // NewRateLimiter creates a rate limiter with the given config.
  38  func NewRateLimiter(cfg RateLimitConfig) *RateLimiter {
  39  	return &RateLimiter{
  40  		cfg:    cfg,
  41  		users:  make(map[string]*userWindow),
  42  		global: newWindow(),
  43  	}
  44  }
  45  
  46  // Check returns nil if the user is allowed to send, or an error describing
  47  // when they can retry.
  48  func (rl *RateLimiter) Check(pubkeyHex string) error {
  49  	rl.mu.Lock()
  50  	defer rl.mu.Unlock()
  51  
  52  	now := time.Now()
  53  
  54  	uw := rl.getUser(pubkeyHex)
  55  
  56  	// Check minimum interval
  57  	if rl.cfg.MinInterval > 0 && !uw.lastSend.IsZero() {
  58  		elapsed := now.Sub(uw.lastSend)
  59  		if elapsed < rl.cfg.MinInterval {
  60  			wait := rl.cfg.MinInterval - elapsed
  61  			return fmt.Errorf("rate limited: wait %v between sends", wait.Round(time.Second))
  62  		}
  63  	}
  64  
  65  	// Check per-user per-hour
  66  	if rl.cfg.PerUserPerHour > 0 {
  67  		count := uw.hourly.countSince(now.Add(-time.Hour))
  68  		if count >= rl.cfg.PerUserPerHour {
  69  			return fmt.Errorf("rate limited: %d emails per hour limit reached", rl.cfg.PerUserPerHour)
  70  		}
  71  	}
  72  
  73  	// Check per-user per-day
  74  	if rl.cfg.PerUserPerDay > 0 {
  75  		count := uw.daily.countSince(now.Add(-24 * time.Hour))
  76  		if count >= rl.cfg.PerUserPerDay {
  77  			return fmt.Errorf("rate limited: %d emails per day limit reached", rl.cfg.PerUserPerDay)
  78  		}
  79  	}
  80  
  81  	// Check global per-hour
  82  	if rl.cfg.GlobalPerHour > 0 {
  83  		count := rl.global.countSince(now.Add(-time.Hour))
  84  		if count >= rl.cfg.GlobalPerHour {
  85  			return fmt.Errorf("rate limited: global hourly limit (%d) reached", rl.cfg.GlobalPerHour)
  86  		}
  87  	}
  88  
  89  	// Check global per-day
  90  	if rl.cfg.GlobalPerDay > 0 {
  91  		count := rl.global.countSince(now.Add(-24 * time.Hour))
  92  		if count >= rl.cfg.GlobalPerDay {
  93  			return fmt.Errorf("rate limited: global daily limit (%d) reached", rl.cfg.GlobalPerDay)
  94  		}
  95  	}
  96  
  97  	return nil
  98  }
  99  
 100  // Record records a send event for rate limiting purposes.
 101  // Call this after a successful send.
 102  func (rl *RateLimiter) Record(pubkeyHex string) {
 103  	rl.mu.Lock()
 104  	defer rl.mu.Unlock()
 105  
 106  	now := time.Now()
 107  
 108  	uw := rl.getUser(pubkeyHex)
 109  	uw.lastSend = now
 110  	uw.hourly.add(now)
 111  	uw.daily.add(now)
 112  
 113  	rl.global.add(now)
 114  }
 115  
 116  func (rl *RateLimiter) getUser(pubkeyHex string) *userWindow {
 117  	uw, ok := rl.users[pubkeyHex]
 118  	if !ok {
 119  		uw = &userWindow{
 120  			hourly: newWindow(),
 121  			daily:  newWindow(),
 122  		}
 123  		rl.users[pubkeyHex] = uw
 124  	}
 125  	return uw
 126  }
 127  
 128  // userWindow tracks per-user rate limiting state.
 129  type userWindow struct {
 130  	lastSend time.Time
 131  	hourly   *window
 132  	daily    *window
 133  }
 134  
 135  // window is a sliding window of timestamps for counting events.
 136  type window struct {
 137  	times []time.Time
 138  }
 139  
 140  func newWindow() *window {
 141  	return &window{}
 142  }
 143  
 144  // add records a new event timestamp.
 145  func (w *window) add(t time.Time) {
 146  	w.times = append(w.times, t)
 147  }
 148  
 149  // countSince returns the number of events since the given time,
 150  // pruning old entries as a side effect.
 151  func (w *window) countSince(since time.Time) int {
 152  	// Prune old entries
 153  	n := 0
 154  	for _, t := range w.times {
 155  		if !t.Before(since) {
 156  			w.times[n] = t
 157  			n++
 158  		}
 159  	}
 160  	w.times = w.times[:n]
 161  	return n
 162  }
 163