bind_windows.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package conn
7
8 import (
9 "encoding/binary"
10 "io"
11 "net"
12 "net/netip"
13 "strconv"
14 "sync"
15 "sync/atomic"
16 "unsafe"
17
18 "golang.org/x/sys/windows"
19
20 "golang.zx2c4.com/wireguard/conn/winrio"
21 )
22
const (
	// packetsPerRing is the number of ringPacket slots in each ring buffer. It
	// is also the size passed to the RIO completion-queue and request-queue
	// constructors.
	packetsPerRing = 1024
	// bytesPerPacket is the payload capacity of one slot. The 32 bytes that are
	// subtracted equal the size of the WinRingEndpoint stored next to the
	// payload, so one ringPacket occupies 2048 bytes.
	bytesPerPacket = 2048 - 32
	// receiveSpins is the maximum number of completion-queue polls Receive
	// makes before it blocks on the I/O completion port.
	receiveSpins = 15
)
28
// ringPacket is one slot of a ring buffer: the peer address followed by the
// datagram payload. Both fields are handed to registered I/O as offsets into
// the ring's registered memory region.
type ringPacket struct {
	addr WinRingEndpoint      // remote address (source on receive, destination on send)
	data [bytesPerPacket]byte // datagram payload
}
33
// ringBuffer is a fixed-size ring of packetsPerRing ringPacket slots together
// with the registered-I/O objects needed to use that memory.
type ringBuffer struct {
	packets    uintptr            // base address of the slot array obtained from VirtualAlloc
	head, tail uint32             // free-running indices; always used modulo packetsPerRing
	id         winrio.BufferId    // registration id of the slot array
	iocp       windows.Handle     // I/O completion port that the completion queue notifies
	isFull     bool               // distinguishes a full ring from an empty one when head and tail coincide
	cq         winrio.Cq          // completion queue for requests using this ring
	mu         sync.Mutex         // held by Receive (rx ring) and Send (tx ring) for the whole call
	overlapped windows.Overlapped // passed to CreateIOCPCompletionQueue together with iocp
}
44
45 func (rb *ringBuffer) Push() *ringPacket {
46 for rb.isFull {
47 panic("ring is full")
48 }
49 ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
50 rb.tail += 1
51 if rb.tail%packetsPerRing == rb.head%packetsPerRing {
52 rb.isFull = true
53 }
54 return ret
55 }
56
57 func (rb *ringBuffer) Return(count uint32) {
58 if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
59 return
60 }
61 rb.head += count
62 rb.isFull = false
63 }
64
// afWinRingBind holds the per-address-family state of a WinRingBind: one UDP
// socket, its receive and transmit rings, and the request queue joining them.
type afWinRingBind struct {
	sock      windows.Handle // UDP socket created through winrio.Socket
	rx, tx    ringBuffer     // receive and transmit rings
	rq        winrio.Rq      // request queue created on sock with rx.cq and tx.cq
	mu        sync.Mutex     // held around the winrio.ReceiveEx and winrio.SendEx calls on rq
	blackhole bool           // when set, WinRingBind.Send skips packets for this family
}
72
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
//
// It keeps one afWinRingBind for IPv4 and one for IPv6.
type WinRingBind struct {
	v4, v6 afWinRingBind
	mu     sync.RWMutex  // write-locked by Open and for teardown in Close; read-locked by the send/receive paths
	isOpen atomic.Uint32 // 0, 1, or 2: Open stores 1, Close stores 2 before tearing down, closeAndZero stores 0
}
79
80 func NewDefaultBind() Bind { return NewWinRingBind() }
81
82 func NewWinRingBind() Bind {
83 if !winrio.Initialize() {
84 return NewStdNetBind()
85 }
86 return new(WinRingBind)
87 }
88
// WinRingEndpoint is a 32-byte socket address: a 16-bit address family
// followed by the remaining raw address bytes. ParseEndpoint fills it by
// copying the address returned by GetAddrInfoW, and the ring code exchanges it
// with registered I/O as the address buffer.
//
// As read by the accessor methods, data[0:2] is the port in big-endian order;
// for AF_INET data[2:6] is the address; for AF_INET6 data[6:22] is the address
// and data[22:26] the scope id.
type WinRingEndpoint struct {
	family uint16
	data   [30]byte
}
93
// Compile-time checks that the types in this file satisfy the package's Bind
// and Endpoint interfaces.
var (
	_ Bind     = (*WinRingBind)(nil)
	_ Endpoint = (*WinRingEndpoint)(nil)
)
98
99 func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
100 host, port, err := net.SplitHostPort(s)
101 if err != nil {
102 return nil, err
103 }
104 host16, err := windows.UTF16PtrFromString(host)
105 if err != nil {
106 return nil, err
107 }
108 port16, err := windows.UTF16PtrFromString(port)
109 if err != nil {
110 return nil, err
111 }
112 hints := windows.AddrinfoW{
113 Flags: windows.AI_NUMERICHOST,
114 Family: windows.AF_UNSPEC,
115 Socktype: windows.SOCK_DGRAM,
116 Protocol: windows.IPPROTO_UDP,
117 }
118 var addrinfo *windows.AddrinfoW
119 err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
120 if err != nil {
121 return nil, err
122 }
123 defer windows.FreeAddrInfoW(addrinfo)
124 if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
125 return nil, windows.ERROR_INVALID_ADDRESS
126 }
127 var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
128 copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
129 return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
130 }
131
132 func (*WinRingEndpoint) ClearSrc() {}
133
134 func (e *WinRingEndpoint) DstIP() netip.Addr {
135 switch e.family {
136 case windows.AF_INET:
137 return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
138 case windows.AF_INET6:
139 return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
140 }
141 return netip.Addr{}
142 }
143
144 func (e *WinRingEndpoint) SrcIP() netip.Addr {
145 return netip.Addr{} // not supported
146 }
147
148 func (e *WinRingEndpoint) DstToBytes() []byte {
149 switch e.family {
150 case windows.AF_INET:
151 b := make([]byte, 0, 6)
152 b = append(b, e.data[2:6]...)
153 b = append(b, e.data[1], e.data[0])
154 return b
155 case windows.AF_INET6:
156 b := make([]byte, 0, 18)
157 b = append(b, e.data[6:22]...)
158 b = append(b, e.data[1], e.data[0])
159 return b
160 }
161 return nil
162 }
163
164 func (e *WinRingEndpoint) DstToString() string {
165 switch e.family {
166 case windows.AF_INET:
167 return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
168 case windows.AF_INET6:
169 var zone string
170 if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
171 zone = strconv.FormatUint(uint64(scope), 10)
172 }
173 return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
174 }
175 return ""
176 }
177
178 func (e *WinRingEndpoint) SrcToString() string {
179 return ""
180 }
181
182 func (ring *ringBuffer) CloseAndZero() {
183 if ring.cq != 0 {
184 winrio.CloseCompletionQueue(ring.cq)
185 ring.cq = 0
186 }
187 if ring.iocp != 0 {
188 windows.CloseHandle(ring.iocp)
189 ring.iocp = 0
190 }
191 if ring.id != 0 {
192 winrio.DeregisterBuffer(ring.id)
193 ring.id = 0
194 }
195 if ring.packets != 0 {
196 windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
197 ring.packets = 0
198 }
199 ring.head = 0
200 ring.tail = 0
201 ring.isFull = false
202 }
203
204 func (bind *afWinRingBind) CloseAndZero() {
205 bind.rx.CloseAndZero()
206 bind.tx.CloseAndZero()
207 if bind.sock != 0 {
208 windows.CloseHandle(bind.sock)
209 bind.sock = 0
210 }
211 bind.blackhole = false
212 }
213
214 func (bind *WinRingBind) closeAndZero() {
215 bind.isOpen.Store(0)
216 bind.v4.CloseAndZero()
217 bind.v6.CloseAndZero()
218 }
219
220 func (ring *ringBuffer) Open() error {
221 var err error
222 packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
223 ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
224 if err != nil {
225 return err
226 }
227 ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
228 if err != nil {
229 return err
230 }
231 ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
232 if err != nil {
233 return err
234 }
235 ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
236 if err != nil {
237 return err
238 }
239 return nil
240 }
241
242 func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
243 var err error
244 bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
245 if err != nil {
246 return nil, err
247 }
248 err = bind.rx.Open()
249 if err != nil {
250 return nil, err
251 }
252 err = bind.tx.Open()
253 if err != nil {
254 return nil, err
255 }
256 bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
257 if err != nil {
258 return nil, err
259 }
260 err = windows.Bind(bind.sock, sa)
261 if err != nil {
262 return nil, err
263 }
264 sa, err = windows.Getsockname(bind.sock)
265 if err != nil {
266 return nil, err
267 }
268 return sa, nil
269 }
270
271 func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
272 bind.mu.Lock()
273 defer bind.mu.Unlock()
274 defer func() {
275 if err != nil {
276 bind.closeAndZero()
277 }
278 }()
279 if bind.isOpen.Load() != 0 {
280 return nil, 0, ErrBindAlreadyOpen
281 }
282 var sa windows.Sockaddr
283 sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
284 if err != nil {
285 return nil, 0, err
286 }
287 sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
288 if err != nil {
289 return nil, 0, err
290 }
291 selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
292 for i := 0; i < packetsPerRing; i++ {
293 err = bind.v4.InsertReceiveRequest()
294 if err != nil {
295 return nil, 0, err
296 }
297 err = bind.v6.InsertReceiveRequest()
298 if err != nil {
299 return nil, 0, err
300 }
301 }
302 bind.isOpen.Store(1)
303 return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
304 }
305
306 func (bind *WinRingBind) Close() error {
307 bind.mu.RLock()
308 if bind.isOpen.Load() != 1 {
309 bind.mu.RUnlock()
310 return nil
311 }
312 bind.isOpen.Store(2)
313 windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
314 windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
315 windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
316 windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
317 bind.mu.RUnlock()
318 bind.mu.Lock()
319 defer bind.mu.Unlock()
320 bind.closeAndZero()
321 return nil
322 }
323
324 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
325 // rename the IdealBatchSize constant to BatchSize.
326 func (bind *WinRingBind) BatchSize() int {
327 // TODO: implement batching in and out of the ring
328 return 1
329 }
330
331 func (bind *WinRingBind) SetMark(mark uint32) error {
332 return nil
333 }
334
335 func (bind *afWinRingBind) InsertReceiveRequest() error {
336 packet := bind.rx.Push()
337 dataBuffer := &winrio.Buffer{
338 Id: bind.rx.id,
339 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
340 Length: uint32(len(packet.data)),
341 }
342 addressBuffer := &winrio.Buffer{
343 Id: bind.rx.id,
344 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
345 Length: uint32(unsafe.Sizeof(packet.addr)),
346 }
347 bind.mu.Lock()
348 defer bind.mu.Unlock()
349 return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
350 }
351
// procyield is a body-less declaration linked to runtime.procyield. Receive
// calls it between polls of the completion queue.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
354
355 func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
356 if isOpen.Load() != 1 {
357 return 0, nil, net.ErrClosed
358 }
359 bind.rx.mu.Lock()
360 defer bind.rx.mu.Unlock()
361
362 var err error
363 var count uint32
364 var results [1]winrio.Result
365 retry:
366 count = 0
367 for tries := 0; count == 0 && tries < receiveSpins; tries++ {
368 if tries > 0 {
369 if isOpen.Load() != 1 {
370 return 0, nil, net.ErrClosed
371 }
372 procyield(1)
373 }
374 count = winrio.DequeueCompletion(bind.rx.cq, results[:])
375 }
376 if count == 0 {
377 err = winrio.Notify(bind.rx.cq)
378 if err != nil {
379 return 0, nil, err
380 }
381 var bytes uint32
382 var key uintptr
383 var overlapped *windows.Overlapped
384 err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
385 if err != nil {
386 return 0, nil, err
387 }
388 if isOpen.Load() != 1 {
389 return 0, nil, net.ErrClosed
390 }
391 count = winrio.DequeueCompletion(bind.rx.cq, results[:])
392 if count == 0 {
393 return 0, nil, io.ErrNoProgress
394 }
395 }
396 bind.rx.Return(1)
397 err = bind.InsertReceiveRequest()
398 if err != nil {
399 return 0, nil, err
400 }
401 // We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
402 // huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
403 // attacker bandwidth, just like the rest of the receive path.
404 if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
405 if isOpen.Load() != 1 {
406 return 0, nil, net.ErrClosed
407 }
408 goto retry
409 }
410 if results[0].Status != 0 {
411 return 0, nil, windows.Errno(results[0].Status)
412 }
413 packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
414 ep := packet.addr
415 n := copy(buf, packet.data[:results[0].BytesTransferred])
416 return n, &ep, nil
417 }
418
419 func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
420 bind.mu.RLock()
421 defer bind.mu.RUnlock()
422 n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
423 sizes[0] = n
424 eps[0] = ep
425 return 1, err
426 }
427
428 func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
429 bind.mu.RLock()
430 defer bind.mu.RUnlock()
431 n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
432 sizes[0] = n
433 eps[0] = ep
434 return 1, err
435 }
436
437 func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
438 if isOpen.Load() != 1 {
439 return net.ErrClosed
440 }
441 if len(buf) > bytesPerPacket {
442 return io.ErrShortBuffer
443 }
444 bind.tx.mu.Lock()
445 defer bind.tx.mu.Unlock()
446 var results [packetsPerRing]winrio.Result
447 count := winrio.DequeueCompletion(bind.tx.cq, results[:])
448 if count == 0 && bind.tx.isFull {
449 err := winrio.Notify(bind.tx.cq)
450 if err != nil {
451 return err
452 }
453 var bytes uint32
454 var key uintptr
455 var overlapped *windows.Overlapped
456 err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
457 if err != nil {
458 return err
459 }
460 if isOpen.Load() != 1 {
461 return net.ErrClosed
462 }
463 count = winrio.DequeueCompletion(bind.tx.cq, results[:])
464 if count == 0 {
465 return io.ErrNoProgress
466 }
467 }
468 if count > 0 {
469 bind.tx.Return(count)
470 }
471 packet := bind.tx.Push()
472 packet.addr = *nend
473 copy(packet.data[:], buf)
474 dataBuffer := &winrio.Buffer{
475 Id: bind.tx.id,
476 Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
477 Length: uint32(len(buf)),
478 }
479 addressBuffer := &winrio.Buffer{
480 Id: bind.tx.id,
481 Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
482 Length: uint32(unsafe.Sizeof(packet.addr)),
483 }
484 bind.mu.Lock()
485 defer bind.mu.Unlock()
486 return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
487 }
488
489 func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
490 nend, ok := endpoint.(*WinRingEndpoint)
491 if !ok {
492 return ErrWrongEndpointType
493 }
494 bind.mu.RLock()
495 defer bind.mu.RUnlock()
496 for _, buf := range bufs {
497 switch nend.family {
498 case windows.AF_INET:
499 if bind.v4.blackhole {
500 continue
501 }
502 if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
503 return err
504 }
505 case windows.AF_INET6:
506 if bind.v6.blackhole {
507 continue
508 }
509 if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
510 return err
511 }
512 }
513 }
514 return nil
515 }
516
517 func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
518 s.mu.Lock()
519 defer s.mu.Unlock()
520 sysconn, err := s.ipv4.SyscallConn()
521 if err != nil {
522 return err
523 }
524 err2 := sysconn.Control(func(fd uintptr) {
525 err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
526 })
527 if err2 != nil {
528 return err2
529 }
530 if err != nil {
531 return err
532 }
533 s.blackhole4 = blackhole
534 return nil
535 }
536
537 func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
538 s.mu.Lock()
539 defer s.mu.Unlock()
540 sysconn, err := s.ipv6.SyscallConn()
541 if err != nil {
542 return err
543 }
544 err2 := sysconn.Control(func(fd uintptr) {
545 err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
546 })
547 if err2 != nil {
548 return err2
549 }
550 if err != nil {
551 return err
552 }
553 s.blackhole6 = blackhole
554 return nil
555 }
556
557 func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
558 bind.mu.RLock()
559 defer bind.mu.RUnlock()
560 if bind.isOpen.Load() != 1 {
561 return net.ErrClosed
562 }
563 err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
564 if err != nil {
565 return err
566 }
567 bind.v4.blackhole = blackhole
568 return nil
569 }
570
571 func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
572 bind.mu.RLock()
573 defer bind.mu.RUnlock()
574 if bind.isOpen.Load() != 1 {
575 return net.ErrClosed
576 }
577 err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
578 if err != nil {
579 return err
580 }
581 bind.v6.blackhole = blackhole
582 return nil
583 }
584
585 func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
586 const IP_UNICAST_IF = 31
587 /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
588 var bytes [4]byte
589 binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
590 interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
591 err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
592 if err != nil {
593 return err
594 }
595 return nil
596 }
597
598 func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
599 const IPV6_UNICAST_IF = 31
600 return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
601 }
602