bind_windows.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package conn
   7  
   8  import (
   9  	"encoding/binary"
  10  	"io"
  11  	"net"
  12  	"net/netip"
  13  	"strconv"
  14  	"sync"
  15  	"sync/atomic"
  16  	"unsafe"
  17  
  18  	"golang.org/x/sys/windows"
  19  
  20  	"golang.zx2c4.com/wireguard/conn/winrio"
  21  )
  22  
const (
	// packetsPerRing is the number of ringPacket slots in each ring buffer.
	packetsPerRing = 1024
	// bytesPerPacket is the payload capacity of one ringPacket. The 32 bytes
	// subtracted equal the size of the WinRingEndpoint stored next to the
	// payload, so a whole ringPacket occupies 2048 bytes.
	bytesPerPacket = 2048 - 32
	// receiveSpins bounds how many times Receive polls the completion queue
	// before it blocks on the I/O completion port.
	receiveSpins   = 15
)
  28  
// ringPacket is one slot of a ring buffer: a socket address followed by the
// datagram payload.
type ringPacket struct {
	addr WinRingEndpoint      // address associated with the datagram
	data [bytesPerPacket]byte // datagram payload
}
  33  
// ringBuffer tracks a fixed ring of packetsPerRing ringPacket slots together
// with the registered-I/O objects used to complete requests on them.
type ringBuffer struct {
	packets    uintptr            // base address of the slot memory returned by VirtualAlloc
	head, tail uint32             // running counters; the slot index is the counter modulo packetsPerRing
	id         winrio.BufferId    // buffer id obtained by registering the slot memory
	iocp       windows.Handle     // I/O completion port handed to the completion queue
	isFull     bool               // tells a full ring from an empty one when head and tail coincide
	cq         winrio.Cq          // completion queue created on top of iocp
	mu         sync.Mutex         // held by Receive/Send while they work on this ring
	overlapped windows.Overlapped // passed to CreateIOCPCompletionQueue when cq is created
}
  44  
  45  func (rb *ringBuffer) Push() *ringPacket {
  46  	for rb.isFull {
  47  		panic("ring is full")
  48  	}
  49  	ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
  50  	rb.tail += 1
  51  	if rb.tail%packetsPerRing == rb.head%packetsPerRing {
  52  		rb.isFull = true
  53  	}
  54  	return ret
  55  }
  56  
  57  func (rb *ringBuffer) Return(count uint32) {
  58  	if rb.head%packetsPerRing == rb.tail%packetsPerRing && !rb.isFull {
  59  		return
  60  	}
  61  	rb.head += count
  62  	rb.isFull = false
  63  }
  64  
// afWinRingBind holds the socket, rings and request queue for a single
// address family.
type afWinRingBind struct {
	sock      windows.Handle // socket created by winrio.Socket
	rx, tx    ringBuffer     // receive and transmit rings
	rq        winrio.Rq      // request queue tying sock to the rx and tx completion queues
	mu        sync.Mutex     // held around ReceiveEx and SendEx calls on rq
	blackhole bool           // when set, WinRingBind.Send skips packets for this family
}
  72  
// WinRingBind uses Windows registered I/O for fast ring buffered networking.
type WinRingBind struct {
	v4, v6 afWinRingBind // per-family state for IPv4 and IPv6
	// mu is write-locked by Open and by the teardown in Close, and read-locked
	// by the receive, send and interface-binding paths.
	mu     sync.RWMutex
	// isOpen is 0 after closeAndZero, 1 once Open has succeeded, and 2 while
	// Close is shutting the bind down.
	isOpen atomic.Uint32 // 0, 1, or 2
}
  79  
// NewDefaultBind returns the Bind produced by NewWinRingBind.
func NewDefaultBind() Bind { return NewWinRingBind() }
  81  
  82  func NewWinRingBind() Bind {
  83  	if !winrio.Initialize() {
  84  		return NewStdNetBind()
  85  	}
  86  	return new(WinRingBind)
  87  }
  88  
// WinRingEndpoint stores a raw socket address: ParseEndpoint fills it by
// copying the bytes returned by GetAddrInfoW. The accessors read data as
// follows: bytes 0-1 hold the port in big-endian order; for AF_INET bytes 2-5
// hold the address; for AF_INET6 bytes 6-21 hold the address and bytes 22-25
// the scope id.
type WinRingEndpoint struct {
	family uint16   // address family (windows.AF_INET or windows.AF_INET6)
	data   [30]byte // remainder of the socket address
}
  93  
// Compile-time checks that the types in this file satisfy the package
// interfaces.
var (
	_ Bind     = (*WinRingBind)(nil)
	_ Endpoint = (*WinRingEndpoint)(nil)
)
  98  
  99  func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {
 100  	host, port, err := net.SplitHostPort(s)
 101  	if err != nil {
 102  		return nil, err
 103  	}
 104  	host16, err := windows.UTF16PtrFromString(host)
 105  	if err != nil {
 106  		return nil, err
 107  	}
 108  	port16, err := windows.UTF16PtrFromString(port)
 109  	if err != nil {
 110  		return nil, err
 111  	}
 112  	hints := windows.AddrinfoW{
 113  		Flags:    windows.AI_NUMERICHOST,
 114  		Family:   windows.AF_UNSPEC,
 115  		Socktype: windows.SOCK_DGRAM,
 116  		Protocol: windows.IPPROTO_UDP,
 117  	}
 118  	var addrinfo *windows.AddrinfoW
 119  	err = windows.GetAddrInfoW(host16, port16, &hints, &addrinfo)
 120  	if err != nil {
 121  		return nil, err
 122  	}
 123  	defer windows.FreeAddrInfoW(addrinfo)
 124  	if (addrinfo.Family != windows.AF_INET && addrinfo.Family != windows.AF_INET6) || addrinfo.Addrlen > unsafe.Sizeof(WinRingEndpoint{}) {
 125  		return nil, windows.ERROR_INVALID_ADDRESS
 126  	}
 127  	var dst [unsafe.Sizeof(WinRingEndpoint{})]byte
 128  	copy(dst[:], unsafe.Slice((*byte)(unsafe.Pointer(addrinfo.Addr)), addrinfo.Addrlen))
 129  	return (*WinRingEndpoint)(unsafe.Pointer(&dst[0])), nil
 130  }
 131  
// ClearSrc is a no-op: WinRingEndpoint does not store a source address.
func (*WinRingEndpoint) ClearSrc() {}
 133  
 134  func (e *WinRingEndpoint) DstIP() netip.Addr {
 135  	switch e.family {
 136  	case windows.AF_INET:
 137  		return netip.AddrFrom4(*(*[4]byte)(e.data[2:6]))
 138  	case windows.AF_INET6:
 139  		return netip.AddrFrom16(*(*[16]byte)(e.data[6:22]))
 140  	}
 141  	return netip.Addr{}
 142  }
 143  
// SrcIP always returns the zero netip.Addr; source addresses are not
// supported by this endpoint type.
func (e *WinRingEndpoint) SrcIP() netip.Addr {
	return netip.Addr{} // not supported
}
 147  
 148  func (e *WinRingEndpoint) DstToBytes() []byte {
 149  	switch e.family {
 150  	case windows.AF_INET:
 151  		b := make([]byte, 0, 6)
 152  		b = append(b, e.data[2:6]...)
 153  		b = append(b, e.data[1], e.data[0])
 154  		return b
 155  	case windows.AF_INET6:
 156  		b := make([]byte, 0, 18)
 157  		b = append(b, e.data[6:22]...)
 158  		b = append(b, e.data[1], e.data[0])
 159  		return b
 160  	}
 161  	return nil
 162  }
 163  
 164  func (e *WinRingEndpoint) DstToString() string {
 165  	switch e.family {
 166  	case windows.AF_INET:
 167  		return netip.AddrPortFrom(netip.AddrFrom4(*(*[4]byte)(e.data[2:6])), binary.BigEndian.Uint16(e.data[0:2])).String()
 168  	case windows.AF_INET6:
 169  		var zone string
 170  		if scope := *(*uint32)(unsafe.Pointer(&e.data[22])); scope > 0 {
 171  			zone = strconv.FormatUint(uint64(scope), 10)
 172  		}
 173  		return netip.AddrPortFrom(netip.AddrFrom16(*(*[16]byte)(e.data[6:22])).WithZone(zone), binary.BigEndian.Uint16(e.data[0:2])).String()
 174  	}
 175  	return ""
 176  }
 177  
// SrcToString always returns the empty string; this endpoint type carries no
// source address.
func (e *WinRingEndpoint) SrcToString() string {
	return ""
}
 181  
// CloseAndZero releases whatever resources the ring currently holds and
// resets its bookkeeping. Each resource is skipped when its field is already
// zero, so the method can be called on a partially opened or already closed
// ring. Errors from the release calls are ignored.
func (ring *ringBuffer) CloseAndZero() {
	// Completion queue first, then the completion port it was created on.
	if ring.cq != 0 {
		winrio.CloseCompletionQueue(ring.cq)
		ring.cq = 0
	}
	if ring.iocp != 0 {
		windows.CloseHandle(ring.iocp)
		ring.iocp = 0
	}
	// Deregister the buffer before freeing the memory that backs it.
	if ring.id != 0 {
		winrio.DeregisterBuffer(ring.id)
		ring.id = 0
	}
	if ring.packets != 0 {
		windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
		ring.packets = 0
	}
	ring.head = 0
	ring.tail = 0
	ring.isFull = false
}
 203  
 204  func (bind *afWinRingBind) CloseAndZero() {
 205  	bind.rx.CloseAndZero()
 206  	bind.tx.CloseAndZero()
 207  	if bind.sock != 0 {
 208  		windows.CloseHandle(bind.sock)
 209  		bind.sock = 0
 210  	}
 211  	bind.blackhole = false
 212  }
 213  
// closeAndZero marks the bind as closed (isOpen = 0) and releases the IPv4
// and IPv6 state. It takes no lock itself; Open and Close call it while
// holding the write lock on bind.mu.
func (bind *WinRingBind) closeAndZero() {
	bind.isOpen.Store(0)
	bind.v4.CloseAndZero()
	bind.v6.CloseAndZero()
}
 219  
 220  func (ring *ringBuffer) Open() error {
 221  	var err error
 222  	packetsLen := unsafe.Sizeof(ringPacket{}) * packetsPerRing
 223  	ring.packets, err = windows.VirtualAlloc(0, packetsLen, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
 224  	if err != nil {
 225  		return err
 226  	}
 227  	ring.id, err = winrio.RegisterPointer(unsafe.Pointer(ring.packets), uint32(packetsLen))
 228  	if err != nil {
 229  		return err
 230  	}
 231  	ring.iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
 232  	if err != nil {
 233  		return err
 234  	}
 235  	ring.cq, err = winrio.CreateIOCPCompletionQueue(packetsPerRing, ring.iocp, 0, &ring.overlapped)
 236  	if err != nil {
 237  		return err
 238  	}
 239  	return nil
 240  }
 241  
 242  func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sockaddr, error) {
 243  	var err error
 244  	bind.sock, err = winrio.Socket(family, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
 245  	if err != nil {
 246  		return nil, err
 247  	}
 248  	err = bind.rx.Open()
 249  	if err != nil {
 250  		return nil, err
 251  	}
 252  	err = bind.tx.Open()
 253  	if err != nil {
 254  		return nil, err
 255  	}
 256  	bind.rq, err = winrio.CreateRequestQueue(bind.sock, packetsPerRing, 1, packetsPerRing, 1, bind.rx.cq, bind.tx.cq, 0)
 257  	if err != nil {
 258  		return nil, err
 259  	}
 260  	err = windows.Bind(bind.sock, sa)
 261  	if err != nil {
 262  		return nil, err
 263  	}
 264  	sa, err = windows.Getsockname(bind.sock)
 265  	if err != nil {
 266  		return nil, err
 267  	}
 268  	return sa, nil
 269  }
 270  
 271  func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
 272  	bind.mu.Lock()
 273  	defer bind.mu.Unlock()
 274  	defer func() {
 275  		if err != nil {
 276  			bind.closeAndZero()
 277  		}
 278  	}()
 279  	if bind.isOpen.Load() != 0 {
 280  		return nil, 0, ErrBindAlreadyOpen
 281  	}
 282  	var sa windows.Sockaddr
 283  	sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
 284  	if err != nil {
 285  		return nil, 0, err
 286  	}
 287  	sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
 288  	if err != nil {
 289  		return nil, 0, err
 290  	}
 291  	selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
 292  	for i := 0; i < packetsPerRing; i++ {
 293  		err = bind.v4.InsertReceiveRequest()
 294  		if err != nil {
 295  			return nil, 0, err
 296  		}
 297  		err = bind.v6.InsertReceiveRequest()
 298  		if err != nil {
 299  			return nil, 0, err
 300  		}
 301  	}
 302  	bind.isOpen.Store(1)
 303  	return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
 304  }
 305  
// Close shuts the bind down. It does nothing unless the bind is currently
// open (isOpen == 1), and it always returns nil.
func (bind *WinRingBind) Close() error {
	bind.mu.RLock()
	if bind.isOpen.Load() != 1 {
		bind.mu.RUnlock()
		return nil
	}
	// Mark the bind as closing, then post a completion to every I/O completion
	// port so that Receive/Send calls blocked in GetQueuedCompletionStatus
	// wake up, see isOpen != 1 and return.
	bind.isOpen.Store(2)
	windows.PostQueuedCompletionStatus(bind.v4.rx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v4.tx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v6.rx.iocp, 0, 0, nil)
	windows.PostQueuedCompletionStatus(bind.v6.tx.iocp, 0, 0, nil)
	bind.mu.RUnlock()
	// Swap the read lock for the write lock: acquiring it waits for the
	// remaining read-lock holders before resources are released.
	bind.mu.Lock()
	defer bind.mu.Unlock()
	bind.closeAndZero()
	return nil
}
 323  
// BatchSize reports how many packets this bind handles per call, which is
// currently always 1.
//
// TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
// rename the IdealBatchSize constant to BatchSize.
func (bind *WinRingBind) BatchSize() int {
	// TODO: implement batching in and out of the ring
	return 1
}
 330  
// SetMark ignores mark and always returns nil.
func (bind *WinRingBind) SetMark(mark uint32) error {
	return nil
}
 334  
 335  func (bind *afWinRingBind) InsertReceiveRequest() error {
 336  	packet := bind.rx.Push()
 337  	dataBuffer := &winrio.Buffer{
 338  		Id:     bind.rx.id,
 339  		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.rx.packets),
 340  		Length: uint32(len(packet.data)),
 341  	}
 342  	addressBuffer := &winrio.Buffer{
 343  		Id:     bind.rx.id,
 344  		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.rx.packets),
 345  		Length: uint32(unsafe.Sizeof(packet.addr)),
 346  	}
 347  	bind.mu.Lock()
 348  	defer bind.mu.Unlock()
 349  	return winrio.ReceiveEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, uintptr(unsafe.Pointer(packet)))
 350  }
 351  
// procyield is declared without a body; the go:linkname directive below binds
// it to runtime.procyield. Receive calls it between polls of the completion
// queue.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
 354  
// Receive waits for one receive request to complete, copies its payload into
// buf and returns the number of bytes copied together with a copy of the
// address stored in the packet slot.
//
// isOpen is checked on entry, between polls and after every blocking wait;
// whenever it is not 1 the call returns net.ErrClosed. Completions whose
// status is WSAEMSGSIZE are skipped and the wait starts over. Any other
// non-zero status is returned as a windows.Errno. io.ErrNoProgress is
// returned when the completion port wakes up but the completion queue is
// still empty.
func (bind *afWinRingBind) Receive(buf []byte, isOpen *atomic.Uint32) (int, Endpoint, error) {
	if isOpen.Load() != 1 {
		return 0, nil, net.ErrClosed
	}
	// rx.mu is held for the whole call, covering the ring bookkeeping and the
	// completion queue.
	bind.rx.mu.Lock()
	defer bind.rx.mu.Unlock()

	var err error
	var count uint32
	var results [1]winrio.Result
retry:
	count = 0
	// Poll the completion queue up to receiveSpins times, calling procyield(1)
	// between attempts, before falling back to a blocking wait.
	for tries := 0; count == 0 && tries < receiveSpins; tries++ {
		if tries > 0 {
			if isOpen.Load() != 1 {
				return 0, nil, net.ErrClosed
			}
			procyield(1)
		}
		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
	}
	if count == 0 {
		// Nothing completed while polling: request a notification for the
		// completion queue and block on the I/O completion port.
		err = winrio.Notify(bind.rx.cq)
		if err != nil {
			return 0, nil, err
		}
		var bytes uint32
		var key uintptr
		var overlapped *windows.Overlapped
		err = windows.GetQueuedCompletionStatus(bind.rx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
		if err != nil {
			return 0, nil, err
		}
		// Close posts a completion to this port after changing isOpen, so the
		// wake-up may be a shutdown signal rather than a packet.
		if isOpen.Load() != 1 {
			return 0, nil, net.ErrClosed
		}
		count = winrio.DequeueCompletion(bind.rx.cq, results[:])
		if count == 0 {
			return 0, nil, io.ErrNoProgress
		}
	}
	// Release one slot and post a fresh receive request right away so the
	// number of outstanding requests stays constant.
	// NOTE(review): the request is re-posted before the completed packet is
	// copied out below, and Push appears to hand back the slot that was just
	// released. Presumably this is safe because the other outstanding
	// requests are filled first — confirm.
	bind.rx.Return(1)
	err = bind.InsertReceiveRequest()
	if err != nil {
		return 0, nil, err
	}
	// We limit the MTU well below the 65k max for practicality, but this means a remote host can still send us
	// huge packets. Just try again when this happens. The infinite loop this could cause is still limited to
	// attacker bandwidth, just like the rest of the receive path.
	if windows.Errno(results[0].Status) == windows.WSAEMSGSIZE {
		if isOpen.Load() != 1 {
			return 0, nil, net.ErrClosed
		}
		goto retry
	}
	if results[0].Status != 0 {
		return 0, nil, windows.Errno(results[0].Status)
	}
	// RequestContext is the slot pointer that InsertReceiveRequest passed to
	// ReceiveEx.
	packet := (*ringPacket)(unsafe.Pointer(uintptr(results[0].RequestContext)))
	// Copy the address out of the ring so the returned endpoint does not
	// alias slot memory.
	ep := packet.addr
	n := copy(buf, packet.data[:results[0].BytesTransferred])
	return n, &ep, nil
}
 418  
 419  func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
 420  	bind.mu.RLock()
 421  	defer bind.mu.RUnlock()
 422  	n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
 423  	sizes[0] = n
 424  	eps[0] = ep
 425  	return 1, err
 426  }
 427  
 428  func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
 429  	bind.mu.RLock()
 430  	defer bind.mu.RUnlock()
 431  	n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
 432  	sizes[0] = n
 433  	eps[0] = ep
 434  	return 1, err
 435  }
 436  
// Send copies buf and the destination address nend into the next slot of the
// transmit ring and posts a send request for it.
//
// It returns net.ErrClosed when isOpen is not 1 (checked on entry and after a
// blocking wait), io.ErrShortBuffer when buf is larger than bytesPerPacket,
// io.ErrNoProgress when the completion port wakes up but no completion can be
// dequeued, and otherwise the error from the underlying winrio/windows call.
func (bind *afWinRingBind) Send(buf []byte, nend *WinRingEndpoint, isOpen *atomic.Uint32) error {
	if isOpen.Load() != 1 {
		return net.ErrClosed
	}
	if len(buf) > bytesPerPacket {
		return io.ErrShortBuffer
	}
	// tx.mu is held for the whole call, covering the ring bookkeeping and the
	// completion queue.
	bind.tx.mu.Lock()
	defer bind.tx.mu.Unlock()
	// Reap every send that has completed so far so its slot can be reused.
	var results [packetsPerRing]winrio.Result
	count := winrio.DequeueCompletion(bind.tx.cq, results[:])
	if count == 0 && bind.tx.isFull {
		// No free slot and nothing completed yet: request a notification and
		// block until at least one send finishes.
		err := winrio.Notify(bind.tx.cq)
		if err != nil {
			return err
		}
		var bytes uint32
		var key uintptr
		var overlapped *windows.Overlapped
		err = windows.GetQueuedCompletionStatus(bind.tx.iocp, &bytes, &key, &overlapped, windows.INFINITE)
		if err != nil {
			return err
		}
		// Close posts a completion to this port after changing isOpen, so the
		// wake-up may be a shutdown signal rather than a finished send.
		if isOpen.Load() != 1 {
			return net.ErrClosed
		}
		count = winrio.DequeueCompletion(bind.tx.cq, results[:])
		if count == 0 {
			return io.ErrNoProgress
		}
	}
	if count > 0 {
		bind.tx.Return(count)
	}
	packet := bind.tx.Push()
	packet.addr = *nend
	copy(packet.data[:], buf)
	// Describe the payload and address to the request as offsets into the
	// registered buffer; only len(buf) bytes of the slot are sent.
	dataBuffer := &winrio.Buffer{
		Id:     bind.tx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.data[0])) - bind.tx.packets),
		Length: uint32(len(buf)),
	}
	addressBuffer := &winrio.Buffer{
		Id:     bind.tx.id,
		Offset: uint32(uintptr(unsafe.Pointer(&packet.addr)) - bind.tx.packets),
		Length: uint32(unsafe.Sizeof(packet.addr)),
	}
	bind.mu.Lock()
	defer bind.mu.Unlock()
	// NOTE(review): if SendEx fails, the slot taken by Push above is not given
	// back, and presumably no completion will arrive to release it — confirm
	// whether repeated failures can leave the ring permanently full.
	return winrio.SendEx(bind.rq, dataBuffer, 1, nil, addressBuffer, nil, nil, 0, 0)
}
 488  
 489  func (bind *WinRingBind) Send(bufs [][]byte, endpoint Endpoint) error {
 490  	nend, ok := endpoint.(*WinRingEndpoint)
 491  	if !ok {
 492  		return ErrWrongEndpointType
 493  	}
 494  	bind.mu.RLock()
 495  	defer bind.mu.RUnlock()
 496  	for _, buf := range bufs {
 497  		switch nend.family {
 498  		case windows.AF_INET:
 499  			if bind.v4.blackhole {
 500  				continue
 501  			}
 502  			if err := bind.v4.Send(buf, nend, &bind.isOpen); err != nil {
 503  				return err
 504  			}
 505  		case windows.AF_INET6:
 506  			if bind.v6.blackhole {
 507  				continue
 508  			}
 509  			if err := bind.v6.Send(buf, nend, &bind.isOpen); err != nil {
 510  				return err
 511  			}
 512  		}
 513  	}
 514  	return nil
 515  }
 516  
 517  func (s *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
 518  	s.mu.Lock()
 519  	defer s.mu.Unlock()
 520  	sysconn, err := s.ipv4.SyscallConn()
 521  	if err != nil {
 522  		return err
 523  	}
 524  	err2 := sysconn.Control(func(fd uintptr) {
 525  		err = bindSocketToInterface4(windows.Handle(fd), interfaceIndex)
 526  	})
 527  	if err2 != nil {
 528  		return err2
 529  	}
 530  	if err != nil {
 531  		return err
 532  	}
 533  	s.blackhole4 = blackhole
 534  	return nil
 535  }
 536  
 537  func (s *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
 538  	s.mu.Lock()
 539  	defer s.mu.Unlock()
 540  	sysconn, err := s.ipv6.SyscallConn()
 541  	if err != nil {
 542  		return err
 543  	}
 544  	err2 := sysconn.Control(func(fd uintptr) {
 545  		err = bindSocketToInterface6(windows.Handle(fd), interfaceIndex)
 546  	})
 547  	if err2 != nil {
 548  		return err2
 549  	}
 550  	if err != nil {
 551  		return err
 552  	}
 553  	s.blackhole6 = blackhole
 554  	return nil
 555  }
 556  
 557  func (bind *WinRingBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
 558  	bind.mu.RLock()
 559  	defer bind.mu.RUnlock()
 560  	if bind.isOpen.Load() != 1 {
 561  		return net.ErrClosed
 562  	}
 563  	err := bindSocketToInterface4(bind.v4.sock, interfaceIndex)
 564  	if err != nil {
 565  		return err
 566  	}
 567  	bind.v4.blackhole = blackhole
 568  	return nil
 569  }
 570  
 571  func (bind *WinRingBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
 572  	bind.mu.RLock()
 573  	defer bind.mu.RUnlock()
 574  	if bind.isOpen.Load() != 1 {
 575  		return net.ErrClosed
 576  	}
 577  	err := bindSocketToInterface6(bind.v6.sock, interfaceIndex)
 578  	if err != nil {
 579  		return err
 580  	}
 581  	bind.v6.blackhole = blackhole
 582  	return nil
 583  }
 584  
 585  func bindSocketToInterface4(handle windows.Handle, interfaceIndex uint32) error {
 586  	const IP_UNICAST_IF = 31
 587  	/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
 588  	var bytes [4]byte
 589  	binary.BigEndian.PutUint32(bytes[:], interfaceIndex)
 590  	interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
 591  	err := windows.SetsockoptInt(handle, windows.IPPROTO_IP, IP_UNICAST_IF, int(interfaceIndex))
 592  	if err != nil {
 593  		return err
 594  	}
 595  	return nil
 596  }
 597  
// bindSocketToInterface6 sets the IPV6_UNICAST_IF option on handle to
// interfaceIndex, passed as is, and returns the error from the setsockopt
// call.
func bindSocketToInterface6(handle windows.Handle, interfaceIndex uint32) error {
	const IPV6_UNICAST_IF = 31
	return windows.SetsockoptInt(handle, windows.IPPROTO_IPV6, IPV6_UNICAST_IF, int(interfaceIndex))
}
 602