allowedips.go raw
/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */
5
6 package device
7
8 import (
9 "container/list"
10 "encoding/binary"
11 "errors"
12 "math/bits"
13 "net"
14 "net/netip"
15 "sync"
16 "unsafe"
17 )
18
19 type parentIndirection struct {
20 parentBit **trieEntry
21 parentBitType uint8
22 }
23
24 type trieEntry struct {
25 peer *Peer
26 child [2]*trieEntry
27 parent parentIndirection
28 cidr uint8
29 bitAtByte uint8
30 bitAtShift uint8
31 bits []byte
32 perPeerElem *list.Element
33 }
34
// commonBits returns the number of leading bits that ip1 and ip2 share.
// ip1 selects the width: it must be 4 or 16 bytes long (any other length
// panics), and ip2 must be at least as long.
func commonBits(ip1, ip2 []byte) uint8 {
	switch len(ip1) {
	case net.IPv4len:
		diff := binary.BigEndian.Uint32(ip1) ^ binary.BigEndian.Uint32(ip2)
		return uint8(bits.LeadingZeros32(diff))
	case net.IPv6len:
		// Compare the high half first; only fall through to the low half
		// when the first 64 bits are identical.
		if hi := binary.BigEndian.Uint64(ip1) ^ binary.BigEndian.Uint64(ip2); hi != 0 {
			return uint8(bits.LeadingZeros64(hi))
		}
		lo := binary.BigEndian.Uint64(ip1[8:]) ^ binary.BigEndian.Uint64(ip2[8:])
		return 64 + uint8(bits.LeadingZeros64(lo))
	default:
		panic("Wrong size bit string")
	}
}
57
58 func (node *trieEntry) addToPeerEntries() {
59 node.perPeerElem = node.peer.trieEntries.PushBack(node)
60 }
61
62 func (node *trieEntry) removeFromPeerEntries() {
63 if node.perPeerElem != nil {
64 node.peer.trieEntries.Remove(node.perPeerElem)
65 node.perPeerElem = nil
66 }
67 }
68
69 func (node *trieEntry) choose(ip []byte) byte {
70 return (ip[node.bitAtByte] >> node.bitAtShift) & 1
71 }
72
73 func (node *trieEntry) maskSelf() {
74 mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
75 for i := 0; i < len(mask); i++ {
76 node.bits[i] &= mask[i]
77 }
78 }
79
80 func (node *trieEntry) zeroizePointers() {
81 // Make the garbage collector's life slightly easier
82 node.peer = nil
83 node.child[0] = nil
84 node.child[1] = nil
85 node.parent.parentBit = nil
86 }
87
88 func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
89 for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
90 parent = node
91 if parent.cidr == cidr {
92 exact = true
93 return
94 }
95 bit := node.choose(ip)
96 node = node.child[bit]
97 }
98 return
99 }
100
101 func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
102 if *trie.parentBit == nil {
103 node := &trieEntry{
104 peer: peer,
105 parent: trie,
106 bits: ip,
107 cidr: cidr,
108 bitAtByte: cidr / 8,
109 bitAtShift: 7 - (cidr % 8),
110 }
111 node.maskSelf()
112 node.addToPeerEntries()
113 *trie.parentBit = node
114 return
115 }
116 node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
117 if exact {
118 node.removeFromPeerEntries()
119 node.peer = peer
120 node.addToPeerEntries()
121 return
122 }
123
124 newNode := &trieEntry{
125 peer: peer,
126 bits: ip,
127 cidr: cidr,
128 bitAtByte: cidr / 8,
129 bitAtShift: 7 - (cidr % 8),
130 }
131 newNode.maskSelf()
132 newNode.addToPeerEntries()
133
134 var down *trieEntry
135 if node == nil {
136 down = *trie.parentBit
137 } else {
138 bit := node.choose(ip)
139 down = node.child[bit]
140 if down == nil {
141 newNode.parent = parentIndirection{&node.child[bit], bit}
142 node.child[bit] = newNode
143 return
144 }
145 }
146 common := commonBits(down.bits, ip)
147 if common < cidr {
148 cidr = common
149 }
150 parent := node
151
152 if newNode.cidr == cidr {
153 bit := newNode.choose(down.bits)
154 down.parent = parentIndirection{&newNode.child[bit], bit}
155 newNode.child[bit] = down
156 if parent == nil {
157 newNode.parent = trie
158 *trie.parentBit = newNode
159 } else {
160 bit := parent.choose(newNode.bits)
161 newNode.parent = parentIndirection{&parent.child[bit], bit}
162 parent.child[bit] = newNode
163 }
164 return
165 }
166
167 node = &trieEntry{
168 bits: append([]byte{}, newNode.bits...),
169 cidr: cidr,
170 bitAtByte: cidr / 8,
171 bitAtShift: 7 - (cidr % 8),
172 }
173 node.maskSelf()
174
175 bit := node.choose(down.bits)
176 down.parent = parentIndirection{&node.child[bit], bit}
177 node.child[bit] = down
178 bit = node.choose(newNode.bits)
179 newNode.parent = parentIndirection{&node.child[bit], bit}
180 node.child[bit] = newNode
181 if parent == nil {
182 node.parent = trie
183 *trie.parentBit = node
184 } else {
185 bit := parent.choose(node.bits)
186 node.parent = parentIndirection{&parent.child[bit], bit}
187 parent.child[bit] = node
188 }
189 }
190
191 func (node *trieEntry) lookup(ip []byte) *Peer {
192 var found *Peer
193 size := uint8(len(ip))
194 for node != nil && commonBits(node.bits, ip) >= node.cidr {
195 if node.peer != nil {
196 found = node.peer
197 }
198 if node.bitAtByte == size {
199 break
200 }
201 bit := node.choose(ip)
202 node = node.child[bit]
203 }
204 return found
205 }
206
207 type AllowedIPs struct {
208 IPv4 *trieEntry
209 IPv6 *trieEntry
210 mutex sync.RWMutex
211 }
212
213 func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
214 table.mutex.RLock()
215 defer table.mutex.RUnlock()
216
217 for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
218 node := elem.Value.(*trieEntry)
219 a, _ := netip.AddrFromSlice(node.bits)
220 if !cb(netip.PrefixFrom(a, int(node.cidr))) {
221 return
222 }
223 }
224 }
225
226 func (node *trieEntry) remove() {
227 node.removeFromPeerEntries()
228 node.peer = nil
229 if node.child[0] != nil && node.child[1] != nil {
230 return
231 }
232 bit := 0
233 if node.child[0] == nil {
234 bit = 1
235 }
236 child := node.child[bit]
237 if child != nil {
238 child.parent = node.parent
239 }
240 *node.parent.parentBit = child
241 if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
242 node.zeroizePointers()
243 return
244 }
245 parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
246 if parent.peer != nil {
247 node.zeroizePointers()
248 return
249 }
250 child = parent.child[node.parent.parentBitType^1]
251 if child != nil {
252 child.parent = parent.parent
253 }
254 *parent.parent.parentBit = child
255 node.zeroizePointers()
256 parent.zeroizePointers()
257 }
258
259 func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
260 table.mutex.Lock()
261 defer table.mutex.Unlock()
262 var node *trieEntry
263 var exact bool
264
265 if prefix.Addr().Is6() {
266 ip := prefix.Addr().As16()
267 node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
268 } else if prefix.Addr().Is4() {
269 ip := prefix.Addr().As4()
270 node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
271 } else {
272 panic(errors.New("removing unknown address type"))
273 }
274 if !exact || node == nil || peer != node.peer {
275 return
276 }
277 node.remove()
278 }
279
280 func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
281 table.mutex.Lock()
282 defer table.mutex.Unlock()
283
284 var next *list.Element
285 for elem := peer.trieEntries.Front(); elem != nil; elem = next {
286 next = elem.Next()
287 elem.Value.(*trieEntry).remove()
288 }
289 }
290
291 func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
292 table.mutex.Lock()
293 defer table.mutex.Unlock()
294
295 if prefix.Addr().Is6() {
296 ip := prefix.Addr().As16()
297 parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
298 } else if prefix.Addr().Is4() {
299 ip := prefix.Addr().As4()
300 parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
301 } else {
302 panic(errors.New("inserting unknown address type"))
303 }
304 }
305
306 func (table *AllowedIPs) Lookup(ip []byte) *Peer {
307 table.mutex.RLock()
308 defer table.mutex.RUnlock()
309 switch len(ip) {
310 case net.IPv6len:
311 return table.IPv6.lookup(ip)
312 case net.IPv4len:
313 return table.IPv4.lookup(ip)
314 default:
315 panic(errors.New("looking up unknown address type"))
316 }
317 }
318