noise-protocol.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package device
7
8 import (
9 "encoding/binary"
10 "errors"
11 "fmt"
12 "sync"
13 "time"
14
15 "golang.org/x/crypto/blake2s"
16 "golang.org/x/crypto/chacha20poly1305"
17 "golang.org/x/crypto/poly1305"
18
19 "golang.zx2c4.com/wireguard/tai64n"
20 )
21
// handshakeState records how far a peer's handshake has progressed.
type handshakeState int

// Handshake states. handshakeZeroed is the zero value, so a Handshake that has
// never been used (or has been cleared) starts there.
const (
	handshakeZeroed handshakeState = iota
	handshakeInitiationCreated
	handshakeInitiationConsumed
	handshakeResponseCreated
	handshakeResponseConsumed
)

// String returns the identifier of a known state, or
// "Handshake(UNKNOWN:<n>)" for any other value.
func (hs handshakeState) String() string {
	names := [...]string{
		handshakeZeroed:             "handshakeZeroed",
		handshakeInitiationCreated:  "handshakeInitiationCreated",
		handshakeInitiationConsumed: "handshakeInitiationConsumed",
		handshakeResponseCreated:    "handshakeResponseCreated",
		handshakeResponseConsumed:   "handshakeResponseConsumed",
	}
	if hs >= 0 && int(hs) < len(names) {
		return names[hs]
	}
	return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
}
48
// Protocol identifier strings.
const (
	// NoiseConstruction names the Noise protocol variant; its BLAKE2s-256
	// hash is InitialChainKey (see init).
	NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
	// WGIdentifier is mixed into InitialChainKey to form InitialHash.
	WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
	// WGLabelMAC1 and WGLabelCookie are label strings not used in this part of
	// the file; presumably they domain-separate the MAC1 and cookie keys
	// (TODO confirm at their use sites).
	WGLabelMAC1   = "mac1----"
	WGLabelCookie = "cookie--"
)

// Message type codes carried in the first field (Type) of every message.
const (
	MessageInitiationType  = iota + 1 // 1
	MessageResponseType               // 2
	MessageCookieReplyType            // 3
	MessageTransportType              // 4
)
62
63 const (
64 MessageInitiationSize = 148 // size of handshake initiation message
65 MessageResponseSize = 92 // size of response message
66 MessageCookieReplySize = 64 // size of cookie reply message
67 MessageTransportHeaderSize = 16 // size of data preceding content in transport message
68 MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
69 MessageKeepaliveSize = MessageTransportSize // size of keepalive
70 MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
71 )
72
73 const (
74 MessageTransportOffsetReceiver = 4
75 MessageTransportOffsetCounter = 8
76 MessageTransportOffsetContent = 16
77 )
78
/* Type is an 8-bit field followed by 3 NUL (zero) bytes.
 * Because messages are marshalled in little-endian byte order,
 * those 4 bytes can be treated as a single 32-bit unsigned int (for now).
 */
84
85 type MessageInitiation struct {
86 Type uint32
87 Sender uint32
88 Ephemeral NoisePublicKey
89 Static [NoisePublicKeySize + poly1305.TagSize]byte
90 Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
91 MAC1 [blake2s.Size128]byte
92 MAC2 [blake2s.Size128]byte
93 }
94
95 type MessageResponse struct {
96 Type uint32
97 Sender uint32
98 Receiver uint32
99 Ephemeral NoisePublicKey
100 Empty [poly1305.TagSize]byte
101 MAC1 [blake2s.Size128]byte
102 MAC2 [blake2s.Size128]byte
103 }
104
// MessageTransport describes a data message. Its header is
// MessageTransportHeaderSize bytes (see the MessageTransportOffset* constants),
// followed by the content.
type MessageTransport struct {
	// Always MessageTransportType.
	Type uint32
	// Index identifying the receiving side's keypair.
	Receiver uint32
	// Presumably the per-keypair nonce counter (TODO confirm at use sites).
	Counter uint64
	// The bytes following the header.
	Content []byte
}
111
112 type MessageCookieReply struct {
113 Type uint32
114 Receiver uint32
115 Nonce [chacha20poly1305.NonceSizeX]byte
116 Cookie [blake2s.Size128 + poly1305.TagSize]byte
117 }
118
// errMessageLengthMismatch is returned by the marshal/unmarshal methods when
// the buffer is not exactly the message's wire size.
var errMessageLengthMismatch error = errors.New("message length mismatch")
120
121 func (msg *MessageInitiation) unmarshal(b []byte) error {
122 if len(b) != MessageInitiationSize {
123 return errMessageLengthMismatch
124 }
125
126 msg.Type = binary.LittleEndian.Uint32(b)
127 msg.Sender = binary.LittleEndian.Uint32(b[4:])
128 copy(msg.Ephemeral[:], b[8:])
129 copy(msg.Static[:], b[8+len(msg.Ephemeral):])
130 copy(msg.Timestamp[:], b[8+len(msg.Ephemeral)+len(msg.Static):])
131 copy(msg.MAC1[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):])
132 copy(msg.MAC2[:], b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):])
133
134 return nil
135 }
136
137 func (msg *MessageInitiation) marshal(b []byte) error {
138 if len(b) != MessageInitiationSize {
139 return errMessageLengthMismatch
140 }
141
142 binary.LittleEndian.PutUint32(b, msg.Type)
143 binary.LittleEndian.PutUint32(b[4:], msg.Sender)
144 copy(b[8:], msg.Ephemeral[:])
145 copy(b[8+len(msg.Ephemeral):], msg.Static[:])
146 copy(b[8+len(msg.Ephemeral)+len(msg.Static):], msg.Timestamp[:])
147 copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp):], msg.MAC1[:])
148 copy(b[8+len(msg.Ephemeral)+len(msg.Static)+len(msg.Timestamp)+len(msg.MAC1):], msg.MAC2[:])
149
150 return nil
151 }
152
153 func (msg *MessageResponse) unmarshal(b []byte) error {
154 if len(b) != MessageResponseSize {
155 return errMessageLengthMismatch
156 }
157
158 msg.Type = binary.LittleEndian.Uint32(b)
159 msg.Sender = binary.LittleEndian.Uint32(b[4:])
160 msg.Receiver = binary.LittleEndian.Uint32(b[8:])
161 copy(msg.Ephemeral[:], b[12:])
162 copy(msg.Empty[:], b[12+len(msg.Ephemeral):])
163 copy(msg.MAC1[:], b[12+len(msg.Ephemeral)+len(msg.Empty):])
164 copy(msg.MAC2[:], b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):])
165
166 return nil
167 }
168
169 func (msg *MessageResponse) marshal(b []byte) error {
170 if len(b) != MessageResponseSize {
171 return errMessageLengthMismatch
172 }
173
174 binary.LittleEndian.PutUint32(b, msg.Type)
175 binary.LittleEndian.PutUint32(b[4:], msg.Sender)
176 binary.LittleEndian.PutUint32(b[8:], msg.Receiver)
177 copy(b[12:], msg.Ephemeral[:])
178 copy(b[12+len(msg.Ephemeral):], msg.Empty[:])
179 copy(b[12+len(msg.Ephemeral)+len(msg.Empty):], msg.MAC1[:])
180 copy(b[12+len(msg.Ephemeral)+len(msg.Empty)+len(msg.MAC1):], msg.MAC2[:])
181
182 return nil
183 }
184
185 func (msg *MessageCookieReply) unmarshal(b []byte) error {
186 if len(b) != MessageCookieReplySize {
187 return errMessageLengthMismatch
188 }
189
190 msg.Type = binary.LittleEndian.Uint32(b)
191 msg.Receiver = binary.LittleEndian.Uint32(b[4:])
192 copy(msg.Nonce[:], b[8:])
193 copy(msg.Cookie[:], b[8+len(msg.Nonce):])
194
195 return nil
196 }
197
198 func (msg *MessageCookieReply) marshal(b []byte) error {
199 if len(b) != MessageCookieReplySize {
200 return errMessageLengthMismatch
201 }
202
203 binary.LittleEndian.PutUint32(b, msg.Type)
204 binary.LittleEndian.PutUint32(b[4:], msg.Receiver)
205 copy(b[8:], msg.Nonce[:])
206 copy(b[8+len(msg.Nonce):], msg.Cookie[:])
207
208 return nil
209 }
210
211 type Handshake struct {
212 state handshakeState
213 mutex sync.RWMutex
214 hash [blake2s.Size]byte // hash value
215 chainKey [blake2s.Size]byte // chain key
216 presharedKey NoisePresharedKey // psk
217 localEphemeral NoisePrivateKey // ephemeral secret key
218 localIndex uint32 // used to clear hash-table
219 remoteIndex uint32 // index for sending
220 remoteStatic NoisePublicKey // long term key
221 remoteEphemeral NoisePublicKey // ephemeral public key
222 precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
223 lastTimestamp tai64n.Timestamp
224 lastInitiationConsumption time.Time
225 lastSentHandshake time.Time
226 }
227
228 var (
229 InitialChainKey [blake2s.Size]byte
230 InitialHash [blake2s.Size]byte
231 ZeroNonce [chacha20poly1305.NonceSize]byte
232 )
233
234 func mixKey(dst, c *[blake2s.Size]byte, data []byte) {
235 KDF1(dst, c[:], data)
236 }
237
238 func mixHash(dst, h *[blake2s.Size]byte, data []byte) {
239 hash, _ := blake2s.New256(nil)
240 hash.Write(h[:])
241 hash.Write(data)
242 hash.Sum(dst[:0])
243 hash.Reset()
244 }
245
246 func (h *Handshake) Clear() {
247 setZero(h.localEphemeral[:])
248 setZero(h.remoteEphemeral[:])
249 setZero(h.chainKey[:])
250 setZero(h.hash[:])
251 h.localIndex = 0
252 h.state = handshakeZeroed
253 }
254
255 func (h *Handshake) mixHash(data []byte) {
256 mixHash(&h.hash, &h.hash, data)
257 }
258
259 func (h *Handshake) mixKey(data []byte) {
260 mixKey(&h.chainKey, &h.chainKey, data)
261 }
262
263 /* Do basic precomputations
264 */
265 func init() {
266 InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
267 mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
268 }
269
270 func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
271 device.staticIdentity.RLock()
272 defer device.staticIdentity.RUnlock()
273
274 handshake := &peer.handshake
275 handshake.mutex.Lock()
276 defer handshake.mutex.Unlock()
277
278 // create ephemeral key
279 var err error
280 handshake.hash = InitialHash
281 handshake.chainKey = InitialChainKey
282 handshake.localEphemeral, err = newPrivateKey()
283 if err != nil {
284 return nil, err
285 }
286
287 handshake.mixHash(handshake.remoteStatic[:])
288
289 msg := MessageInitiation{
290 Type: MessageInitiationType,
291 Ephemeral: handshake.localEphemeral.publicKey(),
292 }
293
294 handshake.mixKey(msg.Ephemeral[:])
295 handshake.mixHash(msg.Ephemeral[:])
296
297 // encrypt static key
298 ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
299 if err != nil {
300 return nil, err
301 }
302 var key [chacha20poly1305.KeySize]byte
303 KDF2(
304 &handshake.chainKey,
305 &key,
306 handshake.chainKey[:],
307 ss[:],
308 )
309 aead, _ := chacha20poly1305.New(key[:])
310 aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
311 handshake.mixHash(msg.Static[:])
312
313 // encrypt timestamp
314 if isZero(handshake.precomputedStaticStatic[:]) {
315 return nil, errInvalidPublicKey
316 }
317 KDF2(
318 &handshake.chainKey,
319 &key,
320 handshake.chainKey[:],
321 handshake.precomputedStaticStatic[:],
322 )
323 timestamp := tai64n.Now()
324 aead, _ = chacha20poly1305.New(key[:])
325 aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
326
327 // assign index
328 device.indexTable.Delete(handshake.localIndex)
329 msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
330 if err != nil {
331 return nil, err
332 }
333 handshake.localIndex = msg.Sender
334
335 handshake.mixHash(msg.Timestamp[:])
336 handshake.state = handshakeInitiationCreated
337 return &msg, nil
338 }
339
340 func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
341 var (
342 hash [blake2s.Size]byte
343 chainKey [blake2s.Size]byte
344 )
345
346 if msg.Type != MessageInitiationType {
347 return nil
348 }
349
350 device.staticIdentity.RLock()
351 defer device.staticIdentity.RUnlock()
352
353 mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
354 mixHash(&hash, &hash, msg.Ephemeral[:])
355 mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
356
357 // decrypt static key
358 var peerPK NoisePublicKey
359 var key [chacha20poly1305.KeySize]byte
360 ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
361 if err != nil {
362 return nil
363 }
364 KDF2(&chainKey, &key, chainKey[:], ss[:])
365 aead, _ := chacha20poly1305.New(key[:])
366 _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
367 if err != nil {
368 return nil
369 }
370 mixHash(&hash, &hash, msg.Static[:])
371
372 // lookup peer
373
374 peer := device.LookupPeer(peerPK)
375 if peer == nil || !peer.isRunning.Load() {
376 return nil
377 }
378
379 handshake := &peer.handshake
380
381 // verify identity
382
383 var timestamp tai64n.Timestamp
384
385 handshake.mutex.RLock()
386
387 if isZero(handshake.precomputedStaticStatic[:]) {
388 handshake.mutex.RUnlock()
389 return nil
390 }
391 KDF2(
392 &chainKey,
393 &key,
394 chainKey[:],
395 handshake.precomputedStaticStatic[:],
396 )
397 aead, _ = chacha20poly1305.New(key[:])
398 _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
399 if err != nil {
400 handshake.mutex.RUnlock()
401 return nil
402 }
403 mixHash(&hash, &hash, msg.Timestamp[:])
404
405 // protect against replay & flood
406
407 replay := !timestamp.After(handshake.lastTimestamp)
408 flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
409 handshake.mutex.RUnlock()
410 if replay {
411 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake replay @ %v", peer, timestamp)
412 return nil
413 }
414 if flood {
415 device.log.Verbosef("%v - ConsumeMessageInitiation: handshake flood", peer)
416 return nil
417 }
418
419 // update handshake state
420
421 handshake.mutex.Lock()
422
423 handshake.hash = hash
424 handshake.chainKey = chainKey
425 handshake.remoteIndex = msg.Sender
426 handshake.remoteEphemeral = msg.Ephemeral
427 if timestamp.After(handshake.lastTimestamp) {
428 handshake.lastTimestamp = timestamp
429 }
430 now := time.Now()
431 if now.After(handshake.lastInitiationConsumption) {
432 handshake.lastInitiationConsumption = now
433 }
434 handshake.state = handshakeInitiationConsumed
435
436 handshake.mutex.Unlock()
437
438 setZero(hash[:])
439 setZero(chainKey[:])
440
441 return peer
442 }
443
444 func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
445 handshake := &peer.handshake
446 handshake.mutex.Lock()
447 defer handshake.mutex.Unlock()
448
449 if handshake.state != handshakeInitiationConsumed {
450 return nil, errors.New("handshake initiation must be consumed first")
451 }
452
453 // assign index
454
455 var err error
456 device.indexTable.Delete(handshake.localIndex)
457 handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
458 if err != nil {
459 return nil, err
460 }
461
462 var msg MessageResponse
463 msg.Type = MessageResponseType
464 msg.Sender = handshake.localIndex
465 msg.Receiver = handshake.remoteIndex
466
467 // create ephemeral key
468
469 handshake.localEphemeral, err = newPrivateKey()
470 if err != nil {
471 return nil, err
472 }
473 msg.Ephemeral = handshake.localEphemeral.publicKey()
474 handshake.mixHash(msg.Ephemeral[:])
475 handshake.mixKey(msg.Ephemeral[:])
476
477 ss, err := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
478 if err != nil {
479 return nil, err
480 }
481 handshake.mixKey(ss[:])
482 ss, err = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
483 if err != nil {
484 return nil, err
485 }
486 handshake.mixKey(ss[:])
487
488 // add preshared key
489
490 var tau [blake2s.Size]byte
491 var key [chacha20poly1305.KeySize]byte
492
493 KDF3(
494 &handshake.chainKey,
495 &tau,
496 &key,
497 handshake.chainKey[:],
498 handshake.presharedKey[:],
499 )
500
501 handshake.mixHash(tau[:])
502
503 aead, _ := chacha20poly1305.New(key[:])
504 aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
505 handshake.mixHash(msg.Empty[:])
506
507 handshake.state = handshakeResponseCreated
508
509 return &msg, nil
510 }
511
512 func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
513 if msg.Type != MessageResponseType {
514 return nil
515 }
516
517 // lookup handshake by receiver
518
519 lookup := device.indexTable.Lookup(msg.Receiver)
520 handshake := lookup.handshake
521 if handshake == nil {
522 return nil
523 }
524
525 var (
526 hash [blake2s.Size]byte
527 chainKey [blake2s.Size]byte
528 )
529
530 ok := func() bool {
531 // lock handshake state
532
533 handshake.mutex.RLock()
534 defer handshake.mutex.RUnlock()
535
536 if handshake.state != handshakeInitiationCreated {
537 return false
538 }
539
540 // lock private key for reading
541
542 device.staticIdentity.RLock()
543 defer device.staticIdentity.RUnlock()
544
545 // finish 3-way DH
546
547 mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
548 mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
549
550 ss, err := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
551 if err != nil {
552 return false
553 }
554 mixKey(&chainKey, &chainKey, ss[:])
555 setZero(ss[:])
556
557 ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
558 if err != nil {
559 return false
560 }
561 mixKey(&chainKey, &chainKey, ss[:])
562 setZero(ss[:])
563
564 // add preshared key (psk)
565
566 var tau [blake2s.Size]byte
567 var key [chacha20poly1305.KeySize]byte
568 KDF3(
569 &chainKey,
570 &tau,
571 &key,
572 chainKey[:],
573 handshake.presharedKey[:],
574 )
575 mixHash(&hash, &hash, tau[:])
576
577 // authenticate transcript
578
579 aead, _ := chacha20poly1305.New(key[:])
580 _, err = aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
581 if err != nil {
582 return false
583 }
584 mixHash(&hash, &hash, msg.Empty[:])
585 return true
586 }()
587
588 if !ok {
589 return nil
590 }
591
592 // update handshake state
593
594 handshake.mutex.Lock()
595
596 handshake.hash = hash
597 handshake.chainKey = chainKey
598 handshake.remoteIndex = msg.Sender
599 handshake.state = handshakeResponseConsumed
600
601 handshake.mutex.Unlock()
602
603 setZero(hash[:])
604 setZero(chainKey[:])
605
606 return lookup.peer
607 }
608
609 /* Derives a new keypair from the current handshake state
610 *
611 */
612 func (peer *Peer) BeginSymmetricSession() error {
613 device := peer.device
614 handshake := &peer.handshake
615 handshake.mutex.Lock()
616 defer handshake.mutex.Unlock()
617
618 // derive keys
619
620 var isInitiator bool
621 var sendKey [chacha20poly1305.KeySize]byte
622 var recvKey [chacha20poly1305.KeySize]byte
623
624 if handshake.state == handshakeResponseConsumed {
625 KDF2(
626 &sendKey,
627 &recvKey,
628 handshake.chainKey[:],
629 nil,
630 )
631 isInitiator = true
632 } else if handshake.state == handshakeResponseCreated {
633 KDF2(
634 &recvKey,
635 &sendKey,
636 handshake.chainKey[:],
637 nil,
638 )
639 isInitiator = false
640 } else {
641 return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
642 }
643
644 // zero handshake
645
646 setZero(handshake.chainKey[:])
647 setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
648 setZero(handshake.localEphemeral[:])
649 peer.handshake.state = handshakeZeroed
650
651 // create AEAD instances
652
653 keypair := new(Keypair)
654 keypair.send, _ = chacha20poly1305.New(sendKey[:])
655 keypair.receive, _ = chacha20poly1305.New(recvKey[:])
656
657 setZero(sendKey[:])
658 setZero(recvKey[:])
659
660 keypair.created = time.Now()
661 keypair.replayFilter.Reset()
662 keypair.isInitiator = isInitiator
663 keypair.localIndex = peer.handshake.localIndex
664 keypair.remoteIndex = peer.handshake.remoteIndex
665
666 // remap index
667
668 device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
669 handshake.localIndex = 0
670
671 // rotate key pairs
672
673 keypairs := &peer.keypairs
674 keypairs.Lock()
675 defer keypairs.Unlock()
676
677 previous := keypairs.previous
678 next := keypairs.next.Load()
679 current := keypairs.current
680
681 if isInitiator {
682 if next != nil {
683 keypairs.next.Store(nil)
684 keypairs.previous = next
685 device.DeleteKeypair(current)
686 } else {
687 keypairs.previous = current
688 }
689 device.DeleteKeypair(previous)
690 keypairs.current = keypair
691 } else {
692 keypairs.next.Store(keypair)
693 device.DeleteKeypair(next)
694 keypairs.previous = nil
695 device.DeleteKeypair(previous)
696 }
697
698 return nil
699 }
700
701 func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
702 keypairs := &peer.keypairs
703
704 if keypairs.next.Load() != receivedKeypair {
705 return false
706 }
707 keypairs.Lock()
708 defer keypairs.Unlock()
709 if keypairs.next.Load() != receivedKeypair {
710 return false
711 }
712 old := keypairs.previous
713 keypairs.previous = keypairs.current
714 peer.device.DeleteKeypair(old)
715 keypairs.current = keypairs.next.Load()
716 keypairs.next.Store(nil)
717 return true
718 }
719