allowedips.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package device
   7  
   8  import (
   9  	"container/list"
  10  	"encoding/binary"
  11  	"errors"
  12  	"math/bits"
  13  	"net"
  14  	"net/netip"
  15  	"sync"
  16  	"unsafe"
  17  )
  18  
  19  type parentIndirection struct {
  20  	parentBit     **trieEntry
  21  	parentBitType uint8
  22  }
  23  
  24  type trieEntry struct {
  25  	peer        *Peer
  26  	child       [2]*trieEntry
  27  	parent      parentIndirection
  28  	cidr        uint8
  29  	bitAtByte   uint8
  30  	bitAtShift  uint8
  31  	bits        []byte
  32  	perPeerElem *list.Element
  33  }
  34  
  35  func commonBits(ip1, ip2 []byte) uint8 {
  36  	size := len(ip1)
  37  	if size == net.IPv4len {
  38  		a := binary.BigEndian.Uint32(ip1)
  39  		b := binary.BigEndian.Uint32(ip2)
  40  		x := a ^ b
  41  		return uint8(bits.LeadingZeros32(x))
  42  	} else if size == net.IPv6len {
  43  		a := binary.BigEndian.Uint64(ip1)
  44  		b := binary.BigEndian.Uint64(ip2)
  45  		x := a ^ b
  46  		if x != 0 {
  47  			return uint8(bits.LeadingZeros64(x))
  48  		}
  49  		a = binary.BigEndian.Uint64(ip1[8:])
  50  		b = binary.BigEndian.Uint64(ip2[8:])
  51  		x = a ^ b
  52  		return 64 + uint8(bits.LeadingZeros64(x))
  53  	} else {
  54  		panic("Wrong size bit string")
  55  	}
  56  }
  57  
  58  func (node *trieEntry) addToPeerEntries() {
  59  	node.perPeerElem = node.peer.trieEntries.PushBack(node)
  60  }
  61  
  62  func (node *trieEntry) removeFromPeerEntries() {
  63  	if node.perPeerElem != nil {
  64  		node.peer.trieEntries.Remove(node.perPeerElem)
  65  		node.perPeerElem = nil
  66  	}
  67  }
  68  
  69  func (node *trieEntry) choose(ip []byte) byte {
  70  	return (ip[node.bitAtByte] >> node.bitAtShift) & 1
  71  }
  72  
  73  func (node *trieEntry) maskSelf() {
  74  	mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
  75  	for i := 0; i < len(mask); i++ {
  76  		node.bits[i] &= mask[i]
  77  	}
  78  }
  79  
  80  func (node *trieEntry) zeroizePointers() {
  81  	// Make the garbage collector's life slightly easier
  82  	node.peer = nil
  83  	node.child[0] = nil
  84  	node.child[1] = nil
  85  	node.parent.parentBit = nil
  86  }
  87  
  88  func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
  89  	for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
  90  		parent = node
  91  		if parent.cidr == cidr {
  92  			exact = true
  93  			return
  94  		}
  95  		bit := node.choose(ip)
  96  		node = node.child[bit]
  97  	}
  98  	return
  99  }
 100  
// insert adds the prefix ip/cidr for peer to the trie rooted at the slot
// trie.parentBit. If that exact prefix is already a node, it is reassigned
// to peer instead.
//
// ip is stored as the new node's bits without copying and is masked in place
// by maskSelf, so the caller's slice is retained and modified. Its length must
// match the addresses already in this trie (4 or 16 bytes).
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
	// Empty trie: the new node simply becomes the root.
	if *trie.parentBit == nil {
		node := &trieEntry{
			peer:       peer,
			parent:     trie,
			bits:       ip,
			cidr:       cidr,
			bitAtByte:  cidr / 8,
			bitAtShift: 7 - (cidr % 8),
		}
		node.maskSelf()
		node.addToPeerEntries()
		*trie.parentBit = node
		return
	}
	// node is the deepest existing node whose prefix contains ip/cidr.
	node, exact := (*trie.parentBit).nodePlacement(ip, cidr)
	if exact {
		// The prefix already has a node (possibly peerless): hand it to peer
		// and move it onto peer's entry list.
		node.removeFromPeerEntries()
		node.peer = peer
		node.addToPeerEntries()
		return
	}

	newNode := &trieEntry{
		peer:       peer,
		bits:       ip,
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	newNode.maskSelf()
	newNode.addToPeerEntries()

	// down is the existing subtree that newNode must be placed above or next
	// to. It is the whole trie when no node contains the new prefix, and
	// otherwise node's child in ip's direction.
	var down *trieEntry
	if node == nil {
		down = *trie.parentBit
	} else {
		bit := node.choose(ip)
		down = node.child[bit]
		if down == nil {
			// Free slot under node: attach newNode there as a leaf.
			newNode.parent = parentIndirection{&node.child[bit], bit}
			node.child[bit] = newNode
			return
		}
	}
	// From here on, cidr is the length of the prefix shared by down and the
	// new address, capped at the new prefix length.
	common := commonBits(down.bits, ip)
	if common < cidr {
		cidr = common
	}
	parent := node

	if newNode.cidr == cidr {
		// down lies inside the new prefix: insert newNode between parent
		// (or the root slot) and down.
		bit := newNode.choose(down.bits)
		down.parent = parentIndirection{&newNode.child[bit], bit}
		newNode.child[bit] = down
		if parent == nil {
			newNode.parent = trie
			*trie.parentBit = newNode
		} else {
			bit := parent.choose(newNode.bits)
			newNode.parent = parentIndirection{&parent.child[bit], bit}
			parent.child[bit] = newNode
		}
		return
	}

	// The two prefixes diverge after `cidr` bits. Create a peerless branch
	// node at that length (with its own copy of the address) whose two
	// children are down and newNode.
	node = &trieEntry{
		bits:       append([]byte{}, newNode.bits...),
		cidr:       cidr,
		bitAtByte:  cidr / 8,
		bitAtShift: 7 - (cidr % 8),
	}
	node.maskSelf()

	bit := node.choose(down.bits)
	down.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = down
	bit = node.choose(newNode.bits)
	newNode.parent = parentIndirection{&node.child[bit], bit}
	node.child[bit] = newNode
	// Hook the branch node into the slot down used to occupy.
	if parent == nil {
		node.parent = trie
		*trie.parentBit = node
	} else {
		bit := parent.choose(node.bits)
		node.parent = parentIndirection{&parent.child[bit], bit}
		parent.child[bit] = node
	}
}
 190  
 191  func (node *trieEntry) lookup(ip []byte) *Peer {
 192  	var found *Peer
 193  	size := uint8(len(ip))
 194  	for node != nil && commonBits(node.bits, ip) >= node.cidr {
 195  		if node.peer != nil {
 196  			found = node.peer
 197  		}
 198  		if node.bitAtByte == size {
 199  			break
 200  		}
 201  		bit := node.choose(ip)
 202  		node = node.child[bit]
 203  	}
 204  	return found
 205  }
 206  
 207  type AllowedIPs struct {
 208  	IPv4  *trieEntry
 209  	IPv6  *trieEntry
 210  	mutex sync.RWMutex
 211  }
 212  
 213  func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
 214  	table.mutex.RLock()
 215  	defer table.mutex.RUnlock()
 216  
 217  	for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
 218  		node := elem.Value.(*trieEntry)
 219  		a, _ := netip.AddrFromSlice(node.bits)
 220  		if !cb(netip.PrefixFrom(a, int(node.cidr))) {
 221  			return
 222  		}
 223  	}
 224  }
 225  
// remove deletes node's prefix from the trie and from its peer's entry list.
//
// There are three cases:
//   - A node with two children stays in place as a peerless branch node.
//   - Otherwise the node is spliced out, and its single child (if any) takes
//     its slot.
//   - If the removed node was a leaf under a peerless parent, that parent is
//     left with one child and is spliced out too.
func (node *trieEntry) remove() {
	node.removeFromPeerEntries()
	node.peer = nil
	if node.child[0] != nil && node.child[1] != nil {
		return
	}
	// Select the only non-nil child, or nil when node is a leaf.
	bit := 0
	if node.child[0] == nil {
		bit = 1
	}
	child := node.child[bit]
	if child != nil {
		child.parent = node.parent
	}
	// Splice: the child (or nil) takes over node's slot.
	*node.parent.parentBit = child
	// node.child is still intact here, so this checks whether node had a
	// child. A root slot (parentBitType > 1) has no parent node to inspect.
	if node.child[0] != nil || node.child[1] != nil || node.parent.parentBitType > 1 {
		node.zeroizePointers()
		return
	}
	// node was a leaf below a parent trieEntry, and parentBit points at
	// parent.child[parentBitType]. Recover the parent itself by stepping back
	// over that array index and the field offset of child. The uintptr
	// arithmetic stays inside a single expression, as the unsafe.Pointer
	// conversion rules require.
	parent := (*trieEntry)(unsafe.Pointer(uintptr(unsafe.Pointer(node.parent.parentBit)) - unsafe.Offsetof(node.child) - unsafe.Sizeof(node.child[0])*uintptr(node.parent.parentBitType)))
	// A parent that owns a prefix must stay, even with a single child.
	if parent.peer != nil {
		node.zeroizePointers()
		return
	}
	// The peerless parent now has only node's sibling left; move the sibling
	// into the parent's slot.
	child = parent.child[node.parent.parentBitType^1]
	if child != nil {
		child.parent = parent.parent
	}
	*parent.parent.parentBit = child
	node.zeroizePointers()
	parent.zeroizePointers()
}
 258  
 259  func (table *AllowedIPs) Remove(prefix netip.Prefix, peer *Peer) {
 260  	table.mutex.Lock()
 261  	defer table.mutex.Unlock()
 262  	var node *trieEntry
 263  	var exact bool
 264  
 265  	if prefix.Addr().Is6() {
 266  		ip := prefix.Addr().As16()
 267  		node, exact = table.IPv6.nodePlacement(ip[:], uint8(prefix.Bits()))
 268  	} else if prefix.Addr().Is4() {
 269  		ip := prefix.Addr().As4()
 270  		node, exact = table.IPv4.nodePlacement(ip[:], uint8(prefix.Bits()))
 271  	} else {
 272  		panic(errors.New("removing unknown address type"))
 273  	}
 274  	if !exact || node == nil || peer != node.peer {
 275  		return
 276  	}
 277  	node.remove()
 278  }
 279  
 280  func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
 281  	table.mutex.Lock()
 282  	defer table.mutex.Unlock()
 283  
 284  	var next *list.Element
 285  	for elem := peer.trieEntries.Front(); elem != nil; elem = next {
 286  		next = elem.Next()
 287  		elem.Value.(*trieEntry).remove()
 288  	}
 289  }
 290  
 291  func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
 292  	table.mutex.Lock()
 293  	defer table.mutex.Unlock()
 294  
 295  	if prefix.Addr().Is6() {
 296  		ip := prefix.Addr().As16()
 297  		parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
 298  	} else if prefix.Addr().Is4() {
 299  		ip := prefix.Addr().As4()
 300  		parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
 301  	} else {
 302  		panic(errors.New("inserting unknown address type"))
 303  	}
 304  }
 305  
 306  func (table *AllowedIPs) Lookup(ip []byte) *Peer {
 307  	table.mutex.RLock()
 308  	defer table.mutex.RUnlock()
 309  	switch len(ip) {
 310  	case net.IPv6len:
 311  		return table.IPv6.lookup(ip)
 312  	case net.IPv4len:
 313  		return table.IPv4.lookup(ip)
 314  	default:
 315  		panic(errors.New("looking up unknown address type"))
 316  	}
 317  }
 318