bind_std.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   */
   5  
   6  package conn
   7  
   8  import (
   9  	"context"
  10  	"errors"
  11  	"fmt"
  12  	"net"
  13  	"net/netip"
  14  	"runtime"
  15  	"strconv"
  16  	"sync"
  17  	"syscall"
  18  
  19  	"golang.org/x/net/ipv4"
  20  	"golang.org/x/net/ipv6"
  21  )
  22  
var (
	// Compile-time assertion that *StdNetBind satisfies the Bind interface.
	_ Bind = (*StdNetBind)(nil)
)
  26  
// StdNetBind implements Bind for all platforms. While Windows has its own Bind
// (see bind_windows.go), it may fall back to StdNetBind.
//
// It holds at most one UDP socket per address family. Open additionally wraps
// each socket in an ipv4/ipv6 PacketConn when GOOS is linux or android; those
// wrappers are what the batched read and write paths use.
//
// TODO: Remove usage of ipv{4,6}.PacketConn when net.UDPConn has comparable
// methods for sending and receiving multiple datagrams per-syscall. See the
// proposal in https://github.com/golang/go/issues/45886#issuecomment-1218301564.
type StdNetBind struct {
	mu            sync.Mutex       // protects all fields except as specified
	ipv4          *net.UDPConn     // IPv4 socket; nil unless Open bound one, reset to nil by Close
	ipv6          *net.UDPConn     // IPv6 socket; nil unless Open bound one, reset to nil by Close
	ipv4PC        *ipv4.PacketConn // will be nil on non-Linux (Open sets it only for GOOS linux/android)
	ipv6PC        *ipv6.PacketConn // will be nil on non-Linux (Open sets it only for GOOS linux/android)
	ipv4TxOffload bool             // set by Open from supportsUDPOffload; Send clears it when errShouldDisableUDPGSO reports true
	ipv4RxOffload bool             // set by Open from supportsUDPOffload
	ipv6TxOffload bool             // IPv6 counterpart of ipv4TxOffload
	ipv6RxOffload bool             // IPv6 counterpart of ipv4RxOffload

	// these two fields are not guarded by mu
	udpAddrPool sync.Pool // pool of *net.UDPAddr scratch values used by Send
	msgsPool    sync.Pool // pool of *[]ipv6.Message batches, IdealBatchSize messages each

	// blackhole4 and blackhole6 are read by Send while holding mu; when the
	// flag for the destination's address family is set, Send returns nil
	// without transmitting. Close resets both to false.
	blackhole4 bool
	blackhole6 bool
}
  50  
  51  func NewStdNetBind() Bind {
  52  	return &StdNetBind{
  53  		udpAddrPool: sync.Pool{
  54  			New: func() any {
  55  				return &net.UDPAddr{
  56  					IP: make([]byte, 16),
  57  				}
  58  			},
  59  		},
  60  
  61  		msgsPool: sync.Pool{
  62  			New: func() any {
  63  				// ipv6.Message and ipv4.Message are interchangeable as they are
  64  				// both aliases for x/net/internal/socket.Message.
  65  				msgs := make([]ipv6.Message, IdealBatchSize)
  66  				for i := range msgs {
  67  					msgs[i].Buffers = make(net.Buffers, 1)
  68  					msgs[i].OOB = make([]byte, 0, stickyControlSize+gsoControlSize)
  69  				}
  70  				return &msgs
  71  			},
  72  		},
  73  	}
  74  }
  75  
// StdNetEndpoint is the Endpoint type created by StdNetBind. It pairs a
// destination address and port with optional sticky-source control data.
type StdNetEndpoint struct {
	// AddrPort is the endpoint destination.
	netip.AddrPort
	// src is the current sticky source address and interface index, if
	// supported. Typically this is a PKTINFO structure from/for control
	// messages, see unix.PKTINFO for an example.
	src []byte
}
  84  
  85  var (
  86  	_ Bind     = (*StdNetBind)(nil)
  87  	_ Endpoint = &StdNetEndpoint{}
  88  )
  89  
  90  func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
  91  	e, err := netip.ParseAddrPort(s)
  92  	if err != nil {
  93  		return nil, err
  94  	}
  95  	return &StdNetEndpoint{
  96  		AddrPort: e,
  97  	}, nil
  98  }
  99  
 100  func (e *StdNetEndpoint) ClearSrc() {
 101  	if e.src != nil {
 102  		// Truncate src, no need to reallocate.
 103  		e.src = e.src[:0]
 104  	}
 105  }
 106  
 107  func (e *StdNetEndpoint) DstIP() netip.Addr {
 108  	return e.AddrPort.Addr()
 109  }
 110  
 111  // See control_default,linux, etc for implementations of SrcIP and SrcIfidx.
 112  
 113  func (e *StdNetEndpoint) DstToBytes() []byte {
 114  	b, _ := e.AddrPort.MarshalBinary()
 115  	return b
 116  }
 117  
 118  func (e *StdNetEndpoint) DstToString() string {
 119  	return e.AddrPort.String()
 120  }
 121  
 122  func listenNet(network string, port int) (*net.UDPConn, int, error) {
 123  	conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
 124  	if err != nil {
 125  		return nil, 0, err
 126  	}
 127  
 128  	// Retrieve port.
 129  	laddr := conn.LocalAddr()
 130  	uaddr, err := net.ResolveUDPAddr(
 131  		laddr.Network(),
 132  		laddr.String(),
 133  	)
 134  	if err != nil {
 135  		return nil, 0, err
 136  	}
 137  	return conn.(*net.UDPConn), uaddr.Port, nil
 138  }
 139  
 140  func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
 141  	s.mu.Lock()
 142  	defer s.mu.Unlock()
 143  
 144  	var err error
 145  	var tries int
 146  
 147  	if s.ipv4 != nil || s.ipv6 != nil {
 148  		return nil, 0, ErrBindAlreadyOpen
 149  	}
 150  
 151  	// Attempt to open ipv4 and ipv6 listeners on the same port.
 152  	// If uport is 0, we can retry on failure.
 153  again:
 154  	port := int(uport)
 155  	var v4conn, v6conn *net.UDPConn
 156  	var v4pc *ipv4.PacketConn
 157  	var v6pc *ipv6.PacketConn
 158  
 159  	v4conn, port, err = listenNet("udp4", port)
 160  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
 161  		return nil, 0, err
 162  	}
 163  
 164  	// Listen on the same port as we're using for ipv4.
 165  	v6conn, port, err = listenNet("udp6", port)
 166  	if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
 167  		v4conn.Close()
 168  		tries++
 169  		goto again
 170  	}
 171  	if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
 172  		v4conn.Close()
 173  		return nil, 0, err
 174  	}
 175  	var fns []ReceiveFunc
 176  	if v4conn != nil {
 177  		s.ipv4TxOffload, s.ipv4RxOffload = supportsUDPOffload(v4conn)
 178  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
 179  			v4pc = ipv4.NewPacketConn(v4conn)
 180  			s.ipv4PC = v4pc
 181  		}
 182  		fns = append(fns, s.makeReceiveIPv4(v4pc, v4conn, s.ipv4RxOffload))
 183  		s.ipv4 = v4conn
 184  	}
 185  	if v6conn != nil {
 186  		s.ipv6TxOffload, s.ipv6RxOffload = supportsUDPOffload(v6conn)
 187  		if runtime.GOOS == "linux" || runtime.GOOS == "android" {
 188  			v6pc = ipv6.NewPacketConn(v6conn)
 189  			s.ipv6PC = v6pc
 190  		}
 191  		fns = append(fns, s.makeReceiveIPv6(v6pc, v6conn, s.ipv6RxOffload))
 192  		s.ipv6 = v6conn
 193  	}
 194  	if len(fns) == 0 {
 195  		return nil, 0, syscall.EAFNOSUPPORT
 196  	}
 197  
 198  	return fns, uint16(port), nil
 199  }
 200  
 201  func (s *StdNetBind) putMessages(msgs *[]ipv6.Message) {
 202  	for i := range *msgs {
 203  		(*msgs)[i].OOB = (*msgs)[i].OOB[:0]
 204  		(*msgs)[i] = ipv6.Message{Buffers: (*msgs)[i].Buffers, OOB: (*msgs)[i].OOB}
 205  	}
 206  	s.msgsPool.Put(msgs)
 207  }
 208  
 209  func (s *StdNetBind) getMessages() *[]ipv6.Message {
 210  	return s.msgsPool.Get().(*[]ipv6.Message)
 211  }
 212  
var (
	// Compile-time guard for the assumption that ipv4.Message and ipv6.Message
	// can be used interchangeably throughout this file.
	// If compilation fails here these are no longer the same underlying type.
	_ ipv6.Message = ipv4.Message{}
)
 217  
// batchReader is the batched-receive method that receiveIP needs from a
// packet conn. makeReceiveIPv4 and makeReceiveIPv6 supply an
// *ipv4.PacketConn and an *ipv6.PacketConn respectively.
type batchReader interface {
	// ReadBatch fills the given messages and returns how many were read.
	// The int argument is passed as 0 by this file; presumably it carries
	// receive flags — see the x/net PacketConn documentation.
	ReadBatch([]ipv6.Message, int) (int, error)
}
 221  
// batchWriter is the batched-send method that send needs from a packet conn.
// Send supplies the bind's *ipv4.PacketConn or *ipv6.PacketConn.
type batchWriter interface {
	// WriteBatch transmits the given messages and returns how many were
	// written, which may be fewer than provided. The int argument is passed
	// as 0 by this file; presumably it carries send flags — see the x/net
	// PacketConn documentation.
	WriteBatch([]ipv6.Message, int) (int, error)
}
 225  
 226  func (s *StdNetBind) receiveIP(
 227  	br batchReader,
 228  	conn *net.UDPConn,
 229  	rxOffload bool,
 230  	bufs [][]byte,
 231  	sizes []int,
 232  	eps []Endpoint,
 233  ) (n int, err error) {
 234  	msgs := s.getMessages()
 235  	for i := range bufs {
 236  		(*msgs)[i].Buffers[0] = bufs[i]
 237  		(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
 238  	}
 239  	defer s.putMessages(msgs)
 240  	var numMsgs int
 241  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
 242  		if rxOffload {
 243  			readAt := len(*msgs) - (IdealBatchSize / udpSegmentMaxDatagrams)
 244  			numMsgs, err = br.ReadBatch((*msgs)[readAt:], 0)
 245  			if err != nil {
 246  				return 0, err
 247  			}
 248  			numMsgs, err = splitCoalescedMessages(*msgs, readAt, getGSOSize)
 249  			if err != nil {
 250  				return 0, err
 251  			}
 252  		} else {
 253  			numMsgs, err = br.ReadBatch(*msgs, 0)
 254  			if err != nil {
 255  				return 0, err
 256  			}
 257  		}
 258  	} else {
 259  		msg := &(*msgs)[0]
 260  		msg.N, msg.NN, _, msg.Addr, err = conn.ReadMsgUDP(msg.Buffers[0], msg.OOB)
 261  		if err != nil {
 262  			return 0, err
 263  		}
 264  		numMsgs = 1
 265  	}
 266  	for i := 0; i < numMsgs; i++ {
 267  		msg := &(*msgs)[i]
 268  		sizes[i] = msg.N
 269  		if sizes[i] == 0 {
 270  			continue
 271  		}
 272  		addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
 273  		ep := &StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
 274  		getSrcFromControl(msg.OOB[:msg.NN], ep)
 275  		eps[i] = ep
 276  	}
 277  	return numMsgs, nil
 278  }
 279  
 280  func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
 281  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
 282  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
 283  	}
 284  }
 285  
 286  func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
 287  	return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
 288  		return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
 289  	}
 290  }
 291  
 292  // TODO: When all Binds handle IdealBatchSize, remove this dynamic function and
 293  // rename the IdealBatchSize constant to BatchSize.
 294  func (s *StdNetBind) BatchSize() int {
 295  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
 296  		return IdealBatchSize
 297  	}
 298  	return 1
 299  }
 300  
 301  func (s *StdNetBind) Close() error {
 302  	s.mu.Lock()
 303  	defer s.mu.Unlock()
 304  
 305  	var err1, err2 error
 306  	if s.ipv4 != nil {
 307  		err1 = s.ipv4.Close()
 308  		s.ipv4 = nil
 309  		s.ipv4PC = nil
 310  	}
 311  	if s.ipv6 != nil {
 312  		err2 = s.ipv6.Close()
 313  		s.ipv6 = nil
 314  		s.ipv6PC = nil
 315  	}
 316  	s.blackhole4 = false
 317  	s.blackhole6 = false
 318  	s.ipv4TxOffload = false
 319  	s.ipv4RxOffload = false
 320  	s.ipv6TxOffload = false
 321  	s.ipv6RxOffload = false
 322  	if err1 != nil {
 323  		return err1
 324  	}
 325  	return err2
 326  }
 327  
 328  type ErrUDPGSODisabled struct {
 329  	onLaddr  string
 330  	RetryErr error
 331  }
 332  
 333  func (e ErrUDPGSODisabled) Error() string {
 334  	return fmt.Sprintf("disabled UDP GSO on %s, NIC(s) may not support checksum offload", e.onLaddr)
 335  }
 336  
 337  func (e ErrUDPGSODisabled) Unwrap() error {
 338  	return e.RetryErr
 339  }
 340  
 341  func (s *StdNetBind) Send(bufs [][]byte, endpoint Endpoint) error {
 342  	s.mu.Lock()
 343  	blackhole := s.blackhole4
 344  	conn := s.ipv4
 345  	offload := s.ipv4TxOffload
 346  	br := batchWriter(s.ipv4PC)
 347  	is6 := false
 348  	if endpoint.DstIP().Is6() {
 349  		blackhole = s.blackhole6
 350  		conn = s.ipv6
 351  		br = s.ipv6PC
 352  		is6 = true
 353  		offload = s.ipv6TxOffload
 354  	}
 355  	s.mu.Unlock()
 356  
 357  	if blackhole {
 358  		return nil
 359  	}
 360  	if conn == nil {
 361  		return syscall.EAFNOSUPPORT
 362  	}
 363  
 364  	msgs := s.getMessages()
 365  	defer s.putMessages(msgs)
 366  	ua := s.udpAddrPool.Get().(*net.UDPAddr)
 367  	defer s.udpAddrPool.Put(ua)
 368  	if is6 {
 369  		as16 := endpoint.DstIP().As16()
 370  		copy(ua.IP, as16[:])
 371  		ua.IP = ua.IP[:16]
 372  	} else {
 373  		as4 := endpoint.DstIP().As4()
 374  		copy(ua.IP, as4[:])
 375  		ua.IP = ua.IP[:4]
 376  	}
 377  	ua.Port = int(endpoint.(*StdNetEndpoint).Port())
 378  	var (
 379  		retried bool
 380  		err     error
 381  	)
 382  retry:
 383  	if offload {
 384  		n := coalesceMessages(ua, endpoint.(*StdNetEndpoint), bufs, *msgs, setGSOSize)
 385  		err = s.send(conn, br, (*msgs)[:n])
 386  		if err != nil && offload && errShouldDisableUDPGSO(err) {
 387  			offload = false
 388  			s.mu.Lock()
 389  			if is6 {
 390  				s.ipv6TxOffload = false
 391  			} else {
 392  				s.ipv4TxOffload = false
 393  			}
 394  			s.mu.Unlock()
 395  			retried = true
 396  			goto retry
 397  		}
 398  	} else {
 399  		for i := range bufs {
 400  			(*msgs)[i].Addr = ua
 401  			(*msgs)[i].Buffers[0] = bufs[i]
 402  			setSrcControl(&(*msgs)[i].OOB, endpoint.(*StdNetEndpoint))
 403  		}
 404  		err = s.send(conn, br, (*msgs)[:len(bufs)])
 405  	}
 406  	if retried {
 407  		return ErrUDPGSODisabled{onLaddr: conn.LocalAddr().String(), RetryErr: err}
 408  	}
 409  	return err
 410  }
 411  
 412  func (s *StdNetBind) send(conn *net.UDPConn, pc batchWriter, msgs []ipv6.Message) error {
 413  	var (
 414  		n     int
 415  		err   error
 416  		start int
 417  	)
 418  	if runtime.GOOS == "linux" || runtime.GOOS == "android" {
 419  		for {
 420  			n, err = pc.WriteBatch(msgs[start:], 0)
 421  			if err != nil || n == len(msgs[start:]) {
 422  				break
 423  			}
 424  			start += n
 425  		}
 426  	} else {
 427  		for _, msg := range msgs {
 428  			_, _, err = conn.WriteMsgUDP(msg.Buffers[0], msg.OOB, msg.Addr.(*net.UDPAddr))
 429  			if err != nil {
 430  				break
 431  			}
 432  		}
 433  	}
 434  	return err
 435  }
 436  
const (
	// Exceeding these values results in EMSGSIZE. They account for layer3 and
	// layer4 headers. IPv6 does not need to account for itself as the payload
	// length field is self excluding.
	maxIPv4PayloadLen = 1<<16 - 1 - 20 - 8 // 65535 less a 20-byte IPv4 header and the 8-byte UDP header
	maxIPv6PayloadLen = 1<<16 - 1 - 8      // 65535 less the 8-byte UDP header

	// This is a hard limit imposed by the kernel. It caps how many datagrams
	// coalesceMessages packs into one message.
	udpSegmentMaxDatagrams = 64
)
 447  
// setGSOFunc records gsoSize, the segment size of a coalesced message, in that
// message's control data. coalesceMessages takes it as a parameter; Send
// passes setGSOSize.
type setGSOFunc func(control *[]byte, gsoSize uint16)
 449  
 450  func coalesceMessages(addr *net.UDPAddr, ep *StdNetEndpoint, bufs [][]byte, msgs []ipv6.Message, setGSO setGSOFunc) int {
 451  	var (
 452  		base     = -1 // index of msg we are currently coalescing into
 453  		gsoSize  int  // segmentation size of msgs[base]
 454  		dgramCnt int  // number of dgrams coalesced into msgs[base]
 455  		endBatch bool // tracking flag to start a new batch on next iteration of bufs
 456  	)
 457  	maxPayloadLen := maxIPv4PayloadLen
 458  	if ep.DstIP().Is6() {
 459  		maxPayloadLen = maxIPv6PayloadLen
 460  	}
 461  	for i, buf := range bufs {
 462  		if i > 0 {
 463  			msgLen := len(buf)
 464  			baseLenBefore := len(msgs[base].Buffers[0])
 465  			freeBaseCap := cap(msgs[base].Buffers[0]) - baseLenBefore
 466  			if msgLen+baseLenBefore <= maxPayloadLen &&
 467  				msgLen <= gsoSize &&
 468  				msgLen <= freeBaseCap &&
 469  				dgramCnt < udpSegmentMaxDatagrams &&
 470  				!endBatch {
 471  				msgs[base].Buffers[0] = append(msgs[base].Buffers[0], buf...)
 472  				if i == len(bufs)-1 {
 473  					setGSO(&msgs[base].OOB, uint16(gsoSize))
 474  				}
 475  				dgramCnt++
 476  				if msgLen < gsoSize {
 477  					// A smaller than gsoSize packet on the tail is legal, but
 478  					// it must end the batch.
 479  					endBatch = true
 480  				}
 481  				continue
 482  			}
 483  		}
 484  		if dgramCnt > 1 {
 485  			setGSO(&msgs[base].OOB, uint16(gsoSize))
 486  		}
 487  		// Reset prior to incrementing base since we are preparing to start a
 488  		// new potential batch.
 489  		endBatch = false
 490  		base++
 491  		gsoSize = len(buf)
 492  		setSrcControl(&msgs[base].OOB, ep)
 493  		msgs[base].Buffers[0] = buf
 494  		msgs[base].Addr = addr
 495  		dgramCnt = 1
 496  	}
 497  	return base + 1
 498  }
 499  
// getGSOFunc extracts the segment size of a received message from its control
// data. splitCoalescedMessages treats a result of 0 as "not coalesced".
// receiveIP passes getGSOSize.
type getGSOFunc func(control []byte) (int, error)
 501  
// splitCoalescedMessages expands the received messages in msgs[firstMsgAt:]
// into individual datagrams stored from msgs[0] onward, and returns the number
// of datagrams written.
//
// For each source message, getGSO reports the segment size from its control
// data; a size of 0 means the message is copied as a single datagram. Each
// segment's bytes are copied into the destination message's first buffer and
// the destination takes the source's Addr and the copied length as N.
// Processing stops at the first source message with N == 0. An error from
// getGSO, or running out of destination slots ahead of the source being
// split, aborts with the count written so far.
func splitCoalescedMessages(msgs []ipv6.Message, firstMsgAt int, getGSO getGSOFunc) (n int, err error) {
	for i := firstMsgAt; i < len(msgs); i++ {
		msg := &msgs[i]
		if msg.N == 0 {
			// No more received messages. err is nil here: it is only ever
			// non-nil immediately before a return.
			return n, err
		}
		var (
			gsoSize    int
			start      int       // offset of the current segment in msg.Buffers[0]
			end        = msg.N   // end offset of the current segment
			numToSplit = 1       // number of datagrams contained in msg
		)
		gsoSize, err = getGSO(msg.OOB[:msg.NN])
		if err != nil {
			return n, err
		}
		if gsoSize > 0 {
			// Round up: the final segment may be shorter than gsoSize.
			numToSplit = (msg.N + gsoSize - 1) / gsoSize
			// NOTE(review): end is not clamped to msg.N here; this assumes
			// gsoSize <= msg.N for the first segment — confirm.
			end = gsoSize
		}
		for j := 0; j < numToSplit; j++ {
			// The destination index must never pass the source index,
			// otherwise a message not yet split would be overwritten.
			if n > i {
				return n, errors.New("splitting coalesced packet resulted in overflow")
			}
			copied := copy(msgs[n].Buffers[0], msg.Buffers[0][start:end])
			msgs[n].N = copied
			msgs[n].Addr = msg.Addr
			start = end
			end += gsoSize
			if end > msg.N {
				end = msg.N
			}
			n++
		}
		if i != n-1 {
			// It is legal for bytes to move within msg.Buffers[0] as a result
			// of splitting, so we only zero the source msg len when it is not
			// the destination of the last split operation above.
			msg.N = 0
		}
	}
	return n, nil
}
 545