gonet.go raw
1 // Copyright 2018 The gVisor Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 // Package gonet provides a Go net package compatible wrapper for a tcpip stack.
16 package gonet
17
18 import (
19 "bytes"
20 "context"
21 "errors"
22 "fmt"
23 "io"
24 "net"
25 "time"
26
27 "gvisor.dev/gvisor/pkg/sync"
28 "gvisor.dev/gvisor/pkg/tcpip"
29 "gvisor.dev/gvisor/pkg/tcpip/stack"
30 "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
31 "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
32 "gvisor.dev/gvisor/pkg/waiter"
33 )
34
// errCanceled is returned by TCPListener.Accept when the listener is shut
// down while Accept is waiting for a connection.
var errCanceled = errors.New("operation canceled")

// errWouldBlock is not referenced in the code visible here; presumably it is
// used elsewhere in the package — verify before removing.
var errWouldBlock = errors.New("operation would block")
39
// timeoutError is how the net package reports timeouts: its Error text is
// "i/o timeout" and both Timeout and Temporary report true, so callers
// checking for a net.Error timeout recognize it.
type timeoutError struct{}

// Error implements error.
func (*timeoutError) Error() string { return "i/o timeout" }

// Timeout reports that this error is a timeout.
func (*timeoutError) Timeout() bool { return true }

// Temporary reports that this error is temporary.
func (*timeoutError) Temporary() bool { return true }
46
47 // A TCPListener is a wrapper around a TCP tcpip.Endpoint that implements
48 // net.Listener.
49 type TCPListener struct {
50 stack *stack.Stack
51 ep tcpip.Endpoint
52 wq *waiter.Queue
53 cancelOnce sync.Once
54 cancel chan struct{}
55 }
56
57 // NewTCPListener creates a new TCPListener from a listening tcpip.Endpoint.
58 func NewTCPListener(s *stack.Stack, wq *waiter.Queue, ep tcpip.Endpoint) *TCPListener {
59 return &TCPListener{
60 stack: s,
61 ep: ep,
62 wq: wq,
63 cancel: make(chan struct{}),
64 }
65 }
66
// maxListenBacklog is the backlog passed to Listen by ListenTCP. It is set
// reasonably high for most uses of gonet. The Go net package uses the value
// in /proc/sys/net/core/somaxconn on Linux as its default listen backlog, and
// the value below matches the default in common Linux distros.
//
// See: https://cs.opensource.google/go/go/+/refs/tags/go1.18.1:src/net/sock_linux.go;drc=refs%2Ftags%2Fgo1.18.1;l=66
const maxListenBacklog = 1 << 12 // 4096
74
75 // ListenTCP creates a new TCPListener.
76 func ListenTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPListener, error) {
77 // Create a TCP endpoint, bind it, then start listening.
78 var wq waiter.Queue
79 ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
80 if err != nil {
81 return nil, errors.New(err.String())
82 }
83
84 if err := ep.Bind(addr); err != nil {
85 ep.Close()
86 return nil, &net.OpError{
87 Op: "bind",
88 Net: "tcp",
89 Addr: fullToTCPAddr(addr),
90 Err: errors.New(err.String()),
91 }
92 }
93
94 if err := ep.Listen(maxListenBacklog); err != nil {
95 ep.Close()
96 return nil, &net.OpError{
97 Op: "listen",
98 Net: "tcp",
99 Addr: fullToTCPAddr(addr),
100 Err: errors.New(err.String()),
101 }
102 }
103
104 return NewTCPListener(s, &wq, ep), nil
105 }
106
107 // Close implements net.Listener.Close.
108 func (l *TCPListener) Close() error {
109 l.ep.Close()
110 return nil
111 }
112
113 // Shutdown stops the HTTP server.
114 func (l *TCPListener) Shutdown() {
115 l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
116 l.cancelOnce.Do(func() {
117 close(l.cancel) // broadcast cancellation
118 })
119 }
120
121 // Addr implements net.Listener.Addr.
122 func (l *TCPListener) Addr() net.Addr {
123 a, err := l.ep.GetLocalAddress()
124 if err != nil {
125 return nil
126 }
127 return fullToTCPAddr(a)
128 }
129
// deadlineTimer provides the deadline half of net.Conn and net.PacketConn.
// Reads and writes each have a cancel channel that is closed once the
// corresponding deadline passes. Blocking operations select on the channel
// returned by readCancel or writeCancel.
type deadlineTimer struct {
	// mu protects the fields below.
	mu sync.Mutex

	readTimer     *time.Timer
	readCancelCh  chan struct{}
	writeTimer    *time.Timer
	writeCancelCh chan struct{}
}

// init allocates the initial, open cancel channels. It must run before any
// other method: setDeadline closes the channel for an already-expired
// deadline, which would panic on a nil channel.
func (d *deadlineTimer) init() {
	d.readCancelCh, d.writeCancelCh = make(chan struct{}), make(chan struct{})
}

// readCancel returns the current read-cancellation channel. The channel is
// closed when the read deadline expires.
func (d *deadlineTimer) readCancel() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.readCancelCh
}

// writeCancel returns the current write-cancellation channel. The channel is
// closed when the write deadline expires.
func (d *deadlineTimer) writeCancel() <-chan struct{} {
	d.mu.Lock()
	defer d.mu.Unlock()
	return d.writeCancelCh
}

// setDeadline contains the shared logic for setting a deadline.
//
// cancelCh and timer must point either at readCancelCh/readTimer or at
// writeCancelCh/writeTimer of d.
//
// setDeadline must only be called while holding d.mu.
func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
	// If the previous timer could not be stopped, its callback has already
	// started and owns the current channel, which it will close. Hand out a
	// fresh channel so the new deadline is not affected.
	if prev := *timer; prev != nil {
		if stopped := prev.Stop(); !stopped {
			*cancelCh = make(chan struct{})
		}
	}

	// An earlier call with an already-expired deadline may have closed the
	// channel directly. The timer race is handled above, so replacing a
	// closed channel here is safe.
	select {
	case <-*cancelCh:
		*cancelCh = make(chan struct{})
	default:
	}

	// "A zero value for t means I/O operations will not time out."
	// - net.Conn.SetDeadline
	if t.IsZero() {
		*timer = nil
		return
	}

	remaining := t.Sub(time.Now())
	if remaining > 0 {
		// Timer.Stop reports whether the AfterFunc has started, not whether it
		// has finished. The callback therefore closes its own copy of the
		// channel, so a later setDeadline that replaces *cancelCh is not
		// affected.
		owned := *cancelCh
		*timer = time.AfterFunc(remaining, func() { close(owned) })
		return
	}

	// The deadline is already in the past: cancel immediately.
	close(*cancelCh)
}

// SetReadDeadline implements net.Conn.SetReadDeadline and
// net.PacketConn.SetReadDeadline. It always returns nil.
func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
	d.mu.Lock()
	d.setDeadline(&d.readCancelCh, &d.readTimer, t)
	d.mu.Unlock()
	return nil
}

// SetWriteDeadline implements net.Conn.SetWriteDeadline and
// net.PacketConn.SetWriteDeadline. It always returns nil.
func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
	d.mu.Lock()
	d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
	d.mu.Unlock()
	return nil
}

// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline
// by applying t to the read side and then the write side under a single
// acquisition of d.mu. It always returns nil.
func (d *deadlineTimer) SetDeadline(t time.Time) error {
	d.mu.Lock()
	d.setDeadline(&d.readCancelCh, &d.readTimer, t)
	d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
	d.mu.Unlock()
	return nil
}
228
229 // A TCPConn is a wrapper around a TCP tcpip.Endpoint that implements the net.Conn
230 // interface.
231 type TCPConn struct {
232 deadlineTimer
233
234 wq *waiter.Queue
235 ep tcpip.Endpoint
236
237 // readMu serializes reads and implicitly protects read.
238 //
239 // Lock ordering:
240 // If both readMu and deadlineTimer.mu are to be used in a single
241 // request, readMu must be acquired before deadlineTimer.mu.
242 readMu sync.Mutex
243
244 // read contains bytes that have been read from the endpoint,
245 // but haven't yet been returned.
246 read []byte
247 }
248
249 // NewTCPConn creates a new TCPConn.
250 func NewTCPConn(wq *waiter.Queue, ep tcpip.Endpoint) *TCPConn {
251 c := &TCPConn{
252 wq: wq,
253 ep: ep,
254 }
255 c.deadlineTimer.init()
256 return c
257 }
258
259 // Accept implements net.Conn.Accept.
260 func (l *TCPListener) Accept() (net.Conn, error) {
261 n, wq, err := l.ep.Accept(nil)
262
263 if _, ok := err.(*tcpip.ErrWouldBlock); ok {
264 // Create wait queue entry that notifies a channel.
265 waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents)
266 l.wq.EventRegister(&waitEntry)
267 defer l.wq.EventUnregister(&waitEntry)
268
269 for {
270 n, wq, err = l.ep.Accept(nil)
271
272 if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
273 break
274 }
275
276 select {
277 case <-l.cancel:
278 return nil, errCanceled
279 case <-notifyCh:
280 }
281 }
282 }
283
284 if err != nil {
285 return nil, &net.OpError{
286 Op: "accept",
287 Net: "tcp",
288 Addr: l.Addr(),
289 Err: errors.New(err.String()),
290 }
291 }
292
293 return NewTCPConn(wq, n), nil
294 }
295
// opErrorer builds the *net.OpError returned to callers for a failed
// operation. TCPConn and UDPConn implement it so commonRead can report errors
// with the connection's addresses.
type opErrorer interface {
	newOpError(op string, err error) *net.OpError
}
299
300 // commonRead implements the common logic between net.Conn.Read and
301 // net.PacketConn.ReadFrom.
302 func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) {
303 select {
304 case <-deadline:
305 return 0, errorer.newOpError("read", &timeoutError{})
306 default:
307 }
308
309 w := tcpip.SliceWriter(b)
310 opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
311 res, err := ep.Read(&w, opts)
312
313 if _, ok := err.(*tcpip.ErrWouldBlock); ok {
314 // Create wait queue entry that notifies a channel.
315 waitEntry, notifyCh := waiter.NewChannelEntry(waiter.ReadableEvents)
316 wq.EventRegister(&waitEntry)
317 defer wq.EventUnregister(&waitEntry)
318 for {
319 res, err = ep.Read(&w, opts)
320 if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
321 break
322 }
323 select {
324 case <-deadline:
325 return 0, errorer.newOpError("read", &timeoutError{})
326 case <-notifyCh:
327 }
328 }
329 }
330
331 if _, ok := err.(*tcpip.ErrClosedForReceive); ok {
332 return 0, io.EOF
333 }
334
335 if err != nil {
336 return 0, errorer.newOpError("read", errors.New(err.String()))
337 }
338
339 if addr != nil {
340 *addr = res.RemoteAddr
341 }
342 return res.Count, nil
343 }
344
345 // Read implements net.Conn.Read.
346 func (c *TCPConn) Read(b []byte) (int, error) {
347 c.readMu.Lock()
348 defer c.readMu.Unlock()
349
350 deadline := c.readCancel()
351
352 n, err := commonRead(b, c.ep, c.wq, deadline, nil, c)
353 if n != 0 {
354 c.ep.ModerateRecvBuf(n)
355 }
356 return n, err
357 }
358
359 // Write implements net.Conn.Write.
360 func (c *TCPConn) Write(b []byte) (int, error) {
361 deadline := c.writeCancel()
362
363 // Check if deadlineTimer has already expired.
364 select {
365 case <-deadline:
366 return 0, c.newOpError("write", &timeoutError{})
367 default:
368 }
369
370 // We must handle two soft failure conditions simultaneously:
371 // 1. Write may write nothing and return *tcpip.ErrWouldBlock.
372 // If this happens, we need to register for notifications if we have
373 // not already and wait to try again.
374 // 2. Write may write fewer than the full number of bytes and return
375 // without error. In this case we need to try writing the remaining
376 // bytes again. I do not need to register for notifications.
377 //
378 // What is more, these two soft failure conditions can be interspersed.
379 // There is no guarantee that all of the condition #1s will occur before
380 // all of the condition #2s or visa-versa.
381 var (
382 r bytes.Reader
383 nbytes int
384 entry waiter.Entry
385 ch <-chan struct{}
386 )
387 for nbytes != len(b) {
388 r.Reset(b[nbytes:])
389 n, err := c.ep.Write(&r, tcpip.WriteOptions{})
390 nbytes += int(n)
391 switch err.(type) {
392 case nil:
393 case *tcpip.ErrWouldBlock:
394 if ch == nil {
395 entry, ch = waiter.NewChannelEntry(waiter.WritableEvents)
396 c.wq.EventRegister(&entry)
397 defer c.wq.EventUnregister(&entry)
398 } else {
399 // Don't wait immediately after registration in case more data
400 // became available between when we last checked and when we setup
401 // the notification.
402 select {
403 case <-deadline:
404 return nbytes, c.newOpError("write", &timeoutError{})
405 case <-ch:
406 continue
407 }
408 }
409 default:
410 return nbytes, c.newOpError("write", errors.New(err.String()))
411 }
412 }
413 return nbytes, nil
414 }
415
416 // Close implements net.Conn.Close.
417 func (c *TCPConn) Close() error {
418 c.ep.Close()
419 return nil
420 }
421
422 // CloseRead shuts down the reading side of the TCP connection. Most callers
423 // should just use Close.
424 //
425 // A TCP Half-Close is performed the same as CloseRead for *net.TCPConn.
426 func (c *TCPConn) CloseRead() error {
427 if terr := c.ep.Shutdown(tcpip.ShutdownRead); terr != nil {
428 return c.newOpError("close", errors.New(terr.String()))
429 }
430 return nil
431 }
432
433 // CloseWrite shuts down the writing side of the TCP connection. Most callers
434 // should just use Close.
435 //
436 // A TCP Half-Close is performed the same as CloseWrite for *net.TCPConn.
437 func (c *TCPConn) CloseWrite() error {
438 if terr := c.ep.Shutdown(tcpip.ShutdownWrite); terr != nil {
439 return c.newOpError("close", errors.New(terr.String()))
440 }
441 return nil
442 }
443
444 // LocalAddr implements net.Conn.LocalAddr.
445 func (c *TCPConn) LocalAddr() net.Addr {
446 a, err := c.ep.GetLocalAddress()
447 if err != nil {
448 return nil
449 }
450 return fullToTCPAddr(a)
451 }
452
453 // RemoteAddr implements net.Conn.RemoteAddr.
454 func (c *TCPConn) RemoteAddr() net.Addr {
455 a, err := c.ep.GetRemoteAddress()
456 if err != nil {
457 return nil
458 }
459 return fullToTCPAddr(a)
460 }
461
462 func (c *TCPConn) newOpError(op string, err error) *net.OpError {
463 return &net.OpError{
464 Op: op,
465 Net: "tcp",
466 Source: c.LocalAddr(),
467 Addr: c.RemoteAddr(),
468 Err: err,
469 }
470 }
471
472 func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr {
473 return &net.TCPAddr{IP: net.IP(addr.Addr.AsSlice()), Port: int(addr.Port)}
474 }
475
476 func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
477 return &net.UDPAddr{IP: net.IP(addr.Addr.AsSlice()), Port: int(addr.Port)}
478 }
479
480 // DialTCP creates a new TCPConn connected to the specified address.
481 func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
482 return DialContextTCP(context.Background(), s, addr, network)
483 }
484
485 // DialTCPWithBind creates a new TCPConn connected to the specified
486 // remoteAddress with its local address bound to localAddr.
487 func DialTCPWithBind(ctx context.Context, s *stack.Stack, localAddr, remoteAddr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
488 // Create TCP endpoint, then connect.
489 var wq waiter.Queue
490 ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
491 if err != nil {
492 return nil, errors.New(err.String())
493 }
494
495 // Create wait queue entry that notifies a channel.
496 //
497 // We do this unconditionally as Connect will always return an error.
498 waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
499 wq.EventRegister(&waitEntry)
500 defer wq.EventUnregister(&waitEntry)
501
502 select {
503 case <-ctx.Done():
504 return nil, ctx.Err()
505 default:
506 }
507
508 // Bind before connect if requested.
509 if localAddr != (tcpip.FullAddress{}) {
510 if err = ep.Bind(localAddr); err != nil {
511 return nil, fmt.Errorf("ep.Bind(%+v) = %s", localAddr, err)
512 }
513 }
514
515 err = ep.Connect(remoteAddr)
516 if _, ok := err.(*tcpip.ErrConnectStarted); ok {
517 select {
518 case <-ctx.Done():
519 ep.Close()
520 return nil, ctx.Err()
521 case <-notifyCh:
522 }
523
524 err = ep.LastError()
525 }
526 if err != nil {
527 ep.Close()
528 return nil, &net.OpError{
529 Op: "connect",
530 Net: "tcp",
531 Addr: fullToTCPAddr(remoteAddr),
532 Err: errors.New(err.String()),
533 }
534 }
535
536 return NewTCPConn(&wq, ep), nil
537 }
538
539 // DialContextTCP creates a new TCPConn connected to the specified address
540 // with the option of adding cancellation and timeouts.
541 func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*TCPConn, error) {
542 return DialTCPWithBind(ctx, s, tcpip.FullAddress{} /* localAddr */, addr /* remoteAddr */, network)
543 }
544
545 // A UDPConn is a wrapper around a UDP tcpip.Endpoint that implements
546 // net.Conn and net.PacketConn.
547 type UDPConn struct {
548 deadlineTimer
549
550 ep tcpip.Endpoint
551 wq *waiter.Queue
552 }
553
554 // NewUDPConn creates a new UDPConn.
555 func NewUDPConn(wq *waiter.Queue, ep tcpip.Endpoint) *UDPConn {
556 c := &UDPConn{
557 ep: ep,
558 wq: wq,
559 }
560 c.deadlineTimer.init()
561 return c
562 }
563
564 // DialUDP creates a new UDPConn.
565 //
566 // If laddr is nil, a local address is automatically chosen.
567 //
568 // If raddr is nil, the UDPConn is left unconnected.
569 func DialUDP(s *stack.Stack, laddr, raddr *tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*UDPConn, error) {
570 var wq waiter.Queue
571 ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
572 if err != nil {
573 return nil, errors.New(err.String())
574 }
575
576 if laddr != nil {
577 if err := ep.Bind(*laddr); err != nil {
578 ep.Close()
579 return nil, &net.OpError{
580 Op: "bind",
581 Net: "udp",
582 Addr: fullToUDPAddr(*laddr),
583 Err: errors.New(err.String()),
584 }
585 }
586 }
587
588 c := NewUDPConn(&wq, ep)
589
590 if raddr != nil {
591 if err := c.ep.Connect(*raddr); err != nil {
592 c.ep.Close()
593 return nil, &net.OpError{
594 Op: "connect",
595 Net: "udp",
596 Addr: fullToUDPAddr(*raddr),
597 Err: errors.New(err.String()),
598 }
599 }
600 }
601
602 return c, nil
603 }
604
605 func (c *UDPConn) newOpError(op string, err error) *net.OpError {
606 return c.newRemoteOpError(op, nil, err)
607 }
608
609 func (c *UDPConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
610 return &net.OpError{
611 Op: op,
612 Net: "udp",
613 Source: c.LocalAddr(),
614 Addr: remote,
615 Err: err,
616 }
617 }
618
619 // RemoteAddr implements net.Conn.RemoteAddr.
620 func (c *UDPConn) RemoteAddr() net.Addr {
621 a, err := c.ep.GetRemoteAddress()
622 if err != nil {
623 return nil
624 }
625 return fullToUDPAddr(a)
626 }
627
628 // Read implements net.Conn.Read
629 func (c *UDPConn) Read(b []byte) (int, error) {
630 bytesRead, _, err := c.ReadFrom(b)
631 return bytesRead, err
632 }
633
634 // ReadFrom implements net.PacketConn.ReadFrom.
635 func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
636 deadline := c.readCancel()
637
638 var addr tcpip.FullAddress
639 n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c)
640 if err != nil {
641 return 0, nil, err
642 }
643 return n, fullToUDPAddr(addr), nil
644 }
645
646 func (c *UDPConn) Write(b []byte) (int, error) {
647 return c.WriteTo(b, nil)
648 }
649
650 // WriteTo implements net.PacketConn.WriteTo.
651 func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) {
652 deadline := c.writeCancel()
653
654 // Check if deadline has already expired.
655 select {
656 case <-deadline:
657 return 0, c.newRemoteOpError("write", addr, &timeoutError{})
658 default:
659 }
660
661 // If we're being called by Write, there is no addr
662 writeOptions := tcpip.WriteOptions{}
663 if addr != nil {
664 ua := addr.(*net.UDPAddr)
665 writeOptions.To = &tcpip.FullAddress{
666 Addr: tcpip.AddrFromSlice(ua.IP),
667 Port: uint16(ua.Port),
668 }
669 }
670
671 var r bytes.Reader
672 r.Reset(b)
673 n, err := c.ep.Write(&r, writeOptions)
674 if _, ok := err.(*tcpip.ErrWouldBlock); ok {
675 // Create wait queue entry that notifies a channel.
676 waitEntry, notifyCh := waiter.NewChannelEntry(waiter.WritableEvents)
677 c.wq.EventRegister(&waitEntry)
678 defer c.wq.EventUnregister(&waitEntry)
679 for {
680 select {
681 case <-deadline:
682 return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
683 case <-notifyCh:
684 }
685
686 n, err = c.ep.Write(&r, writeOptions)
687 if _, ok := err.(*tcpip.ErrWouldBlock); !ok {
688 break
689 }
690 }
691 }
692
693 if err == nil {
694 return int(n), nil
695 }
696
697 return int(n), c.newRemoteOpError("write", addr, errors.New(err.String()))
698 }
699
700 // Close implements net.PacketConn.Close.
701 func (c *UDPConn) Close() error {
702 c.ep.Close()
703 return nil
704 }
705
706 // LocalAddr implements net.PacketConn.LocalAddr.
707 func (c *UDPConn) LocalAddr() net.Addr {
708 a, err := c.ep.GetLocalAddress()
709 if err != nil {
710 return nil
711 }
712 return fullToUDPAddr(a)
713 }
714