1 // Copyright 2020 The gVisor Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 15 package stack
16 17 import (
18 "encoding/binary"
19 "fmt"
20 "math"
21 "math/rand"
22 "sync"
23 "time"
24 25 "gvisor.dev/gvisor/pkg/atomicbitops"
26 "gvisor.dev/gvisor/pkg/tcpip"
27 "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
28 "gvisor.dev/gvisor/pkg/tcpip/header"
29 "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack"
30 )
31 32 // Connection tracking is used to track and manipulate packets for NAT rules.
33 // The connection is created for a packet if it does not exist. Every
34 // connection contains two tuples (original and reply). The tuples are
35 // manipulated if there is a matching NAT rule. The packet is modified by
36 // looking at the tuples in each hook.
37 //
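// Connections are tracked for TCP, UDP and ICMP Echo packets, and ICMP errors
// are matched to the connection of the packet embedded in them. Protocol state
// is currently only tracked for TCP.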
// numBuckets is the number of buckets in the connection tracking hash table.
// Our hash table has 16K (1 << 14 = 16384) buckets.
const numBuckets = 1 << 14
const (
	// establishedTimeout is how long a connection whose TCB reports
	// tcpconntrack.ResultAlive may go unused before (*conn).timedOut considers
	// it expired (5 days).
	establishedTimeout time.Duration = 5 * 24 * time.Hour
	// unestablishedTimeout is how long any other connection may go unused
	// before (*conn).timedOut considers it expired (120 seconds).
	unestablishedTimeout time.Duration = 120 * time.Second
)
// tuple holds a connection's identifying and manipulating data in one
// direction. It is immutable once the connection's NAT has been configured.
//
// NOTE(review): (*conn).performNAT writes to the reply tuple's tupleID through
// pointers while holding conn.mu, so "immutable" does not hold for the reply
// tuple before NAT setup is complete.
//
// +stateify savable
type tuple struct {
	// tupleEntry is used to build an intrusive list of tuples.
	tupleEntry

	// conn is the connection tracking entry this tuple belongs to.
	conn *conn

	// reply is true iff the tuple's direction is opposite that of the first
	// packet seen on the connection.
	reply bool

	// tupleID is set at initialization and is immutable (see the note on the
	// type for the reply tuple's exception during NAT setup).
	tupleID tupleID
}
// tupleID uniquely identifies a trackable connection in one direction.
//
// tupleIDs are compared with == when looking up connections, so two IDs refer
// to the same direction of a connection iff every field is equal.
//
// +stateify savable
type tupleID struct {
	// srcAddr is the source network address of packets in this direction.
	srcAddr tcpip.Address
	// The source port of a packet in the original direction is overloaded with
	// the ident of an Echo Request packet.
	//
	// This also matches the behaviour of sending packets on Linux where the
	// socket's source port value is used for the source port of outgoing packets
	// for TCP/UDP and the ident field for outgoing Echo Requests on Ping sockets:
	//
	// IPv4: https://github.com/torvalds/linux/blob/c5c17547b778975b3d83a73c8d84e8fb5ecf3ba5/net/ipv4/ping.c#L810
	// IPv6: https://github.com/torvalds/linux/blob/c5c17547b778975b3d83a73c8d84e8fb5ecf3ba5/net/ipv6/ping.c#L133
	srcPortOrEchoRequestIdent uint16
	// dstAddr is the destination network address of packets in this direction.
	dstAddr tcpip.Address
	// The opposite of srcPortOrEchoRequestIdent; the destination port of a packet
	// in the reply direction is overloaded with the ident of an Echo Reply.
	dstPortOrEchoReplyIdent uint16
	// transProto is the transport protocol number of the tracked packets.
	transProto tcpip.TransportProtocolNumber
	// netProto is the network protocol number of the tracked packets.
	netProto tcpip.NetworkProtocolNumber
}
89 90 // reply creates the reply tupleID.
91 func (ti tupleID) reply() tupleID {
92 return tupleID{
93 srcAddr: ti.dstAddr,
94 srcPortOrEchoRequestIdent: ti.dstPortOrEchoReplyIdent,
95 dstAddr: ti.srcAddr,
96 dstPortOrEchoReplyIdent: ti.srcPortOrEchoRequestIdent,
97 transProto: ti.transProto,
98 netProto: ti.netProto,
99 }
100 }
// manipType records whether NAT has been configured for one side (source or
// destination) of a connection. The zero value is manipNotPerformed.
type manipType int

const (
	// manipNotPerformed indicates that NAT has not been performed.
	manipNotPerformed manipType = iota

	// manipPerformed indicates that NAT was performed.
	manipPerformed

	// manipPerformedNoop indicates that NAT was performed but it was a no-op.
	manipPerformedNoop
)
// finalizeResult is the outcome of finalizing a connection. It is stored in
// conn.finalizeResult as a uint32.
type finalizeResult uint32

const (
	// A finalizeResult must be explicitly set so we don't make use of the zero
	// value.
	_ finalizeResult = iota

	// finalizeResultSuccess indicates that the connection's reply tuple is in
	// the table (see (*ConnTrack).finalize).
	finalizeResultSuccess
	// finalizeResultConflict indicates that another connection already owns
	// the reply tuple, so the connection could not be finalized.
	finalizeResultConflict
)
// conn is a tracked connection.
//
// +stateify savable
type conn struct {
	// ct is the ConnTrack this connection belongs to. Its clock, random
	// source and bucket table are used by the connection's methods.
	ct *ConnTrack

	// original is the tuple in original direction. It is immutable.
	original tuple

	// reply is the tuple in reply direction. Its tupleID may be rewritten by
	// performNAT; it is inserted into the table by (*ConnTrack).finalize.
	reply tuple

	// finalizeOnce ensures (*ConnTrack).finalize is run at most once for this
	// connection.
	//
	// TODO(b/341946753): Restore when netstack is savable.
	finalizeOnce sync.Once `state:"nosave"`
	// Holds a finalizeResult.
	finalizeResult atomicbitops.Uint32

	// mu protects sourceManip and destinationManip.
	mu connRWMutex `state:"nosave"`
	// sourceManip indicates the source manipulation type.
	//
	// +checklocks:mu
	sourceManip manipType
	// destinationManip indicates the destination's manipulation type.
	//
	// +checklocks:mu
	destinationManip manipType

	// stateMu protects tcb and lastUsed.
	stateMu stateConnRWMutex `state:"nosave"`
	// tcb is TCB control block. It is used to keep track of states
	// of tcp connection.
	//
	// +checklocks:stateMu
	tcb tcpconntrack.TCB
	// lastUsed is the last time the connection saw a relevant packet, and
	// is updated by each packet on the connection.
	//
	// +checklocks:stateMu
	lastUsed tcpip.MonotonicTime
}
165 166 // timedOut returns whether the connection timed out based on its state.
167 func (cn *conn) timedOut(now tcpip.MonotonicTime) bool {
168 cn.stateMu.RLock()
169 defer cn.stateMu.RUnlock()
170 if cn.tcb.State() == tcpconntrack.ResultAlive {
171 // Use the same default as Linux, which doesn't delete
172 // established connections for 5(!) days.
173 return now.Sub(cn.lastUsed) > establishedTimeout
174 }
175 // Use the same default as Linux, which lets connections in most states
176 // other than established remain for <= 120 seconds.
177 return now.Sub(cn.lastUsed) > unestablishedTimeout
178 }
179 180 // update the connection tracking state.
181 func (cn *conn) update(pkt *PacketBuffer, reply bool) {
182 cn.stateMu.Lock()
183 defer cn.stateMu.Unlock()
184 185 // Mark the connection as having been used recently so it isn't reaped.
186 cn.lastUsed = cn.ct.clock.NowMonotonic()
187 188 if pkt.TransportProtocolNumber != header.TCPProtocolNumber {
189 return
190 }
191 192 tcpHeader := header.TCP(pkt.TransportHeader().Slice())
193 194 // Update the state of tcb. tcb assumes it's always initialized on the
195 // client. However, we only need to know whether the connection is
196 // established or not, so the client/server distinction isn't important.
197 if cn.tcb.IsEmpty() {
198 cn.tcb.Init(tcpHeader, pkt.Data().Size())
199 return
200 }
201 202 if reply {
203 cn.tcb.UpdateStateReply(tcpHeader, pkt.Data().Size())
204 } else {
205 cn.tcb.UpdateStateOriginal(tcpHeader, pkt.Data().Size())
206 }
207 }
// ConnTrack tracks all connections created for NAT rules. Most users are
// expected to only call handlePacket, insertRedirectConn, and maybeInsertNoop.
//
// ConnTrack keeps all connections in a slice of buckets, each of which holds a
// linked list of tuples. This gives us some desirable properties:
//   - Each bucket has its own lock, lessening lock contention.
//   - The slice is large enough that lists stay short (<10 elements on average).
//     Thus traversal is fast.
//   - During linked list traversal we reap expired connections. This amortizes
//     the cost of reaping them and makes reapUnused faster.
//
// Locks are ordered by their location in the buckets slice. That is, a
// goroutine that locks buckets[i] can only lock buckets[j] s.t. i < j.
//
// +stateify savable
type ConnTrack struct {
	// seed is a one-time random value initialized at stack startup
	// and is used in the calculation of hash keys for the list of buckets.
	// It is immutable.
	seed uint32

	// clock provides timing used to determine conntrack reapings.
	clock tcpip.Clock
	// rand is the source of the random port/ident offsets that performNAT uses
	// when searching for a unique reply tuple.
	//
	// TODO(b/341946753): Restore when netstack is savable.
	rand *rand.Rand `state:"nosave"`

	mu connTrackRWMutex `state:"nosave"`
	// mu protects the buckets slice, but not buckets' contents. Only take
	// the write lock if you are modifying the slice or saving for S/R.
	//
	// +checklocks:mu
	buckets []bucket
}
// bucket is one entry of ConnTrack.buckets: a list of tuples guarded by its
// own lock.
//
// +stateify savable
type bucket struct {
	// mu protects tuples.
	mu bucketRWMutex `state:"nosave"`
	// tuples is the list of tuples stored in this bucket.
	//
	// +checklocks:mu
	tuples tupleList
}
// A netAndTransHeadersFunc returns the network and transport headers found
// in an ICMP payload. The transport layer's payload will not be returned.
//
// icmpPayload holds the packet embedded in the ICMP message, starting at its
// network header. minTransHdrLen is the number of transport header bytes to
// return; the implementations in this file (v4NetAndTransHdr and
// v6NetAndTransHdr) truncate transHdrBytes to exactly that length.
//
// May panic if the packet does not hold the transport header.
type netAndTransHeadersFunc func(icmpPayload []byte, minTransHdrLen int) (netHdr header.Network, transHdrBytes []byte)
255 256 func v4NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
257 netHdr := header.IPv4(icmpPayload)
258 // Do not use netHdr.Payload() as we might not hold the full packet
259 // in the ICMP error; Payload() panics if the buffer is smaller than
260 // the total length specified in the IPv4 header.
261 transHdr := icmpPayload[netHdr.HeaderLength():]
262 return netHdr, transHdr[:minTransHdrLen]
263 }
264 265 func v6NetAndTransHdr(icmpPayload []byte, minTransHdrLen int) (header.Network, []byte) {
266 netHdr := header.IPv6(icmpPayload)
267 // Do not use netHdr.Payload() as we might not hold the full packet
268 // in the ICMP error; Payload() panics if the IP payload is smaller than
269 // the payload length specified in the IPv6 header.
270 transHdr := icmpPayload[header.IPv6MinimumSize:]
271 return netHdr, transHdr[:minTransHdrLen]
272 }
273 274 func getEmbeddedNetAndTransHeaders(pkt *PacketBuffer, netHdrLength int, getNetAndTransHdr netAndTransHeadersFunc, transProto tcpip.TransportProtocolNumber) (header.Network, header.ChecksummableTransport, bool) {
275 switch transProto {
276 case header.TCPProtocolNumber:
277 if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.TCPMinimumSize); ok {
278 netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.TCPMinimumSize)
279 return netHeader, header.TCP(transHeaderBytes), true
280 }
281 case header.UDPProtocolNumber:
282 if netAndTransHeader, ok := pkt.Data().PullUp(netHdrLength + header.UDPMinimumSize); ok {
283 netHeader, transHeaderBytes := getNetAndTransHdr(netAndTransHeader, header.UDPMinimumSize)
284 return netHeader, header.UDP(transHeaderBytes), true
285 }
286 }
287 return nil, nil, false
288 }
289 290 func getHeaders(pkt *PacketBuffer) (netHdr header.Network, transHdr header.Transport, isICMPError bool, ok bool) {
291 switch pkt.TransportProtocolNumber {
292 case header.TCPProtocolNumber:
293 if tcpHeader := header.TCP(pkt.TransportHeader().Slice()); len(tcpHeader) >= header.TCPMinimumSize {
294 return pkt.Network(), tcpHeader, false, true
295 }
296 return nil, nil, false, false
297 case header.UDPProtocolNumber:
298 if udpHeader := header.UDP(pkt.TransportHeader().Slice()); len(udpHeader) >= header.UDPMinimumSize {
299 return pkt.Network(), udpHeader, false, true
300 }
301 return nil, nil, false, false
302 case header.ICMPv4ProtocolNumber:
303 icmpHeader := header.ICMPv4(pkt.TransportHeader().Slice())
304 if len(icmpHeader) < header.ICMPv4MinimumSize {
305 return nil, nil, false, false
306 }
307 308 switch icmpType := icmpHeader.Type(); icmpType {
309 case header.ICMPv4Echo, header.ICMPv4EchoReply:
310 return pkt.Network(), icmpHeader, false, true
311 case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem:
312 default:
313 panic(fmt.Sprintf("unexpected ICMPv4 type = %d", icmpType))
314 }
315 316 h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
317 if !ok {
318 panic(fmt.Sprintf("should have a valid IPv4 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv4MinimumSize))
319 }
320 321 if header.IPv4(h).HeaderLength() > header.IPv4MinimumSize {
322 // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
323 panic("should have dropped packets with IPv4 options")
324 }
325 326 if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv4MinimumSize, v4NetAndTransHdr, pkt.tuple.tupleID.transProto); ok {
327 return netHdr, transHdr, true, true
328 }
329 return nil, nil, false, false
330 case header.ICMPv6ProtocolNumber:
331 icmpHeader := header.ICMPv6(pkt.TransportHeader().Slice())
332 if len(icmpHeader) < header.ICMPv6MinimumSize {
333 return nil, nil, false, false
334 }
335 336 switch icmpType := icmpHeader.Type(); icmpType {
337 case header.ICMPv6EchoRequest, header.ICMPv6EchoReply:
338 return pkt.Network(), icmpHeader, false, true
339 case header.ICMPv6DstUnreachable, header.ICMPv6PacketTooBig, header.ICMPv6TimeExceeded, header.ICMPv6ParamProblem:
340 default:
341 panic(fmt.Sprintf("unexpected ICMPv6 type = %d", icmpType))
342 }
343 344 h, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
345 if !ok {
346 panic(fmt.Sprintf("should have a valid IPv6 packet; only have %d bytes, want at least %d bytes", pkt.Data().Size(), header.IPv6MinimumSize))
347 }
348 349 // We do not support extension headers in ICMP errors so the next header
350 // in the IPv6 packet should be a tracked protocol if we reach this point.
351 //
352 // TODO(https://gvisor.dev/issue/6789): Support extension headers.
353 transProto := pkt.tuple.tupleID.transProto
354 if got := header.IPv6(h).TransportProtocol(); got != transProto {
355 panic(fmt.Sprintf("got TransportProtocol() = %d, want = %d", got, transProto))
356 }
357 358 if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, header.IPv6MinimumSize, v6NetAndTransHdr, transProto); ok {
359 return netHdr, transHdr, true, true
360 }
361 return nil, nil, false, false
362 default:
363 panic(fmt.Sprintf("unexpected transport protocol = %d", pkt.TransportProtocolNumber))
364 }
365 }
366 367 func getTupleIDForRegularPacket(netHdr header.Network, netProto tcpip.NetworkProtocolNumber, transHdr header.Transport, transProto tcpip.TransportProtocolNumber) tupleID {
368 return tupleID{
369 srcAddr: netHdr.SourceAddress(),
370 srcPortOrEchoRequestIdent: transHdr.SourcePort(),
371 dstAddr: netHdr.DestinationAddress(),
372 dstPortOrEchoReplyIdent: transHdr.DestinationPort(),
373 transProto: transProto,
374 netProto: netProto,
375 }
376 }
377 378 func getTupleIDForPacketInICMPError(pkt *PacketBuffer, getNetAndTransHdr netAndTransHeadersFunc, netProto tcpip.NetworkProtocolNumber, netLen int, transProto tcpip.TransportProtocolNumber) (tupleID, bool) {
379 if netHdr, transHdr, ok := getEmbeddedNetAndTransHeaders(pkt, netLen, getNetAndTransHdr, transProto); ok {
380 return tupleID{
381 srcAddr: netHdr.DestinationAddress(),
382 srcPortOrEchoRequestIdent: transHdr.DestinationPort(),
383 dstAddr: netHdr.SourceAddress(),
384 dstPortOrEchoReplyIdent: transHdr.SourcePort(),
385 transProto: transProto,
386 netProto: netProto,
387 }, true
388 }
389 390 return tupleID{}, false
391 }
// getTupleIDDisposition is the second result of getTupleID. It tells the
// caller whether the returned tupleID is usable and, if so, whether a new
// connection may be created for it.
type getTupleIDDisposition int

const (
	// getTupleIDNotOK indicates that no tupleID could be derived from the
	// packet; the returned tupleID is the zero value.
	getTupleIDNotOK getTupleIDDisposition = iota
	// getTupleIDOKAndAllowNewConn indicates a valid tupleID for which a new
	// connection may be created if none exists.
	getTupleIDOKAndAllowNewConn
	// getTupleIDOKAndDontAllowNewConn indicates a valid tupleID that may only
	// match an existing connection (returned for Echo Replies and ICMP
	// errors).
	getTupleIDOKAndDontAllowNewConn
)
400 401 func getTupleIDForEchoPacket(pkt *PacketBuffer, ident uint16, request bool) tupleID {
402 netHdr := pkt.Network()
403 tid := tupleID{
404 srcAddr: netHdr.SourceAddress(),
405 dstAddr: netHdr.DestinationAddress(),
406 transProto: pkt.TransportProtocolNumber,
407 netProto: pkt.NetworkProtocolNumber,
408 }
409 410 if request {
411 tid.srcPortOrEchoRequestIdent = ident
412 } else {
413 tid.dstPortOrEchoReplyIdent = ident
414 }
415 416 return tid
417 }
418 419 func getTupleID(pkt *PacketBuffer) (tupleID, getTupleIDDisposition) {
420 switch pkt.TransportProtocolNumber {
421 case header.TCPProtocolNumber:
422 if transHeader := header.TCP(pkt.TransportHeader().Slice()); len(transHeader) >= header.TCPMinimumSize {
423 return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), getTupleIDOKAndAllowNewConn
424 }
425 case header.UDPProtocolNumber:
426 if transHeader := header.UDP(pkt.TransportHeader().Slice()); len(transHeader) >= header.UDPMinimumSize {
427 return getTupleIDForRegularPacket(pkt.Network(), pkt.NetworkProtocolNumber, transHeader, pkt.TransportProtocolNumber), getTupleIDOKAndAllowNewConn
428 }
429 case header.ICMPv4ProtocolNumber:
430 icmp := header.ICMPv4(pkt.TransportHeader().Slice())
431 if len(icmp) < header.ICMPv4MinimumSize {
432 return tupleID{}, getTupleIDNotOK
433 }
434 435 switch icmp.Type() {
436 case header.ICMPv4Echo:
437 return getTupleIDForEchoPacket(pkt, icmp.Ident(), true /* request */), getTupleIDOKAndAllowNewConn
438 case header.ICMPv4EchoReply:
439 // Do not create a new connection in response to a reply packet as only
440 // the first packet of a connection should create a conntrack entry but
441 // a reply is never the first packet sent for a connection.
442 return getTupleIDForEchoPacket(pkt, icmp.Ident(), false /* request */), getTupleIDOKAndDontAllowNewConn
443 case header.ICMPv4DstUnreachable, header.ICMPv4TimeExceeded, header.ICMPv4ParamProblem:
444 default:
445 // Unsupported ICMP type for NAT-ing.
446 return tupleID{}, getTupleIDNotOK
447 }
448 449 h, ok := pkt.Data().PullUp(header.IPv4MinimumSize)
450 if !ok {
451 return tupleID{}, getTupleIDNotOK
452 }
453 454 ipv4 := header.IPv4(h)
455 if ipv4.HeaderLength() > header.IPv4MinimumSize {
456 // TODO(https://gvisor.dev/issue/6765): Handle IPv4 options.
457 return tupleID{}, getTupleIDNotOK
458 }
459 460 if tid, ok := getTupleIDForPacketInICMPError(pkt, v4NetAndTransHdr, header.IPv4ProtocolNumber, header.IPv4MinimumSize, ipv4.TransportProtocol()); ok {
461 // Do not create a new connection in response to an ICMP error.
462 return tid, getTupleIDOKAndDontAllowNewConn
463 }
464 case header.ICMPv6ProtocolNumber:
465 icmp := header.ICMPv6(pkt.TransportHeader().Slice())
466 if len(icmp) < header.ICMPv6MinimumSize {
467 return tupleID{}, getTupleIDNotOK
468 }
469 470 switch icmp.Type() {
471 case header.ICMPv6EchoRequest:
472 return getTupleIDForEchoPacket(pkt, icmp.Ident(), true /* request */), getTupleIDOKAndAllowNewConn
473 case header.ICMPv6EchoReply:
474 // Do not create a new connection in response to a reply packet as only
475 // the first packet of a connection should create a conntrack entry but
476 // a reply is never the first packet sent for a connection.
477 return getTupleIDForEchoPacket(pkt, icmp.Ident(), false /* request */), getTupleIDOKAndDontAllowNewConn
478 case header.ICMPv6DstUnreachable, header.ICMPv6PacketTooBig, header.ICMPv6TimeExceeded, header.ICMPv6ParamProblem:
479 default:
480 return tupleID{}, getTupleIDNotOK
481 }
482 483 h, ok := pkt.Data().PullUp(header.IPv6MinimumSize)
484 if !ok {
485 return tupleID{}, getTupleIDNotOK
486 }
487 488 // TODO(https://gvisor.dev/issue/6789): Handle extension headers.
489 if tid, ok := getTupleIDForPacketInICMPError(pkt, v6NetAndTransHdr, header.IPv6ProtocolNumber, header.IPv6MinimumSize, header.IPv6(h).TransportProtocol()); ok {
490 // Do not create a new connection in response to an ICMP error.
491 return tid, getTupleIDOKAndDontAllowNewConn
492 }
493 }
494 495 return tupleID{}, getTupleIDNotOK
496 }
497 498 func (ct *ConnTrack) init() {
499 ct.mu.Lock()
500 defer ct.mu.Unlock()
501 ct.buckets = make([]bucket, numBuckets)
502 }
// getConnAndUpdate attempts to get a connection or creates one if no
// connection exists for the packet and packet's protocol is trackable.
//
// If the packet's protocol is trackable, the connection's state is updated to
// match the contents of the packet.
//
// skipChecksumValidation makes the TCP/UDP validity checks treat the
// transport checksum as already validated, in the same way as
// pkt.RXChecksumValidated does.
//
// The returned tuple is the one matching the packet's direction. nil is
// returned if no tuple ID can be derived from the packet, if the packet fails
// TCP/UDP validation, or if no connection exists and the packet is not allowed
// to create one.
func (ct *ConnTrack) getConnAndUpdate(pkt *PacketBuffer, skipChecksumValidation bool) *tuple {
	// Get or (maybe) create a connection.
	t := func() *tuple {
		var allowNewConn bool
		tid, res := getTupleID(pkt)
		switch res {
		case getTupleIDNotOK:
			return nil
		case getTupleIDOKAndAllowNewConn:
			allowNewConn = true
		case getTupleIDOKAndDontAllowNewConn:
			allowNewConn = false
		default:
			panic(fmt.Sprintf("unhandled %[1]T = %[1]d", res))
		}

		// Just skip bad packets. They'll be rejected later by the appropriate
		// protocol package.
		switch pkt.TransportProtocolNumber {
		case header.TCPProtocolNumber:
			_, csumValid, ok := header.TCPValid(
				header.TCP(pkt.TransportHeader().Slice()),
				func() uint16 { return pkt.Data().Checksum() },
				uint16(pkt.Data().Size()),
				tid.srcAddr,
				tid.dstAddr,
				pkt.RXChecksumValidated || skipChecksumValidation)
			if !csumValid || !ok {
				return nil
			}
		case header.UDPProtocolNumber:
			lengthValid, csumValid := header.UDPValid(
				header.UDP(pkt.TransportHeader().Slice()),
				func() uint16 { return pkt.Data().Checksum() },
				uint16(pkt.Data().Size()),
				pkt.NetworkProtocolNumber,
				tid.srcAddr,
				tid.dstAddr,
				pkt.RXChecksumValidated || skipChecksumValidation)
			if !lengthValid || !csumValid {
				return nil
			}
		}

		// ct.mu only guards the buckets slice itself; the bucket's contents are
		// guarded by the bucket's own lock, so ct.mu is released right away.
		ct.mu.RLock()
		bkt := &ct.buckets[ct.bucket(tid)]
		ct.mu.RUnlock()

		// First look the connection up holding only the bucket's read lock.
		now := ct.clock.NowMonotonic()
		if t := bkt.connForTID(tid, now); t != nil {
			return t
		}

		if !allowNewConn {
			return nil
		}

		// The write lock is held until this closure returns so that the
		// re-check and the insertion below happen atomically.
		bkt.mu.Lock()
		defer bkt.mu.Unlock()

		// Make sure a connection wasn't added between when we last checked the
		// bucket and acquired the bucket's write lock.
		if t := bkt.connForTIDRLocked(tid, now); t != nil {
			return t
		}

		// This is the first packet we're seeing for the connection. Create an entry
		// for this new connection.
		conn := &conn{
			ct:       ct,
			original: tuple{tupleID: tid},
			reply:    tuple{tupleID: tid.reply(), reply: true},
			lastUsed: now,
		}
		// Both tuples point back at the connection that embeds them.
		conn.original.conn = conn
		conn.reply.conn = conn

		// For now, we only map an entry for the packet's original tuple as NAT may be
		// performed on this connection. Until the packet goes through all the hooks
		// and its final address/port is known, we cannot know what the response
		// packet's addresses/ports will look like.
		//
		// This is okay because the destination cannot send its response until it
		// receives the packet; the packet will only be received once all the hooks
		// have been performed.
		//
		// See (*conn).finalize.
		bkt.tuples.PushFront(&conn.original)
		return &conn.original
	}()
	if t != nil {
		// The bucket lock has been released by now; update takes the
		// connection's own state lock.
		t.conn.update(pkt, t.reply)
	}
	return t
}
604 605 func (ct *ConnTrack) connForTID(tid tupleID) *tuple {
606 ct.mu.RLock()
607 bkt := &ct.buckets[ct.bucket(tid)]
608 ct.mu.RUnlock()
609 610 return bkt.connForTID(tid, ct.clock.NowMonotonic())
611 }
612 613 func (bkt *bucket) connForTID(tid tupleID, now tcpip.MonotonicTime) *tuple {
614 bkt.mu.RLock()
615 defer bkt.mu.RUnlock()
616 return bkt.connForTIDRLocked(tid, now)
617 }
618 619 // +checklocksread:bkt.mu
620 func (bkt *bucket) connForTIDRLocked(tid tupleID, now tcpip.MonotonicTime) *tuple {
621 for other := bkt.tuples.Front(); other != nil; other = other.Next() {
622 if tid == other.tupleID && !other.conn.timedOut(now) {
623 return other
624 }
625 }
626 return nil
627 }
// finalize tries to insert cn's reply tuple into the table.
//
// It returns finalizeResultSuccess if the reply tuple was inserted, or if the
// table already holds an entry for the reply tuple ID that belongs to cn.
// If a different connection already owns that tuple ID, cn's original tuple is
// removed from the table and finalizeResultConflict is returned.
func (ct *ConnTrack) finalize(cn *conn) finalizeResult {
	// Take a snapshot of the buckets slice; bucket indices below are computed
	// against the length of this snapshot.
	ct.mu.RLock()
	buckets := ct.buckets
	ct.mu.RUnlock()

	{
		tid := cn.reply.tupleID
		id := ct.bucketWithTableLength(tid, len(buckets))

		bkt := &buckets[id]
		bkt.mu.Lock()
		t := bkt.connForTIDRLocked(tid, ct.clock.NowMonotonic())
		if t == nil {
			bkt.tuples.PushFront(&cn.reply)
			bkt.mu.Unlock()
			return finalizeResultSuccess
		}
		// The reply tuple's bucket lock is released before the original
		// tuple's bucket lock is taken below, so at most one bucket lock is
		// held at a time.
		bkt.mu.Unlock()

		if t.conn == cn {
			// We already have an entry for the reply tuple.
			//
			// This can occur when the source address/port is the same as the
			// destination address/port. In this scenario, tid == tid.reply().
			return finalizeResultSuccess
		}
	}

	// Another connection for the reply already exists. Remove the original and
	// let the caller know we failed.
	//
	// TODO(https://gvisor.dev/issue/6850): Investigate handling this clash
	// better.

	tid := cn.original.tupleID
	id := ct.bucketWithTableLength(tid, len(buckets))
	bkt := &buckets[id]
	bkt.mu.Lock()
	defer bkt.mu.Unlock()
	bkt.tuples.Remove(&cn.original)
	return finalizeResultConflict
}
671 672 func (cn *conn) getFinalizeResult() finalizeResult {
673 return finalizeResult(cn.finalizeResult.Load())
674 }
675 676 // finalize attempts to finalize the connection and returns true iff the
677 // connection was successfully finalized.
678 //
679 // If the connection failed to finalize, the caller should drop the packet
680 // associated with the connection.
681 //
682 // If multiple goroutines attempt to finalize at the same time, only one
683 // goroutine will perform the work to finalize the connection, but all
684 // goroutines will block until the finalizing goroutine finishes finalizing.
685 func (cn *conn) finalize() bool {
686 cn.finalizeOnce.Do(func() {
687 cn.finalizeResult.Store(uint32(cn.ct.finalize(cn)))
688 })
689 690 switch res := cn.getFinalizeResult(); res {
691 case finalizeResultSuccess:
692 return true
693 case finalizeResultConflict:
694 return false
695 default:
696 panic(fmt.Sprintf("unhandled result = %d", res))
697 }
698 }
699 700 // If NAT has not been configured for this connection, either mark the
701 // connection as configured for "no-op NAT", in the case of DNAT, or, in the
702 // case of SNAT, perform source port remapping so that source ports used by
703 // locally-generated traffic do not conflict with ports occupied by existing NAT
704 // bindings.
705 //
706 // Note that in the typical case this is also a no-op, because `snatAction`
707 // will do nothing if the original tuple is already unique.
708 func (cn *conn) maybePerformNoopNAT(pkt *PacketBuffer, hook Hook, r *Route, dnat bool) {
709 cn.mu.Lock()
710 var manip *manipType
711 if dnat {
712 manip = &cn.destinationManip
713 } else {
714 manip = &cn.sourceManip
715 }
716 if *manip != manipNotPerformed {
717 cn.mu.Unlock()
718 _ = cn.handlePacket(pkt, hook, r)
719 return
720 }
721 if dnat {
722 *manip = manipPerformedNoop
723 cn.mu.Unlock()
724 _ = cn.handlePacket(pkt, hook, r)
725 return
726 }
727 cn.mu.Unlock()
728 729 // At this point, we know that NAT has not yet been performed on this
730 // connection, and the DNAT case has been handled with a no-op. For SNAT, we
731 // simply perform source port remapping to ensure that source ports for
732 // locally generated traffic do not clash with ports used by existing NAT
733 // bindings.
734 _, _ = snatAction(pkt, hook, r, 0, tcpip.Address{}, true /* changePort */, false /* changeAddress */)
735 }
// portOrIdentRange is a contiguous range of ports (or Echo idents) that NAT
// may pick from. It covers the values start through start+size-1.
type portOrIdentRange struct {
	// start is the first port/ident in the range.
	start uint16
	// size is the number of ports/idents in the range. performNAT panics if
	// start+size-1 does not fit in a uint16.
	size uint32
}
// performNAT sets up the connection for the specified NAT and rewrites the
// packet.
//
// If NAT has already been performed on the connection, then the packet will
// be rewritten with the NAT performed on the connection, ignoring the passed
// address and port range.
//
// Generally, only the first packet of a connection reaches this method; other
// packets will be manipulated without needing to modify the connection.
//
// Parameters:
//   - portsOrIdents is the range the translated port/ident must fall in; it is
//     only consulted when changePort is true.
//   - natAddress is the address to translate to; it is only used when
//     changeAddress is true.
//   - dnat selects destination NAT (which updates the reply tuple's source)
//     over source NAT (which updates the reply tuple's destination).
//
// Panics if the end of portsOrIdents does not fit in a uint16.
func (cn *conn) performNAT(pkt *PacketBuffer, hook Hook, r *Route, portsOrIdents portOrIdentRange, natAddress tcpip.Address, dnat, changePort, changeAddress bool) {
	// NOTE(review): with start == 0 and size == 0 the uint32 arithmetic below
	// wraps around and triggers the panic; callers presumably always pass
	// size >= 1 - confirm.
	lastPortOrIdent := func() uint16 {
		lastPortOrIdent := uint32(portsOrIdents.start) + portsOrIdents.size - 1
		if lastPortOrIdent > math.MaxUint16 {
			panic(fmt.Sprintf("got lastPortOrIdent = %d, want <= MaxUint16(=%d); portsOrIdents=%#v", lastPortOrIdent, math.MaxUint16, portsOrIdents))
		}
		return uint16(lastPortOrIdent)
	}()

	// Make sure the packet is re-written after performing NAT.
	//
	// Deferred calls run in LIFO order, so this runs after cn.mu has been
	// released by the deferred Unlock below; handlePacket takes cn.mu itself.
	defer func() {
		// handlePacket returns true if the packet may skip the NAT table as the
		// connection is already NATed, but if we reach this point we must be in the
		// NAT table, so the return value is useless for us.
		_ = cn.handlePacket(pkt, hook, r)
	}()

	cn.mu.Lock()
	defer cn.mu.Unlock()

	// NAT is recorded by rewriting the reply tuple: the reply to a DNAT-ed
	// packet comes from the new destination, and the reply to an SNAT-ed packet
	// is sent to the new source.
	var manip *manipType
	var address *tcpip.Address
	var portOrIdent *uint16
	if dnat {
		manip = &cn.destinationManip
		address = &cn.reply.tupleID.srcAddr
		portOrIdent = &cn.reply.tupleID.srcPortOrEchoRequestIdent
	} else {
		manip = &cn.sourceManip
		address = &cn.reply.tupleID.dstAddr
		portOrIdent = &cn.reply.tupleID.dstPortOrEchoReplyIdent
	}

	if *manip != manipNotPerformed {
		// NAT was already set up for this connection; the deferred
		// handlePacket call rewrites the packet accordingly.
		return
	}
	*manip = manipPerformed
	if changeAddress {
		*address = natAddress
	}

	// Everything below here is port-fiddling.
	if !changePort {
		return
	}

	// Does the current port/ident fit in the range?
	if portsOrIdents.start <= *portOrIdent && *portOrIdent <= lastPortOrIdent {
		// Yes, is the current reply tuple unique?
		//
		// Or, does the reply tuple refer to the same connection as the current one that
		// we are NATing? This would apply, for example, to a self-connected socket,
		// where the original and reply tuples are identical.
		other := cn.ct.connForTID(cn.reply.tupleID)
		if other == nil || other.conn == cn {
			// Yes! No need to change the port.
			return
		}
	}

	// Try our best to find a port/ident that results in a unique reply tuple.
	//
	// We limit the number of attempts to find a unique tuple to not waste a lot
	// of time looking for a unique tuple.
	//
	// Matches linux behaviour introduced in
	// https://github.com/torvalds/linux/commit/a504b703bb1da526a01593da0e4be2af9d9f5fa8.
	const maxAttemptsForInitialRound uint32 = 128
	const minAttemptsToContinue = 16

	allowedInitialAttempts := maxAttemptsForInitialRound
	if allowedInitialAttempts > portsOrIdents.size {
		allowedInitialAttempts = portsOrIdents.size
	}

	// Each round tries maxAttempts consecutive ports/idents (wrapping around
	// within the range); the number of attempts is halved for every new round.
	for maxAttempts := allowedInitialAttempts; ; maxAttempts /= 2 {
		// Start each round with a random initial port/ident offset.
		randOffset := cn.ct.rand.Uint32()

		for i := uint32(0); i < maxAttempts; i++ {
			newPortOrIdentU32 := uint32(portsOrIdents.start) + (randOffset+i)%portsOrIdents.size
			if newPortOrIdentU32 > math.MaxUint16 {
				panic(fmt.Sprintf("got newPortOrIdentU32 = %d, want <= MaxUint16(=%d); portsOrIdents=%#v, randOffset=%d", newPortOrIdentU32, math.MaxUint16, portsOrIdents, randOffset))
			}

			// Writes through to the reply tuple so that the lookup below sees
			// the candidate port/ident.
			*portOrIdent = uint16(newPortOrIdentU32)

			if other := cn.ct.connForTID(cn.reply.tupleID); other == nil {
				// We found a unique tuple!
				return
			}
		}

		if maxAttempts == portsOrIdents.size {
			// We already tried all the ports/idents in the range so no need to keep
			// trying.
			return
		}

		if maxAttempts < minAttemptsToContinue {
			return
		}
	}

	// We did not find a unique tuple, use the last used port anyways.
	// TODO(https://gvisor.dev/issue/6850): Handle not finding a unique tuple
	// better (e.g. remove the connection and drop the packet).
}
859 860 // handlePacket attempts to handle a packet and perform NAT if the connection
861 // has had NAT performed on it.
862 //
863 // Returns true if the packet can skip the NAT table.
864 func (cn *conn) handlePacket(pkt *PacketBuffer, hook Hook, rt *Route) bool {
865 netHdr, transHdr, isICMPError, ok := getHeaders(pkt)
866 if !ok {
867 return false
868 }
869 870 fullChecksum := false
871 updatePseudoHeader := false
872 natDone := &pkt.snatDone
873 dnat := false
874 switch hook {
875 case Prerouting:
876 // Packet came from outside the stack so it must have a checksum set
877 // already.
878 fullChecksum = true
879 updatePseudoHeader = true
880 881 natDone = &pkt.dnatDone
882 dnat = true
883 case Input:
884 case Forward:
885 panic("should not handle packet in the forwarding hook")
886 case Output:
887 natDone = &pkt.dnatDone
888 dnat = true
889 fallthrough
890 case Postrouting:
891 if pkt.TransportProtocolNumber == header.TCPProtocolNumber && pkt.GSOOptions.Type != GSONone && pkt.GSOOptions.NeedsCsum {
892 updatePseudoHeader = true
893 } else if rt.RequiresTXTransportChecksum() {
894 fullChecksum = true
895 updatePseudoHeader = true
896 }
897 default:
898 panic(fmt.Sprintf("unrecognized hook = %d", hook))
899 }
900 901 if *natDone {
902 panic(fmt.Sprintf("packet already had NAT(dnat=%t) performed at hook=%s; pkt=%#v", dnat, hook, pkt))
903 }
904 905 // TODO(gvisor.dev/issue/5748): TCP checksums on inbound packets should be
906 // validated if checksum offloading is off. It may require IP defrag if the
907 // packets are fragmented.
908 909 reply := pkt.tuple.reply
910 911 tid, manip := func() (tupleID, manipType) {
912 cn.mu.RLock()
913 defer cn.mu.RUnlock()
914 915 if reply {
916 tid := cn.original.tupleID
917 918 if dnat {
919 return tid, cn.sourceManip
920 }
921 return tid, cn.destinationManip
922 }
923 924 tid := cn.reply.tupleID
925 if dnat {
926 return tid, cn.destinationManip
927 }
928 return tid, cn.sourceManip
929 }()
930 switch manip {
931 case manipNotPerformed:
932 return false
933 case manipPerformedNoop:
934 *natDone = true
935 return true
936 case manipPerformed:
937 default:
938 panic(fmt.Sprintf("unhandled manip = %d", manip))
939 }
940 941 newPort := tid.dstPortOrEchoReplyIdent
942 newAddr := tid.dstAddr
943 if dnat {
944 newPort = tid.srcPortOrEchoRequestIdent
945 newAddr = tid.srcAddr
946 }
947 948 rewritePacket(
949 netHdr,
950 transHdr,
951 !dnat != isICMPError,
952 fullChecksum,
953 updatePseudoHeader,
954 newPort,
955 newAddr,
956 )
957 958 *natDone = true
959 960 if !isICMPError {
961 return true
962 }
963 964 // We performed NAT on (erroneous) packet that triggered an ICMP response, but
965 // not the ICMP packet itself.
966 switch pkt.TransportProtocolNumber {
967 case header.ICMPv4ProtocolNumber:
968 icmp := header.ICMPv4(pkt.TransportHeader().Slice())
969 // TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
970 icmp.SetChecksum(0)
971 icmp.SetChecksum(header.ICMPv4Checksum(icmp, pkt.Data().Checksum()))
972 973 network := header.IPv4(pkt.NetworkHeader().Slice())
974 if dnat {
975 network.SetDestinationAddressWithChecksumUpdate(tid.srcAddr)
976 } else {
977 network.SetSourceAddressWithChecksumUpdate(tid.dstAddr)
978 }
979 case header.ICMPv6ProtocolNumber:
980 network := header.IPv6(pkt.NetworkHeader().Slice())
981 srcAddr := network.SourceAddress()
982 dstAddr := network.DestinationAddress()
983 if dnat {
984 dstAddr = tid.srcAddr
985 } else {
986 srcAddr = tid.dstAddr
987 }
988 989 icmp := header.ICMPv6(pkt.TransportHeader().Slice())
990 // TODO(https://gvisor.dev/issue/6788): Incrementally update ICMP checksum.
991 icmp.SetChecksum(0)
992 payload := pkt.Data()
993 icmp.SetChecksum(header.ICMPv6Checksum(header.ICMPv6ChecksumParams{
994 Header: icmp,
995 Src: srcAddr,
996 Dst: dstAddr,
997 PayloadCsum: payload.Checksum(),
998 PayloadLen: payload.Size(),
999 }))
1000 1001 if dnat {
1002 network.SetDestinationAddress(dstAddr)
1003 } else {
1004 network.SetSourceAddress(srcAddr)
1005 }
1006 }
1007 1008 return true
1009 }
1010 1011 // bucket gets the conntrack bucket for a tupleID.
1012 // +checklocksread:ct.mu
1013 func (ct *ConnTrack) bucket(id tupleID) int {
1014 return ct.bucketWithTableLength(id, len(ct.buckets))
1015 }
1016 1017 func (ct *ConnTrack) bucketWithTableLength(id tupleID, tableLength int) int {
1018 h := jenkins.Sum32(ct.seed)
1019 h.Write(id.srcAddr.AsSlice())
1020 h.Write(id.dstAddr.AsSlice())
1021 shortBuf := make([]byte, 2)
1022 binary.LittleEndian.PutUint16(shortBuf, id.srcPortOrEchoRequestIdent)
1023 h.Write([]byte(shortBuf))
1024 binary.LittleEndian.PutUint16(shortBuf, id.dstPortOrEchoReplyIdent)
1025 h.Write([]byte(shortBuf))
1026 binary.LittleEndian.PutUint16(shortBuf, uint16(id.transProto))
1027 h.Write([]byte(shortBuf))
1028 binary.LittleEndian.PutUint16(shortBuf, uint16(id.netProto))
1029 h.Write([]byte(shortBuf))
1030 return int(h.Sum32()) % tableLength
1031 }
1032 1033 // reapUnused deletes timed out entries from the conntrack map. The rules for
1034 // reaping are:
1035 // - Each call to reapUnused traverses a fraction of the conntrack table.
1036 // Specifically, it traverses len(ct.buckets)/fractionPerReaping.
1037 // - After reaping, reapUnused decides when it should next run based on the
1038 // ratio of expired connections to examined connections. If the ratio is
1039 // greater than maxExpiredPct, it schedules the next run quickly. Otherwise it
1040 // slightly increases the interval between runs.
1041 // - maxFullTraversal caps the time it takes to traverse the entire table.
1042 //
1043 // reapUnused returns the next bucket that should be checked and the time after
1044 // which it should be called again.
1045 func (ct *ConnTrack) reapUnused(start int, prevInterval time.Duration) (int, time.Duration) {
1046 const fractionPerReaping = 128
1047 const maxExpiredPct = 50
1048 const maxFullTraversal = 60 * time.Second
1049 const minInterval = 10 * time.Millisecond
1050 const maxInterval = maxFullTraversal / fractionPerReaping
1051 1052 now := ct.clock.NowMonotonic()
1053 checked := 0
1054 expired := 0
1055 var idx int
1056 ct.mu.RLock()
1057 defer ct.mu.RUnlock()
1058 for i := 0; i < len(ct.buckets)/fractionPerReaping; i++ {
1059 idx = (i + start) % len(ct.buckets)
1060 bkt := &ct.buckets[idx]
1061 bkt.mu.Lock()
1062 for tuple := bkt.tuples.Front(); tuple != nil; {
1063 // reapTupleLocked updates tuple's next pointer so we grab it here.
1064 nextTuple := tuple.Next()
1065 1066 checked++
1067 if ct.reapTupleLocked(tuple, idx, bkt, now) {
1068 expired++
1069 }
1070 1071 tuple = nextTuple
1072 }
1073 bkt.mu.Unlock()
1074 }
1075 // We already checked buckets[idx].
1076 idx++
1077 1078 // If half or more of the connections are expired, the table has gotten
1079 // stale. Reschedule quickly.
1080 expiredPct := 0
1081 if checked != 0 {
1082 expiredPct = expired * 100 / checked
1083 }
1084 if expiredPct > maxExpiredPct {
1085 return idx, minInterval
1086 }
1087 if interval := prevInterval + minInterval; interval <= maxInterval {
1088 // Increment the interval between runs.
1089 return idx, interval
1090 }
1091 // We've hit the maximum interval.
1092 return idx, maxInterval
1093 }
1094 1095 // reapTupleLocked tries to remove tuple and its reply from the table. It
1096 // returns whether the tuple's connection has timed out.
1097 //
1098 // Precondition: ct.mu is read locked and bkt.mu is write locked.
1099 // +checklocksread:ct.mu
1100 // +checklocks:bkt.mu
1101 func (ct *ConnTrack) reapTupleLocked(reapingTuple *tuple, bktID int, bkt *bucket, now tcpip.MonotonicTime) bool {
1102 if !reapingTuple.conn.timedOut(now) {
1103 return false
1104 }
1105 1106 var otherTuple *tuple
1107 if reapingTuple.reply {
1108 otherTuple = &reapingTuple.conn.original
1109 } else {
1110 otherTuple = &reapingTuple.conn.reply
1111 }
1112 1113 otherTupleBktID := ct.bucket(otherTuple.tupleID)
1114 replyTupleInserted := reapingTuple.conn.getFinalizeResult() == finalizeResultSuccess
1115 1116 // To maintain lock order, we can only reap both tuples if the tuple for the
1117 // other direction appears later in the table.
1118 if bktID > otherTupleBktID && replyTupleInserted {
1119 return true
1120 }
1121 1122 bkt.tuples.Remove(reapingTuple)
1123 1124 if !replyTupleInserted {
1125 // The other tuple is the reply which has not yet been inserted.
1126 return true
1127 }
1128 1129 // Reap the other connection.
1130 if bktID == otherTupleBktID {
1131 // Don't re-lock if both tuples are in the same bucket.
1132 bkt.tuples.Remove(otherTuple)
1133 } else {
1134 otherTupleBkt := &ct.buckets[otherTupleBktID]
1135 otherTupleBkt.mu.NestedLock(bucketLockOthertuple)
1136 otherTupleBkt.tuples.Remove(otherTuple)
1137 otherTupleBkt.mu.NestedUnlock(bucketLockOthertuple)
1138 }
1139 1140 return true
1141 }
1142 1143 func (ct *ConnTrack) originalDst(epID TransportEndpointID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber) (tcpip.Address, uint16, tcpip.Error) {
1144 // Lookup the connection. The reply's original destination
1145 // describes the original address.
1146 tid := tupleID{
1147 srcAddr: epID.LocalAddress,
1148 srcPortOrEchoRequestIdent: epID.LocalPort,
1149 dstAddr: epID.RemoteAddress,
1150 dstPortOrEchoReplyIdent: epID.RemotePort,
1151 transProto: transProto,
1152 netProto: netProto,
1153 }
1154 t := ct.connForTID(tid)
1155 if t == nil {
1156 // Not a tracked connection.
1157 return tcpip.Address{}, 0, &tcpip.ErrNotConnected{}
1158 }
1159 1160 t.conn.mu.RLock()
1161 defer t.conn.mu.RUnlock()
1162 if t.conn.destinationManip == manipNotPerformed {
1163 // Unmanipulated destination.
1164 return tcpip.Address{}, 0, &tcpip.ErrInvalidOptionValue{}
1165 }
1166 1167 id := t.conn.original.tupleID
1168 return id.dstAddr, id.dstPortOrEchoReplyIdent, nil
1169 }
1170