device.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package device
   7  
   8  import (
   9  	"runtime"
  10  	"sync"
  11  	"sync/atomic"
  12  	"time"
  13  
  14  	"golang.zx2c4.com/wireguard/conn"
  15  	"golang.zx2c4.com/wireguard/ratelimiter"
  16  	"golang.zx2c4.com/wireguard/rwcancel"
  17  	"golang.zx2c4.com/wireguard/tun"
  18  )
  19  
// Device is the central object of this package. It owns the TUN device, the
// UDP bind, the peer set, the allowed-IPs table, and the handshake,
// encryption and decryption queues served by the worker goroutines that
// NewDevice starts.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that that intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		stopping sync.WaitGroup
		// The embedded Mutex (device.state.Lock/Unlock) protects state changes;
		// changeState and Close hold it for their whole duration.
		sync.Mutex
	}

	// net holds the UDP bind and its settings, guarded by the embedded
	// RWMutex. stopping counts the RoutineReceiveIncoming goroutines started
	// by BindUpdate; closeBindLocked waits on it.
	net struct {
		stopping sync.WaitGroup
		sync.RWMutex
		bind          conn.Bind // bind interface
		netlinkCancel *rwcancel.RWCancel
		port          uint16 // listening port
		fwmark        uint32 // mark value (0 = disabled)
		brokenRoaming bool
	}

	// staticIdentity is the device's own key pair; SetPrivateKey replaces it
	// while holding the embedded RWMutex for writing.
	staticIdentity struct {
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey  NoisePublicKey
	}

	peers struct {
		sync.RWMutex // protects keyMap
		keyMap       map[NoisePublicKey]*Peer
	}

	// rate: underLoadUntil is a UnixNano deadline maintained by IsUnderLoad;
	// limiter is initialized in NewDevice and closed in Close.
	rate struct {
		underLoadUntil atomic.Int64
		limiter        ratelimiter.Ratelimiter
	}

	allowedips    AllowedIPs
	indexTable    IndexTable
	cookieChecker CookieChecker

	// pool is filled by PopulatePools, which NewDevice calls.
	pool struct {
		inboundElementsContainer  *WaitPool
		outboundElementsContainer *WaitPool
		messageBuffers            *WaitPool
		inboundElements           *WaitPool
		outboundElements          *WaitPool
	}

	// queue holds the work queues created in NewDevice; Close releases the
	// device's own reference on each queue's wg.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	// tun is the TUN device; mtu is stored atomically (first set in NewDevice).
	tun struct {
		device tun.Device
		mtu    atomic.Int32
	}

	// ipcMutex serializes IPC configuration with upLocked and Close.
	ipcMutex sync.RWMutex
	// closed is closed at the very end of Close; Wait returns it.
	closed chan struct{}
	log    *Logger
}
  93  
  94  // deviceState represents the state of a Device.
  95  // There are three states: down, up, closed.
  96  // Transitions:
  97  //
  98  //	down -----+
  99  //	  ↑↓      ↓
 100  //	  up -> closed
 101  type deviceState uint32
 102  
 103  //go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
 104  const (
 105  	deviceStateDown deviceState = iota
 106  	deviceStateUp
 107  	deviceStateClosed
 108  )
 109  
 110  // deviceState returns device.state.state as a deviceState
 111  // See those docs for how to interpret this value.
 112  func (device *Device) deviceState() deviceState {
 113  	return deviceState(device.state.state.Load())
 114  }
 115  
 116  // isClosed reports whether the device is closed (or is closing).
 117  // See device.state.state comments for how to interpret this value.
 118  func (device *Device) isClosed() bool {
 119  	return device.deviceState() == deviceStateClosed
 120  }
 121  
 122  // isUp reports whether the device is up (or is attempting to come up).
 123  // See device.state.state comments for how to interpret this value.
 124  func (device *Device) isUp() bool {
 125  	return device.deviceState() == deviceStateUp
 126  }
 127  
 128  // Must hold device.peers.Lock()
 129  func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
 130  	// stop routing and processing of packets
 131  	device.allowedips.RemoveByPeer(peer)
 132  	peer.Stop()
 133  
 134  	// remove from peer map
 135  	delete(device.peers.keyMap, key)
 136  }
 137  
 138  // changeState attempts to change the device state to match want.
 139  func (device *Device) changeState(want deviceState) (err error) {
 140  	device.state.Lock()
 141  	defer device.state.Unlock()
 142  	old := device.deviceState()
 143  	if old == deviceStateClosed {
 144  		// once closed, always closed
 145  		device.log.Verbosef("Interface closed, ignored requested state %s", want)
 146  		return nil
 147  	}
 148  	switch want {
 149  	case old:
 150  		return nil
 151  	case deviceStateUp:
 152  		device.state.state.Store(uint32(deviceStateUp))
 153  		err = device.upLocked()
 154  		if err == nil {
 155  			break
 156  		}
 157  		fallthrough // up failed; bring the device all the way back down
 158  	case deviceStateDown:
 159  		device.state.state.Store(uint32(deviceStateDown))
 160  		errDown := device.downLocked()
 161  		if err == nil {
 162  			err = errDown
 163  		}
 164  	}
 165  	device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
 166  	return
 167  }
 168  
 169  // upLocked attempts to bring the device up and reports whether it succeeded.
 170  // The caller must hold device.state.mu and is responsible for updating device.state.state.
 171  func (device *Device) upLocked() error {
 172  	if err := device.BindUpdate(); err != nil {
 173  		device.log.Errorf("Unable to update bind: %v", err)
 174  		return err
 175  	}
 176  
 177  	// The IPC set operation waits for peers to be created before calling Start() on them,
 178  	// so if there's a concurrent IPC set request happening, we should wait for it to complete.
 179  	device.ipcMutex.Lock()
 180  	defer device.ipcMutex.Unlock()
 181  
 182  	device.peers.RLock()
 183  	for _, peer := range device.peers.keyMap {
 184  		peer.Start()
 185  		if peer.persistentKeepaliveInterval.Load() > 0 {
 186  			peer.SendKeepalive()
 187  		}
 188  	}
 189  	device.peers.RUnlock()
 190  	return nil
 191  }
 192  
 193  // downLocked attempts to bring the device down.
 194  // The caller must hold device.state.mu and is responsible for updating device.state.state.
 195  func (device *Device) downLocked() error {
 196  	err := device.BindClose()
 197  	if err != nil {
 198  		device.log.Errorf("Bind close failed: %v", err)
 199  	}
 200  
 201  	device.peers.RLock()
 202  	for _, peer := range device.peers.keyMap {
 203  		peer.Stop()
 204  	}
 205  	device.peers.RUnlock()
 206  	return err
 207  }
 208  
 209  func (device *Device) Up() error {
 210  	return device.changeState(deviceStateUp)
 211  }
 212  
 213  func (device *Device) Down() error {
 214  	return device.changeState(deviceStateDown)
 215  }
 216  
 217  func (device *Device) IsUnderLoad() bool {
 218  	// check if currently under load
 219  	now := time.Now()
 220  	underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
 221  	if underLoad {
 222  		device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
 223  		return true
 224  	}
 225  	// check if recently under load
 226  	return device.rate.underLoadUntil.Load() > now.UnixNano()
 227  }
 228  
// SetPrivateKey installs sk as the device's static private key and updates
// the state derived from it. It always returns nil.
//
// If sk equals the current private key, nothing happens. Otherwise it:
//   - removes every peer whose remote static key equals sk's public key,
//   - stores sk and its public key and re-initializes the cookie checker,
//   - recomputes handshake.precomputedStaticStatic for each remaining peer,
//   - and finally calls ExpireCurrentKeypairs on each of those peers.
//
// Locks are taken in the order staticIdentity (write), peers (write), then
// each peer's handshake.mutex (read).
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
	// lock required resources

	device.staticIdentity.Lock()
	defer device.staticIdentity.Unlock()

	if sk.Equals(device.staticIdentity.privateKey) {
		return nil
	}

	device.peers.Lock()
	defer device.peers.Unlock()

	// Read-lock every peer's handshake. lockedPeers records each locked peer,
	// including ones removed below, so every lock is released exactly once.
	lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		peer.handshake.mutex.RLock()
		lockedPeers = append(lockedPeers, peer)
	}

	// remove peers with matching public keys

	publicKey := sk.publicKey()
	for key, peer := range device.peers.keyMap {
		if peer.handshake.remoteStatic.Equals(publicKey) {
			// NOTE(review): the read lock is presumably dropped because
			// removePeerLocked (via peer.Stop) needs this peer's handshake
			// mutex. It is re-taken so the unlock loop below stays balanced.
			peer.handshake.mutex.RUnlock()
			removePeerLocked(device, peer, key)
			peer.handshake.mutex.RLock()
		}
	}

	// update key material

	device.staticIdentity.privateKey = sk
	device.staticIdentity.publicKey = publicKey
	device.cookieChecker.Init(publicKey)

	// do static-static DH pre-computations
	// (any error returned by sharedSecret is ignored here)

	expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		handshake := &peer.handshake
		handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
		expiredPeers = append(expiredPeers, peer)
	}

	// Release all handshake read locks before expiring keypairs.
	// NOTE(review): presumably ExpireCurrentKeypairs takes handshake locks itself.
	for _, peer := range lockedPeers {
		peer.handshake.mutex.RUnlock()
	}
	for _, peer := range expiredPeers {
		peer.ExpireCurrentKeypairs()
	}

	return nil
}
 283  
 284  func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
 285  	device := new(Device)
 286  	device.state.state.Store(uint32(deviceStateDown))
 287  	device.closed = make(chan struct{})
 288  	device.log = logger
 289  	device.net.bind = bind
 290  	device.tun.device = tunDevice
 291  	mtu, err := device.tun.device.MTU()
 292  	if err != nil {
 293  		device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
 294  		mtu = DefaultMTU
 295  	}
 296  	device.tun.mtu.Store(int32(mtu))
 297  	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 298  	device.rate.limiter.Init()
 299  	device.indexTable.Init()
 300  
 301  	device.PopulatePools()
 302  
 303  	// create queues
 304  
 305  	device.queue.handshake = newHandshakeQueue()
 306  	device.queue.encryption = newOutboundQueue()
 307  	device.queue.decryption = newInboundQueue()
 308  
 309  	// start workers
 310  
 311  	cpus := runtime.NumCPU()
 312  	device.state.stopping.Wait()
 313  	device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
 314  	for i := 0; i < cpus; i++ {
 315  		go device.RoutineEncryption(i + 1)
 316  		go device.RoutineDecryption(i + 1)
 317  		go device.RoutineHandshake(i + 1)
 318  	}
 319  
 320  	device.state.stopping.Add(1)      // RoutineReadFromTUN
 321  	device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
 322  	go device.RoutineReadFromTUN()
 323  	go device.RoutineTUNEventReader()
 324  
 325  	return device
 326  }
 327  
 328  // BatchSize returns the BatchSize for the device as a whole which is the max of
 329  // the bind batch size and the tun batch size. The batch size reported by device
 330  // is the size used to construct memory pools, and is the allowed batch size for
 331  // the lifetime of the device.
 332  func (device *Device) BatchSize() int {
 333  	size := device.net.bind.BatchSize()
 334  	dSize := device.tun.device.BatchSize()
 335  	if size < dSize {
 336  		size = dSize
 337  	}
 338  	return size
 339  }
 340  
 341  func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
 342  	device.peers.RLock()
 343  	defer device.peers.RUnlock()
 344  
 345  	return device.peers.keyMap[pk]
 346  }
 347  
 348  func (device *Device) RemovePeer(key NoisePublicKey) {
 349  	device.peers.Lock()
 350  	defer device.peers.Unlock()
 351  	// stop peer and remove from routing
 352  
 353  	peer, ok := device.peers.keyMap[key]
 354  	if ok {
 355  		removePeerLocked(device, peer, key)
 356  	}
 357  }
 358  
 359  func (device *Device) RemoveAllPeers() {
 360  	device.peers.Lock()
 361  	defer device.peers.Unlock()
 362  
 363  	for key, peer := range device.peers.keyMap {
 364  		removePeerLocked(device, peer, key)
 365  	}
 366  
 367  	device.peers.keyMap = make(map[NoisePublicKey]*Peer)
 368  }
 369  
// Close permanently shuts the device down. Only the first call does any work;
// later calls see the closed state and return immediately.
//
// It holds the state mutex and then ipcMutex for its whole duration. In order,
// it marks the device closed, closes the TUN device, takes the device down and
// removes all peers. It then drops the device's own reference on each queue's
// wg, waits for the goroutines counted in state.stopping, closes the rate
// limiter and finally closes the channel returned by Wait.
func (device *Device) Close() {
	device.state.Lock()
	defer device.state.Unlock()
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()
	if device.isClosed() {
		return
	}
	// Publish the closed state before any teardown, so isClosed/changeState
	// see it right away.
	device.state.state.Store(uint32(deviceStateClosed))
	device.log.Verbosef("Device closing")

	// Errors are not propagated here: Close has no error result, and
	// downLocked already logs a bind-close failure.
	device.tun.device.Close()
	device.downLocked()

	// Remove peers before closing queues,
	// because peers assume that queues are active.
	device.RemoveAllPeers()

	// We kept a reference to the encryption and decryption queues,
	// in case we started any new peers that might write to them.
	// No new peers are coming; we are done with these queues.
	device.queue.encryption.wg.Done()
	device.queue.decryption.wg.Done()
	device.queue.handshake.wg.Done()
	// Waits for goroutines counted in state.stopping (RoutineReadFromTUN, per NewDevice).
	device.state.stopping.Wait()

	device.rate.limiter.Close()

	device.log.Verbosef("Device closed")
	close(device.closed)
}
 401  
 402  func (device *Device) Wait() chan struct{} {
 403  	return device.closed
 404  }
 405  
 406  func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
 407  	if !device.isUp() {
 408  		return
 409  	}
 410  
 411  	device.peers.RLock()
 412  	for _, peer := range device.peers.keyMap {
 413  		peer.keypairs.RLock()
 414  		sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
 415  		peer.keypairs.RUnlock()
 416  		if sendKeepalive {
 417  			peer.SendKeepalive()
 418  		}
 419  	}
 420  	device.peers.RUnlock()
 421  }
 422  
 423  // closeBindLocked closes the device's net.bind.
 424  // The caller must hold the net mutex.
 425  func closeBindLocked(device *Device) error {
 426  	var err error
 427  	netc := &device.net
 428  	if netc.netlinkCancel != nil {
 429  		netc.netlinkCancel.Cancel()
 430  	}
 431  	if netc.bind != nil {
 432  		err = netc.bind.Close()
 433  	}
 434  	netc.stopping.Wait()
 435  	return err
 436  }
 437  
 438  func (device *Device) Bind() conn.Bind {
 439  	device.net.Lock()
 440  	defer device.net.Unlock()
 441  	return device.net.bind
 442  }
 443  
 444  func (device *Device) BindSetMark(mark uint32) error {
 445  	device.net.Lock()
 446  	defer device.net.Unlock()
 447  
 448  	// check if modified
 449  	if device.net.fwmark == mark {
 450  		return nil
 451  	}
 452  
 453  	// update fwmark on existing bind
 454  	device.net.fwmark = mark
 455  	if device.isUp() && device.net.bind != nil {
 456  		if err := device.net.bind.SetMark(mark); err != nil {
 457  			return err
 458  		}
 459  	}
 460  
 461  	// clear cached source addresses
 462  	device.peers.RLock()
 463  	for _, peer := range device.peers.keyMap {
 464  		peer.markEndpointSrcForClearing()
 465  	}
 466  	device.peers.RUnlock()
 467  
 468  	return nil
 469  }
 470  
 471  func (device *Device) BindUpdate() error {
 472  	device.net.Lock()
 473  	defer device.net.Unlock()
 474  
 475  	// close existing sockets
 476  	if err := closeBindLocked(device); err != nil {
 477  		return err
 478  	}
 479  
 480  	// open new sockets
 481  	if !device.isUp() {
 482  		return nil
 483  	}
 484  
 485  	// bind to new port
 486  	var err error
 487  	var recvFns []conn.ReceiveFunc
 488  	netc := &device.net
 489  
 490  	recvFns, netc.port, err = netc.bind.Open(netc.port)
 491  	if err != nil {
 492  		netc.port = 0
 493  		return err
 494  	}
 495  
 496  	netc.netlinkCancel, err = device.startRouteListener(netc.bind)
 497  	if err != nil {
 498  		netc.bind.Close()
 499  		netc.port = 0
 500  		return err
 501  	}
 502  
 503  	// set fwmark
 504  	if netc.fwmark != 0 {
 505  		err = netc.bind.SetMark(netc.fwmark)
 506  		if err != nil {
 507  			return err
 508  		}
 509  	}
 510  
 511  	// clear cached source addresses
 512  	device.peers.RLock()
 513  	for _, peer := range device.peers.keyMap {
 514  		peer.markEndpointSrcForClearing()
 515  	}
 516  	device.peers.RUnlock()
 517  
 518  	// start receiving routines
 519  	device.net.stopping.Add(len(recvFns))
 520  	device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
 521  	device.queue.handshake.wg.Add(len(recvFns))  // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
 522  	batchSize := netc.bind.BatchSize()
 523  	for _, fn := range recvFns {
 524  		go device.RoutineReceiveIncoming(batchSize, fn)
 525  	}
 526  
 527  	device.log.Verbosef("UDP bind has been updated")
 528  	return nil
 529  }
 530  
 531  func (device *Device) BindClose() error {
 532  	device.net.Lock()
 533  	err := closeBindLocked(device)
 534  	device.net.Unlock()
 535  	return err
 536  }
 537