ratelimit_test.go raw

   1  package graph
   2  
   3  import (
   4  	"context"
   5  	"testing"
   6  	"time"
   7  )
   8  
   9  func TestRateLimiterQueryCost(t *testing.T) {
  10  	rl := NewRateLimiter(DefaultRateLimiterConfig())
  11  
  12  	tests := []struct {
  13  		name    string
  14  		query   *Query
  15  		minCost float64
  16  		maxCost float64
  17  	}{
  18  		{
  19  			name:    "nil query",
  20  			query:   nil,
  21  			minCost: 1.0,
  22  			maxCost: 1.0,
  23  		},
  24  		{
  25  			name:    "depth 1 unidirectional",
  26  			query:   &Query{Pubkey: "abc", Depth: 1, Edge: "pp", Direction: "out"},
  27  			minCost: 1.5, // depthFactor^1 = 2
  28  			maxCost: 2.5,
  29  		},
  30  		{
  31  			name:    "depth 2 unidirectional",
  32  			query:   &Query{Pubkey: "abc", Depth: 2, Edge: "pp", Direction: "out"},
  33  			minCost: 3.5, // depthFactor^2 = 4
  34  			maxCost: 4.5,
  35  		},
  36  		{
  37  			name:    "depth 3 unidirectional",
  38  			query:   &Query{Pubkey: "abc", Depth: 3, Edge: "pp", Direction: "out"},
  39  			minCost: 7.5, // depthFactor^3 = 8
  40  			maxCost: 8.5,
  41  		},
  42  		{
  43  			name:    "depth 2 bidirectional (1.5x cost)",
  44  			query:   &Query{Pubkey: "abc", Depth: 2, Edge: "pp", Direction: "both"},
  45  			minCost: 5.5, // depthFactor^2 * 1.5 = 6
  46  			maxCost: 6.5,
  47  		},
  48  	}
  49  
  50  	for _, tt := range tests {
  51  		t.Run(tt.name, func(t *testing.T) {
  52  			cost := rl.QueryCost(tt.query)
  53  			if cost < tt.minCost || cost > tt.maxCost {
  54  				t.Errorf("QueryCost() = %v, want between %v and %v", cost, tt.minCost, tt.maxCost)
  55  			}
  56  		})
  57  	}
  58  }
  59  
  60  func TestRateLimiterOperationCost(t *testing.T) {
  61  	rl := NewRateLimiter(DefaultRateLimiterConfig())
  62  
  63  	// Depth 0, 1 node
  64  	cost0 := rl.OperationCost(0, 1)
  65  	if cost0 < 1.0 || cost0 > 1.1 {
  66  		t.Errorf("OperationCost(0, 1) = %v, want ~1.01", cost0)
  67  	}
  68  
  69  	// Depth 1, 1 node
  70  	cost1 := rl.OperationCost(1, 1)
  71  	if cost1 < 2.0 || cost1 > 2.1 {
  72  		t.Errorf("OperationCost(1, 1) = %v, want ~2.02", cost1)
  73  	}
  74  
  75  	// Depth 2, 100 nodes
  76  	cost2 := rl.OperationCost(2, 100)
  77  	if cost2 < 8.0 {
  78  		t.Errorf("OperationCost(2, 100) = %v, want > 8", cost2)
  79  	}
  80  }
  81  
  82  func TestRateLimiterAcquire(t *testing.T) {
  83  	cfg := DefaultRateLimiterConfig()
  84  	cfg.MaxTokens = 10
  85  	cfg.RefillRate = 100 // Fast refill for testing
  86  	rl := NewRateLimiter(cfg)
  87  
  88  	ctx := context.Background()
  89  
  90  	// Should acquire immediately when tokens available
  91  	delay, err := rl.Acquire(ctx, 5)
  92  	if err != nil {
  93  		t.Fatalf("unexpected error: %v", err)
  94  	}
  95  	if delay > time.Millisecond*10 {
  96  		t.Errorf("expected minimal delay, got %v", delay)
  97  	}
  98  
  99  	// Check remaining tokens
 100  	remaining := rl.AvailableTokens()
 101  	if remaining > 6 {
 102  		t.Errorf("expected ~5 tokens remaining, got %v", remaining)
 103  	}
 104  }
 105  
 106  func TestRateLimiterTryAcquire(t *testing.T) {
 107  	cfg := DefaultRateLimiterConfig()
 108  	cfg.MaxTokens = 10
 109  	rl := NewRateLimiter(cfg)
 110  
 111  	// Should succeed with enough tokens
 112  	if !rl.TryAcquire(5) {
 113  		t.Error("TryAcquire(5) should succeed with 10 tokens")
 114  	}
 115  
 116  	// Should succeed again
 117  	if !rl.TryAcquire(5) {
 118  		t.Error("TryAcquire(5) should succeed with 5 tokens")
 119  	}
 120  
 121  	// Should fail with insufficient tokens
 122  	if rl.TryAcquire(1) {
 123  		t.Error("TryAcquire(1) should fail with 0 tokens")
 124  	}
 125  }
 126  
 127  func TestRateLimiterContextCancellation(t *testing.T) {
 128  	cfg := DefaultRateLimiterConfig()
 129  	cfg.MaxTokens = 1
 130  	cfg.RefillRate = 0.1 // Very slow refill
 131  	rl := NewRateLimiter(cfg)
 132  
 133  	// Drain tokens
 134  	rl.TryAcquire(1)
 135  
 136  	// Create cancellable context
 137  	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
 138  	defer cancel()
 139  
 140  	// Try to acquire - should be cancelled
 141  	_, err := rl.Acquire(ctx, 10)
 142  	if err != context.DeadlineExceeded {
 143  		t.Errorf("expected DeadlineExceeded, got %v", err)
 144  	}
 145  }
 146  
 147  func TestRateLimiterRefill(t *testing.T) {
 148  	cfg := DefaultRateLimiterConfig()
 149  	cfg.MaxTokens = 10
 150  	cfg.RefillRate = 1000 // 1000 tokens per second
 151  	rl := NewRateLimiter(cfg)
 152  
 153  	// Drain tokens
 154  	rl.TryAcquire(10)
 155  
 156  	// Wait for refill
 157  	time.Sleep(15 * time.Millisecond)
 158  
 159  	// Should have some tokens now
 160  	available := rl.AvailableTokens()
 161  	if available < 5 {
 162  		t.Errorf("expected >= 5 tokens after 15ms at 1000/s, got %v", available)
 163  	}
 164  	if available > 10 {
 165  		t.Errorf("expected <= 10 tokens (max), got %v", available)
 166  	}
 167  }
 168  
 169  func TestRateLimiterPause(t *testing.T) {
 170  	rl := NewRateLimiter(DefaultRateLimiterConfig())
 171  	ctx := context.Background()
 172  
 173  	start := time.Now()
 174  	err := rl.Pause(ctx, 1, 0)
 175  	elapsed := time.Since(start)
 176  
 177  	if err != nil {
 178  		t.Fatalf("unexpected error: %v", err)
 179  	}
 180  
 181  	// Should have paused for at least baseDelay
 182  	if elapsed < rl.baseDelay {
 183  		t.Errorf("pause duration %v < baseDelay %v", elapsed, rl.baseDelay)
 184  	}
 185  }
 186  
 187  func TestThrottler(t *testing.T) {
 188  	cfg := DefaultRateLimiterConfig()
 189  	cfg.BaseDelay = 100 * time.Microsecond // Short for testing
 190  	rl := NewRateLimiter(cfg)
 191  
 192  	throttler := NewThrottler(rl, 1)
 193  	ctx := context.Background()
 194  
 195  	// Process items
 196  	for i := 0; i < 100; i++ {
 197  		if err := throttler.Tick(ctx); err != nil {
 198  			t.Fatalf("unexpected error at tick %d: %v", i, err)
 199  		}
 200  	}
 201  
 202  	processed := throttler.Complete()
 203  	if processed != 100 {
 204  		t.Errorf("expected 100 items processed, got %d", processed)
 205  	}
 206  }
 207  
 208  func TestThrottlerContextCancellation(t *testing.T) {
 209  	cfg := DefaultRateLimiterConfig()
 210  	rl := NewRateLimiter(cfg)
 211  
 212  	throttler := NewThrottler(rl, 2) // depth 2 = more frequent pauses
 213  	ctx, cancel := context.WithCancel(context.Background())
 214  
 215  	// Process some items
 216  	for i := 0; i < 20; i++ {
 217  		throttler.Tick(ctx)
 218  	}
 219  
 220  	// Cancel context
 221  	cancel()
 222  
 223  	// Next tick that would pause should return error
 224  	for i := 0; i < 100; i++ {
 225  		if err := throttler.Tick(ctx); err != nil {
 226  			// Expected - context was cancelled
 227  			return
 228  		}
 229  	}
 230  	// If we get here without error, the throttler didn't check context
 231  	// This is acceptable if no pause was needed
 232  }
 233