sticky_linux.go raw

   1  /* SPDX-License-Identifier: MIT
   2   *
   3   * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
   4   *
   5   * This implements userspace semantics of "sticky sockets", modeled after
   6   * WireGuard's kernelspace implementation. This is more or less a straight port
   7   * of the sticky-sockets.c example code:
   8   * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
   9   *
  10   * Currently there is no way to achieve this within the net package:
  11   * See e.g. https://github.com/golang/go/issues/17930
  12   * So this code remains platform dependent.
  13   */
  14  
  15  package device
  16  
  17  import (
  18  	"sync"
  19  	"unsafe"
  20  
  21  	"golang.org/x/sys/unix"
  22  
  23  	"golang.zx2c4.com/wireguard/conn"
  24  	"golang.zx2c4.com/wireguard/rwcancel"
  25  )
  26  
  27  func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
  28  	if !conn.StdNetSupportsStickySockets {
  29  		return nil, nil
  30  	}
  31  	if _, ok := bind.(*conn.StdNetBind); !ok {
  32  		return nil, nil
  33  	}
  34  
  35  	netlinkSock, err := createNetlinkRouteSocket()
  36  	if err != nil {
  37  		return nil, err
  38  	}
  39  	netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
  40  	if err != nil {
  41  		unix.Close(netlinkSock)
  42  		return nil, err
  43  	}
  44  
  45  	go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
  46  
  47  	return netlinkCancel, nil
  48  }
  49  
  50  func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
  51  	type peerEndpointPtr struct {
  52  		peer     *Peer
  53  		endpoint *conn.Endpoint
  54  	}
  55  	var reqPeer map[uint32]peerEndpointPtr
  56  	var reqPeerLock sync.Mutex
  57  
  58  	defer netlinkCancel.Close()
  59  	defer unix.Close(netlinkSock)
  60  
  61  	for msg := make([]byte, 1<<16); ; {
  62  		var err error
  63  		var msgn int
  64  		for {
  65  			msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
  66  			if err == nil || !rwcancel.RetryAfterError(err) {
  67  				break
  68  			}
  69  			if !netlinkCancel.ReadyRead() {
  70  				return
  71  			}
  72  		}
  73  		if err != nil {
  74  			return
  75  		}
  76  
  77  		for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
  78  
  79  			hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
  80  
  81  			if uint(hdr.Len) > uint(len(remain)) {
  82  				break
  83  			}
  84  
  85  			switch hdr.Type {
  86  			case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
  87  				if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
  88  					if uint(len(remain)) < uint(hdr.Len) {
  89  						break
  90  					}
  91  					if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
  92  						attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
  93  						for {
  94  							if uint(len(attr)) < uint(unix.SizeofRtAttr) {
  95  								break
  96  							}
  97  							attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
  98  							if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
  99  								break
 100  							}
 101  							if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
 102  								ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
 103  								reqPeerLock.Lock()
 104  								if reqPeer == nil {
 105  									reqPeerLock.Unlock()
 106  									break
 107  								}
 108  								pePtr, ok := reqPeer[hdr.Seq]
 109  								reqPeerLock.Unlock()
 110  								if !ok {
 111  									break
 112  								}
 113  								pePtr.peer.endpoint.Lock()
 114  								if &pePtr.peer.endpoint.val != pePtr.endpoint {
 115  									pePtr.peer.endpoint.Unlock()
 116  									break
 117  								}
 118  								if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
 119  									pePtr.peer.endpoint.Unlock()
 120  									break
 121  								}
 122  								pePtr.peer.endpoint.clearSrcOnTx = true
 123  								pePtr.peer.endpoint.Unlock()
 124  							}
 125  							attr = attr[attrhdr.Len:]
 126  						}
 127  					}
 128  					break
 129  				}
 130  				reqPeerLock.Lock()
 131  				reqPeer = make(map[uint32]peerEndpointPtr)
 132  				reqPeerLock.Unlock()
 133  				go func() {
 134  					device.peers.RLock()
 135  					i := uint32(1)
 136  					for _, peer := range device.peers.keyMap {
 137  						peer.endpoint.Lock()
 138  						if peer.endpoint.val == nil {
 139  							peer.endpoint.Unlock()
 140  							continue
 141  						}
 142  						nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
 143  						if nativeEP == nil {
 144  							peer.endpoint.Unlock()
 145  							continue
 146  						}
 147  						if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
 148  							peer.endpoint.Unlock()
 149  							break
 150  						}
 151  						nlmsg := struct {
 152  							hdr     unix.NlMsghdr
 153  							msg     unix.RtMsg
 154  							dsthdr  unix.RtAttr
 155  							dst     [4]byte
 156  							srchdr  unix.RtAttr
 157  							src     [4]byte
 158  							markhdr unix.RtAttr
 159  							mark    uint32
 160  						}{
 161  							unix.NlMsghdr{
 162  								Type:  uint16(unix.RTM_GETROUTE),
 163  								Flags: unix.NLM_F_REQUEST,
 164  								Seq:   i,
 165  							},
 166  							unix.RtMsg{
 167  								Family:  unix.AF_INET,
 168  								Dst_len: 32,
 169  								Src_len: 32,
 170  							},
 171  							unix.RtAttr{
 172  								Len:  8,
 173  								Type: unix.RTA_DST,
 174  							},
 175  							nativeEP.DstIP().As4(),
 176  							unix.RtAttr{
 177  								Len:  8,
 178  								Type: unix.RTA_SRC,
 179  							},
 180  							nativeEP.SrcIP().As4(),
 181  							unix.RtAttr{
 182  								Len:  8,
 183  								Type: unix.RTA_MARK,
 184  							},
 185  							device.net.fwmark,
 186  						}
 187  						nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
 188  						reqPeerLock.Lock()
 189  						reqPeer[i] = peerEndpointPtr{
 190  							peer:     peer,
 191  							endpoint: &peer.endpoint.val,
 192  						}
 193  						reqPeerLock.Unlock()
 194  						peer.endpoint.Unlock()
 195  						i++
 196  						_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
 197  						if err != nil {
 198  							break
 199  						}
 200  					}
 201  					device.peers.RUnlock()
 202  				}()
 203  			}
 204  			remain = remain[hdr.Len:]
 205  		}
 206  	}
 207  }
 208  
 209  func createNetlinkRouteSocket() (int, error) {
 210  	sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
 211  	if err != nil {
 212  		return -1, err
 213  	}
 214  	saddr := &unix.SockaddrNetlink{
 215  		Family: unix.AF_NETLINK,
 216  		Groups: unix.RTMGRP_IPV4_ROUTE,
 217  	}
 218  	err = unix.Bind(sock, saddr)
 219  	if err != nil {
 220  		unix.Close(sock)
 221  		return -1, err
 222  	}
 223  	return sock, nil
 224  }
 225