receive.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package device
7
8 import (
9 "encoding/binary"
10 "errors"
11 "net"
12 "sync"
13 "time"
14
15 "golang.org/x/crypto/chacha20poly1305"
16 "golang.org/x/net/ipv4"
17 "golang.org/x/net/ipv6"
18 "golang.zx2c4.com/wireguard/conn"
19 )
20
21 type QueueHandshakeElement struct {
22 msgType uint32
23 packet []byte
24 endpoint conn.Endpoint
25 buffer *[MaxMessageSize]byte
26 }
27
28 type QueueInboundElement struct {
29 buffer *[MaxMessageSize]byte
30 packet []byte
31 counter uint64
32 keypair *Keypair
33 endpoint conn.Endpoint
34 }
35
36 type QueueInboundElementsContainer struct {
37 sync.Mutex
38 elems []*QueueInboundElement
39 }
40
41 // clearPointers clears elem fields that contain pointers.
42 // This makes the garbage collector's life easier and
43 // avoids accidentally keeping other objects around unnecessarily.
44 // It also reduces the possible collateral damage from use-after-free bugs.
45 func (elem *QueueInboundElement) clearPointers() {
46 elem.buffer = nil
47 elem.packet = nil
48 elem.keypair = nil
49 elem.endpoint = nil
50 }
51
52 /* Called when a new authenticated message has been received
53 *
54 * NOTE: Not thread safe, but called by sequential receiver!
55 */
56 func (peer *Peer) keepKeyFreshReceiving() {
57 if peer.timers.sentLastMinuteHandshake.Load() {
58 return
59 }
60 keypair := peer.keypairs.Current()
61 if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
62 peer.timers.sentLastMinuteHandshake.Store(true)
63 peer.SendHandshakeInitiation(false)
64 }
65 }
66
67 /* Receives incoming datagrams for the device
68 *
69 * Every time the bind is updated a new routine is started for
70 * IPv4 and IPv6 (separately)
71 */
72 func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.ReceiveFunc) {
73 recvName := recv.PrettyName()
74 defer func() {
75 device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
76 device.queue.decryption.wg.Done()
77 device.queue.handshake.wg.Done()
78 device.net.stopping.Done()
79 }()
80
81 device.log.Verbosef("Routine: receive incoming %s - started", recvName)
82
83 // receive datagrams until conn is closed
84
85 var (
86 bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize)
87 bufs = make([][]byte, maxBatchSize)
88 err error
89 sizes = make([]int, maxBatchSize)
90 count int
91 endpoints = make([]conn.Endpoint, maxBatchSize)
92 deathSpiral int
93 elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize)
94 )
95
96 for i := range bufsArrs {
97 bufsArrs[i] = device.GetMessageBuffer()
98 bufs[i] = bufsArrs[i][:]
99 }
100
101 defer func() {
102 for i := 0; i < maxBatchSize; i++ {
103 if bufsArrs[i] != nil {
104 device.PutMessageBuffer(bufsArrs[i])
105 }
106 }
107 }()
108
109 for {
110 count, err = recv(bufs, sizes, endpoints)
111 if err != nil {
112 if errors.Is(err, net.ErrClosed) {
113 return
114 }
115 device.log.Verbosef("Failed to receive %s packet: %v", recvName, err)
116 if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
117 return
118 }
119 if deathSpiral < 10 {
120 deathSpiral++
121 time.Sleep(time.Second / 3)
122 continue
123 }
124 return
125 }
126 deathSpiral = 0
127
128 // handle each packet in the batch
129 for i, size := range sizes[:count] {
130 if size < MinMessageSize {
131 continue
132 }
133
134 // check size of packet
135
136 packet := bufsArrs[i][:size]
137 msgType := binary.LittleEndian.Uint32(packet[:4])
138
139 switch msgType {
140
141 // check if transport
142
143 case MessageTransportType:
144
145 // check size
146
147 if len(packet) < MessageTransportSize {
148 continue
149 }
150
151 // lookup key pair
152
153 receiver := binary.LittleEndian.Uint32(
154 packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
155 )
156 value := device.indexTable.Lookup(receiver)
157 keypair := value.keypair
158 if keypair == nil {
159 continue
160 }
161
162 // check keypair expiry
163
164 if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
165 continue
166 }
167
168 // create work element
169 peer := value.peer
170 elem := device.GetInboundElement()
171 elem.packet = packet
172 elem.buffer = bufsArrs[i]
173 elem.keypair = keypair
174 elem.endpoint = endpoints[i]
175 elem.counter = 0
176
177 elemsForPeer, ok := elemsByPeer[peer]
178 if !ok {
179 elemsForPeer = device.GetInboundElementsContainer()
180 elemsForPeer.Lock()
181 elemsByPeer[peer] = elemsForPeer
182 }
183 elemsForPeer.elems = append(elemsForPeer.elems, elem)
184 bufsArrs[i] = device.GetMessageBuffer()
185 bufs[i] = bufsArrs[i][:]
186 continue
187
188 // otherwise it is a fixed size & handshake related packet
189
190 case MessageInitiationType:
191 if len(packet) != MessageInitiationSize {
192 continue
193 }
194
195 case MessageResponseType:
196 if len(packet) != MessageResponseSize {
197 continue
198 }
199
200 case MessageCookieReplyType:
201 if len(packet) != MessageCookieReplySize {
202 continue
203 }
204
205 default:
206 device.log.Verbosef("Received message with unknown type")
207 continue
208 }
209
210 select {
211 case device.queue.handshake.c <- QueueHandshakeElement{
212 msgType: msgType,
213 buffer: bufsArrs[i],
214 packet: packet,
215 endpoint: endpoints[i],
216 }:
217 bufsArrs[i] = device.GetMessageBuffer()
218 bufs[i] = bufsArrs[i][:]
219 default:
220 }
221 }
222 for peer, elemsContainer := range elemsByPeer {
223 if peer.isRunning.Load() {
224 peer.queue.inbound.c <- elemsContainer
225 device.queue.decryption.c <- elemsContainer
226 } else {
227 for _, elem := range elemsContainer.elems {
228 device.PutMessageBuffer(elem.buffer)
229 device.PutInboundElement(elem)
230 }
231 device.PutInboundElementsContainer(elemsContainer)
232 }
233 delete(elemsByPeer, peer)
234 }
235 }
236 }
237
238 func (device *Device) RoutineDecryption(id int) {
239 var nonce [chacha20poly1305.NonceSize]byte
240
241 defer device.log.Verbosef("Routine: decryption worker %d - stopped", id)
242 device.log.Verbosef("Routine: decryption worker %d - started", id)
243
244 for elemsContainer := range device.queue.decryption.c {
245 for _, elem := range elemsContainer.elems {
246 // split message into fields
247 counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
248 content := elem.packet[MessageTransportOffsetContent:]
249
250 // decrypt and release to consumer
251 var err error
252 elem.counter = binary.LittleEndian.Uint64(counter)
253 // copy counter to nonce
254 binary.LittleEndian.PutUint64(nonce[0x4:0xc], elem.counter)
255 elem.packet, err = elem.keypair.receive.Open(
256 content[:0],
257 nonce[:],
258 content,
259 nil,
260 )
261 if err != nil {
262 elem.packet = nil
263 }
264 }
265 elemsContainer.Unlock()
266 }
267 }
268
269 /* Handles incoming packets related to handshake
270 */
271 func (device *Device) RoutineHandshake(id int) {
272 defer func() {
273 device.log.Verbosef("Routine: handshake worker %d - stopped", id)
274 device.queue.encryption.wg.Done()
275 }()
276 device.log.Verbosef("Routine: handshake worker %d - started", id)
277
278 for elem := range device.queue.handshake.c {
279
280 // handle cookie fields and ratelimiting
281
282 switch elem.msgType {
283
284 case MessageCookieReplyType:
285
286 // unmarshal packet
287
288 var reply MessageCookieReply
289 err := reply.unmarshal(elem.packet)
290 if err != nil {
291 device.log.Verbosef("Failed to decode cookie reply")
292 goto skip
293 }
294
295 // lookup peer from index
296
297 entry := device.indexTable.Lookup(reply.Receiver)
298
299 if entry.peer == nil {
300 goto skip
301 }
302
303 // consume reply
304
305 if peer := entry.peer; peer.isRunning.Load() {
306 device.log.Verbosef("Receiving cookie response from %s", elem.endpoint.DstToString())
307 if !peer.cookieGenerator.ConsumeReply(&reply) {
308 device.log.Verbosef("Could not decrypt invalid cookie response")
309 }
310 }
311
312 goto skip
313
314 case MessageInitiationType, MessageResponseType:
315
316 // check mac fields and maybe ratelimit
317
318 if !device.cookieChecker.CheckMAC1(elem.packet) {
319 device.log.Verbosef("Received packet with invalid mac1")
320 goto skip
321 }
322
323 // endpoints destination address is the source of the datagram
324
325 if device.IsUnderLoad() {
326
327 // verify MAC2 field
328
329 if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
330 device.SendHandshakeCookie(&elem)
331 goto skip
332 }
333
334 // check ratelimiter
335
336 if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
337 goto skip
338 }
339 }
340
341 default:
342 device.log.Errorf("Invalid packet ended up in the handshake queue")
343 goto skip
344 }
345
346 // handle handshake initiation/response content
347
348 switch elem.msgType {
349 case MessageInitiationType:
350
351 // unmarshal
352
353 var msg MessageInitiation
354 err := msg.unmarshal(elem.packet)
355 if err != nil {
356 device.log.Errorf("Failed to decode initiation message")
357 goto skip
358 }
359
360 // consume initiation
361
362 peer := device.ConsumeMessageInitiation(&msg)
363 if peer == nil {
364 device.log.Verbosef("Received invalid initiation message from %s", elem.endpoint.DstToString())
365 goto skip
366 }
367
368 // update timers
369
370 peer.timersAnyAuthenticatedPacketTraversal()
371 peer.timersAnyAuthenticatedPacketReceived()
372
373 // update endpoint
374 peer.SetEndpointFromPacket(elem.endpoint)
375
376 device.log.Verbosef("%v - Received handshake initiation", peer)
377 peer.rxBytes.Add(uint64(len(elem.packet)))
378
379 peer.SendHandshakeResponse()
380
381 case MessageResponseType:
382
383 // unmarshal
384
385 var msg MessageResponse
386 err := msg.unmarshal(elem.packet)
387 if err != nil {
388 device.log.Errorf("Failed to decode response message")
389 goto skip
390 }
391
392 // consume response
393
394 peer := device.ConsumeMessageResponse(&msg)
395 if peer == nil {
396 device.log.Verbosef("Received invalid response message from %s", elem.endpoint.DstToString())
397 goto skip
398 }
399
400 // update endpoint
401 peer.SetEndpointFromPacket(elem.endpoint)
402
403 device.log.Verbosef("%v - Received handshake response", peer)
404 peer.rxBytes.Add(uint64(len(elem.packet)))
405
406 // update timers
407
408 peer.timersAnyAuthenticatedPacketTraversal()
409 peer.timersAnyAuthenticatedPacketReceived()
410
411 // derive keypair
412
413 err = peer.BeginSymmetricSession()
414
415 if err != nil {
416 device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
417 goto skip
418 }
419
420 peer.timersSessionDerived()
421 peer.timersHandshakeComplete()
422 peer.SendKeepalive()
423 }
424 skip:
425 device.PutMessageBuffer(elem.buffer)
426 }
427 }
428
429 func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
430 device := peer.device
431 defer func() {
432 device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
433 peer.stopping.Done()
434 }()
435 device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
436
437 bufs := make([][]byte, 0, maxBatchSize)
438
439 for elemsContainer := range peer.queue.inbound.c {
440 if elemsContainer == nil {
441 return
442 }
443 elemsContainer.Lock()
444 validTailPacket := -1
445 dataPacketReceived := false
446 rxBytesLen := uint64(0)
447 for i, elem := range elemsContainer.elems {
448 if elem.packet == nil {
449 // decryption failed
450 continue
451 }
452
453 if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
454 continue
455 }
456
457 validTailPacket = i
458 if peer.ReceivedWithKeypair(elem.keypair) {
459 peer.SetEndpointFromPacket(elem.endpoint)
460 peer.timersHandshakeComplete()
461 peer.SendStagedPackets()
462 }
463 rxBytesLen += uint64(len(elem.packet) + MinMessageSize)
464
465 if len(elem.packet) == 0 {
466 device.log.Verbosef("%v - Receiving keepalive packet", peer)
467 continue
468 }
469 dataPacketReceived = true
470
471 switch elem.packet[0] >> 4 {
472 case 4:
473 if len(elem.packet) < ipv4.HeaderLen {
474 continue
475 }
476 field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
477 length := binary.BigEndian.Uint16(field)
478 if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
479 continue
480 }
481 elem.packet = elem.packet[:length]
482 src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
483 if device.allowedips.Lookup(src) != peer {
484 device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
485 continue
486 }
487
488 case 6:
489 if len(elem.packet) < ipv6.HeaderLen {
490 continue
491 }
492 field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
493 length := binary.BigEndian.Uint16(field)
494 length += ipv6.HeaderLen
495 if int(length) > len(elem.packet) {
496 continue
497 }
498 elem.packet = elem.packet[:length]
499 src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
500 if device.allowedips.Lookup(src) != peer {
501 device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
502 continue
503 }
504
505 default:
506 device.log.Verbosef("Packet with invalid IP version from %v", peer)
507 continue
508 }
509
510 bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
511 }
512
513 peer.rxBytes.Add(rxBytesLen)
514 if validTailPacket >= 0 {
515 peer.SetEndpointFromPacket(elemsContainer.elems[validTailPacket].endpoint)
516 peer.keepKeyFreshReceiving()
517 peer.timersAnyAuthenticatedPacketTraversal()
518 peer.timersAnyAuthenticatedPacketReceived()
519 }
520 if dataPacketReceived {
521 peer.timersDataReceived()
522 }
523 if len(bufs) > 0 {
524 _, err := device.tun.device.Write(bufs, MessageTransportOffsetContent)
525 if err != nil && !device.isClosed() {
526 device.log.Errorf("Failed to write packets to TUN device: %v", err)
527 }
528 }
529 for _, elem := range elemsContainer.elems {
530 device.PutMessageBuffer(elem.buffer)
531 device.PutInboundElement(elem)
532 }
533 bufs = bufs[:0]
534 device.PutInboundElementsContainer(elemsContainer)
535 }
536 }
537