send.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package device
   7  
   8  import (
   9  	"encoding/binary"
  10  	"errors"
  11  	"net"
  12  	"os"
  13  	"sync"
  14  	"time"
  15  
  16  	"golang.org/x/crypto/chacha20poly1305"
  17  	"golang.org/x/net/ipv4"
  18  	"golang.org/x/net/ipv6"
  19  	"golang.zx2c4.com/wireguard/conn"
  20  	"golang.zx2c4.com/wireguard/tun"
  21  )
  22  
  23  /* Outbound flow
  24   *
  25   * 1. TUN queue
  26   * 2. Routing (sequential)
  27   * 3. Nonce assignment (sequential)
  28   * 4. Encryption (parallel)
  29   * 5. Transmission (sequential)
  30   *
  31   * The functions in this file occur (roughly) in the order in
  32   * which the packets are processed.
  33   *
  34   * Locking, Producers and Consumers
  35   *
  36   * The order of packets (per peer) must be maintained,
 * but encryption of packets happens out-of-order:
  38   *
  39   * The sequential consumers will attempt to take the lock,
  40   * workers release lock when they have completed work (encryption) on the packet.
  41   *
  42   * If the element is inserted into the "encryption queue",
  43   * the content is preceded by enough "junk" to contain the transport header
  44   * (to allow the construction of transport messages in-place)
  45   */
  46  
  47  type QueueOutboundElement struct {
  48  	buffer  *[MaxMessageSize]byte // slice holding the packet data
  49  	packet  []byte                // slice of "buffer" (always!)
  50  	nonce   uint64                // nonce for encryption
  51  	keypair *Keypair              // keypair for encryption
  52  	peer    *Peer                 // related peer
  53  }
  54  
  55  type QueueOutboundElementsContainer struct {
  56  	sync.Mutex
  57  	elems []*QueueOutboundElement
  58  }
  59  
  60  func (device *Device) NewOutboundElement() *QueueOutboundElement {
  61  	elem := device.GetOutboundElement()
  62  	elem.buffer = device.GetMessageBuffer()
  63  	elem.nonce = 0
  64  	// keypair and peer were cleared (if necessary) by clearPointers.
  65  	return elem
  66  }
  67  
  68  // clearPointers clears elem fields that contain pointers.
  69  // This makes the garbage collector's life easier and
  70  // avoids accidentally keeping other objects around unnecessarily.
  71  // It also reduces the possible collateral damage from use-after-free bugs.
  72  func (elem *QueueOutboundElement) clearPointers() {
  73  	elem.buffer = nil
  74  	elem.packet = nil
  75  	elem.keypair = nil
  76  	elem.peer = nil
  77  }
  78  
  79  /* Queues a keepalive if no packets are queued for peer
  80   */
  81  func (peer *Peer) SendKeepalive() {
  82  	if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
  83  		elem := peer.device.NewOutboundElement()
  84  		elemsContainer := peer.device.GetOutboundElementsContainer()
  85  		elemsContainer.elems = append(elemsContainer.elems, elem)
  86  		select {
  87  		case peer.queue.staged <- elemsContainer:
  88  			peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
  89  		default:
  90  			peer.device.PutMessageBuffer(elem.buffer)
  91  			peer.device.PutOutboundElement(elem)
  92  			peer.device.PutOutboundElementsContainer(elemsContainer)
  93  		}
  94  	}
  95  	peer.SendStagedPackets()
  96  }
  97  
  98  func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
  99  	if !isRetry {
 100  		peer.timers.handshakeAttempts.Store(0)
 101  	}
 102  
 103  	peer.handshake.mutex.RLock()
 104  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
 105  		peer.handshake.mutex.RUnlock()
 106  		return nil
 107  	}
 108  	peer.handshake.mutex.RUnlock()
 109  
 110  	peer.handshake.mutex.Lock()
 111  	if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
 112  		peer.handshake.mutex.Unlock()
 113  		return nil
 114  	}
 115  	peer.handshake.lastSentHandshake = time.Now()
 116  	peer.handshake.mutex.Unlock()
 117  
 118  	peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
 119  
 120  	msg, err := peer.device.CreateMessageInitiation(peer)
 121  	if err != nil {
 122  		peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
 123  		return err
 124  	}
 125  
 126  	packet := make([]byte, MessageInitiationSize)
 127  	_ = msg.marshal(packet)
 128  	peer.cookieGenerator.AddMacs(packet)
 129  
 130  	peer.timersAnyAuthenticatedPacketTraversal()
 131  	peer.timersAnyAuthenticatedPacketSent()
 132  
 133  	err = peer.SendBuffers([][]byte{packet})
 134  	if err != nil {
 135  		peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
 136  	}
 137  	peer.timersHandshakeInitiated()
 138  
 139  	return err
 140  }
 141  
 142  func (peer *Peer) SendHandshakeResponse() error {
 143  	peer.handshake.mutex.Lock()
 144  	peer.handshake.lastSentHandshake = time.Now()
 145  	peer.handshake.mutex.Unlock()
 146  
 147  	peer.device.log.Verbosef("%v - Sending handshake response", peer)
 148  
 149  	response, err := peer.device.CreateMessageResponse(peer)
 150  	if err != nil {
 151  		peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
 152  		return err
 153  	}
 154  
 155  	packet := make([]byte, MessageResponseSize)
 156  	_ = response.marshal(packet)
 157  	peer.cookieGenerator.AddMacs(packet)
 158  
 159  	err = peer.BeginSymmetricSession()
 160  	if err != nil {
 161  		peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
 162  		return err
 163  	}
 164  
 165  	peer.timersSessionDerived()
 166  	peer.timersAnyAuthenticatedPacketTraversal()
 167  	peer.timersAnyAuthenticatedPacketSent()
 168  
 169  	// TODO: allocation could be avoided
 170  	err = peer.SendBuffers([][]byte{packet})
 171  	if err != nil {
 172  		peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
 173  	}
 174  	return err
 175  }
 176  
 177  func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
 178  	device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
 179  
 180  	sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
 181  	reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
 182  	if err != nil {
 183  		device.log.Errorf("Failed to create cookie reply: %v", err)
 184  		return err
 185  	}
 186  
 187  	packet := make([]byte, MessageCookieReplySize)
 188  	_ = reply.marshal(packet)
 189  	// TODO: allocation could be avoided
 190  	device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
 191  
 192  	return nil
 193  }
 194  
 195  func (peer *Peer) keepKeyFreshSending() {
 196  	keypair := peer.keypairs.Current()
 197  	if keypair == nil {
 198  		return
 199  	}
 200  	nonce := keypair.sendNonce.Load()
 201  	if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
 202  		peer.SendHandshakeInitiation(false)
 203  	}
 204  }
 205  
 206  func (device *Device) RoutineReadFromTUN() {
 207  	defer func() {
 208  		device.log.Verbosef("Routine: TUN reader - stopped")
 209  		device.state.stopping.Done()
 210  		device.queue.encryption.wg.Done()
 211  	}()
 212  
 213  	device.log.Verbosef("Routine: TUN reader - started")
 214  
 215  	var (
 216  		batchSize   = device.BatchSize()
 217  		readErr     error
 218  		elems       = make([]*QueueOutboundElement, batchSize)
 219  		bufs        = make([][]byte, batchSize)
 220  		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
 221  		count       = 0
 222  		sizes       = make([]int, batchSize)
 223  		offset      = MessageTransportHeaderSize
 224  	)
 225  
 226  	for i := range elems {
 227  		elems[i] = device.NewOutboundElement()
 228  		bufs[i] = elems[i].buffer[:]
 229  	}
 230  
 231  	defer func() {
 232  		for _, elem := range elems {
 233  			if elem != nil {
 234  				device.PutMessageBuffer(elem.buffer)
 235  				device.PutOutboundElement(elem)
 236  			}
 237  		}
 238  	}()
 239  
 240  	for {
 241  		// read packets
 242  		count, readErr = device.tun.device.Read(bufs, sizes, offset)
 243  		for i := 0; i < count; i++ {
 244  			if sizes[i] < 1 {
 245  				continue
 246  			}
 247  
 248  			elem := elems[i]
 249  			elem.packet = bufs[i][offset : offset+sizes[i]]
 250  
 251  			// lookup peer
 252  			var peer *Peer
 253  			switch elem.packet[0] >> 4 {
 254  			case 4:
 255  				if len(elem.packet) < ipv4.HeaderLen {
 256  					continue
 257  				}
 258  				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
 259  				peer = device.allowedips.Lookup(dst)
 260  
 261  			case 6:
 262  				if len(elem.packet) < ipv6.HeaderLen {
 263  					continue
 264  				}
 265  				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
 266  				peer = device.allowedips.Lookup(dst)
 267  
 268  			default:
 269  				device.log.Verbosef("Received packet with unknown IP version")
 270  			}
 271  
 272  			if peer == nil {
 273  				continue
 274  			}
 275  			elemsForPeer, ok := elemsByPeer[peer]
 276  			if !ok {
 277  				elemsForPeer = device.GetOutboundElementsContainer()
 278  				elemsByPeer[peer] = elemsForPeer
 279  			}
 280  			elemsForPeer.elems = append(elemsForPeer.elems, elem)
 281  			elems[i] = device.NewOutboundElement()
 282  			bufs[i] = elems[i].buffer[:]
 283  		}
 284  
 285  		for peer, elemsForPeer := range elemsByPeer {
 286  			if peer.isRunning.Load() {
 287  				peer.StagePackets(elemsForPeer)
 288  				peer.SendStagedPackets()
 289  			} else {
 290  				for _, elem := range elemsForPeer.elems {
 291  					device.PutMessageBuffer(elem.buffer)
 292  					device.PutOutboundElement(elem)
 293  				}
 294  				device.PutOutboundElementsContainer(elemsForPeer)
 295  			}
 296  			delete(elemsByPeer, peer)
 297  		}
 298  
 299  		if readErr != nil {
 300  			if errors.Is(readErr, tun.ErrTooManySegments) {
 301  				// TODO: record stat for this
 302  				// This will happen if MSS is surprisingly small (< 576)
 303  				// coincident with reasonably high throughput.
 304  				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
 305  				continue
 306  			}
 307  			if !device.isClosed() {
 308  				if !errors.Is(readErr, os.ErrClosed) {
 309  					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
 310  				}
 311  				go device.Close()
 312  			}
 313  			return
 314  		}
 315  	}
 316  }
 317  
 318  func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
 319  	for {
 320  		select {
 321  		case peer.queue.staged <- elems:
 322  			return
 323  		default:
 324  		}
 325  		select {
 326  		case tooOld := <-peer.queue.staged:
 327  			for _, elem := range tooOld.elems {
 328  				peer.device.PutMessageBuffer(elem.buffer)
 329  				peer.device.PutOutboundElement(elem)
 330  			}
 331  			peer.device.PutOutboundElementsContainer(tooOld)
 332  		default:
 333  		}
 334  	}
 335  }
 336  
 337  func (peer *Peer) SendStagedPackets() {
 338  top:
 339  	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
 340  		return
 341  	}
 342  
 343  	keypair := peer.keypairs.Current()
 344  	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
 345  		peer.SendHandshakeInitiation(false)
 346  		return
 347  	}
 348  
 349  	for {
 350  		var elemsContainerOOO *QueueOutboundElementsContainer
 351  		select {
 352  		case elemsContainer := <-peer.queue.staged:
 353  			i := 0
 354  			for _, elem := range elemsContainer.elems {
 355  				elem.peer = peer
 356  				elem.nonce = keypair.sendNonce.Add(1) - 1
 357  				if elem.nonce >= RejectAfterMessages {
 358  					keypair.sendNonce.Store(RejectAfterMessages)
 359  					if elemsContainerOOO == nil {
 360  						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
 361  					}
 362  					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
 363  					continue
 364  				} else {
 365  					elemsContainer.elems[i] = elem
 366  					i++
 367  				}
 368  
 369  				elem.keypair = keypair
 370  			}
 371  			elemsContainer.Lock()
 372  			elemsContainer.elems = elemsContainer.elems[:i]
 373  
 374  			if elemsContainerOOO != nil {
 375  				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
 376  			}
 377  
 378  			if len(elemsContainer.elems) == 0 {
 379  				peer.device.PutOutboundElementsContainer(elemsContainer)
 380  				goto top
 381  			}
 382  
 383  			// add to parallel and sequential queue
 384  			if peer.isRunning.Load() {
 385  				peer.queue.outbound.c <- elemsContainer
 386  				peer.device.queue.encryption.c <- elemsContainer
 387  			} else {
 388  				for _, elem := range elemsContainer.elems {
 389  					peer.device.PutMessageBuffer(elem.buffer)
 390  					peer.device.PutOutboundElement(elem)
 391  				}
 392  				peer.device.PutOutboundElementsContainer(elemsContainer)
 393  			}
 394  
 395  			if elemsContainerOOO != nil {
 396  				goto top
 397  			}
 398  		default:
 399  			return
 400  		}
 401  	}
 402  }
 403  
 404  func (peer *Peer) FlushStagedPackets() {
 405  	for {
 406  		select {
 407  		case elemsContainer := <-peer.queue.staged:
 408  			for _, elem := range elemsContainer.elems {
 409  				peer.device.PutMessageBuffer(elem.buffer)
 410  				peer.device.PutOutboundElement(elem)
 411  			}
 412  			peer.device.PutOutboundElementsContainer(elemsContainer)
 413  		default:
 414  			return
 415  		}
 416  	}
 417  }
 418  
 419  func calculatePaddingSize(packetSize, mtu int) int {
 420  	lastUnit := packetSize
 421  	if mtu == 0 {
 422  		return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
 423  	}
 424  	if lastUnit > mtu {
 425  		lastUnit %= mtu
 426  	}
 427  	paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
 428  	if paddedSize > mtu {
 429  		paddedSize = mtu
 430  	}
 431  	return paddedSize - lastUnit
 432  }
 433  
 434  /* Encrypts the elements in the queue
 435   * and marks them for sequential consumption (by releasing the mutex)
 436   *
 437   * Obs. One instance per core
 438   */
 439  func (device *Device) RoutineEncryption(id int) {
 440  	var paddingZeros [PaddingMultiple]byte
 441  	var nonce [chacha20poly1305.NonceSize]byte
 442  
 443  	defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
 444  	device.log.Verbosef("Routine: encryption worker %d - started", id)
 445  
 446  	for elemsContainer := range device.queue.encryption.c {
 447  		for _, elem := range elemsContainer.elems {
 448  			// populate header fields
 449  			header := elem.buffer[:MessageTransportHeaderSize]
 450  
 451  			fieldType := header[0:4]
 452  			fieldReceiver := header[4:8]
 453  			fieldNonce := header[8:16]
 454  
 455  			binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
 456  			binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
 457  			binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
 458  
 459  			// pad content to multiple of 16
 460  			paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
 461  			elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
 462  
 463  			// encrypt content and release to consumer
 464  
 465  			binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
 466  			elem.packet = elem.keypair.send.Seal(
 467  				header,
 468  				nonce[:],
 469  				elem.packet,
 470  				nil,
 471  			)
 472  		}
 473  		elemsContainer.Unlock()
 474  	}
 475  }
 476  
 477  func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
 478  	device := peer.device
 479  	defer func() {
 480  		defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
 481  		peer.stopping.Done()
 482  	}()
 483  	device.log.Verbosef("%v - Routine: sequential sender - started", peer)
 484  
 485  	bufs := make([][]byte, 0, maxBatchSize)
 486  
 487  	for elemsContainer := range peer.queue.outbound.c {
 488  		bufs = bufs[:0]
 489  		if elemsContainer == nil {
 490  			return
 491  		}
 492  		if !peer.isRunning.Load() {
 493  			// peer has been stopped; return re-usable elems to the shared pool.
 494  			// This is an optimization only. It is possible for the peer to be stopped
 495  			// immediately after this check, in which case, elem will get processed.
 496  			// The timers and SendBuffers code are resilient to a few stragglers.
 497  			// TODO: rework peer shutdown order to ensure
 498  			// that we never accidentally keep timers alive longer than necessary.
 499  			elemsContainer.Lock()
 500  			for _, elem := range elemsContainer.elems {
 501  				device.PutMessageBuffer(elem.buffer)
 502  				device.PutOutboundElement(elem)
 503  			}
 504  			device.PutOutboundElementsContainer(elemsContainer)
 505  			continue
 506  		}
 507  		dataSent := false
 508  		elemsContainer.Lock()
 509  		for _, elem := range elemsContainer.elems {
 510  			if len(elem.packet) != MessageKeepaliveSize {
 511  				dataSent = true
 512  			}
 513  			bufs = append(bufs, elem.packet)
 514  		}
 515  
 516  		peer.timersAnyAuthenticatedPacketTraversal()
 517  		peer.timersAnyAuthenticatedPacketSent()
 518  
 519  		err := peer.SendBuffers(bufs)
 520  		if dataSent {
 521  			peer.timersDataSent()
 522  		}
 523  		for _, elem := range elemsContainer.elems {
 524  			device.PutMessageBuffer(elem.buffer)
 525  			device.PutOutboundElement(elem)
 526  		}
 527  		device.PutOutboundElementsContainer(elemsContainer)
 528  		if err != nil {
 529  			var errGSO conn.ErrUDPGSODisabled
 530  			if errors.As(err, &errGSO) {
 531  				device.log.Verbosef(err.Error())
 532  				err = errGSO.RetryErr
 533  			}
 534  		}
 535  		if err != nil {
 536  			device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
 537  			continue
 538  		}
 539  
 540  		peer.keepKeyFreshSending()
 541  	}
 542  }
 543