peer.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package device
   7  
   8  import (
   9  	"container/list"
  10  	"errors"
  11  	"sync"
  12  	"sync/atomic"
  13  	"time"
  14  
  15  	"golang.zx2c4.com/wireguard/conn"
  16  )
  17  
// Peer holds the state kept for one remote peer of a Device: the keypairs and
// Noise handshake state for the peer's static public key
// (handshake.remoteStatic), its endpoint, protocol timers, per-peer packet
// queues and traffic counters.
//
// Peers are created with Device.NewPeer and started/stopped with Start and
// Stop. isRunning is set at the end of Start and cleared by Stop; stopping
// counts the two sequential sender/receiver goroutines launched by Start.
type Peer struct {
	isRunning         atomic.Bool
	keypairs          Keypairs
	handshake         Handshake
	device            *Device
	stopping          sync.WaitGroup // routines pending stop
	txBytes           atomic.Uint64  // bytes sent to peer (endpoint)
	rxBytes           atomic.Uint64  // bytes received from peer
	lastHandshakeNano atomic.Int64   // nano seconds since epoch

	// endpoint is guarded by its embedded mutex. val is where packets are
	// sent (nil until known); disableRoaming prevents SetEndpointFromPacket
	// from replacing val.
	endpoint struct {
		sync.Mutex
		val            conn.Endpoint
		clearSrcOnTx   bool // signal to val.ClearSrc() prior to next packet transmission
		disableRoaming bool
	}

	// timers holds the peer's protocol timers and related flags. They are set
	// up in NewPeer (timersInit) and started/stopped by Start/Stop.
	timers struct {
		retransmitHandshake     *Timer
		sendKeepalive           *Timer
		newHandshake            *Timer
		zeroKeyMaterial         *Timer
		persistentKeepalive     *Timer
		handshakeAttempts       atomic.Uint32
		needAnotherKeepalive    atomic.Bool
		sentLastMinuteHandshake atomic.Bool
	}

	state struct {
		sync.Mutex // protects against concurrent Start/Stop
	}

	// queue holds the per-peer packet queues, all created in NewPeer.
	queue struct {
		staged   chan *QueueOutboundElementsContainer // staged packets before a handshake is available
		outbound *autodrainingOutboundQueue           // sequential ordering of udp transmission
		inbound  *autodrainingInboundQueue            // sequential ordering of tun writing
	}

	// cookieGenerator is initialized with the peer's public key in NewPeer.
	// NOTE(review): trieEntries presumably tracks this peer's allowed-IPs trie
	// nodes, and persistentKeepaliveInterval is presumably in seconds — confirm.
	cookieGenerator             CookieGenerator
	trieEntries                 list.List
	persistentKeepaliveInterval atomic.Uint32
}
  60  
  61  func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
  62  	if device.isClosed() {
  63  		return nil, errors.New("device closed")
  64  	}
  65  
  66  	// lock resources
  67  	device.staticIdentity.RLock()
  68  	defer device.staticIdentity.RUnlock()
  69  
  70  	device.peers.Lock()
  71  	defer device.peers.Unlock()
  72  
  73  	// check if over limit
  74  	if len(device.peers.keyMap) >= MaxPeers {
  75  		return nil, errors.New("too many peers")
  76  	}
  77  
  78  	// create peer
  79  	peer := new(Peer)
  80  
  81  	peer.cookieGenerator.Init(pk)
  82  	peer.device = device
  83  	peer.queue.outbound = newAutodrainingOutboundQueue(device)
  84  	peer.queue.inbound = newAutodrainingInboundQueue(device)
  85  	peer.queue.staged = make(chan *QueueOutboundElementsContainer, QueueStagedSize)
  86  
  87  	// map public key
  88  	_, ok := device.peers.keyMap[pk]
  89  	if ok {
  90  		return nil, errors.New("adding existing peer")
  91  	}
  92  
  93  	// pre-compute DH
  94  	handshake := &peer.handshake
  95  	handshake.mutex.Lock()
  96  	handshake.precomputedStaticStatic, _ = device.staticIdentity.privateKey.sharedSecret(pk)
  97  	handshake.remoteStatic = pk
  98  	handshake.mutex.Unlock()
  99  
 100  	// reset endpoint
 101  	peer.endpoint.Lock()
 102  	peer.endpoint.val = nil
 103  	peer.endpoint.disableRoaming = false
 104  	peer.endpoint.clearSrcOnTx = false
 105  	peer.endpoint.Unlock()
 106  
 107  	// init timers
 108  	peer.timersInit()
 109  
 110  	// add
 111  	device.peers.keyMap[pk] = peer
 112  
 113  	return peer, nil
 114  }
 115  
 116  func (peer *Peer) SendBuffers(buffers [][]byte) error {
 117  	peer.device.net.RLock()
 118  	defer peer.device.net.RUnlock()
 119  
 120  	if peer.device.isClosed() {
 121  		return nil
 122  	}
 123  
 124  	peer.endpoint.Lock()
 125  	endpoint := peer.endpoint.val
 126  	if endpoint == nil {
 127  		peer.endpoint.Unlock()
 128  		return errors.New("no known endpoint for peer")
 129  	}
 130  	if peer.endpoint.clearSrcOnTx {
 131  		endpoint.ClearSrc()
 132  		peer.endpoint.clearSrcOnTx = false
 133  	}
 134  	peer.endpoint.Unlock()
 135  
 136  	err := peer.device.net.bind.Send(buffers, endpoint)
 137  	if err == nil {
 138  		var totalLen uint64
 139  		for _, b := range buffers {
 140  			totalLen += uint64(len(b))
 141  		}
 142  		peer.txBytes.Add(totalLen)
 143  	}
 144  	return err
 145  }
 146  
 147  func (peer *Peer) String() string {
 148  	// The awful goo that follows is identical to:
 149  	//
 150  	//   base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
 151  	//   abbreviatedKey := base64Key[0:4] + "…" + base64Key[39:43]
 152  	//   return fmt.Sprintf("peer(%s)", abbreviatedKey)
 153  	//
 154  	// except that it is considerably more efficient.
 155  	src := peer.handshake.remoteStatic
 156  	b64 := func(input byte) byte {
 157  		return input + 'A' + byte(((25-int(input))>>8)&6) - byte(((51-int(input))>>8)&75) - byte(((61-int(input))>>8)&15) + byte(((62-int(input))>>8)&3)
 158  	}
 159  	b := []byte("peer(____…____)")
 160  	const first = len("peer(")
 161  	const second = len("peer(____…")
 162  	b[first+0] = b64((src[0] >> 2) & 63)
 163  	b[first+1] = b64(((src[0] << 4) | (src[1] >> 4)) & 63)
 164  	b[first+2] = b64(((src[1] << 2) | (src[2] >> 6)) & 63)
 165  	b[first+3] = b64(src[2] & 63)
 166  	b[second+0] = b64(src[29] & 63)
 167  	b[second+1] = b64((src[30] >> 2) & 63)
 168  	b[second+2] = b64(((src[30] << 4) | (src[31] >> 4)) & 63)
 169  	b[second+3] = b64((src[31] << 2) & 63)
 170  	return string(b)
 171  }
 172  
// Start brings the peer up: it resets handshake timing, takes a hold on the
// device encryption queue, starts the timers, flushes the peer's queues and
// launches RoutineSequentialSender and RoutineSequentialReceiver.
//
// It is a no-op if the device is closed or the peer is already running.
// Start and Stop are serialized by peer.state. isRunning is set only after
// everything has been started.
func (peer *Peer) Start() {
	// should never start a peer on a closed device
	if peer.device.isClosed() {
		return
	}

	// prevent simultaneous start/stop operations
	peer.state.Lock()
	defer peer.state.Unlock()

	if peer.isRunning.Load() {
		return
	}

	device := peer.device
	device.log.Verbosef("%v - Starting", peer)

	// reset routine state: wait for any routines from a previous run, then
	// account for the two goroutines launched below.
	// NOTE(review): presumably each routine calls stopping.Done on exit
	// (not visible here) — confirm.
	peer.stopping.Wait()
	peer.stopping.Add(2)

	// Backdate lastSentHandshake by more than RekeyTimeout.
	// NOTE(review): presumably so that a new handshake initiation is not
	// suppressed as "sent too recently" — confirm.
	peer.handshake.mutex.Lock()
	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
	peer.handshake.mutex.Unlock()

	peer.device.queue.encryption.wg.Add(1) // keep encryption queue open for our writes; released by Stop

	peer.timersStart()

	// Flush whatever is left in the peer's queues before the routines start.
	device.flushInboundQueue(peer.queue.inbound)
	device.flushOutboundQueue(peer.queue.outbound)

	// Use the device batch size, not the bind batch size, as the device size is
	// the size of the batch pools.
	batchSize := peer.device.BatchSize()
	go peer.RoutineSequentialSender(batchSize)
	go peer.RoutineSequentialReceiver(batchSize)

	peer.isRunning.Store(true)
}
 213  
 214  func (peer *Peer) ZeroAndFlushAll() {
 215  	device := peer.device
 216  
 217  	// clear key pairs
 218  
 219  	keypairs := &peer.keypairs
 220  	keypairs.Lock()
 221  	device.DeleteKeypair(keypairs.previous)
 222  	device.DeleteKeypair(keypairs.current)
 223  	device.DeleteKeypair(keypairs.next.Load())
 224  	keypairs.previous = nil
 225  	keypairs.current = nil
 226  	keypairs.next.Store(nil)
 227  	keypairs.Unlock()
 228  
 229  	// clear handshake state
 230  
 231  	handshake := &peer.handshake
 232  	handshake.mutex.Lock()
 233  	device.indexTable.Delete(handshake.localIndex)
 234  	handshake.Clear()
 235  	handshake.mutex.Unlock()
 236  
 237  	peer.FlushStagedPackets()
 238  }
 239  
 240  func (peer *Peer) ExpireCurrentKeypairs() {
 241  	handshake := &peer.handshake
 242  	handshake.mutex.Lock()
 243  	peer.device.indexTable.Delete(handshake.localIndex)
 244  	handshake.Clear()
 245  	peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
 246  	handshake.mutex.Unlock()
 247  
 248  	keypairs := &peer.keypairs
 249  	keypairs.Lock()
 250  	if keypairs.current != nil {
 251  		keypairs.current.sendNonce.Store(RejectAfterMessages)
 252  	}
 253  	if next := keypairs.next.Load(); next != nil {
 254  		next.sendNonce.Store(RejectAfterMessages)
 255  	}
 256  	keypairs.Unlock()
 257  }
 258  
// Stop takes a running peer down; it is a no-op if the peer is not running.
//
// In order, it:
//   - stops the timers;
//   - signals the sequential sender and receiver routines to exit by sending
//     nil on their queues, then waits for them via peer.stopping;
//   - releases the hold on the device encryption queue taken in Start;
//   - clears keypairs, handshake state and staged packets (ZeroAndFlushAll).
//
// Start and Stop are serialized by peer.state.
func (peer *Peer) Stop() {
	peer.state.Lock()
	defer peer.state.Unlock()

	// Atomically clear isRunning; bail out if it was already false.
	if !peer.isRunning.Swap(false) {
		return
	}

	peer.device.log.Verbosef("%v - Stopping", peer)

	peer.timersStop()
	// Signal that RoutineSequentialSender and RoutineSequentialReceiver should exit.
	peer.queue.inbound.c <- nil
	peer.queue.outbound.c <- nil
	peer.stopping.Wait()
	peer.device.queue.encryption.wg.Done() // no more writes to encryption queue from us

	peer.ZeroAndFlushAll()
}
 278  
 279  func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
 280  	peer.endpoint.Lock()
 281  	defer peer.endpoint.Unlock()
 282  	if peer.endpoint.disableRoaming {
 283  		return
 284  	}
 285  	peer.endpoint.clearSrcOnTx = false
 286  	peer.endpoint.val = endpoint
 287  }
 288  
 289  func (peer *Peer) markEndpointSrcForClearing() {
 290  	peer.endpoint.Lock()
 291  	defer peer.endpoint.Unlock()
 292  	if peer.endpoint.val == nil {
 293  		return
 294  	}
 295  	peer.endpoint.clearSrcOnTx = true
 296  }
 297