device.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package device
7
8 import (
9 "runtime"
10 "sync"
11 "sync/atomic"
12 "time"
13
14 "golang.zx2c4.com/wireguard/conn"
15 "golang.zx2c4.com/wireguard/ratelimiter"
16 "golang.zx2c4.com/wireguard/rwcancel"
17 "golang.zx2c4.com/wireguard/tun"
18 )
19
// Device is a WireGuard network device. It connects a TUN interface (tun) to a
// UDP bind (net) and holds the local static identity, the configured peers, the
// shared packet queues and element pools, and the synchronization that governs
// bringing the device up, down, and closed.
type Device struct {
	state struct {
		// state holds the device's state. It is accessed atomically.
		// Use the device.deviceState method to read it.
		// device.deviceState does not acquire the mutex, so it captures only a snapshot.
		// During state transitions, the state variable is updated before the device itself.
		// The state is thus either the current state of the device or
		// the intended future state of the device.
		// For example, while executing a call to Up, state will be deviceStateUp.
		// There is no guarantee that that intended future state of the device
		// will become the actual state; Up can fail.
		// The device can also change state multiple times between time of check and time of use.
		// Unsynchronized uses of state must therefore be advisory/best-effort only.
		state atomic.Uint32 // actually a deviceState, but typed uint32 for convenience
		// stopping blocks until all inputs to Device have been closed.
		// NewDevice adds one for RoutineReadFromTUN; Close waits on it.
		stopping sync.WaitGroup
		// The embedded Mutex (referred to as device.state.mu in other comments)
		// serializes state changes; changeState and Close hold it throughout.
		sync.Mutex
	}

	net struct {
		// stopping counts the RoutineReceiveIncoming goroutines started by
		// BindUpdate; closeBindLocked waits on it.
		stopping sync.WaitGroup
		// The embedded RWMutex guards bind, netlinkCancel, port and fwmark.
		sync.RWMutex
		bind          conn.Bind          // bind interface
		netlinkCancel *rwcancel.RWCancel // cancel handle from startRouteListener; may be nil
		port          uint16             // listening port
		fwmark        uint32             // mark value (0 = disabled)
		brokenRoaming bool               // not used in this file; presumably disables endpoint roaming — TODO confirm
	}

	staticIdentity struct {
		// The embedded RWMutex guards the key pair; SetPrivateKey holds the
		// write lock while replacing it.
		sync.RWMutex
		privateKey NoisePrivateKey
		publicKey  NoisePublicKey // derived from privateKey in SetPrivateKey
	}

	peers struct {
		sync.RWMutex // protects keyMap
		keyMap       map[NoisePublicKey]*Peer // peers indexed by their public key
	}

	rate struct {
		// underLoadUntil is a UnixNano deadline; IsUnderLoad reports true
		// until it passes.
		underLoadUntil atomic.Int64
		// limiter is initialized in NewDevice and closed in Close.
		limiter ratelimiter.Ratelimiter
	}

	// allowedips presumably maps allowed IP prefixes to peers (entries are
	// dropped per peer by removePeerLocked) — TODO confirm.
	allowedips AllowedIPs
	// indexTable is initialized in NewDevice; presumably maps session
	// indices to peers — TODO confirm.
	indexTable IndexTable
	// cookieChecker is re-initialized with the public key by SetPrivateKey.
	cookieChecker CookieChecker

	// pool holds element and buffer pools, filled by PopulatePools (defined
	// elsewhere) during NewDevice.
	pool struct {
		inboundElementsContainer  *WaitPool
		outboundElementsContainer *WaitPool
		messageBuffers            *WaitPool
		inboundElements           *WaitPool
		outboundElements          *WaitPool
	}

	// queue holds the shared work queues. Each queue's wg counts the
	// goroutines, plus the device itself, that may still write to it. Close
	// releases the device's own reference.
	queue struct {
		encryption *outboundQueue
		decryption *inboundQueue
		handshake  *handshakeQueue
	}

	tun struct {
		device tun.Device
		mtu    atomic.Int32 // initialized from device.MTU() (or DefaultMTU) in NewDevice
	}

	// ipcMutex serializes IPC configuration with upLocked and Close.
	ipcMutex sync.RWMutex
	// closed is closed at the end of Close and is returned by Wait.
	closed chan struct{}
	log    *Logger
}
93
// deviceState represents the state of a Device.
// There are three states: down, up, closed.
// Transitions:
//
//	down -----+
//	 ↑↓       ↓
//	 up  -> closed
//
// down and up may alternate any number of times (via changeState). closed is
// terminal: it can be reached from either other state (via Close) and is
// never left.
type deviceState uint32

// The go:generate directive below produces a String method (with the
// "deviceState" prefix trimmed). Log messages print states with %s.
//
//go:generate go run golang.org/x/tools/cmd/stringer -type deviceState -trimprefix=deviceState
const (
	deviceStateDown deviceState = iota // zero value; a new Device starts here
	deviceStateUp
	deviceStateClosed
)
109
110 // deviceState returns device.state.state as a deviceState
111 // See those docs for how to interpret this value.
112 func (device *Device) deviceState() deviceState {
113 return deviceState(device.state.state.Load())
114 }
115
116 // isClosed reports whether the device is closed (or is closing).
117 // See device.state.state comments for how to interpret this value.
118 func (device *Device) isClosed() bool {
119 return device.deviceState() == deviceStateClosed
120 }
121
122 // isUp reports whether the device is up (or is attempting to come up).
123 // See device.state.state comments for how to interpret this value.
124 func (device *Device) isUp() bool {
125 return device.deviceState() == deviceStateUp
126 }
127
128 // Must hold device.peers.Lock()
129 func removePeerLocked(device *Device, peer *Peer, key NoisePublicKey) {
130 // stop routing and processing of packets
131 device.allowedips.RemoveByPeer(peer)
132 peer.Stop()
133
134 // remove from peer map
135 delete(device.peers.keyMap, key)
136 }
137
// changeState attempts to change the device state to match want.
//
// It holds the device.state mutex for the whole call, so transitions are
// serialized with each other and with Close. Once the device is closed, every
// request is ignored (logged) and nil is returned. Requesting the current
// state is a no-op. Any other want (e.g. deviceStateClosed) matches no case
// and only produces the final log line.
//
// For deviceStateUp, state.state is set to Up before upLocked runs. If
// upLocked fails, control falls through to the Down case so the device is
// brought all the way back down. The returned error is the upLocked error if
// there was one, otherwise the downLocked error (if any).
func (device *Device) changeState(want deviceState) (err error) {
	device.state.Lock()
	defer device.state.Unlock()
	old := device.deviceState()
	if old == deviceStateClosed {
		// once closed, always closed
		device.log.Verbosef("Interface closed, ignored requested state %s", want)
		return nil
	}
	switch want {
	case old:
		return nil
	case deviceStateUp:
		// Publish the intended state before attempting the transition; see the
		// device.state.state docs.
		device.state.state.Store(uint32(deviceStateUp))
		err = device.upLocked()
		if err == nil {
			break
		}
		fallthrough // up failed; bring the device all the way back down
	case deviceStateDown:
		device.state.state.Store(uint32(deviceStateDown))
		errDown := device.downLocked()
		// Report the original upLocked failure in preference to a teardown error.
		if err == nil {
			err = errDown
		}
	}
	device.log.Verbosef("Interface state was %s, requested %s, now %s", old, want, device.deviceState())
	return
}
168
// upLocked attempts to bring the device up. It (re)opens the UDP bind via
// BindUpdate, then starts every peer and sends a keepalive to each peer that has
// a non-zero persistent keepalive interval.
//
// If BindUpdate fails, the error is logged and returned, and no peers are
// started. Otherwise upLocked returns nil.
// The caller must hold device.state.mu and is responsible for updating device.state.state.
func (device *Device) upLocked() error {
	if err := device.BindUpdate(); err != nil {
		device.log.Errorf("Unable to update bind: %v", err)
		return err
	}

	// The IPC set operation waits for peers to be created before calling Start() on them,
	// so if there's a concurrent IPC set request happening, we should wait for it to complete.
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()

	device.peers.RLock()
	for _, peer := range device.peers.keyMap {
		peer.Start()
		if peer.persistentKeepaliveInterval.Load() > 0 {
			peer.SendKeepalive()
		}
	}
	device.peers.RUnlock()
	return nil
}
192
193 // downLocked attempts to bring the device down.
194 // The caller must hold device.state.mu and is responsible for updating device.state.state.
195 func (device *Device) downLocked() error {
196 err := device.BindClose()
197 if err != nil {
198 device.log.Errorf("Bind close failed: %v", err)
199 }
200
201 device.peers.RLock()
202 for _, peer := range device.peers.keyMap {
203 peer.Stop()
204 }
205 device.peers.RUnlock()
206 return err
207 }
208
209 func (device *Device) Up() error {
210 return device.changeState(deviceStateUp)
211 }
212
213 func (device *Device) Down() error {
214 return device.changeState(deviceStateDown)
215 }
216
217 func (device *Device) IsUnderLoad() bool {
218 // check if currently under load
219 now := time.Now()
220 underLoad := len(device.queue.handshake.c) >= QueueHandshakeSize/8
221 if underLoad {
222 device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime).UnixNano())
223 return true
224 }
225 // check if recently under load
226 return device.rate.underLoadUntil.Load() > now.UnixNano()
227 }
228
// SetPrivateKey replaces the device's static private key and the public key
// derived from it. Setting the current key again is a no-op.
//
// Any peer whose remote static key equals the new public key is removed, since
// such a peer would share our own identity. The cookie checker is then
// re-initialized with the new public key. Each remaining peer's static-static
// DH value is recomputed, and each remaining peer's current keypairs are
// expired. It always returns nil.
//
// Lock order: staticIdentity (write), peers (write), then every peer's
// handshake.mutex (read).
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
	// lock required resources

	device.staticIdentity.Lock()
	defer device.staticIdentity.Unlock()

	if sk.Equals(device.staticIdentity.privateKey) {
		return nil
	}

	device.peers.Lock()
	defer device.peers.Unlock()

	// Hold every peer's handshake read lock while the identity changes,
	// presumably so no handshake observes half-updated key material.
	lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		peer.handshake.mutex.RLock()
		lockedPeers = append(lockedPeers, peer)
	}

	// remove peers with matching public keys

	publicKey := sk.publicKey()
	for key, peer := range device.peers.keyMap {
		if peer.handshake.remoteStatic.Equals(publicKey) {
			// Drop our read lock around removal, presumably because stopping the
			// peer needs its handshake lock. Re-acquire it afterwards: the peer is
			// still in lockedPeers, so the unlock loop below must stay balanced.
			peer.handshake.mutex.RUnlock()
			removePeerLocked(device, peer, key)
			peer.handshake.mutex.RLock()
		}
	}

	// update key material

	device.staticIdentity.privateKey = sk
	device.staticIdentity.publicKey = publicKey
	device.cookieChecker.Init(publicKey)

	// do static-static DH pre-computations

	expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
	for _, peer := range device.peers.keyMap {
		handshake := &peer.handshake
		// NOTE(review): this writes handshake state while holding only the
		// handshake read lock. Exclusivity presumably comes from also holding
		// peers.Lock and staticIdentity.Lock — confirm. The sharedSecret error
		// is discarded here.
		handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
		expiredPeers = append(expiredPeers, peer)
	}

	for _, peer := range lockedPeers {
		peer.handshake.mutex.RUnlock()
	}
	// Expire only after releasing the handshake locks, presumably because
	// ExpireCurrentKeypairs acquires them itself — TODO confirm.
	for _, peer := range expiredPeers {
		peer.ExpireCurrentKeypairs()
	}

	return nil
}
283
284 func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
285 device := new(Device)
286 device.state.state.Store(uint32(deviceStateDown))
287 device.closed = make(chan struct{})
288 device.log = logger
289 device.net.bind = bind
290 device.tun.device = tunDevice
291 mtu, err := device.tun.device.MTU()
292 if err != nil {
293 device.log.Errorf("Trouble determining MTU, assuming default: %v", err)
294 mtu = DefaultMTU
295 }
296 device.tun.mtu.Store(int32(mtu))
297 device.peers.keyMap = make(map[NoisePublicKey]*Peer)
298 device.rate.limiter.Init()
299 device.indexTable.Init()
300
301 device.PopulatePools()
302
303 // create queues
304
305 device.queue.handshake = newHandshakeQueue()
306 device.queue.encryption = newOutboundQueue()
307 device.queue.decryption = newInboundQueue()
308
309 // start workers
310
311 cpus := runtime.NumCPU()
312 device.state.stopping.Wait()
313 device.queue.encryption.wg.Add(cpus) // One for each RoutineHandshake
314 for i := 0; i < cpus; i++ {
315 go device.RoutineEncryption(i + 1)
316 go device.RoutineDecryption(i + 1)
317 go device.RoutineHandshake(i + 1)
318 }
319
320 device.state.stopping.Add(1) // RoutineReadFromTUN
321 device.queue.encryption.wg.Add(1) // RoutineReadFromTUN
322 go device.RoutineReadFromTUN()
323 go device.RoutineTUNEventReader()
324
325 return device
326 }
327
328 // BatchSize returns the BatchSize for the device as a whole which is the max of
329 // the bind batch size and the tun batch size. The batch size reported by device
330 // is the size used to construct memory pools, and is the allowed batch size for
331 // the lifetime of the device.
332 func (device *Device) BatchSize() int {
333 size := device.net.bind.BatchSize()
334 dSize := device.tun.device.BatchSize()
335 if size < dSize {
336 size = dSize
337 }
338 return size
339 }
340
341 func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
342 device.peers.RLock()
343 defer device.peers.RUnlock()
344
345 return device.peers.keyMap[pk]
346 }
347
348 func (device *Device) RemovePeer(key NoisePublicKey) {
349 device.peers.Lock()
350 defer device.peers.Unlock()
351 // stop peer and remove from routing
352
353 peer, ok := device.peers.keyMap[key]
354 if ok {
355 removePeerLocked(device, peer, key)
356 }
357 }
358
359 func (device *Device) RemoveAllPeers() {
360 device.peers.Lock()
361 defer device.peers.Unlock()
362
363 for key, peer := range device.peers.keyMap {
364 removePeerLocked(device, peer, key)
365 }
366
367 device.peers.keyMap = make(map[NoisePublicKey]*Peer)
368 }
369
// Close permanently shuts the device down. Calls after the first return
// immediately.
//
// Close holds both the state mutex and ipcMutex for its whole duration. It
// performs these steps in order:
//   - marks the device closed;
//   - closes the TUN device;
//   - brings the device down (closing the bind and stopping peers);
//   - removes all peers;
//   - drops the device's own references on the encryption, decryption, and
//     handshake queues;
//   - waits for the goroutines tracked by state.stopping;
//   - closes the rate limiter;
//   - closes the channel returned by Wait.
//
// Errors from closing the TUN device and from downLocked are not returned.
func (device *Device) Close() {
	device.state.Lock()
	defer device.state.Unlock()
	device.ipcMutex.Lock()
	defer device.ipcMutex.Unlock()
	if device.isClosed() {
		return
	}
	device.state.state.Store(uint32(deviceStateClosed))
	device.log.Verbosef("Device closing")

	// Error intentionally ignored. Closing the TUN device presumably also
	// unblocks RoutineReadFromTUN — TODO confirm.
	device.tun.device.Close()
	// downLocked logs its own bind-close error; teardown proceeds regardless.
	device.downLocked()

	// Remove peers before closing queues,
	// because peers assume that queues are active.
	device.RemoveAllPeers()

	// We kept a reference to the encryption and decryption queues,
	// in case we started any new peers that might write to them.
	// No new peers are coming; we are done with these queues.
	device.queue.encryption.wg.Done()
	device.queue.decryption.wg.Done()
	device.queue.handshake.wg.Done()
	device.state.stopping.Wait()

	device.rate.limiter.Close()

	device.log.Verbosef("Device closed")
	close(device.closed)
}
401
402 func (device *Device) Wait() chan struct{} {
403 return device.closed
404 }
405
406 func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
407 if !device.isUp() {
408 return
409 }
410
411 device.peers.RLock()
412 for _, peer := range device.peers.keyMap {
413 peer.keypairs.RLock()
414 sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
415 peer.keypairs.RUnlock()
416 if sendKeepalive {
417 peer.SendKeepalive()
418 }
419 }
420 device.peers.RUnlock()
421 }
422
423 // closeBindLocked closes the device's net.bind.
424 // The caller must hold the net mutex.
425 func closeBindLocked(device *Device) error {
426 var err error
427 netc := &device.net
428 if netc.netlinkCancel != nil {
429 netc.netlinkCancel.Cancel()
430 }
431 if netc.bind != nil {
432 err = netc.bind.Close()
433 }
434 netc.stopping.Wait()
435 return err
436 }
437
438 func (device *Device) Bind() conn.Bind {
439 device.net.Lock()
440 defer device.net.Unlock()
441 return device.net.bind
442 }
443
444 func (device *Device) BindSetMark(mark uint32) error {
445 device.net.Lock()
446 defer device.net.Unlock()
447
448 // check if modified
449 if device.net.fwmark == mark {
450 return nil
451 }
452
453 // update fwmark on existing bind
454 device.net.fwmark = mark
455 if device.isUp() && device.net.bind != nil {
456 if err := device.net.bind.SetMark(mark); err != nil {
457 return err
458 }
459 }
460
461 // clear cached source addresses
462 device.peers.RLock()
463 for _, peer := range device.peers.keyMap {
464 peer.markEndpointSrcForClearing()
465 }
466 device.peers.RUnlock()
467
468 return nil
469 }
470
471 func (device *Device) BindUpdate() error {
472 device.net.Lock()
473 defer device.net.Unlock()
474
475 // close existing sockets
476 if err := closeBindLocked(device); err != nil {
477 return err
478 }
479
480 // open new sockets
481 if !device.isUp() {
482 return nil
483 }
484
485 // bind to new port
486 var err error
487 var recvFns []conn.ReceiveFunc
488 netc := &device.net
489
490 recvFns, netc.port, err = netc.bind.Open(netc.port)
491 if err != nil {
492 netc.port = 0
493 return err
494 }
495
496 netc.netlinkCancel, err = device.startRouteListener(netc.bind)
497 if err != nil {
498 netc.bind.Close()
499 netc.port = 0
500 return err
501 }
502
503 // set fwmark
504 if netc.fwmark != 0 {
505 err = netc.bind.SetMark(netc.fwmark)
506 if err != nil {
507 return err
508 }
509 }
510
511 // clear cached source addresses
512 device.peers.RLock()
513 for _, peer := range device.peers.keyMap {
514 peer.markEndpointSrcForClearing()
515 }
516 device.peers.RUnlock()
517
518 // start receiving routines
519 device.net.stopping.Add(len(recvFns))
520 device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
521 device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
522 batchSize := netc.bind.BatchSize()
523 for _, fn := range recvFns {
524 go device.RoutineReceiveIncoming(batchSize, fn)
525 }
526
527 device.log.Verbosef("UDP bind has been updated")
528 return nil
529 }
530
531 func (device *Device) BindClose() error {
532 device.net.Lock()
533 err := closeBindLocked(device)
534 device.net.Unlock()
535 return err
536 }
537