rio_windows.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package winrio
7
8 import (
9 "log"
10 "sync"
11 "syscall"
12 "unsafe"
13
14 "golang.org/x/sys/windows"
15 )
16
// Message flags mirroring the RIO_MSG_* constants; intended for the sflags
// argument of SendEx and ReceiveEx.
const (
	MsgDontNotify = 1 << iota
	MsgDefer
	MsgWaitAll
	MsgCommitOnly
)

// MaxCqSize is the largest completion queue size accepted by RIO
// (RIO_MAX_CQ_SIZE).
const MaxCqSize = 0x8000000

// Sentinel return values the RIO wrappers compare against to detect failure.
const (
	invalidBufferId = 0xFFFFFFFF // RIO_INVALID_BUFFERID
	invalidCq       = 0          // RIO_INVALID_CQ
	invalidRq       = 0          // RIO_INVALID_RQ
	corruptCq       = 0xFFFFFFFF // RIO_CORRUPT_CQ
)
30
// extensionFunctionTable is filled in by Initialize through WSAIoctl with
// SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER. Windows writes into it
// directly, so it must keep the RIO_EXTENSION_FUNCTION_TABLE layout: a
// uint32 size followed by the function pointers in exactly this order.
var extensionFunctionTable struct {
	cbSize uint32

	rioReceive, rioReceiveEx uintptr
	rioSend, rioSendEx       uintptr

	rioCloseCompletionQueue, rioCreateCompletionQueue uintptr
	rioCreateRequestQueue                             uintptr
	rioDequeueCompletion, rioDeregisterBuffer         uintptr
	rioNotify, rioRegisterBuffer                      uintptr
	rioResizeCompletionQueue, rioResizeRequestQueue   uintptr
}
47
// Opaque, pointer-sized handles returned by the RIO functions.
type (
	// Cq is a RIO completion queue handle.
	Cq uintptr
	// Rq is a RIO request queue handle.
	Rq uintptr
	// BufferId identifies a region registered with RegisterBuffer or
	// RegisterPointer.
	BufferId uintptr
)

// Buffer describes a slice of a registered region (RIO_BUF). It is handed to
// Windows by pointer, so its field order and types must not change.
type Buffer struct {
	Id             BufferId
	Offset, Length uint32
}

// Result is one completion entry written by DequeueCompletion (RIORESULT).
// Windows fills these in directly, so the layout must not change.
type Result struct {
	Status                        int32
	BytesTransferred              uint32
	SocketContext, RequestContext uint64
}
66
// notificationCompletionType tags which notification variant is passed to
// RIOCreateCompletionQueue.
type notificationCompletionType uint32

const (
	// eventCompletion selects notification through an event handle.
	eventCompletion notificationCompletionType = iota + 1
	// iocpCompletion selects notification through an I/O completion port.
	iocpCompletion
)
73
74 type eventNotificationCompletion struct {
75 completionType notificationCompletionType
76 event windows.Handle
77 notifyReset uint32
78 }
79
80 type iocpNotificationCompletion struct {
81 completionType notificationCompletionType
82 iocp windows.Handle
83 key uintptr
84 overlapped *windows.Overlapped
85 }
86
// initialized makes sure the RIO probe in Initialize runs at most once.
var initialized sync.Once

// available caches the probe's outcome; it stays false until the probe
// succeeds.
var available bool
91
92 func Initialize() bool {
93 initialized.Do(func() {
94 var (
95 err error
96 socket windows.Handle
97 cq Cq
98 )
99 defer func() {
100 if err == nil {
101 return
102 }
103 if maj, _, _ := windows.RtlGetNtVersionNumbers(); maj <= 7 {
104 return
105 }
106 log.Printf("Registered I/O is unavailable: %v", err)
107 }()
108 socket, err = Socket(windows.AF_INET, windows.SOCK_DGRAM, windows.IPPROTO_UDP)
109 if err != nil {
110 return
111 }
112 defer windows.CloseHandle(socket)
113 WSAID_MULTIPLE_RIO := &windows.GUID{0x8509e081, 0x96dd, 0x4005, [8]byte{0xb1, 0x65, 0x9e, 0x2e, 0xe8, 0xc7, 0x9e, 0x3f}}
114 const SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER = 0xc8000024
115 ob := uint32(0)
116 err = windows.WSAIoctl(socket, SIO_GET_MULTIPLE_EXTENSION_FUNCTION_POINTER,
117 (*byte)(unsafe.Pointer(WSAID_MULTIPLE_RIO)), uint32(unsafe.Sizeof(*WSAID_MULTIPLE_RIO)),
118 (*byte)(unsafe.Pointer(&extensionFunctionTable)), uint32(unsafe.Sizeof(extensionFunctionTable)),
119 &ob, nil, 0)
120 if err != nil {
121 return
122 }
123
124 // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
125 // failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
126 var iocp windows.Handle
127 iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
128 if err != nil {
129 return
130 }
131 defer windows.CloseHandle(iocp)
132 var overlapped windows.Overlapped
133 cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
134 if err != nil {
135 return
136 }
137 defer CloseCompletionQueue(cq)
138 _, err = CreateRequestQueue(socket, 1, 1, 1, 1, cq, cq, 0)
139 if err != nil {
140 return
141 }
142 available = true
143 })
144 return available
145 }
146
147 func Socket(af, typ, proto int32) (windows.Handle, error) {
148 return windows.WSASocket(af, typ, proto, nil, 0, windows.WSA_FLAG_REGISTERED_IO)
149 }
150
151 func CloseCompletionQueue(cq Cq) {
152 _, _, _ = syscall.Syscall(extensionFunctionTable.rioCloseCompletionQueue, 1, uintptr(cq), 0, 0)
153 }
154
155 func CreateEventCompletionQueue(queueSize uint32, event windows.Handle, notifyReset bool) (Cq, error) {
156 notificationCompletion := &eventNotificationCompletion{
157 completionType: eventCompletion,
158 event: event,
159 }
160 if notifyReset {
161 notificationCompletion.notifyReset = 1
162 }
163 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
164 if ret == invalidCq {
165 return 0, err
166 }
167 return Cq(ret), nil
168 }
169
170 func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintptr, overlapped *windows.Overlapped) (Cq, error) {
171 notificationCompletion := &iocpNotificationCompletion{
172 completionType: iocpCompletion,
173 iocp: iocp,
174 key: key,
175 overlapped: overlapped,
176 }
177 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)
178 if ret == invalidCq {
179 return 0, err
180 }
181 return Cq(ret), nil
182 }
183
184 func CreatePolledCompletionQueue(queueSize uint32) (Cq, error) {
185 ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), 0, 0)
186 if ret == invalidCq {
187 return 0, err
188 }
189 return Cq(ret), nil
190 }
191
192 func CreateRequestQueue(socket windows.Handle, maxOutstandingReceive, maxReceiveDataBuffers, maxOutstandingSend, maxSendDataBuffers uint32, receiveCq, sendCq Cq, socketContext uintptr) (Rq, error) {
193 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioCreateRequestQueue, 8, uintptr(socket), uintptr(maxOutstandingReceive), uintptr(maxReceiveDataBuffers), uintptr(maxOutstandingSend), uintptr(maxSendDataBuffers), uintptr(receiveCq), uintptr(sendCq), socketContext, 0)
194 if ret == invalidRq {
195 return 0, err
196 }
197 return Rq(ret), nil
198 }
199
200 func DequeueCompletion(cq Cq, results []Result) uint32 {
201 var array uintptr
202 if len(results) > 0 {
203 array = uintptr(unsafe.Pointer(&results[0]))
204 }
205 ret, _, _ := syscall.Syscall(extensionFunctionTable.rioDequeueCompletion, 3, uintptr(cq), array, uintptr(len(results)))
206 if ret == corruptCq {
207 panic("cq is corrupt")
208 }
209 return uint32(ret)
210 }
211
212 func DeregisterBuffer(id BufferId) {
213 _, _, _ = syscall.Syscall(extensionFunctionTable.rioDeregisterBuffer, 1, uintptr(id), 0, 0)
214 }
215
216 func RegisterBuffer(buffer []byte) (BufferId, error) {
217 var buf unsafe.Pointer
218 if len(buffer) > 0 {
219 buf = unsafe.Pointer(&buffer[0])
220 }
221 return RegisterPointer(buf, uint32(len(buffer)))
222 }
223
224 func RegisterPointer(ptr unsafe.Pointer, size uint32) (BufferId, error) {
225 ret, _, err := syscall.Syscall(extensionFunctionTable.rioRegisterBuffer, 2, uintptr(ptr), uintptr(size), 0)
226 if ret == invalidBufferId {
227 return 0, err
228 }
229 return BufferId(ret), nil
230 }
231
232 func SendEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
233 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioSendEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
234 if ret == 0 {
235 return err
236 }
237 return nil
238 }
239
240 func ReceiveEx(rq Rq, buf *Buffer, dataBufferCount uint32, localAddress, remoteAddress, controlContext, flags *Buffer, sflags uint32, requestContext uintptr) error {
241 ret, _, err := syscall.Syscall9(extensionFunctionTable.rioReceiveEx, 9, uintptr(rq), uintptr(unsafe.Pointer(buf)), uintptr(dataBufferCount), uintptr(unsafe.Pointer(localAddress)), uintptr(unsafe.Pointer(remoteAddress)), uintptr(unsafe.Pointer(controlContext)), uintptr(unsafe.Pointer(flags)), uintptr(sflags), requestContext)
242 if ret == 0 {
243 return err
244 }
245 return nil
246 }
247
248 func Notify(cq Cq) error {
249 ret, _, _ := syscall.Syscall(extensionFunctionTable.rioNotify, 1, uintptr(cq), 0, 0)
250 if ret != 0 {
251 return windows.Errno(ret)
252 }
253 return nil
254 }
255