/* SPDX-License-Identifier: MIT
 *
 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
 */

package device

import (
	"encoding/binary"
	"errors"
	"net"
	"os"
	"sync"
	"time"

	"golang.org/x/crypto/chacha20poly1305"
	"golang.org/x/net/ipv4"
	"golang.org/x/net/ipv6"
	"golang.zx2c4.com/wireguard/conn"
	"golang.zx2c4.com/wireguard/tun"
)

/* Outbound flow
 *
 * 1. TUN queue
 * 2. Routing (sequential)
 * 3. Nonce assignment (sequential)
 * 4. Encryption (parallel)
 * 5. Transmission (sequential)
 *
 * The functions in this file appear (roughly) in the order in
 * which packets are processed.
 *
 * Locking, Producers and Consumers
 *
 * The order of packets (per peer) must be maintained,
 * but packets are encrypted out of order:
 *
 * The sequential consumers attempt to take each batch's lock;
 * the encryption workers release that lock once they have finished
 * encrypting the packets in the batch.
 *
 * When an element is inserted into the encryption queue,
 * its content is preceded by enough headroom to hold the transport header,
 * so that transport messages can be constructed in place.
 */

// QueueOutboundElement is one outbound packet as it moves through the send
// pipeline: it is filled by the TUN reader, given a nonce and keypair in
// SendStagedPackets, encrypted by RoutineEncryption and transmitted by
// RoutineSequentialSender.
type QueueOutboundElement struct {
	buffer  *[MaxMessageSize]byte // backing storage; the transport header is written at its front
	packet  []byte                // slice of "buffer" (always!): plaintext, then header+ciphertext after encryption
	nonce   uint64                // nonce for encryption, assigned from keypair.sendNonce
	keypair *Keypair              // keypair for encryption
	peer    *Peer                 // related peer
}

// QueueOutboundElementsContainer is a batch of outbound elements for a single
// peer that travels through the staging, encryption and sequential-send queues
// as a unit.
//
// The embedded mutex is locked (in SendStagedPackets) before the batch is
// pushed onto both the peer's outbound queue and the device's encryption
// queue. The encryption worker unlocks it after sealing every element, so the
// sequential sender's Lock call waits until the batch is ready to transmit.
type QueueOutboundElementsContainer struct {
	sync.Mutex
	elems []*QueueOutboundElement
}

60 func (device *Device) NewOutboundElement() *QueueOutboundElement {
61 elem := device.GetOutboundElement()
62 elem.buffer = device.GetMessageBuffer()
63 elem.nonce = 0
64 // keypair and peer were cleared (if necessary) by clearPointers.
65 return elem
66 }
67
68 // clearPointers clears elem fields that contain pointers.
69 // This makes the garbage collector's life easier and
70 // avoids accidentally keeping other objects around unnecessarily.
71 // It also reduces the possible collateral damage from use-after-free bugs.
72 func (elem *QueueOutboundElement) clearPointers() {
73 elem.buffer = nil
74 elem.packet = nil
75 elem.keypair = nil
76 elem.peer = nil
77 }
78
79 /* Queues a keepalive if no packets are queued for peer
80 */
81 func (peer *Peer) SendKeepalive() {
82 if len(peer.queue.staged) == 0 && peer.isRunning.Load() {
83 elem := peer.device.NewOutboundElement()
84 elemsContainer := peer.device.GetOutboundElementsContainer()
85 elemsContainer.elems = append(elemsContainer.elems, elem)
86 select {
87 case peer.queue.staged <- elemsContainer:
88 peer.device.log.Verbosef("%v - Sending keepalive packet", peer)
89 default:
90 peer.device.PutMessageBuffer(elem.buffer)
91 peer.device.PutOutboundElement(elem)
92 peer.device.PutOutboundElementsContainer(elemsContainer)
93 }
94 }
95 peer.SendStagedPackets()
96 }
97
98 func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
99 if !isRetry {
100 peer.timers.handshakeAttempts.Store(0)
101 }
102
103 peer.handshake.mutex.RLock()
104 if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
105 peer.handshake.mutex.RUnlock()
106 return nil
107 }
108 peer.handshake.mutex.RUnlock()
109
110 peer.handshake.mutex.Lock()
111 if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
112 peer.handshake.mutex.Unlock()
113 return nil
114 }
115 peer.handshake.lastSentHandshake = time.Now()
116 peer.handshake.mutex.Unlock()
117
118 peer.device.log.Verbosef("%v - Sending handshake initiation", peer)
119
120 msg, err := peer.device.CreateMessageInitiation(peer)
121 if err != nil {
122 peer.device.log.Errorf("%v - Failed to create initiation message: %v", peer, err)
123 return err
124 }
125
126 packet := make([]byte, MessageInitiationSize)
127 _ = msg.marshal(packet)
128 peer.cookieGenerator.AddMacs(packet)
129
130 peer.timersAnyAuthenticatedPacketTraversal()
131 peer.timersAnyAuthenticatedPacketSent()
132
133 err = peer.SendBuffers([][]byte{packet})
134 if err != nil {
135 peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err)
136 }
137 peer.timersHandshakeInitiated()
138
139 return err
140 }
141
142 func (peer *Peer) SendHandshakeResponse() error {
143 peer.handshake.mutex.Lock()
144 peer.handshake.lastSentHandshake = time.Now()
145 peer.handshake.mutex.Unlock()
146
147 peer.device.log.Verbosef("%v - Sending handshake response", peer)
148
149 response, err := peer.device.CreateMessageResponse(peer)
150 if err != nil {
151 peer.device.log.Errorf("%v - Failed to create response message: %v", peer, err)
152 return err
153 }
154
155 packet := make([]byte, MessageResponseSize)
156 _ = response.marshal(packet)
157 peer.cookieGenerator.AddMacs(packet)
158
159 err = peer.BeginSymmetricSession()
160 if err != nil {
161 peer.device.log.Errorf("%v - Failed to derive keypair: %v", peer, err)
162 return err
163 }
164
165 peer.timersSessionDerived()
166 peer.timersAnyAuthenticatedPacketTraversal()
167 peer.timersAnyAuthenticatedPacketSent()
168
169 // TODO: allocation could be avoided
170 err = peer.SendBuffers([][]byte{packet})
171 if err != nil {
172 peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err)
173 }
174 return err
175 }
176
177 func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
178 device.log.Verbosef("Sending cookie response for denied handshake message for %v", initiatingElem.endpoint.DstToString())
179
180 sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
181 reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
182 if err != nil {
183 device.log.Errorf("Failed to create cookie reply: %v", err)
184 return err
185 }
186
187 packet := make([]byte, MessageCookieReplySize)
188 _ = reply.marshal(packet)
189 // TODO: allocation could be avoided
190 device.net.bind.Send([][]byte{packet}, initiatingElem.endpoint)
191
192 return nil
193 }
194
195 func (peer *Peer) keepKeyFreshSending() {
196 keypair := peer.keypairs.Current()
197 if keypair == nil {
198 return
199 }
200 nonce := keypair.sendNonce.Load()
201 if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
202 peer.SendHandshakeInitiation(false)
203 }
204 }
205
// RoutineReadFromTUN is the device's TUN reader loop. It reads batches of up
// to device.BatchSize() packets from the TUN device and looks up each packet's
// destination address (IPv4 or IPv6) in the allowed-IPs table to find the
// peer. It then groups the packets by peer and stages each group on that peer,
// followed by SendStagedPackets.
//
// Packets are read at offset MessageTransportHeaderSize within their message
// buffer, which leaves room to write the transport header in place later.
// Empty packets, truncated packets, packets with an unknown IP version and
// packets with no matching peer are skipped. Their element stays in its slot
// and is reused by the next read.
//
// A tun.ErrTooManySegments read error is logged and the loop continues. Any
// other read error ends the routine. If the device is not already closed, it
// is closed from a separate goroutine.
func (device *Device) RoutineReadFromTUN() {
	defer func() {
		device.log.Verbosef("Routine: TUN reader - stopped")
		device.state.stopping.Done()
		device.queue.encryption.wg.Done()
	}()

	device.log.Verbosef("Routine: TUN reader - started")

	var (
		batchSize   = device.BatchSize()
		readErr     error
		elems       = make([]*QueueOutboundElement, batchSize)
		bufs        = make([][]byte, batchSize)
		elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize)
		count       = 0
		sizes       = make([]int, batchSize)
		offset      = MessageTransportHeaderSize
	)

	// One pooled element (and message buffer) per batch slot; bufs[i] aliases
	// the full buffer of elems[i] so the TUN device reads straight into it.
	for i := range elems {
		elems[i] = device.NewOutboundElement()
		bufs[i] = elems[i].buffer[:]
	}

	// On exit, return the elements this routine still owns to the pools.
	defer func() {
		for _, elem := range elems {
			if elem != nil {
				device.PutMessageBuffer(elem.buffer)
				device.PutOutboundElement(elem)
			}
		}
	}()

	for {
		// read packets
		count, readErr = device.tun.device.Read(bufs, sizes, offset)
		for i := 0; i < count; i++ {
			if sizes[i] < 1 {
				continue
			}

			elem := elems[i]
			elem.packet = bufs[i][offset : offset+sizes[i]]

			// lookup peer by destination address; the IP version is the
			// high nibble of the first byte
			var peer *Peer
			switch elem.packet[0] >> 4 {
			case 4:
				if len(elem.packet) < ipv4.HeaderLen {
					continue
				}
				dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
				peer = device.allowedips.Lookup(dst)

			case 6:
				if len(elem.packet) < ipv6.HeaderLen {
					continue
				}
				dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
				peer = device.allowedips.Lookup(dst)

			default:
				device.log.Verbosef("Received packet with unknown IP version")
			}

			if peer == nil {
				continue
			}
			elemsForPeer, ok := elemsByPeer[peer]
			if !ok {
				elemsForPeer = device.GetOutboundElementsContainer()
				elemsByPeer[peer] = elemsForPeer
			}
			elemsForPeer.elems = append(elemsForPeer.elems, elem)
			// elem now belongs to the peer's batch; refill this slot with a
			// fresh element for the next read.
			elems[i] = device.NewOutboundElement()
			bufs[i] = elems[i].buffer[:]
		}

		// Dispatch each peer's batch (or recycle it if the peer is stopped).
		// Entries are deleted as we go, which Go permits during range, so the
		// map is empty again for the next read.
		for peer, elemsForPeer := range elemsByPeer {
			if peer.isRunning.Load() {
				peer.StagePackets(elemsForPeer)
				peer.SendStagedPackets()
			} else {
				for _, elem := range elemsForPeer.elems {
					device.PutMessageBuffer(elem.buffer)
					device.PutOutboundElement(elem)
				}
				device.PutOutboundElementsContainer(elemsForPeer)
			}
			delete(elemsByPeer, peer)
		}

		// Packets returned together with an error have already been
		// dispatched above.
		if readErr != nil {
			if errors.Is(readErr, tun.ErrTooManySegments) {
				// TODO: record stat for this
				// This will happen if MSS is surprisingly small (< 576)
				// coincident with reasonably high throughput.
				device.log.Verbosef("Dropped some packets from multi-segment read: %v", readErr)
				continue
			}
			if !device.isClosed() {
				// os.ErrClosed is not reported as a failure.
				if !errors.Is(readErr, os.ErrClosed) {
					device.log.Errorf("Failed to read packet from TUN device: %v", readErr)
				}
				// NOTE(review): presumably closed asynchronously because Close
				// waits for this routine to finish — confirm.
				go device.Close()
			}
			return
		}
	}
}

318 func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) {
319 for {
320 select {
321 case peer.queue.staged <- elems:
322 return
323 default:
324 }
325 select {
326 case tooOld := <-peer.queue.staged:
327 for _, elem := range tooOld.elems {
328 peer.device.PutMessageBuffer(elem.buffer)
329 peer.device.PutOutboundElement(elem)
330 }
331 peer.device.PutOutboundElementsContainer(tooOld)
332 default:
333 }
334 }
335 }
336
// SendStagedPackets drains peer's staged queue. Each element is tagged with
// the peer, given the next nonce from the current keypair, and its batch is
// pushed onto both the peer's sequential outbound queue and the device's
// parallel encryption queue.
//
// It returns immediately if nothing is staged or the device is not up. If
// there is no current keypair, or that keypair has reached RejectAfterMessages
// or RejectAfterTime, it starts a handshake and leaves the packets staged.
// Elements whose nonce would reach RejectAfterMessages are re-staged (out of
// order). The function then restarts from the top, where the exhausted
// keypair normally triggers a handshake.
func (peer *Peer) SendStagedPackets() {
top:
	if len(peer.queue.staged) == 0 || !peer.device.isUp() {
		return
	}

	keypair := peer.keypairs.Current()
	if keypair == nil || keypair.sendNonce.Load() >= RejectAfterMessages || time.Since(keypair.created) >= RejectAfterTime {
		peer.SendHandshakeInitiation(false)
		return
	}

	for {
		// Elements that could not get a usable nonce from this keypair.
		var elemsContainerOOO *QueueOutboundElementsContainer
		select {
		case elemsContainer := <-peer.queue.staged:
			// Filter elemsContainer.elems in place: keep elements with a valid
			// nonce at the front, move the rest into elemsContainerOOO.
			i := 0
			for _, elem := range elemsContainer.elems {
				elem.peer = peer
				elem.nonce = keypair.sendNonce.Add(1) - 1
				if elem.nonce >= RejectAfterMessages {
					// Pin the counter at the limit so it cannot keep growing.
					keypair.sendNonce.Store(RejectAfterMessages)
					if elemsContainerOOO == nil {
						elemsContainerOOO = peer.device.GetOutboundElementsContainer()
					}
					elemsContainerOOO.elems = append(elemsContainerOOO.elems, elem)
					continue
				} else {
					elemsContainer.elems[i] = elem
					i++
				}

				elem.keypair = keypair
			}
			// Locked until the encryption worker unlocks it; the sequential
			// sender's Lock therefore waits for encryption to finish.
			elemsContainer.Lock()
			elemsContainer.elems = elemsContainer.elems[:i]

			if elemsContainerOOO != nil {
				peer.StagePackets(elemsContainerOOO) // XXX: Out of order, but we can't front-load go chans
			}

			if len(elemsContainer.elems) == 0 {
				peer.device.PutOutboundElementsContainer(elemsContainer)
				goto top
			}

			// add to parallel and sequential queue
			if peer.isRunning.Load() {
				peer.queue.outbound.c <- elemsContainer
				peer.device.queue.encryption.c <- elemsContainer
			} else {
				for _, elem := range elemsContainer.elems {
					peer.device.PutMessageBuffer(elem.buffer)
					peer.device.PutOutboundElement(elem)
				}
				peer.device.PutOutboundElementsContainer(elemsContainer)
			}

			if elemsContainerOOO != nil {
				goto top
			}
		default:
			return
		}
	}
}

404 func (peer *Peer) FlushStagedPackets() {
405 for {
406 select {
407 case elemsContainer := <-peer.queue.staged:
408 for _, elem := range elemsContainer.elems {
409 peer.device.PutMessageBuffer(elem.buffer)
410 peer.device.PutOutboundElement(elem)
411 }
412 peer.device.PutOutboundElementsContainer(elemsContainer)
413 default:
414 return
415 }
416 }
417 }
418
419 func calculatePaddingSize(packetSize, mtu int) int {
420 lastUnit := packetSize
421 if mtu == 0 {
422 return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
423 }
424 if lastUnit > mtu {
425 lastUnit %= mtu
426 }
427 paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
428 if paddedSize > mtu {
429 paddedSize = mtu
430 }
431 return paddedSize - lastUnit
432 }
433
434 /* Encrypts the elements in the queue
435 * and marks them for sequential consumption (by releasing the mutex)
436 *
437 * Obs. One instance per core
438 */
439 func (device *Device) RoutineEncryption(id int) {
440 var paddingZeros [PaddingMultiple]byte
441 var nonce [chacha20poly1305.NonceSize]byte
442
443 defer device.log.Verbosef("Routine: encryption worker %d - stopped", id)
444 device.log.Verbosef("Routine: encryption worker %d - started", id)
445
446 for elemsContainer := range device.queue.encryption.c {
447 for _, elem := range elemsContainer.elems {
448 // populate header fields
449 header := elem.buffer[:MessageTransportHeaderSize]
450
451 fieldType := header[0:4]
452 fieldReceiver := header[4:8]
453 fieldNonce := header[8:16]
454
455 binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
456 binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
457 binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
458
459 // pad content to multiple of 16
460 paddingSize := calculatePaddingSize(len(elem.packet), int(device.tun.mtu.Load()))
461 elem.packet = append(elem.packet, paddingZeros[:paddingSize]...)
462
463 // encrypt content and release to consumer
464
465 binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
466 elem.packet = elem.keypair.send.Seal(
467 header,
468 nonce[:],
469 elem.packet,
470 nil,
471 )
472 }
473 elemsContainer.Unlock()
474 }
475 }
476
477 func (peer *Peer) RoutineSequentialSender(maxBatchSize int) {
478 device := peer.device
479 defer func() {
480 defer device.log.Verbosef("%v - Routine: sequential sender - stopped", peer)
481 peer.stopping.Done()
482 }()
483 device.log.Verbosef("%v - Routine: sequential sender - started", peer)
484
485 bufs := make([][]byte, 0, maxBatchSize)
486
487 for elemsContainer := range peer.queue.outbound.c {
488 bufs = bufs[:0]
489 if elemsContainer == nil {
490 return
491 }
492 if !peer.isRunning.Load() {
493 // peer has been stopped; return re-usable elems to the shared pool.
494 // This is an optimization only. It is possible for the peer to be stopped
495 // immediately after this check, in which case, elem will get processed.
496 // The timers and SendBuffers code are resilient to a few stragglers.
497 // TODO: rework peer shutdown order to ensure
498 // that we never accidentally keep timers alive longer than necessary.
499 elemsContainer.Lock()
500 for _, elem := range elemsContainer.elems {
501 device.PutMessageBuffer(elem.buffer)
502 device.PutOutboundElement(elem)
503 }
504 device.PutOutboundElementsContainer(elemsContainer)
505 continue
506 }
507 dataSent := false
508 elemsContainer.Lock()
509 for _, elem := range elemsContainer.elems {
510 if len(elem.packet) != MessageKeepaliveSize {
511 dataSent = true
512 }
513 bufs = append(bufs, elem.packet)
514 }
515
516 peer.timersAnyAuthenticatedPacketTraversal()
517 peer.timersAnyAuthenticatedPacketSent()
518
519 err := peer.SendBuffers(bufs)
520 if dataSent {
521 peer.timersDataSent()
522 }
523 for _, elem := range elemsContainer.elems {
524 device.PutMessageBuffer(elem.buffer)
525 device.PutOutboundElement(elem)
526 }
527 device.PutOutboundElementsContainer(elemsContainer)
528 if err != nil {
529 var errGSO conn.ErrUDPGSODisabled
530 if errors.As(err, &errGSO) {
531 device.log.Verbosef(err.Error())
532 err = errGSO.RetryErr
533 }
534 }
535 if err != nil {
536 device.log.Errorf("%v - Failed to send data packets: %v", peer, err)
537 continue
538 }
539
540 peer.keepKeyFreshSending()
541 }
542 }
543