receive.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package device
   7  
   8  import (
   9  	"encoding/binary"
  10  	"errors"
  11  	"net"
  12  	"sync"
  13  	"time"
  14  
  15  	"golang.org/x/crypto/chacha20poly1305"
  16  	"golang.org/x/net/ipv4"
  17  	"golang.org/x/net/ipv6"
  18  	"golang.zx2c4.com/wireguard/conn"
  19  )
  20  
// QueueHandshakeElement carries one handshake-related datagram (initiation,
// response or cookie reply) from RoutineReceiveIncoming to a RoutineHandshake
// worker via device.queue.handshake.c.
type QueueHandshakeElement struct {
	// Message type, read little-endian from the first 4 bytes of packet.
	msgType uint32
	// The received datagram; a slice of *buffer, already size-checked for msgType.
	packet []byte
	// Endpoint the datagram arrived on; its Dst* accessors give the sender's address.
	endpoint conn.Endpoint
	// Pooled backing buffer; the handshake worker returns it to the pool when done.
	buffer *[MaxMessageSize]byte
}
  27  
// QueueInboundElement is one received transport-data message travelling
// through the decryption workers and then the peer's sequential receiver.
type QueueInboundElement struct {
	// Pooled backing array holding the received datagram.
	buffer *[MaxMessageSize]byte
	// Before decryption: the whole transport message (a slice of *buffer).
	// After RoutineDecryption: the plaintext, decrypted in place starting at
	// MessageTransportOffsetContent, or nil if authentication failed.
	packet []byte
	// Zero when queued; set by RoutineDecryption to the message's counter,
	// which the sequential receiver checks against the replay filter.
	counter uint64
	// Keypair found through the message's receiver index.
	keypair *Keypair
	// Endpoint the datagram arrived on.
	endpoint conn.Endpoint
}
  35  
// QueueInboundElementsContainer groups the transport messages for a single
// peer that arrived in one receive batch.
//
// The embedded Mutex serves as a completion signal rather than guarding
// concurrent field access:
//   - RoutineReceiveIncoming locks it when the container is created.
//   - RoutineDecryption unlocks it after decrypting every element.
//   - RoutineSequentialReceiver locks it before reading, so acquiring the lock
//     waits for decryption to finish.
//
// NOTE(review): containers are returned to the pool while still locked;
// presumably the pool helpers reset the mutex — confirm.
type QueueInboundElementsContainer struct {
	sync.Mutex
	elems []*QueueInboundElement
}
  40  
  41  // clearPointers clears elem fields that contain pointers.
  42  // This makes the garbage collector's life easier and
  43  // avoids accidentally keeping other objects around unnecessarily.
  44  // It also reduces the possible collateral damage from use-after-free bugs.
  45  func (elem *QueueInboundElement) clearPointers() {
  46  	elem.buffer = nil
  47  	elem.packet = nil
  48  	elem.keypair = nil
  49  	elem.endpoint = nil
  50  }
  51  
  52  /* Called when a new authenticated message has been received
  53   *
  54   * NOTE: Not thread safe, but called by sequential receiver!
  55   */
  56  func (peer *Peer) keepKeyFreshReceiving() {
  57  	if peer.timers.sentLastMinuteHandshake.Load() {
  58  		return
  59  	}
  60  	keypair := peer.keypairs.Current()
  61  	if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
  62  		peer.timers.sentLastMinuteHandshake.Store(true)
  63  		peer.SendHandshakeInitiation(false)
  64  	}
  65  }
  66  
// RoutineReceiveIncoming reads batches of datagrams from recv and dispatches
// them until recv reports one of the following:
//   - net.ErrClosed;
//   - a non-temporary net.Error;
//   - an error after 10 consecutive retried errors.
//
// Dispatch depends on the message type:
//   - Transport-data messages are grouped per peer into a locked
//     QueueInboundElementsContainer. The container is sent to the peer's
//     inbound queue and then to the device decryption queue.
//   - Handshake-related messages (initiation, response, cookie reply) are
//     offered, without blocking, to the handshake queue.
//   - Anything else is dropped.
//
// Every time the bind is updated a new routine is started for IPv4 and IPv6
// (separately).
//
// maxBatchSize sizes the per-call slices passed to recv. One pooled message
// buffer is kept per slot.
func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
	recvName := recv.PrettyName()
	// On exit, signal the decryption, handshake and network-stopping wait
	// groups. NOTE(review): presumably each receive routine was added to all
	// three by the caller — confirm.
	defer func() {
		device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
		device.queue.decryption.wg.Done()
		device.queue.handshake.wg.Done()
		device.net.stopping.Done()
	}()

	device.log.Verbosef("Routine: receive incoming %s - started", recvName)

	// receive datagrams until conn is closed

	// bufsArrs[i] is the pooled buffer currently owned by slot i. bufs[i] is
	// the same memory viewed as a []byte for recv.
	var (
		bufsArrs    = make([]*[MaxMessageSize]byte, maxBatchSize)
		bufs        = make([][]byte, maxBatchSize)
		err         error
		sizes       = make([]int, maxBatchSize)
		count       int
		endpoints   = make([]conn.Endpoint, maxBatchSize)
		deathSpiral int
		elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
	)

	for i := range bufsArrs {
		bufsArrs[i] = device.GetMessageBuffer()
		bufs[i] = bufsArrs[i][:]
	}

	// Return whichever buffers this routine still owns to the pool on exit.
	defer func() {
		for i := 0; i < maxBatchSize; i++ {
			if bufsArrs[i] != nil {
				device.PutMessageBuffer(bufsArrs[i])
			}
		}
	}()

	for {
		count, err = recv(bufs, sizes, endpoints)
		if err != nil {
			if errors.Is(err, net.ErrClosed) {
				return
			}
			device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
			if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
				return
			}
			// Retry transient errors, pausing 1/3 s each time. After 10
			// consecutive retries, the next error ends the routine.
			if deathSpiral < 10 {
				deathSpiral++
				time.Sleep(time.Second / 3)
				continue
			}
			return
		}
		deathSpiral = 0

		// handle each packet in the batch
		for i, size := range sizes[:count] {
			if size < MinMessageSize {
				continue
			}

			// check size of packet

			packet := bufsArrs[i][:size]
			msgType := binary.LittleEndian.Uint32(packet[:4])

			switch msgType {

			// check if transport

			case MessageTransportType:

				// check size

				if len(packet) < MessageTransportSize {
					continue
				}

				// lookup key pair

				receiver := binary.LittleEndian.Uint32(
					packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
				)
				value := device.indexTable.Lookup(receiver)
				keypair := value.keypair
				if keypair == nil {
					continue
				}

				// check keypair expiry

				if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
					continue
				}

				// create work element
				peer := value.peer
				elem := device.GetInboundElement()
				elem.packet = packet
				elem.buffer = bufsArrs[i]
				elem.keypair = keypair
				elem.endpoint = endpoints[i]
				elem.counter = 0

				// A new container is locked before use. RoutineDecryption
				// unlocks it once every element has been decrypted.
				elemsForPeer, ok := elemsByPeer[peer]
				if !ok {
					elemsForPeer = device.GetInboundElementsContainer()
					elemsForPeer.Lock()
					elemsByPeer[peer] = elemsForPeer
				}
				elemsForPeer.elems = append(elemsForPeer.elems, elem)
				// The element now owns slot i's buffer; give the slot a fresh one.
				bufsArrs[i] = device.GetMessageBuffer()
				bufs[i] = bufsArrs[i][:]
				continue

			// otherwise it is a fixed size & handshake related packet

			case MessageInitiationType:
				if len(packet) != MessageInitiationSize {
					continue
				}

			case MessageResponseType:
				if len(packet) != MessageResponseSize {
					continue
				}

			case MessageCookieReplyType:
				if len(packet) != MessageCookieReplySize {
					continue
				}

			default:
				device.log.Verbosef("Received message with unknown type")
				continue
			}

			// Non-blocking hand-off. If the handshake queue is full, the
			// packet is dropped and slot i keeps its buffer for reuse.
			// Otherwise the handshake worker owns the buffer and the slot
			// gets a fresh one.
			select {
			case device.queue.handshake.c <- QueueHandshakeElement{
				msgType:  msgType,
				buffer:   bufsArrs[i],
				packet:   packet,
				endpoint: endpoints[i],
			}:
				bufsArrs[i] = device.GetMessageBuffer()
				bufs[i] = bufsArrs[i][:]
			default:
			}
		}
		// Dispatch this batch's per-peer containers.
		//   - Running peers: the container goes on the peer's inbound queue,
		//     which keeps per-peer batch order, and then on the decryption
		//     queue. The sequential receiver's Lock waits for decryption.
		//   - Stopped peers: buffers, elements and container go straight back
		//     to their pools.
		// Deleting while ranging is allowed; it leaves the map empty for the
		// next batch.
		for peer, elemsContainer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.queue.inbound.c <- elemsContainer
				device.queue.decryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutInboundElement(elem)
				}
				device.PutInboundElementsContainer(elemsContainer)
			}
			delete(elemsByPeer, peer)
		}
	}
}
 237  
 238  func (device *Device) RoutineDecryption(id int) {
 239  	var nonce [chacha20poly1305.NonceSize]byte
 240  
 241  	defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
 242  	device.log.Verbosef("Routine: decryption worker %d - started", id)
 243  
 244  	for elemsContainer := range device.queue.decryption.c {
 245  		for _, elem := range elemsContainer.elems {
 246  			// split message into fields
 247  			counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
 248  			content := elem.packet[MessageTransportOffsetContent:]
 249  
 250  			// decrypt and release to consumer
 251  			var err error
 252  			elem.counter = binary.LittleEndian.Uint64(counter)
 253  			// copy counter to nonce
 254  			binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
 255  			elem.packet, err = elem.keypair.receive.Open(
 256  				content[:0],
 257  				nonce[:],
 258  				content,
 259  				nil,
 260  			)
 261  			if err != nil {
 262  				elem.packet = nil
 263  			}
 264  		}
 265  		elemsContainer.Unlock()
 266  	}
 267  }
 268  
// RoutineHandshake is a handshake worker. It consumes elements from
// device.queue.handshake.c until that channel is closed. Handling depends on
// the message type:
//
//   - Cookie reply: decode it and find the peer through the receiver index.
//     If that peer is running, let its cookie generator consume the reply.
//   - Initiation or response: require a valid MAC1. When the device is under
//     load, also require a valid MAC2 for the sender's address; if MAC2 is
//     invalid, send a cookie reply and drop the packet. Then apply the
//     per-source-IP rate limiter.
//
// A surviving initiation goes to ConsumeMessageInitiation. On success the
// worker updates the peer's timers and endpoint, counts the received bytes
// and sends a handshake response.
//
// A surviving response goes to ConsumeMessageResponse. On success the worker
// updates the endpoint, byte count and timers, derives the session with
// BeginSymmetricSession and sends a keepalive.
//
// Whatever the outcome, each element's buffer is returned to the message pool
// at the skip label. id is only used in log messages.
func (device *Device) RoutineHandshake(id int) {
	// NOTE(review): this signals the encryption queue's wait group, not the
	// handshake queue's. Presumably this is intentional because handshake
	// handling sends packets (responses, keepalives) — confirm against where
	// the matching Add happens.
	defer func() {
		device.log.Verbosef("Routine: handshake worker %d - stopped", id)
		device.queue.encryption.wg.Done()
	}()
	device.log.Verbosef("Routine: handshake worker %d - started", id)

	for elem := range device.queue.handshake.c {

		// handle cookie fields and ratelimiting

		switch elem.msgType {

		case MessageCookieReplyType:

			// unmarshal packet

			var reply MessageCookieReply
			err := reply.unmarshal(elem.packet)
			if err != nil {
				device.log.Verbosef("Failed to decode cookie reply")
				goto skip
			}

			// lookup peer from index

			entry := device.indexTable.Lookup(reply.Receiver)

			if entry.peer == nil {
				goto skip
			}

			// consume reply

			if peer := entry.peer; peer.isRunning.Load() {
				device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
				if !peer.cookieGenerator.ConsumeReply(&reply) {
					device.log.Verbosef("Could not decrypt invalid cookie response")
				}
			}

			// A cookie reply never reaches the second switch below.
			goto skip

		case MessageInitiationType, MessageResponseType:

			// check mac fields and maybe ratelimit

			if !device.cookieChecker.CheckMAC1(elem.packet) {
				device.log.Verbosef("Received packet with invalid mac1")
				goto skip
			}

			// endpoints destination address is the source of the datagram

			if device.IsUnderLoad() {

				// verify MAC2 field

				// Without a valid MAC2 under load, answer with a cookie
				// reply instead of processing the handshake.
				if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
					device.SendHandshakeCookie(&elem)
					goto skip
				}

				// check ratelimiter

				if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
					goto skip
				}
			}

		default:
			device.log.Errorf("Invalid packet ended up in the handshake queue")
			goto skip
		}

		// handle handshake initiation/response content

		switch elem.msgType {
		case MessageInitiationType:

			// unmarshal

			var msg MessageInitiation
			err := msg.unmarshal(elem.packet)
			if err != nil {
				device.log.Errorf("Failed to decode initiation message")
				goto skip
			}

			// consume initiation

			peer := device.ConsumeMessageInitiation(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake initiation", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			peer.SendHandshakeResponse()

		case MessageResponseType:

			// unmarshal

			var msg MessageResponse
			err := msg.unmarshal(elem.packet)
			if err != nil {
				device.log.Errorf("Failed to decode response message")
				goto skip
			}

			// consume response

			peer := device.ConsumeMessageResponse(&msg)
			if peer == nil {
				device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
				goto skip
			}

			// update endpoint
			peer.SetEndpointFromPacket(elem.endpoint)

			device.log.Verbosef("%v - Received handshake response", peer)
			peer.rxBytes.Add(uint64(len(elem.packet)))

			// update timers

			peer.timersAnyAuthenticatedPacketTraversal()
			peer.timersAnyAuthenticatedPacketReceived()

			// derive keypair

			err = peer.BeginSymmetricSession()

			if err != nil {
				device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
				goto skip
			}

			peer.timersSessionDerived()
			peer.timersHandshakeComplete()
			peer.SendKeepalive()
		}
		// Every path ends here: the worker owns elem.buffer and returns it.
	skip:
		device.PutMessageBuffer(elem.buffer)
	}
}
 428  
 429  func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
 430  	device := peer.device
 431  	defer func() {
 432  		device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
 433  		peer.stopping.Done()
 434  	}()
 435  	device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
 436  
 437  	bufs := make([][]byte, 0, maxBatchSize)
 438  
 439  	for elemsContainer := range peer.queue.inbound.c {
 440  		if elemsContainer == nil {
 441  			return
 442  		}
 443  		elemsContainer.Lock()
 444  		validTailPacket := -1
 445  		dataPacketReceived := false
 446  		rxBytesLen := uint64(0)
 447  		for i, elem := range elemsContainer.elems {
 448  			if elem.packet == nil {
 449  				// decryption failed
 450  				continue
 451  			}
 452  
 453  			if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
 454  				continue
 455  			}
 456  
 457  			validTailPacket = i
 458  			if peer.ReceivedWithKeypair(elem.keypair) {
 459  				peer.SetEndpointFromPacket(elem.endpoint)
 460  				peer.timersHandshakeComplete()
 461  				peer.SendStagedPackets()
 462  			}
 463  			rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
 464  
 465  			if len(elem.packet) == 0 {
 466  				device.log.Verbosef("%v - Receiving keepalive packet", peer)
 467  				continue
 468  			}
 469  			dataPacketReceived = true
 470  
 471  			switch elem.packet[0] >> 4 {
 472  			case 4:
 473  				if len(elem.packet) < ipv4.HeaderLen {
 474  					continue
 475  				}
 476  				field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
 477  				length := binary.BigEndian.Uint16(field)
 478  				if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
 479  					continue
 480  				}
 481  				elem.packet = elem.packet[:length]
 482  				src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
 483  				if device.allowedips.Lookup(src) != peer {
 484  					device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
 485  					continue
 486  				}
 487  
 488  			case 6:
 489  				if len(elem.packet) < ipv6.HeaderLen {
 490  					continue
 491  				}
 492  				field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
 493  				length := binary.BigEndian.Uint16(field)
 494  				length += ipv6.HeaderLen
 495  				if int(length) > len(elem.packet) {
 496  					continue
 497  				}
 498  				elem.packet = elem.packet[:length]
 499  				src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
 500  				if device.allowedips.Lookup(src) != peer {
 501  					device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
 502  					continue
 503  				}
 504  
 505  			default:
 506  				device.log.Verbosef("Packet with invalid IP version from %v", peer)
 507  				continue
 508  			}
 509  
 510  			bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
 511  		}
 512  
 513  		peer.rxBytes.Add(rxBytesLen)
 514  		if validTailPacket >= 0 {
 515  			peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
 516  			peer.keepKeyFreshReceiving()
 517  			peer.timersAnyAuthenticatedPacketTraversal()
 518  			peer.timersAnyAuthenticatedPacketReceived()
 519  		}
 520  		if dataPacketReceived {
 521  			peer.timersDataReceived()
 522  		}
 523  		if len(bufs) > 0 {
 524  			_, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
 525  			if err != nil && !device.isClosed() {
 526  				device.log.Errorf("Failed to write packets to TUN device: %v", err)
 527  			}
 528  		}
 529  		for _, elem := range elemsContainer.elems {
 530  			device.PutMessageBuffer(elem.buffer)
 531  			device.PutInboundElement(elem)
 532  		}
 533  		bufs = bufs[:0]
 534  		device.PutInboundElementsContainer(elemsContainer)
 535  	}
 536  }
 537