ratelimit_test.go raw
1 package graph
2
import (
	"context"
	"errors"
	"testing"
	"time"
)
8
9 func TestRateLimiterQueryCost(t *testing.T) {
10 rl := NewRateLimiter(DefaultRateLimiterConfig())
11
12 tests := []struct {
13 name string
14 query *Query
15 minCost float64
16 maxCost float64
17 }{
18 {
19 name: "nil query",
20 query: nil,
21 minCost: 1.0,
22 maxCost: 1.0,
23 },
24 {
25 name: "depth 1 unidirectional",
26 query: &Query{Pubkey: "abc", Depth: 1, Edge: "pp", Direction: "out"},
27 minCost: 1.5, // depthFactor^1 = 2
28 maxCost: 2.5,
29 },
30 {
31 name: "depth 2 unidirectional",
32 query: &Query{Pubkey: "abc", Depth: 2, Edge: "pp", Direction: "out"},
33 minCost: 3.5, // depthFactor^2 = 4
34 maxCost: 4.5,
35 },
36 {
37 name: "depth 3 unidirectional",
38 query: &Query{Pubkey: "abc", Depth: 3, Edge: "pp", Direction: "out"},
39 minCost: 7.5, // depthFactor^3 = 8
40 maxCost: 8.5,
41 },
42 {
43 name: "depth 2 bidirectional (1.5x cost)",
44 query: &Query{Pubkey: "abc", Depth: 2, Edge: "pp", Direction: "both"},
45 minCost: 5.5, // depthFactor^2 * 1.5 = 6
46 maxCost: 6.5,
47 },
48 }
49
50 for _, tt := range tests {
51 t.Run(tt.name, func(t *testing.T) {
52 cost := rl.QueryCost(tt.query)
53 if cost < tt.minCost || cost > tt.maxCost {
54 t.Errorf("QueryCost() = %v, want between %v and %v", cost, tt.minCost, tt.maxCost)
55 }
56 })
57 }
58 }
59
60 func TestRateLimiterOperationCost(t *testing.T) {
61 rl := NewRateLimiter(DefaultRateLimiterConfig())
62
63 // Depth 0, 1 node
64 cost0 := rl.OperationCost(0, 1)
65 if cost0 < 1.0 || cost0 > 1.1 {
66 t.Errorf("OperationCost(0, 1) = %v, want ~1.01", cost0)
67 }
68
69 // Depth 1, 1 node
70 cost1 := rl.OperationCost(1, 1)
71 if cost1 < 2.0 || cost1 > 2.1 {
72 t.Errorf("OperationCost(1, 1) = %v, want ~2.02", cost1)
73 }
74
75 // Depth 2, 100 nodes
76 cost2 := rl.OperationCost(2, 100)
77 if cost2 < 8.0 {
78 t.Errorf("OperationCost(2, 100) = %v, want > 8", cost2)
79 }
80 }
81
82 func TestRateLimiterAcquire(t *testing.T) {
83 cfg := DefaultRateLimiterConfig()
84 cfg.MaxTokens = 10
85 cfg.RefillRate = 100 // Fast refill for testing
86 rl := NewRateLimiter(cfg)
87
88 ctx := context.Background()
89
90 // Should acquire immediately when tokens available
91 delay, err := rl.Acquire(ctx, 5)
92 if err != nil {
93 t.Fatalf("unexpected error: %v", err)
94 }
95 if delay > time.Millisecond*10 {
96 t.Errorf("expected minimal delay, got %v", delay)
97 }
98
99 // Check remaining tokens
100 remaining := rl.AvailableTokens()
101 if remaining > 6 {
102 t.Errorf("expected ~5 tokens remaining, got %v", remaining)
103 }
104 }
105
106 func TestRateLimiterTryAcquire(t *testing.T) {
107 cfg := DefaultRateLimiterConfig()
108 cfg.MaxTokens = 10
109 rl := NewRateLimiter(cfg)
110
111 // Should succeed with enough tokens
112 if !rl.TryAcquire(5) {
113 t.Error("TryAcquire(5) should succeed with 10 tokens")
114 }
115
116 // Should succeed again
117 if !rl.TryAcquire(5) {
118 t.Error("TryAcquire(5) should succeed with 5 tokens")
119 }
120
121 // Should fail with insufficient tokens
122 if rl.TryAcquire(1) {
123 t.Error("TryAcquire(1) should fail with 0 tokens")
124 }
125 }
126
127 func TestRateLimiterContextCancellation(t *testing.T) {
128 cfg := DefaultRateLimiterConfig()
129 cfg.MaxTokens = 1
130 cfg.RefillRate = 0.1 // Very slow refill
131 rl := NewRateLimiter(cfg)
132
133 // Drain tokens
134 rl.TryAcquire(1)
135
136 // Create cancellable context
137 ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
138 defer cancel()
139
140 // Try to acquire - should be cancelled
141 _, err := rl.Acquire(ctx, 10)
142 if err != context.DeadlineExceeded {
143 t.Errorf("expected DeadlineExceeded, got %v", err)
144 }
145 }
146
147 func TestRateLimiterRefill(t *testing.T) {
148 cfg := DefaultRateLimiterConfig()
149 cfg.MaxTokens = 10
150 cfg.RefillRate = 1000 // 1000 tokens per second
151 rl := NewRateLimiter(cfg)
152
153 // Drain tokens
154 rl.TryAcquire(10)
155
156 // Wait for refill
157 time.Sleep(15 * time.Millisecond)
158
159 // Should have some tokens now
160 available := rl.AvailableTokens()
161 if available < 5 {
162 t.Errorf("expected >= 5 tokens after 15ms at 1000/s, got %v", available)
163 }
164 if available > 10 {
165 t.Errorf("expected <= 10 tokens (max), got %v", available)
166 }
167 }
168
169 func TestRateLimiterPause(t *testing.T) {
170 rl := NewRateLimiter(DefaultRateLimiterConfig())
171 ctx := context.Background()
172
173 start := time.Now()
174 err := rl.Pause(ctx, 1, 0)
175 elapsed := time.Since(start)
176
177 if err != nil {
178 t.Fatalf("unexpected error: %v", err)
179 }
180
181 // Should have paused for at least baseDelay
182 if elapsed < rl.baseDelay {
183 t.Errorf("pause duration %v < baseDelay %v", elapsed, rl.baseDelay)
184 }
185 }
186
187 func TestThrottler(t *testing.T) {
188 cfg := DefaultRateLimiterConfig()
189 cfg.BaseDelay = 100 * time.Microsecond // Short for testing
190 rl := NewRateLimiter(cfg)
191
192 throttler := NewThrottler(rl, 1)
193 ctx := context.Background()
194
195 // Process items
196 for i := 0; i < 100; i++ {
197 if err := throttler.Tick(ctx); err != nil {
198 t.Fatalf("unexpected error at tick %d: %v", i, err)
199 }
200 }
201
202 processed := throttler.Complete()
203 if processed != 100 {
204 t.Errorf("expected 100 items processed, got %d", processed)
205 }
206 }
207
208 func TestThrottlerContextCancellation(t *testing.T) {
209 cfg := DefaultRateLimiterConfig()
210 rl := NewRateLimiter(cfg)
211
212 throttler := NewThrottler(rl, 2) // depth 2 = more frequent pauses
213 ctx, cancel := context.WithCancel(context.Background())
214
215 // Process some items
216 for i := 0; i < 20; i++ {
217 throttler.Tick(ctx)
218 }
219
220 // Cancel context
221 cancel()
222
223 // Next tick that would pause should return error
224 for i := 0; i < 100; i++ {
225 if err := throttler.Tick(ctx); err != nil {
226 // Expected - context was cancelled
227 return
228 }
229 }
230 // If we get here without error, the throttler didn't check context
231 // This is acceptable if no pause was needed
232 }
233