sticky_linux.go raw
1 /* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2025 WireGuard LLC. All Rights Reserved.
4 *
5 * This implements userspace semantics of "sticky sockets", modeled after
6 * WireGuard's kernelspace implementation. This is more or less a straight port
7 * of the sticky-sockets.c example code:
8 * https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
9 *
10 * Currently there is no way to achieve this within the net package:
11 * See e.g. https://github.com/golang/go/issues/17930
12 * So this code remains platform dependent.
13 */
14
15 package device
16
17 import (
18 "sync"
19 "unsafe"
20
21 "golang.org/x/sys/unix"
22
23 "golang.zx2c4.com/wireguard/conn"
24 "golang.zx2c4.com/wireguard/rwcancel"
25 )
26
27 func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
28 if !conn.StdNetSupportsStickySockets {
29 return nil, nil
30 }
31 if _, ok := bind.(*conn.StdNetBind); !ok {
32 return nil, nil
33 }
34
35 netlinkSock, err := createNetlinkRouteSocket()
36 if err != nil {
37 return nil, err
38 }
39 netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
40 if err != nil {
41 unix.Close(netlinkSock)
42 return nil, err
43 }
44
45 go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
46
47 return netlinkCancel, nil
48 }
49
50 func (device *Device) routineRouteListener(_ conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
51 type peerEndpointPtr struct {
52 peer *Peer
53 endpoint *conn.Endpoint
54 }
55 var reqPeer map[uint32]peerEndpointPtr
56 var reqPeerLock sync.Mutex
57
58 defer netlinkCancel.Close()
59 defer unix.Close(netlinkSock)
60
61 for msg := make([]byte, 1<<16); ; {
62 var err error
63 var msgn int
64 for {
65 msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
66 if err == nil || !rwcancel.RetryAfterError(err) {
67 break
68 }
69 if !netlinkCancel.ReadyRead() {
70 return
71 }
72 }
73 if err != nil {
74 return
75 }
76
77 for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
78
79 hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
80
81 if uint(hdr.Len) > uint(len(remain)) {
82 break
83 }
84
85 switch hdr.Type {
86 case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
87 if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
88 if uint(len(remain)) < uint(hdr.Len) {
89 break
90 }
91 if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
92 attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
93 for {
94 if uint(len(attr)) < uint(unix.SizeofRtAttr) {
95 break
96 }
97 attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
98 if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
99 break
100 }
101 if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
102 ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
103 reqPeerLock.Lock()
104 if reqPeer == nil {
105 reqPeerLock.Unlock()
106 break
107 }
108 pePtr, ok := reqPeer[hdr.Seq]
109 reqPeerLock.Unlock()
110 if !ok {
111 break
112 }
113 pePtr.peer.endpoint.Lock()
114 if &pePtr.peer.endpoint.val != pePtr.endpoint {
115 pePtr.peer.endpoint.Unlock()
116 break
117 }
118 if uint32(pePtr.peer.endpoint.val.(*conn.StdNetEndpoint).SrcIfidx()) == ifidx {
119 pePtr.peer.endpoint.Unlock()
120 break
121 }
122 pePtr.peer.endpoint.clearSrcOnTx = true
123 pePtr.peer.endpoint.Unlock()
124 }
125 attr = attr[attrhdr.Len:]
126 }
127 }
128 break
129 }
130 reqPeerLock.Lock()
131 reqPeer = make(map[uint32]peerEndpointPtr)
132 reqPeerLock.Unlock()
133 go func() {
134 device.peers.RLock()
135 i := uint32(1)
136 for _, peer := range device.peers.keyMap {
137 peer.endpoint.Lock()
138 if peer.endpoint.val == nil {
139 peer.endpoint.Unlock()
140 continue
141 }
142 nativeEP, _ := peer.endpoint.val.(*conn.StdNetEndpoint)
143 if nativeEP == nil {
144 peer.endpoint.Unlock()
145 continue
146 }
147 if nativeEP.DstIP().Is6() || nativeEP.SrcIfidx() == 0 {
148 peer.endpoint.Unlock()
149 break
150 }
151 nlmsg := struct {
152 hdr unix.NlMsghdr
153 msg unix.RtMsg
154 dsthdr unix.RtAttr
155 dst [4]byte
156 srchdr unix.RtAttr
157 src [4]byte
158 markhdr unix.RtAttr
159 mark uint32
160 }{
161 unix.NlMsghdr{
162 Type: uint16(unix.RTM_GETROUTE),
163 Flags: unix.NLM_F_REQUEST,
164 Seq: i,
165 },
166 unix.RtMsg{
167 Family: unix.AF_INET,
168 Dst_len: 32,
169 Src_len: 32,
170 },
171 unix.RtAttr{
172 Len: 8,
173 Type: unix.RTA_DST,
174 },
175 nativeEP.DstIP().As4(),
176 unix.RtAttr{
177 Len: 8,
178 Type: unix.RTA_SRC,
179 },
180 nativeEP.SrcIP().As4(),
181 unix.RtAttr{
182 Len: 8,
183 Type: unix.RTA_MARK,
184 },
185 device.net.fwmark,
186 }
187 nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
188 reqPeerLock.Lock()
189 reqPeer[i] = peerEndpointPtr{
190 peer: peer,
191 endpoint: &peer.endpoint.val,
192 }
193 reqPeerLock.Unlock()
194 peer.endpoint.Unlock()
195 i++
196 _, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
197 if err != nil {
198 break
199 }
200 }
201 device.peers.RUnlock()
202 }()
203 }
204 remain = remain[hdr.Len:]
205 }
206 }
207 }
208
209 func createNetlinkRouteSocket() (int, error) {
210 sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW|unix.SOCK_CLOEXEC, unix.NETLINK_ROUTE)
211 if err != nil {
212 return -1, err
213 }
214 saddr := &unix.SockaddrNetlink{
215 Family: unix.AF_NETLINK,
216 Groups: unix.RTMGRP_IPV4_ROUTE,
217 }
218 err = unix.Bind(sock, saddr)
219 if err != nil {
220 unix.Close(sock)
221 return -1, err
222 }
223 return sock, nil
224 }
225