ratelimit_test.go raw

   1  package bridge
   2  
   3  import (
   4  	"strings"
   5  	"testing"
   6  	"time"
   7  )
   8  
   9  func TestRateLimiter_AllowsInitialSend(t *testing.T) {
  10  	rl := NewRateLimiter(DefaultRateLimitConfig())
  11  
  12  	if err := rl.Check("user1"); err != nil {
  13  		t.Errorf("first send should be allowed: %v", err)
  14  	}
  15  }
  16  
  17  func TestRateLimiter_MinInterval(t *testing.T) {
  18  	cfg := RateLimitConfig{
  19  		MinInterval:    1 * time.Second,
  20  		PerUserPerHour: 100, // high limits so only interval matters
  21  		PerUserPerDay:  1000,
  22  		GlobalPerHour:  10000,
  23  		GlobalPerDay:   100000,
  24  	}
  25  	rl := NewRateLimiter(cfg)
  26  
  27  	// First send
  28  	if err := rl.Check("user1"); err != nil {
  29  		t.Fatalf("first check failed: %v", err)
  30  	}
  31  	rl.Record("user1")
  32  
  33  	// Immediate retry should be rate limited
  34  	if err := rl.Check("user1"); err == nil {
  35  		t.Error("expected rate limit error for immediate retry")
  36  	} else if !strings.Contains(err.Error(), "wait") {
  37  		t.Errorf("error should mention wait time: %v", err)
  38  	}
  39  
  40  	// Different user should be fine
  41  	if err := rl.Check("user2"); err != nil {
  42  		t.Errorf("different user should not be limited: %v", err)
  43  	}
  44  }
  45  
  46  func TestRateLimiter_PerUserPerHour(t *testing.T) {
  47  	cfg := RateLimitConfig{
  48  		MinInterval:    0, // disable interval check
  49  		PerUserPerHour: 3,
  50  		PerUserPerDay:  1000,
  51  		GlobalPerHour:  10000,
  52  		GlobalPerDay:   100000,
  53  	}
  54  	rl := NewRateLimiter(cfg)
  55  
  56  	for i := 0; i < 3; i++ {
  57  		if err := rl.Check("user1"); err != nil {
  58  			t.Fatalf("send %d should be allowed: %v", i+1, err)
  59  		}
  60  		rl.Record("user1")
  61  	}
  62  
  63  	// 4th send should be rate limited
  64  	if err := rl.Check("user1"); err == nil {
  65  		t.Error("expected per-user hourly rate limit")
  66  	} else if !strings.Contains(err.Error(), "per hour") {
  67  		t.Errorf("error should mention per hour: %v", err)
  68  	}
  69  
  70  	// Different user should still be fine
  71  	if err := rl.Check("user2"); err != nil {
  72  		t.Errorf("different user should not be limited: %v", err)
  73  	}
  74  }
  75  
  76  func TestRateLimiter_PerUserPerDay(t *testing.T) {
  77  	cfg := RateLimitConfig{
  78  		MinInterval:    0,
  79  		PerUserPerHour: 1000, // high
  80  		PerUserPerDay:  3,
  81  		GlobalPerHour:  10000,
  82  		GlobalPerDay:   100000,
  83  	}
  84  	rl := NewRateLimiter(cfg)
  85  
  86  	for i := 0; i < 3; i++ {
  87  		if err := rl.Check("user1"); err != nil {
  88  			t.Fatalf("send %d should be allowed: %v", i+1, err)
  89  		}
  90  		rl.Record("user1")
  91  	}
  92  
  93  	if err := rl.Check("user1"); err == nil {
  94  		t.Error("expected per-user daily rate limit")
  95  	} else if !strings.Contains(err.Error(), "per day") {
  96  		t.Errorf("error should mention per day: %v", err)
  97  	}
  98  }
  99  
 100  func TestRateLimiter_GlobalPerHour(t *testing.T) {
 101  	cfg := RateLimitConfig{
 102  		MinInterval:    0,
 103  		PerUserPerHour: 1000,
 104  		PerUserPerDay:  10000,
 105  		GlobalPerHour:  3,
 106  		GlobalPerDay:   100000,
 107  	}
 108  	rl := NewRateLimiter(cfg)
 109  
 110  	// 3 different users each send once
 111  	for i := 0; i < 3; i++ {
 112  		user := "user" + string(rune('A'+i))
 113  		if err := rl.Check(user); err != nil {
 114  			t.Fatalf("send from %s should be allowed: %v", user, err)
 115  		}
 116  		rl.Record(user)
 117  	}
 118  
 119  	// 4th user should hit global limit
 120  	if err := rl.Check("userD"); err == nil {
 121  		t.Error("expected global hourly rate limit")
 122  	} else if !strings.Contains(err.Error(), "global hourly") {
 123  		t.Errorf("error should mention global hourly: %v", err)
 124  	}
 125  }
 126  
 127  func TestRateLimiter_GlobalPerDay(t *testing.T) {
 128  	cfg := RateLimitConfig{
 129  		MinInterval:    0,
 130  		PerUserPerHour: 1000,
 131  		PerUserPerDay:  10000,
 132  		GlobalPerHour:  10000,
 133  		GlobalPerDay:   3,
 134  	}
 135  	rl := NewRateLimiter(cfg)
 136  
 137  	for i := 0; i < 3; i++ {
 138  		user := "user" + string(rune('A'+i))
 139  		if err := rl.Check(user); err != nil {
 140  			t.Fatalf("send from %s should be allowed: %v", user, err)
 141  		}
 142  		rl.Record(user)
 143  	}
 144  
 145  	if err := rl.Check("userD"); err == nil {
 146  		t.Error("expected global daily rate limit")
 147  	} else if !strings.Contains(err.Error(), "global daily") {
 148  		t.Errorf("error should mention global daily: %v", err)
 149  	}
 150  }
 151  
 152  func TestRateLimiter_ZeroConfigDisables(t *testing.T) {
 153  	// All zeros means no limits
 154  	cfg := RateLimitConfig{}
 155  	rl := NewRateLimiter(cfg)
 156  
 157  	for i := 0; i < 100; i++ {
 158  		if err := rl.Check("user1"); err != nil {
 159  			t.Fatalf("send %d should be allowed with zero config: %v", i+1, err)
 160  		}
 161  		rl.Record("user1")
 162  	}
 163  }
 164  
 165  func TestRateLimiter_DefaultConfig(t *testing.T) {
 166  	cfg := DefaultRateLimitConfig()
 167  
 168  	if cfg.PerUserPerHour != 10 {
 169  		t.Errorf("PerUserPerHour = %d, want 10", cfg.PerUserPerHour)
 170  	}
 171  	if cfg.PerUserPerDay != 50 {
 172  		t.Errorf("PerUserPerDay = %d, want 50", cfg.PerUserPerDay)
 173  	}
 174  	if cfg.GlobalPerHour != 100 {
 175  		t.Errorf("GlobalPerHour = %d, want 100", cfg.GlobalPerHour)
 176  	}
 177  	if cfg.GlobalPerDay != 500 {
 178  		t.Errorf("GlobalPerDay = %d, want 500", cfg.GlobalPerDay)
 179  	}
 180  	if cfg.MinInterval != 30*time.Second {
 181  		t.Errorf("MinInterval = %v, want 30s", cfg.MinInterval)
 182  	}
 183  }
 184  
 185  func TestWindow_CountSince_Prunes(t *testing.T) {
 186  	w := newWindow()
 187  
 188  	now := time.Now()
 189  	// Add 3 old timestamps and 2 recent ones
 190  	w.add(now.Add(-2 * time.Hour))
 191  	w.add(now.Add(-90 * time.Minute))
 192  	w.add(now.Add(-61 * time.Minute))
 193  	w.add(now.Add(-30 * time.Minute))
 194  	w.add(now.Add(-5 * time.Minute))
 195  
 196  	count := w.countSince(now.Add(-time.Hour))
 197  	if count != 2 {
 198  		t.Errorf("countSince = %d, want 2", count)
 199  	}
 200  
 201  	// Old entries should be pruned
 202  	if len(w.times) != 2 {
 203  		t.Errorf("after prune, len = %d, want 2", len(w.times))
 204  	}
 205  }
 206  
 207  func TestRateLimiter_ConcurrentAccess(t *testing.T) {
 208  	rl := NewRateLimiter(DefaultRateLimitConfig())
 209  
 210  	done := make(chan struct{})
 211  	for i := 0; i < 10; i++ {
 212  		go func(id int) {
 213  			defer func() { done <- struct{}{} }()
 214  			user := "user" + string(rune('A'+id))
 215  			for j := 0; j < 5; j++ {
 216  				rl.Check(user)
 217  				rl.Record(user)
 218  			}
 219  		}(i)
 220  	}
 221  
 222  	for i := 0; i < 10; i++ {
 223  		<-done
 224  	}
 225  	// If we get here without a race condition panic, the test passes.
 226  }
 227