ratelimiter.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package ratelimiter
   7  
   8  import (
   9  	"net/netip"
  10  	"sync"
  11  	"time"
  12  )
  13  
  14  const (
  15  	packetsPerSecond   = 20
  16  	packetsBurstable   = 5
  17  	garbageCollectTime = time.Second
  18  	packetCost         = 1000000000 / packetsPerSecond
  19  	maxTokens          = packetCost * packetsBurstable
  20  )
  21  
  22  type RatelimiterEntry struct {
  23  	mu       sync.Mutex
  24  	lastTime time.Time
  25  	tokens   int64
  26  }
  27  
  28  type Ratelimiter struct {
  29  	mu      sync.RWMutex
  30  	timeNow func() time.Time
  31  
  32  	stopReset chan struct{} // send to reset, close to stop
  33  	table     map[netip.Addr]*RatelimiterEntry
  34  }
  35  
  36  func (rate *Ratelimiter) Close() {
  37  	rate.mu.Lock()
  38  	defer rate.mu.Unlock()
  39  
  40  	if rate.stopReset != nil {
  41  		close(rate.stopReset)
  42  	}
  43  }
  44  
  45  func (rate *Ratelimiter) Init() {
  46  	rate.mu.Lock()
  47  	defer rate.mu.Unlock()
  48  
  49  	if rate.timeNow == nil {
  50  		rate.timeNow = time.Now
  51  	}
  52  
  53  	// stop any ongoing garbage collection routine
  54  	if rate.stopReset != nil {
  55  		close(rate.stopReset)
  56  	}
  57  
  58  	rate.stopReset = make(chan struct{})
  59  	rate.table = make(map[netip.Addr]*RatelimiterEntry)
  60  
  61  	stopReset := rate.stopReset // store in case Init is called again.
  62  
  63  	// Start garbage collection routine.
  64  	go func() {
  65  		ticker := time.NewTicker(time.Second)
  66  		ticker.Stop()
  67  		for {
  68  			select {
  69  			case _, ok := <-stopReset:
  70  				ticker.Stop()
  71  				if !ok {
  72  					return
  73  				}
  74  				ticker = time.NewTicker(time.Second)
  75  			case <-ticker.C:
  76  				if rate.cleanup() {
  77  					ticker.Stop()
  78  				}
  79  			}
  80  		}
  81  	}()
  82  }
  83  
  84  func (rate *Ratelimiter) cleanup() (empty bool) {
  85  	rate.mu.Lock()
  86  	defer rate.mu.Unlock()
  87  
  88  	for key, entry := range rate.table {
  89  		entry.mu.Lock()
  90  		if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
  91  			delete(rate.table, key)
  92  		}
  93  		entry.mu.Unlock()
  94  	}
  95  
  96  	return len(rate.table) == 0
  97  }
  98  
  99  func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
 100  	var entry *RatelimiterEntry
 101  	// lookup entry
 102  	rate.mu.RLock()
 103  	entry = rate.table[ip]
 104  	rate.mu.RUnlock()
 105  
 106  	// make new entry if not found
 107  	if entry == nil {
 108  		entry = new(RatelimiterEntry)
 109  		entry.tokens = maxTokens - packetCost
 110  		entry.lastTime = rate.timeNow()
 111  		rate.mu.Lock()
 112  		rate.table[ip] = entry
 113  		if len(rate.table) == 1 {
 114  			rate.stopReset <- struct{}{}
 115  		}
 116  		rate.mu.Unlock()
 117  		return true
 118  	}
 119  
 120  	// add tokens to entry
 121  	entry.mu.Lock()
 122  	now := rate.timeNow()
 123  	entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
 124  	entry.lastTime = now
 125  	if entry.tokens > maxTokens {
 126  		entry.tokens = maxTokens
 127  	}
 128  
 129  	// subtract cost of packet
 130  	if entry.tokens > packetCost {
 131  		entry.tokens -= packetCost
 132  		entry.mu.Unlock()
 133  		return true
 134  	}
 135  	entry.mu.Unlock()
 136  	return false
 137  }
 138