ratelimiter.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package ratelimiter
7
8 import (
9 "net/netip"
10 "sync"
11 "time"
12 )
13
// Token-bucket parameters. Token balances are measured in nanoseconds of
// elapsed time: Allow credits a bucket with one token per nanosecond.
const (
	// packetsPerSecond is the sustained rate admitted for one IP address.
	packetsPerSecond = 20
	// packetsBurstable is how many packets' worth of tokens a bucket may hold.
	packetsBurstable = 5
	// garbageCollectTime is how long an entry may go unused before cleanup
	// removes it from the table.
	garbageCollectTime = time.Second
	// packetCost is the number of tokens (nanoseconds) charged per packet.
	packetCost = 1_000_000_000 / packetsPerSecond
	// maxTokens is the largest balance a bucket can accumulate.
	maxTokens = packetsBurstable * packetCost
)
21
type (
	// RatelimiterEntry is the token bucket kept for a single IP address.
	RatelimiterEntry struct {
		mu       sync.Mutex // held while reading or updating lastTime and tokens
		lastTime time.Time  // when the bucket was created or last refilled
		tokens   int64      // current balance, in nanoseconds' worth of tokens
	}

	// Ratelimiter is a per-IP-address token-bucket rate limiter. Init must be
	// called before use (it allocates the table) and Close when finished (it
	// stops the garbage-collection goroutine started by Init).
	Ratelimiter struct {
		mu      sync.RWMutex     // guards table and stopReset
		timeNow func() time.Time // clock source; Init defaults it to time.Now

		stopReset chan struct{} // send to reset, close to stop
		table     map[netip.Addr]*RatelimiterEntry
	}
)
35
36 func (rate *Ratelimiter) Close() {
37 rate.mu.Lock()
38 defer rate.mu.Unlock()
39
40 if rate.stopReset != nil {
41 close(rate.stopReset)
42 }
43 }
44
45 func (rate *Ratelimiter) Init() {
46 rate.mu.Lock()
47 defer rate.mu.Unlock()
48
49 if rate.timeNow == nil {
50 rate.timeNow = time.Now
51 }
52
53 // stop any ongoing garbage collection routine
54 if rate.stopReset != nil {
55 close(rate.stopReset)
56 }
57
58 rate.stopReset = make(chan struct{})
59 rate.table = make(map[netip.Addr]*RatelimiterEntry)
60
61 stopReset := rate.stopReset // store in case Init is called again.
62
63 // Start garbage collection routine.
64 go func() {
65 ticker := time.NewTicker(time.Second)
66 ticker.Stop()
67 for {
68 select {
69 case _, ok := <-stopReset:
70 ticker.Stop()
71 if !ok {
72 return
73 }
74 ticker = time.NewTicker(time.Second)
75 case <-ticker.C:
76 if rate.cleanup() {
77 ticker.Stop()
78 }
79 }
80 }
81 }()
82 }
83
84 func (rate *Ratelimiter) cleanup() (empty bool) {
85 rate.mu.Lock()
86 defer rate.mu.Unlock()
87
88 for key, entry := range rate.table {
89 entry.mu.Lock()
90 if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
91 delete(rate.table, key)
92 }
93 entry.mu.Unlock()
94 }
95
96 return len(rate.table) == 0
97 }
98
99 func (rate *Ratelimiter) Allow(ip netip.Addr) bool {
100 var entry *RatelimiterEntry
101 // lookup entry
102 rate.mu.RLock()
103 entry = rate.table[ip]
104 rate.mu.RUnlock()
105
106 // make new entry if not found
107 if entry == nil {
108 entry = new(RatelimiterEntry)
109 entry.tokens = maxTokens - packetCost
110 entry.lastTime = rate.timeNow()
111 rate.mu.Lock()
112 rate.table[ip] = entry
113 if len(rate.table) == 1 {
114 rate.stopReset <- struct{}{}
115 }
116 rate.mu.Unlock()
117 return true
118 }
119
120 // add tokens to entry
121 entry.mu.Lock()
122 now := rate.timeNow()
123 entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
124 entry.lastTime = now
125 if entry.tokens > maxTokens {
126 entry.tokens = maxTokens
127 }
128
129 // subtract cost of packet
130 if entry.tokens > packetCost {
131 entry.tokens -= packetCost
132 entry.mu.Unlock()
133 return true
134 }
135 entry.mu.Unlock()
136 return false
137 }
138