peer.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package device
7
8 import (
9 "container/list"
10 "errors"
11 "sync"
12 "sync/atomic"
13 "time"
14
15 "golang.zx2c4.com/wireguard/conn"
16 )
17
// Peer holds the state of one remote peer attached to a Device: its session
// keypairs, handshake state, current endpoint, timers, and the queues feeding
// its sequential sender and receiver routines.
type Peer struct {
	isRunning         atomic.Bool    // set at the end of Start; cleared (via Swap) by Stop
	keypairs          Keypairs       // previous, current and next session keypairs
	handshake         Handshake      // handshake state; remoteStatic holds the peer's public key
	device            *Device        // device this peer is registered with
	stopping          sync.WaitGroup // routines pending stop; Start adds 2, Stop waits on it
	txBytes           atomic.Uint64  // bytes sent to peer (endpoint)
	rxBytes           atomic.Uint64  // bytes received from peer
	lastHandshakeNano atomic.Int64   // nanoseconds since epoch

	// endpoint is the peer's current remote address, guarded by its own mutex.
	endpoint struct {
		sync.Mutex
		val            conn.Endpoint
		clearSrcOnTx   bool // signal to val.ClearSrc() prior to next packet transmission
		disableRoaming bool // when set, SetEndpointFromPacket leaves val unchanged
	}

	timers struct {
		retransmitHandshake     *Timer
		sendKeepalive           *Timer
		newHandshake            *Timer
		zeroKeyMaterial         *Timer
		persistentKeepalive     *Timer
		handshakeAttempts       atomic.Uint32
		needAnotherKeepalive    atomic.Bool
		sentLastMinuteHandshake atomic.Bool
	}

	state struct {
		sync.Mutex // protects against concurrent Start/Stop
	}

	queue struct {
		staged   chan *QueueOutboundElementsContainer // staged packets before a handshake is available
		outbound *autodrainingOutboundQueue           // sequential ordering of udp transmission
		inbound  *autodrainingInboundQueue            // sequential ordering of tun writing
	}

	cookieGenerator             CookieGenerator
	trieEntries                 list.List
	persistentKeepaliveInterval atomic.Uint32 // NOTE(review): unit not visible here (presumably seconds) — confirm
}
60
61 func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
62 if device.isClosed() {
63 return nil, errors.New("device closed")
64 }
65
66 // lock resources
67 device.staticIdentity.RLock()
68 defer device.staticIdentity.RUnlock()
69
70 device.peers.Lock()
71 defer device.peers.Unlock()
72
73 // check if over limit
74 if len(device.peers.keyMap) >= MaxPeers {
75 return nil, errors.New("too many peers")
76 }
77
78 // create peer
79 peer := new(Peer)
80
81 peer.cookieGenerator.Init(pk)
82 peer.device = device
83 peer.queue.outbound = newAutodrainingOutboundQueue(device)
84 peer.queue.inbound = newAutodrainingInboundQueue(device)
85 peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
86
87 // map public key
88 _, ok := device.peers.keyMap[pk]
89 if ok {
90 return nil, errors.New("adding existing peer")
91 }
92
93 // pre-compute DH
94 handshake := &peer.handshake
95 handshake.mutex.Lock()
96 handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
97 handshake.remoteStatic = pk
98 handshake.mutex.Unlock()
99
100 // reset endpoint
101 peer.endpoint.Lock()
102 peer.endpoint.val = nil
103 peer.endpoint.disableRoaming = false
104 peer.endpoint.clearSrcOnTx = false
105 peer.endpoint.Unlock()
106
107 // init timers
108 peer.timersInit()
109
110 // add
111 device.peers.keyMap[pk] = peer
112
113 return peer, nil
114 }
115
116 func (peer *Peer) SendBuffers(buffers [][]byte) error {
117 peer.device.net.RLock()
118 defer peer.device.net.RUnlock()
119
120 if peer.device.isClosed() {
121 return nil
122 }
123
124 peer.endpoint.Lock()
125 endpoint := peer.endpoint.val
126 if endpoint == nil {
127 peer.endpoint.Unlock()
128 return errors.New("no known endpoint for peer")
129 }
130 if peer.endpoint.clearSrcOnTx {
131 endpoint.ClearSrc()
132 peer.endpoint.clearSrcOnTx = false
133 }
134 peer.endpoint.Unlock()
135
136 err := peer.device.net.bind.Send(buffers, endpoint)
137 if err == nil {
138 var totalLen uint64
139 for _, b := range buffers {
140 totalLen += uint64(len(b))
141 }
142 peer.txBytes.Add(totalLen)
143 }
144 return err
145 }
146
147 func (peer *Peer) String() string {
148 // The awful goo that follows is identical to:
149 //
150 // base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
151 // abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
152 // return fmt.Sprintf("peer(%s)", abbreviatedKey)
153 //
154 // except that it is considerably more efficient.
155 src := peer.handshake.remoteStatic
156 b64 := func(input byte) byte {
157 return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
158 }
159 b := []byte("peer(____…____)")
160 const first = len("peer(")
161 const second = len("peer(____…")
162 b[first+0] = b64((src[0] >> 2) & 63)
163 b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
164 b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
165 b[first+3] = b64(src[2] & 63)
166 b[second+0] = b64(src[29] & 63)
167 b[second+1] = b64((src[30] >> 2) & 63)
168 b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
169 b[second+3] = b64((src[31] << 2) & 63)
170 return string(b)
171 }
172
// Start brings the peer up: it starts the peer's timers and launches its
// sequential sender and receiver routines. It returns without doing anything
// if the device is closed or the peer is already running. peer.state
// serializes Start against Stop.
func (peer *Peer) Start() {
	// should never start a peer on a closed device
	if peer.device.isClosed() {
		return
	}

	// prevent simultaneous start/stop operations
	peer.state.Lock()
	defer peer.state.Unlock()

	if peer.isRunning.Load() {
		return
	}

	device := peer.device
	device.log.Verbosef("%v - Starting", peer)

	// reset routine state: wait for any routines still counted from a previous
	// run, then count the two routines launched below.
	peer.stopping.Wait()
	peer.stopping.Add(2)

	// Backdate lastSentHandshake by more than RekeyTimeout.
	// NOTE(review): presumably so a handshake can be initiated right away
	// instead of being treated as recently sent — confirm.
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	peer.handshake.mutex.Unlock()

	peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes; Stop calls Done

	peer.timersStart()

	// NOTE(review): presumably discards elements left in the peer's queues
	// from a previous run — confirm against flushInboundQueue/flushOutboundQueue.
	device.flushInboundQueue(peer.queue.inbound)
	device.flushOutboundQueue(peer.queue.outbound)

	// Use the device batch size, not the bind batch size, as the device size is
	// the size of the batch pools.
	batchSize := peer.device.BatchSize()
	go peer.RoutineSequentialSender(batchSize)
	go peer.RoutineSequentialReceiver(batchSize)

	// Report the peer as running only once timers and routines are in place.
	peer.isRunning.Store(true)
}
213
214 func (peer *Peer) ZeroAndFlushAll() {
215 device := peer.device
216
217 // clear key pairs
218
219 keypairs := &peer.keypairs
220 keypairs.Lock()
221 device.DeleteKeypair(keypairs.previous)
222 device.DeleteKeypair(keypairs.current)
223 device.DeleteKeypair(keypairs.next.Load())
224 keypairs.previous = nil
225 keypairs.current = nil
226 keypairs.next.Store(nil)
227 keypairs.Unlock()
228
229 // clear handshake state
230
231 handshake := &peer.handshake
232 handshake.mutex.Lock()
233 device.indexTable.Delete(handshake.localIndex)
234 handshake.Clear()
235 handshake.mutex.Unlock()
236
237 peer.FlushStagedPackets()
238 }
239
240 func (peer *Peer) ExpireCurrentKeypairs() {
241 handshake := &peer.handshake
242 handshake.mutex.Lock()
243 peer.device.indexTable.Delete(handshake.localIndex)
244 handshake.Clear()
245 peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
246 handshake.mutex.Unlock()
247
248 keypairs := &peer.keypairs
249 keypairs.Lock()
250 if keypairs.current != nil {
251 keypairs.current.sendNonce.Store(RejectAfterMessages)
252 }
253 if next := keypairs.next.Load(); next != nil {
254 next.sendNonce.Store(RejectAfterMessages)
255 }
256 keypairs.Unlock()
257 }
258
// Stop brings a running peer down: it stops the timers, signals the sequential
// routines to exit and waits for them, releases the peer's hold on the
// device's encryption queue, and drops keypairs, handshake state and staged
// packets. It is a no-op if the peer is not running. peer.state serializes
// Stop against Start.
func (peer *Peer) Stop() {
	peer.state.Lock()
	defer peer.state.Unlock()

	// Atomically clear the running flag; only a call that flips it from true
	// to false proceeds with shutdown.
	if !peer.isRunning.Swap(false) {
		return
	}

	peer.device.log.Verbosef("%v - Stopping", peer)

	peer.timersStop()
	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
	peer.queue.inbound.c <- nil
	peer.queue.outbound.c <- nil
	// Wait for the two routines counted by Start's stopping.Add(2).
	peer.stopping.Wait()
	peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us

	peer.ZeroAndFlushAll()
}
278
279 func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
280 peer.endpoint.Lock()
281 defer peer.endpoint.Unlock()
282 if peer.endpoint.disableRoaming {
283 return
284 }
285 peer.endpoint.clearSrcOnTx = false
286 peer.endpoint.val = endpoint
287 }
288
289 func (peer *Peer) markEndpointSrcForClearing() {
290 peer.endpoint.Lock()
291 defer peer.endpoint.Unlock()
292 if peer.endpoint.val == nil {
293 return
294 }
295 peer.endpoint.clearSrcOnTx = true
296 }
297