bind_std.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 */
5
6 package conn
7
8 import (
9 "context"
10 "errors"
11 "fmt"
12 "net"
13 "net/netip"
14 "runtime"
15 "strconv"
16 "sync"
17 "syscall"
18
19 "golang.org/x/net/ipv4"
20 "golang.org/x/net/ipv6"
21 )
22
23 var (
24 _ Bind = (*StdNetBind)(nil)
25 )
26
27 // StdNetBind implements Bind for all platforms. While Windows has its own Bind
28 // (see bind_windows.go), it may fall back to StdNetBind.
29 // TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
30 // methods for sending and receiving multiple datagrams per-syscall. See the
31 // proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
32 type StdNetBind struct {
33 mu sync.Mutex // protects all fields except as specified
34 ipv4 *net.UDPConn
35 ipv6 *net.UDPConn
36 ipv4PC *ipv4.PacketConn // will be nil on non-Linux
37 ipv6PC *ipv6.PacketConn // will be nil on non-Linux
38 ipv4TxOffload bool
39 ipv4RxOffload bool
40 ipv6TxOffload bool
41 ipv6RxOffload bool
42
43 // these two fields are not guarded by mu
44 udpAddrPool sync.Pool
45 msgsPool sync.Pool
46
47 blackhole4 bool
48 blackhole6 bool
49 }
50
51 func NewStdNetBind() Bind {
52 return &StdNetBind{
53 udpAddrPool: sync.Pool{
54 New: func() any {
55 return &net.UDPAddr{
56 IP: make([]byte, 16),
57 }
58 },
59 },
60
61 msgsPool: sync.Pool{
62 New: func() any {
63 // ipv6.Message and ipv4.Message are interchangeable as they are
64 // both aliases for x/net/internal/socket.Message.
65 msgs := make([]ipv6.Message, IdealBatchSize)
66 for i := range msgs {
67 msgs[i].Buffers = make(net.Buffers, 1)
68 msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
69 }
70 return &msgs
71 },
72 },
73 }
74 }
75
// StdNetEndpoint is the Endpoint implementation used by StdNetBind.
type StdNetEndpoint struct {
	// Embedded destination (peer) address and port.
	netip.AddrPort

	// src is the sticky source address and interface index, when the
	// platform supports them, stored in control-message form (typically a
	// PKTINFO structure; see unix.PKTINFO). ClearSrc truncates it to
	// length 0 to forget the source while keeping the backing array.
	src []byte
}
84
85 var (
86 _ Bind = (*StdNetBind)(nil)
87 _ Endpoint = &StdNetEndpoint{}
88 )
89
90 func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
91 e, err := netip.ParseAddrPort(s)
92 if err != nil {
93 return nil, err
94 }
95 return &StdNetEndpoint{
96 AddrPort: e,
97 }, nil
98 }
99
100 func (e *StdNetEndpoint) ClearSrc() {
101 if e.src != nil {
102 // Truncate src, no need to reallocate.
103 e.src = e.src[:0]
104 }
105 }
106
107 func (e *StdNetEndpoint) DstIP() netip.Addr {
108 return e.AddrPort.Addr()
109 }
110
111 // See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
112
113 func (e *StdNetEndpoint) DstToBytes() []byte {
114 b, _ := e.AddrPort.MarshalBinary()
115 return b
116 }
117
118 func (e *StdNetEndpoint) DstToString() string {
119 return e.AddrPort.String()
120 }
121
122 func listenNet(network string, port int) (*net.UDPConn, int, error) {
123 conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
124 if err != nil {
125 return nil, 0, err
126 }
127
128 // Retrieve port.
129 laddr := conn.LocalAddr()
130 uaddr, err := net.ResolveUDPAddr(
131 laddr.Network(),
132 laddr.String(),
133 )
134 if err != nil {
135 return nil, 0, err
136 }
137 return conn.(*net.UDPConn), uaddr.Port, nil
138 }
139
140 func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
141 s.mu.Lock()
142 defer s.mu.Unlock()
143
144 var err error
145 var tries int
146
147 if s.ipv4 != nil || s.ipv6 != nil {
148 return nil, 0, ErrBindAlreadyOpen
149 }
150
151 // Attempt to open ipv4 and ipv6 listeners on the same port.
152 // If uport is 0, we can retry on failure.
153 again:
154 port := int(uport)
155 var v4conn, v6conn *net.UDPConn
156 var v4pc *ipv4.PacketConn
157 var v6pc *ipv6.PacketConn
158
159 v4conn, port, err = listenNet("udp4", port)
160 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
161 return nil, 0, err
162 }
163
164 // Listen on the same port as we're using for ipv4.
165 v6conn, port, err = listenNet("udp6", port)
166 if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
167 v4conn.Close()
168 tries++
169 goto again
170 }
171 if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
172 v4conn.Close()
173 return nil, 0, err
174 }
175 var fns []ReceiveFunc
176 if v4conn != nil {
177 s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
178 if runtime.GOOS == "linux" || runtime.GOOS == "android" {
179 v4pc = ipv4.NewPacketConn(v4conn)
180 s.ipv4PC = v4pc
181 }
182 fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
183 s.ipv4 = v4conn
184 }
185 if v6conn != nil {
186 s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
187 if runtime.GOOS == "linux" || runtime.GOOS == "android" {
188 v6pc = ipv6.NewPacketConn(v6conn)
189 s.ipv6PC = v6pc
190 }
191 fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
192 s.ipv6 = v6conn
193 }
194 if len(fns) == 0 {
195 return nil, 0, syscall.EAFNOSUPPORT
196 }
197
198 return fns, uint16(port), nil
199 }
200
201 func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
202 for i := range *msgs {
203 (*msgs)[i].OOB = (*msgs)[i].OOB[:0]
204 (*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
205 }
206 s.msgsPool.Put(msgs)
207 }
208
209 func (s *StdNetBind) getMessages() *[]ipv6.Message {
210 return s.msgsPool.Get().(*[]ipv6.Message)
211 }
212
213 var (
214 // If compilation fails here these are no longer the same underlying type.
215 _ ipv6.Message = ipv4.Message{}
216 )
217
218 type batchReader interface {
219 ReadBatch([]ipv6.Message, int) (int, error)
220 }
221
222 type batchWriter interface {
223 WriteBatch([]ipv6.Message, int) (int, error)
224 }
225
226 func (s *StdNetBind) receiveIP(
227 br batchReader,
228 conn *net.UDPConn,
229 rxOffload bool,
230 bufs [][]byte,
231 sizes []int,
232 eps []Endpoint,
233 ) (n int, err error) {
234 msgs := s.getMessages()
235 for i := range bufs {
236 (*msgs)[i].Buffers[0] = bufs[i]
237 (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
238 }
239 defer s.putMessages(msgs)
240 var numMsgs int
241 if runtime.GOOS == "linux" || runtime.GOOS == "android" {
242 if rxOffload {
243 readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
244 numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
245 if err != nil {
246 return 0, err
247 }
248 numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
249 if err != nil {
250 return 0, err
251 }
252 } else {
253 numMsgs, err = br.ReadBatch(*msgs, 0)
254 if err != nil {
255 return 0, err
256 }
257 }
258 } else {
259 msg := &(*msgs)[0]
260 msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
261 if err != nil {
262 return 0, err
263 }
264 numMsgs = 1
265 }
266 for i := 0; i < numMsgs; i++ {
267 msg := &(*msgs)[i]
268 sizes[i] = msg.N
269 if sizes[i] == 0 {
270 continue
271 }
272 addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
273 ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
274 getSrcFromControl(msg.OOB[:msg.NN], ep)
275 eps[i] = ep
276 }
277 return numMsgs, nil
278 }
279
280 func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
281 return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
282 return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
283 }
284 }
285
286 func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
287 return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
288 return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
289 }
290 }
291
292 // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
293 // rename the IdealBatchSize constant to BatchSize.
294 func (s *StdNetBind) BatchSize() int {
295 if runtime.GOOS == "linux" || runtime.GOOS == "android" {
296 return IdealBatchSize
297 }
298 return 1
299 }
300
301 func (s *StdNetBind) Close() error {
302 s.mu.Lock()
303 defer s.mu.Unlock()
304
305 var err1, err2 error
306 if s.ipv4 != nil {
307 err1 = s.ipv4.Close()
308 s.ipv4 = nil
309 s.ipv4PC = nil
310 }
311 if s.ipv6 != nil {
312 err2 = s.ipv6.Close()
313 s.ipv6 = nil
314 s.ipv6PC = nil
315 }
316 s.blackhole4 = false
317 s.blackhole6 = false
318 s.ipv4TxOffload = false
319 s.ipv4RxOffload = false
320 s.ipv6TxOffload = false
321 s.ipv6RxOffload = false
322 if err1 != nil {
323 return err1
324 }
325 return err2
326 }
327
// ErrUDPGSODisabled is returned by Send after a GSO send was rejected and
// UDP GSO was turned off for the socket. RetryErr is the outcome of
// resending the same packets without GSO; it is nil if the resend
// succeeded.
type ErrUDPGSODisabled struct {
	onLaddr  string // local address of the socket whose GSO was disabled
	RetryErr error
}

// Error implements the error interface.
func (e ErrUDPGSODisabled) Error() string {
	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
}

// Unwrap exposes RetryErr to errors.Is and errors.As.
func (e ErrUDPGSODisabled) Unwrap() error { return e.RetryErr }
340
341 func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
342 s.mu.Lock()
343 blackhole := s.blackhole4
344 conn := s.ipv4
345 offload := s.ipv4TxOffload
346 br := batchWriter(s.ipv4PC)
347 is6 := false
348 if endpoint.DstIP().Is6() {
349 blackhole = s.blackhole6
350 conn = s.ipv6
351 br = s.ipv6PC
352 is6 = true
353 offload = s.ipv6TxOffload
354 }
355 s.mu.Unlock()
356
357 if blackhole {
358 return nil
359 }
360 if conn == nil {
361 return syscall.EAFNOSUPPORT
362 }
363
364 msgs := s.getMessages()
365 defer s.putMessages(msgs)
366 ua := s.udpAddrPool.Get().(*net.UDPAddr)
367 defer s.udpAddrPool.Put(ua)
368 if is6 {
369 as16 := endpoint.DstIP().As16()
370 copy(ua.IP, as16[:])
371 ua.IP = ua.IP[:16]
372 } else {
373 as4 := endpoint.DstIP().As4()
374 copy(ua.IP, as4[:])
375 ua.IP = ua.IP[:4]
376 }
377 ua.Port = int(endpoint.(*StdNetEndpoint).Port())
378 var (
379 retried bool
380 err error
381 )
382 retry:
383 if offload {
384 n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
385 err = s.send(conn, br, (*msgs)[:n])
386 if err != nil && offload && errShouldDisableUDPGSO(err) {
387 offload = false
388 s.mu.Lock()
389 if is6 {
390 s.ipv6TxOffload = false
391 } else {
392 s.ipv4TxOffload = false
393 }
394 s.mu.Unlock()
395 retried = true
396 goto retry
397 }
398 } else {
399 for i := range bufs {
400 (*msgs)[i].Addr = ua
401 (*msgs)[i].Buffers[0] = bufs[i]
402 setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
403 }
404 err = s.send(conn, br, (*msgs)[:len(bufs)])
405 }
406 if retried {
407 return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
408 }
409 return err
410 }
411
412 func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
413 var (
414 n int
415 err error
416 start int
417 )
418 if runtime.GOOS == "linux" || runtime.GOOS == "android" {
419 for {
420 n, err = pc.WriteBatch(msgs[start:], 0)
421 if err != nil || n == len(msgs[start:]) {
422 break
423 }
424 start += n
425 }
426 } else {
427 for _, msg := range msgs {
428 _, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
429 if err != nil {
430 break
431 }
432 }
433 }
434 return err
435 }
436
const (
	// Largest UDP payload that fits in a single datagram; exceeding it
	// results in EMSGSIZE. The 16-bit IPv4 total-length field also covers
	// the 20-byte IPv4 header and the 8-byte UDP header. The IPv6
	// payload-length field excludes the IPv6 header itself, so only the UDP
	// header is subtracted there.
	maxIPv4PayloadLen = 0xffff - 20 - 8
	maxIPv6PayloadLen = 0xffff - 8

	// udpSegmentMaxDatagrams is the kernel's hard cap on the number of
	// datagrams carried by one coalesced (GSO) message.
	udpSegmentMaxDatagrams = 64
)
447
448 type setGSOFunc func(control *[]byte, gsoSize uint16)
449
450 func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
451 var (
452 base = -1 // index of msg we are currently coalescing into
453 gsoSize int // segmentation size of msgs[base]
454 dgramCnt int // number of dgrams coalesced into msgs[base]
455 endBatch bool // tracking flag to start a new batch on next iteration of bufs
456 )
457 maxPayloadLen := maxIPv4PayloadLen
458 if ep.DstIP().Is6() {
459 maxPayloadLen = maxIPv6PayloadLen
460 }
461 for i, buf := range bufs {
462 if i > 0 {
463 msgLen := len(buf)
464 baseLenBefore := len(msgs[base].Buffers[0])
465 freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
466 if msgLen+baseLenBefore <= maxPayloadLen &&
467 msgLen <= gsoSize &&
468 msgLen <= freeBaseCap &&
469 dgramCnt < udpSegmentMaxDatagrams &&
470 !endBatch {
471 msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
472 if i == len(bufs)-1 {
473 setGSO(&msgs[base].OOB, uint16(gsoSize))
474 }
475 dgramCnt++
476 if msgLen < gsoSize {
477 // A smaller than gsoSize packet on the tail is legal, but
478 // it must end the batch.
479 endBatch = true
480 }
481 continue
482 }
483 }
484 if dgramCnt > 1 {
485 setGSO(&msgs[base].OOB, uint16(gsoSize))
486 }
487 // Reset prior to incrementing base since we are preparing to start a
488 // new potential batch.
489 endBatch = false
490 base++
491 gsoSize = len(buf)
492 setSrcControl(&msgs[base].OOB, ep)
493 msgs[base].Buffers[0] = buf
494 msgs[base].Addr = addr
495 dgramCnt = 1
496 }
497 return base + 1
498 }
499
500 type getGSOFunc func(control []byte) (int, error)
501
502 func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
503 for i := firstMsgAt; i < len(msgs); i++ {
504 msg := &msgs[i]
505 if msg.N == 0 {
506 return n, err
507 }
508 var (
509 gsoSize int
510 start int
511 end = msg.N
512 numToSplit = 1
513 )
514 gsoSize, err = getGSO(msg.OOB[:msg.NN])
515 if err != nil {
516 return n, err
517 }
518 if gsoSize > 0 {
519 numToSplit = (msg.N + gsoSize - 1) / gsoSize
520 end = gsoSize
521 }
522 for j := 0; j < numToSplit; j++ {
523 if n > i {
524 return n, errors.New("splitting coalesced packet resulted in overflow")
525 }
526 copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
527 msgs[n].N = copied
528 msgs[n].Addr = msg.Addr
529 start = end
530 end += gsoSize
531 if end > msg.N {
532 end = msg.N
533 }
534 n++
535 }
536 if i != n-1 {
537 // It is legal for bytes to move within msg.Buffers[0] as a result
538 // of splitting, so we only zero the source msg len when it is not
539 // the destination of the last split operation above.
540 msg.N = 0
541 }
542 }
543 return n, nil
544 }
545