rio_windows.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package winrio
   7  
   8  import (
   9  	"log"
  10  	"sync"
  11  	"syscall"
  12  	"unsafe"
  13  
  14  	"golang.org/x/sys/windows"
  15  )
  16  
  17  const (
  18  	MsgDontNotify = 1
  19  	MsgDefer      = 2
  20  	MsgWaitAll    = 4
  21  	MsgCommitOnly = 8
  22  
  23  	MaxCqSize = 0x8000000
  24  
  25  	invalidBufferId = 0xFFFFFFFF
  26  	invalidCq       = 0
  27  	invalidRq       = 0
  28  	corruptCq       = 0xFFFFFFFF
  29  )
  30  
// extensionFunctionTable holds the RIO function pointers that Initialize
// obtains by passing this variable as the output buffer of
// WSAIoctl(SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER). Every wrapper in this
// package calls through these entries, which stay zero until Initialize has
// run. The field order and sizes must match Winsock's
// RIO_EXTENSION_FUNCTION_TABLE (cbSize followed by the function pointers in
// exactly this order). Do not reorder, insert or remove fields.
var extensionFunctionTable struct {
	cbSize                   uint32
	rioReceive               uintptr
	rioReceiveEx             uintptr
	rioSend                  uintptr
	rioSendEx                uintptr
	rioCloseCompletionQueue  uintptr
	rioCreateCompletionQueue uintptr
	rioCreateRequestQueue    uintptr
	rioDequeueCompletion     uintptr
	rioDeregisterBuffer      uintptr
	rioNotify                uintptr
	rioRegisterBuffer        uintptr
	rioResizeCompletionQueue uintptr
	rioResizeRequestQueue    uintptr
}
  47  
  48  type Cq uintptr
  49  
  50  type Rq uintptr
  51  
  52  type BufferId uintptr
  53  
  54  type Buffer struct {
  55  	Id     BufferId
  56  	Offset uint32
  57  	Length uint32
  58  }
  59  
  60  type Result struct {
  61  	Status           int32
  62  	BytesTransferred uint32
  63  	SocketContext    uint64
  64  	RequestContext   uint64
  65  }
  66  
  67  type notificationCompletionType uint32
  68  
  69  const (
  70  	eventCompletion notificationCompletionType = 1
  71  	iocpCompletion  notificationCompletionType = 2
  72  )
  73  
// eventNotificationCompletion is the event-handle form of Winsock's
// RIO_NOTIFICATION_COMPLETION structure. CreateEventCompletionQueue passes it
// to RIOCreateCompletionQueue. The field order and sizes are an ABI contract
// with Windows; do not reorder or resize fields.
type eventNotificationCompletion struct {
	completionType notificationCompletionType // set to eventCompletion
	event          windows.Handle             // event handle used for notification
	notifyReset    uint32                     // 1 if the caller asked for notifyReset, else 0 (Win32 BOOL)
}

// iocpNotificationCompletion is the I/O-completion-port form of Winsock's
// RIO_NOTIFICATION_COMPLETION structure. CreateIOCPCompletionQueue passes it
// to RIOCreateCompletionQueue. The field order and sizes are an ABI contract
// with Windows; do not reorder or resize fields.
//
// NOTE(review): Windows presumably keeps using the overlapped pointer after the
// queue is created. If so, callers must keep that Overlapped alive until the
// completion queue is closed; confirm this.
type iocpNotificationCompletion struct {
	completionType notificationCompletionType // set to iocpCompletion
	iocp           windows.Handle             // completion port that receives notifications
	key            uintptr                    // completion key supplied by the caller
	overlapped     *windows.Overlapped        // OVERLAPPED pointer supplied by the caller
}
  86  
  87  var (
  88  	initialized sync.Once
  89  	available   bool
  90  )
  91  
  92  func Initialize() bool {
  93  	initialized.Do(func() {
  94  		var (
  95  			err    error
  96  			socket windows.Handle
  97  			cq     Cq
  98  		)
  99  		defer func() {
 100  			if err == nil {
 101  				return
 102  			}
 103  			if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
 104  				return
 105  			}
 106  			log.Printf("Registered I/O is unavailable: %v", err)
 107  		}()
 108  		socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
 109  		if err != nil {
 110  			return
 111  		}
 112  		defer windows.CloseHandle(socket)
 113  		WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
 114  		const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
 115  		ob := uint32(0)
 116  		err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
 117  			(*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
 118  			(*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
 119  			&ob, nil, 0)
 120  		if err != nil {
 121  			return
 122  		}
 123  
 124  		// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
 125  		// failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
 126  		var iocp windows.Handle
 127  		iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
 128  		if err != nil {
 129  			return
 130  		}
 131  		defer windows.CloseHandle(iocp)
 132  		var overlapped windows.Overlapped
 133  		cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
 134  		if err != nil {
 135  			return
 136  		}
 137  		defer CloseCompletionQueue(cq)
 138  		_, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
 139  		if err != nil {
 140  			return
 141  		}
 142  		available = true
 143  	})
 144  	return available
 145  }
 146  
 147  func Socket(af, typ, proto int32) (windows.Handle, error) {
 148  	return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
 149  }
 150  
 151  func CloseCompletionQueue(cq Cq) {
 152  	_, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
 153  }
 154  
 155  func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
 156  	notificationCompletion := &eventNotificationCompletion{
 157  		completionType: eventCompletion,
 158  		event:          event,
 159  	}
 160  	if notifyReset {
 161  		notificationCompletion.notifyReset = 1
 162  	}
 163  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
 164  	if ret == invalidCq {
 165  		return 0, err
 166  	}
 167  	return Cq(ret), nil
 168  }
 169  
 170  func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
 171  	notificationCompletion := &iocpNotificationCompletion{
 172  		completionType: iocpCompletion,
 173  		iocp:           iocp,
 174  		key:            key,
 175  		overlapped:     overlapped,
 176  	}
 177  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
 178  	if ret == invalidCq {
 179  		return 0, err
 180  	}
 181  	return Cq(ret), nil
 182  }
 183  
 184  func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
 185  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
 186  	if ret == invalidCq {
 187  		return 0, err
 188  	}
 189  	return Cq(ret), nil
 190  }
 191  
 192  func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
 193  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
 194  	if ret == invalidRq {
 195  		return 0, err
 196  	}
 197  	return Rq(ret), nil
 198  }
 199  
 200  func DequeueCompletion(cq Cq, results []Result) uint32 {
 201  	var array uintptr
 202  	if len(results) > 0 {
 203  		array = uintptr(unsafe.Pointer(&results[0]))
 204  	}
 205  	ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
 206  	if ret == corruptCq {
 207  		panic("cq is corrupt")
 208  	}
 209  	return uint32(ret)
 210  }
 211  
 212  func DeregisterBuffer(id BufferId) {
 213  	_, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
 214  }
 215  
 216  func RegisterBuffer(buffer []byte) (BufferId, error) {
 217  	var buf unsafe.Pointer
 218  	if len(buffer) > 0 {
 219  		buf = unsafe.Pointer(&buffer[0])
 220  	}
 221  	return RegisterPointer(buf, uint32(len(buffer)))
 222  }
 223  
 224  func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
 225  	ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
 226  	if ret == invalidBufferId {
 227  		return 0, err
 228  	}
 229  	return BufferId(ret), nil
 230  }
 231  
 232  func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
 233  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
 234  	if ret == 0 {
 235  		return err
 236  	}
 237  	return nil
 238  }
 239  
 240  func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
 241  	ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
 242  	if ret == 0 {
 243  		return err
 244  	}
 245  	return nil
 246  }
 247  
 248  func Notify(cq Cq) error {
 249  	ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
 250  	if ret != 0 {
 251  		return windows.Errno(ret)
 252  	}
 253  	return nil
 254  }
 255