tun_windows.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package tun
   7  
   8  import (
   9  	"errors"
  10  	"fmt"
  11  	"os"
  12  	"sync"
  13  	"sync/atomic"
  14  	"time"
  15  	_ "unsafe"
  16  
  17  	"golang.org/x/sys/windows"
  18  	"golang.zx2c4.com/wintun"
  19  )
  20  
  21  const (
  22  	rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
  23  	spinloopRateThreshold      = 800000000 / 8                                   // 800mbps
  24  	spinloopDuration           = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
  25  )
  26  
  27  type rateJuggler struct {
  28  	current       atomic.Uint64
  29  	nextByteCount atomic.Uint64
  30  	nextStartTime atomic.Int64
  31  	changing      atomic.Bool
  32  }
  33  
  34  type NativeTun struct {
  35  	wt        *wintun.Adapter
  36  	name      string
  37  	handle    windows.Handle
  38  	rate      rateJuggler
  39  	session   wintun.Session
  40  	readWait  windows.Handle
  41  	events    chan Event
  42  	running   sync.WaitGroup
  43  	closeOnce sync.Once
  44  	close     atomic.Bool
  45  	forcedMTU int
  46  	outSizes  []int
  47  }
  48  
  49  var (
  50  	WintunTunnelType          = "WireGuard"
  51  	WintunStaticRequestedGUID *windows.GUID
  52  )
  53  
  54  //go:linkname procyield runtime.procyield
  55  func procyield(cycles uint32)
  56  
  57  //go:linkname nanotime runtime.nanotime
  58  func nanotime() int64
  59  
  60  // CreateTUN creates a Wintun interface with the given name. Should a Wintun
  61  // interface with the same name exist, it is reused.
  62  func CreateTUN(ifname string, mtu int) (Device, error) {
  63  	return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
  64  }
  65  
  66  // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
  67  // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
  68  func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
  69  	wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
  70  	if err != nil {
  71  		return nil, fmt.Errorf("Error creating interface: %w", err)
  72  	}
  73  
  74  	forcedMTU := 1420
  75  	if mtu > 0 {
  76  		forcedMTU = mtu
  77  	}
  78  
  79  	tun := &NativeTun{
  80  		wt:        wt,
  81  		name:      ifname,
  82  		handle:    windows.InvalidHandle,
  83  		events:    make(chan Event, 10),
  84  		forcedMTU: forcedMTU,
  85  	}
  86  
  87  	tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
  88  	if err != nil {
  89  		tun.wt.Close()
  90  		close(tun.events)
  91  		return nil, fmt.Errorf("Error starting session: %w", err)
  92  	}
  93  	tun.readWait = tun.session.ReadWaitEvent()
  94  	return tun, nil
  95  }
  96  
  97  func (tun *NativeTun) Name() (string, error) {
  98  	return tun.name, nil
  99  }
 100  
 101  func (tun *NativeTun) File() *os.File {
 102  	return nil
 103  }
 104  
 105  func (tun *NativeTun) Events() <-chan Event {
 106  	return tun.events
 107  }
 108  
 109  func (tun *NativeTun) Close() error {
 110  	var err error
 111  	tun.closeOnce.Do(func() {
 112  		tun.close.Store(true)
 113  		windows.SetEvent(tun.readWait)
 114  		tun.running.Wait()
 115  		tun.session.End()
 116  		if tun.wt != nil {
 117  			tun.wt.Close()
 118  		}
 119  		close(tun.events)
 120  	})
 121  	return err
 122  }
 123  
 124  func (tun *NativeTun) MTU() (int, error) {
 125  	return tun.forcedMTU, nil
 126  }
 127  
 128  // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
 129  func (tun *NativeTun) ForceMTU(mtu int) {
 130  	if tun.close.Load() {
 131  		return
 132  	}
 133  	update := tun.forcedMTU != mtu
 134  	tun.forcedMTU = mtu
 135  	if update {
 136  		tun.events <- EventMTUUpdate
 137  	}
 138  }
 139  
 140  func (tun *NativeTun) BatchSize() int {
 141  	// TODO: implement batching with wintun
 142  	return 1
 143  }
 144  
 145  // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
 146  
 147  func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
 148  	tun.running.Add(1)
 149  	defer tun.running.Done()
 150  retry:
 151  	if tun.close.Load() {
 152  		return 0, os.ErrClosed
 153  	}
 154  	start := nanotime()
 155  	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
 156  	for {
 157  		if tun.close.Load() {
 158  			return 0, os.ErrClosed
 159  		}
 160  		packet, err := tun.session.ReceivePacket()
 161  		switch err {
 162  		case nil:
 163  			n := copy(bufs[0][offset:], packet)
 164  			sizes[0] = n
 165  			tun.session.ReleaseReceivePacket(packet)
 166  			tun.rate.update(uint64(n))
 167  			return 1, nil
 168  		case windows.ERROR_NO_MORE_ITEMS:
 169  			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
 170  				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
 171  				goto retry
 172  			}
 173  			procyield(1)
 174  			continue
 175  		case windows.ERROR_HANDLE_EOF:
 176  			return 0, os.ErrClosed
 177  		case windows.ERROR_INVALID_DATA:
 178  			return 0, errors.New("Send ring corrupt")
 179  		}
 180  		return 0, fmt.Errorf("Read failed: %w", err)
 181  	}
 182  }
 183  
 184  func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
 185  	tun.running.Add(1)
 186  	defer tun.running.Done()
 187  	if tun.close.Load() {
 188  		return 0, os.ErrClosed
 189  	}
 190  
 191  	for i, buf := range bufs {
 192  		packetSize := len(buf) - offset
 193  		tun.rate.update(uint64(packetSize))
 194  
 195  		packet, err := tun.session.AllocateSendPacket(packetSize)
 196  		switch err {
 197  		case nil:
 198  			// TODO: Explore options to eliminate this copy.
 199  			copy(packet, buf[offset:])
 200  			tun.session.SendPacket(packet)
 201  			continue
 202  		case windows.ERROR_HANDLE_EOF:
 203  			return i, os.ErrClosed
 204  		case windows.ERROR_BUFFER_OVERFLOW:
 205  			continue // Dropping when ring is full.
 206  		default:
 207  			return i, fmt.Errorf("Write failed: %w", err)
 208  		}
 209  	}
 210  	return len(bufs), nil
 211  }
 212  
 213  // LUID returns Windows interface instance ID.
 214  func (tun *NativeTun) LUID() uint64 {
 215  	tun.running.Add(1)
 216  	defer tun.running.Done()
 217  	if tun.close.Load() {
 218  		return 0
 219  	}
 220  	return tun.wt.LUID()
 221  }
 222  
 223  // RunningVersion returns the running version of the Wintun driver.
 224  func (tun *NativeTun) RunningVersion() (version uint32, err error) {
 225  	return wintun.RunningVersion()
 226  }
 227  
 228  func (rate *rateJuggler) update(packetLen uint64) {
 229  	now := nanotime()
 230  	total := rate.nextByteCount.Add(packetLen)
 231  	period := uint64(now - rate.nextStartTime.Load())
 232  	if period >= rateMeasurementGranularity {
 233  		if !rate.changing.CompareAndSwap(false, true) {
 234  			return
 235  		}
 236  		rate.nextStartTime.Store(now)
 237  		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
 238  		rate.nextByteCount.Store(0)
 239  		rate.changing.Store(false)
 240  	}
 241  }
 242