tun_windows.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package tun
7
8 import (
9 "errors"
10 "fmt"
11 "os"
12 "sync"
13 "sync/atomic"
14 "time"
15 _ "unsafe"
16
17 "golang.org/x/sys/windows"
18 "golang.zx2c4.com/wintun"
19 )
20
// Tuning for Read's adaptive busy-polling. Rates are expressed in bytes per
// second and durations in nanoseconds, the units used by rateJuggler and
// nanotime.
const (
	// rateMeasurementGranularity is the minimum length of one throughput
	// measurement window: half a second.
	rateMeasurementGranularity = uint64(time.Second / 2 / time.Nanosecond)

	// spinloopRateThreshold is the rate at or above which Read spins before
	// blocking: 800 Mbit/s expressed in bytes per second.
	spinloopRateThreshold = 800_000_000 / 8

	// spinloopDuration caps how long Read spins on an empty ring: 12.5µs.
	// The original annotation ties this to ~1 Gbit/s (presumably about one
	// full-size packet's time at that rate).
	spinloopDuration = uint64(time.Millisecond / time.Nanosecond / 80)
)
26
// rateJuggler keeps a rough, lock-free estimate of packet throughput in bytes
// per second. Read and Write feed it through update, and Read consults the
// published value to decide whether to busy-poll.
type rateJuggler struct {
	nextStartTime atomic.Int64  // nanotime() at which the current window began
	nextByteCount atomic.Uint64 // bytes accumulated in the current window
	current       atomic.Uint64 // most recently published rate, bytes/second
	changing      atomic.Bool   // held (true) while one caller publishes a new rate
}
33
34 type NativeTun struct {
35 wt *wintun.Adapter
36 name string
37 handle windows.Handle
38 rate rateJuggler
39 session wintun.Session
40 readWait windows.Handle
41 events chan Event
42 running sync.WaitGroup
43 closeOnce sync.Once
44 close atomic.Bool
45 forcedMTU int
46 outSizes []int
47 }
48
49 var (
50 WintunTunnelType = "WireGuard"
51 WintunStaticRequestedGUID *windows.GUID
52 )
53
// procyield is bound via go:linkname to the Go runtime's internal
// runtime.procyield (the blank "unsafe" import is what permits the
// directive). Read calls procyield(1) as a brief pause between polls while
// spinning on an empty ring.
// NOTE(review): linknamed runtime internals are not a stable API; re-check
// that this symbol stays linkable on each Go toolchain upgrade.
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
56
// nanotime is bound via go:linkname to runtime.nanotime and returns a
// timestamp in nanoseconds. Read and rateJuggler.update use differences of
// these timestamps for spin budgets and rate windows (presumably a monotonic
// clock — confirm against the runtime).
// NOTE(review): linknamed runtime internals are not a stable API.
//go:linkname nanotime runtime.nanotime
func nanotime() int64
59
60 // CreateTUN creates a Wintun interface with the given name. Should a Wintun
61 // interface with the same name exist, it is reused.
62 func CreateTUN(ifname string, mtu int) (Device, error) {
63 return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
64 }
65
66 // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
67 // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
68 func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
69 wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
70 if err != nil {
71 return nil, fmt.Errorf("Error creating interface: %w", err)
72 }
73
74 forcedMTU := 1420
75 if mtu > 0 {
76 forcedMTU = mtu
77 }
78
79 tun := &NativeTun{
80 wt: wt,
81 name: ifname,
82 handle: windows.InvalidHandle,
83 events: make(chan Event, 10),
84 forcedMTU: forcedMTU,
85 }
86
87 tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
88 if err != nil {
89 tun.wt.Close()
90 close(tun.events)
91 return nil, fmt.Errorf("Error starting session: %w", err)
92 }
93 tun.readWait = tun.session.ReadWaitEvent()
94 return tun, nil
95 }
96
97 func (tun *NativeTun) Name() (string, error) {
98 return tun.name, nil
99 }
100
101 func (tun *NativeTun) File() *os.File {
102 return nil
103 }
104
105 func (tun *NativeTun) Events() <-chan Event {
106 return tun.events
107 }
108
109 func (tun *NativeTun) Close() error {
110 var err error
111 tun.closeOnce.Do(func() {
112 tun.close.Store(true)
113 windows.SetEvent(tun.readWait)
114 tun.running.Wait()
115 tun.session.End()
116 if tun.wt != nil {
117 tun.wt.Close()
118 }
119 close(tun.events)
120 })
121 return err
122 }
123
124 func (tun *NativeTun) MTU() (int, error) {
125 return tun.forcedMTU, nil
126 }
127
128 // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
129 func (tun *NativeTun) ForceMTU(mtu int) {
130 if tun.close.Load() {
131 return
132 }
133 update := tun.forcedMTU != mtu
134 tun.forcedMTU = mtu
135 if update {
136 tun.events <- EventMTUUpdate
137 }
138 }
139
140 func (tun *NativeTun) BatchSize() int {
141 // TODO: implement batching with wintun
142 return 1
143 }
144
145 // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
146
// Read receives one packet from the Wintun session into bufs[0][offset:],
// stores its length in sizes[0], and returns 1. It blocks until a packet
// arrives or the device is closed. Only bufs[0] and sizes[0] are used.
//
// If the recently measured throughput is at least spinloopRateThreshold,
// Read busy-polls the ring for up to spinloopDuration nanoseconds. After
// that, or when not spinning, it waits on the session's read event.
//
// Returns os.ErrClosed after Close or when the session reports
// ERROR_HANDLE_EOF. A packet longer than bufs[0][offset:] is silently
// truncated by copy.
func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
	// Register as in-flight so Close waits for us before ending the session.
	tun.running.Add(1)
	defer tun.running.Done()
retry:
	if tun.close.Load() {
		return 0, os.ErrClosed
	}
	start := nanotime()
	// Spin only when the published rate is high and still fresh. Fresh means
	// the current measurement window began no more than two granularity
	// periods ago; a stale estimate falls back to blocking.
	shouldSpin := tun.rate.current.Load() >= spinloopRateThreshold && uint64(start-tun.rate.nextStartTime.Load()) <= rateMeasurementGranularity*2
	for {
		if tun.close.Load() {
			return 0, os.ErrClosed
		}
		packet, err := tun.session.ReceivePacket()
		switch err {
		case nil:
			n := copy(bufs[0][offset:], packet)
			sizes[0] = n
			// Hand the ring slot back to the session as soon as the data is copied out.
			tun.session.ReleaseReceivePacket(packet)
			tun.rate.update(uint64(n))
			return 1, nil
		case windows.ERROR_NO_MORE_ITEMS:
			// Ring is empty. Block on the read event once spinning is disabled
			// or its budget is spent. Then start over, which re-checks close
			// (Close signals readWait) and re-evaluates shouldSpin.
			// NOTE(review): WaitForSingleObject's result is ignored. A failing
			// wait would make this path loop without blocking.
			if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
				windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
				goto retry
			}
			procyield(1)
			continue
		case windows.ERROR_HANDLE_EOF:
			return 0, os.ErrClosed
		case windows.ERROR_INVALID_DATA:
			return 0, errors.New("Send ring corrupt")
		}
		return 0, fmt.Errorf("Read failed: %w", err)
	}
}
183
184 func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) {
185 tun.running.Add(1)
186 defer tun.running.Done()
187 if tun.close.Load() {
188 return 0, os.ErrClosed
189 }
190
191 for i, buf := range bufs {
192 packetSize := len(buf) - offset
193 tun.rate.update(uint64(packetSize))
194
195 packet, err := tun.session.AllocateSendPacket(packetSize)
196 switch err {
197 case nil:
198 // TODO: Explore options to eliminate this copy.
199 copy(packet, buf[offset:])
200 tun.session.SendPacket(packet)
201 continue
202 case windows.ERROR_HANDLE_EOF:
203 return i, os.ErrClosed
204 case windows.ERROR_BUFFER_OVERFLOW:
205 continue // Dropping when ring is full.
206 default:
207 return i, fmt.Errorf("Write failed: %w", err)
208 }
209 }
210 return len(bufs), nil
211 }
212
213 // LUID returns Windows interface instance ID.
214 func (tun *NativeTun) LUID() uint64 {
215 tun.running.Add(1)
216 defer tun.running.Done()
217 if tun.close.Load() {
218 return 0
219 }
220 return tun.wt.LUID()
221 }
222
223 // RunningVersion returns the running version of the Wintun driver.
224 func (tun *NativeTun) RunningVersion() (version uint32, err error) {
225 return wintun.RunningVersion()
226 }
227
// update adds packetLen bytes to the current measurement window. Once at
// least rateMeasurementGranularity nanoseconds have passed since the window
// began, one caller publishes a new rate and starts a fresh window. The rate
// is bytes accumulated * 1e9 / elapsed nanoseconds, i.e. bytes per second.
// Callers that lose the race to publish simply return; their bytes are
// already counted.
//
// NOTE(review): bytes added by other callers between this caller's Add and
// the Store(0) below are discarded. Because nextStartTime starts at 0, the
// first window spans everything since nanotime's origin. Both effects look
// acceptable for a spin heuristic.
func (rate *rateJuggler) update(packetLen uint64) {
	now := nanotime()
	total := rate.nextByteCount.Add(packetLen)
	period := uint64(now - rate.nextStartTime.Load())
	if period >= rateMeasurementGranularity {
		// Only one caller may publish; the CAS acts as a try-lock.
		if !rate.changing.CompareAndSwap(false, true) {
			return
		}
		rate.nextStartTime.Store(now)
		rate.current.Store(total * uint64(time.Second/time.Nanosecond) / period)
		rate.nextByteCount.Store(0)
		rate.changing.Store(false)
	}
}
242