netconn.go raw

   1  package websocket
   2  
   3  import (
   4  	"context"
   5  	"fmt"
   6  	"io"
   7  	"math"
   8  	"net"
   9  	"sync/atomic"
  10  	"time"
  11  )
  12  
  13  // NetConn converts a *websocket.Conn into a net.Conn.
  14  //
  15  // It's for tunneling arbitrary protocols over WebSockets.
  16  // Few users of the library will need this but it's tricky to implement
  17  // correctly and so provided in the library.
  18  // See https://github.com/nhooyr/websocket/issues/100.
  19  //
  20  // Every Write to the net.Conn will correspond to a message write of
  21  // the given type on *websocket.Conn.
  22  //
  23  // The passed ctx bounds the lifetime of the net.Conn. If cancelled,
  24  // all reads and writes on the net.Conn will be cancelled.
  25  //
  26  // If a message is read that is not of the correct type, the connection
  27  // will be closed with StatusUnsupportedData and an error will be returned.
  28  //
  29  // Close will close the *websocket.Conn with StatusNormalClosure.
  30  //
  31  // When a deadline is hit and there is an active read or write goroutine, the
  32  // connection will be closed. This is different from most net.Conn implementations
  33  // where only the reading/writing goroutines are interrupted but the connection
  34  // is kept alive.
  35  //
  36  // The Addr methods will return the real addresses for connections obtained
  37  // from websocket.Accept. But for connections obtained from websocket.Dial, a mock net.Addr
  38  // will be returned that gives "websocket" for Network() and "websocket/unknown-addr" for
  39  // String(). This is because websocket.Dial only exposes a io.ReadWriteCloser instead of the
  40  // full net.Conn to us.
  41  //
  42  // When running as WASM, the Addr methods will always return the mock address described above.
  43  //
  44  // A received StatusNormalClosure or StatusGoingAway close frame will be translated to
  45  // io.EOF when reading.
  46  //
  47  // Furthermore, the ReadLimit is set to -1 to disable it.
  48  func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn {
  49  	c.SetReadLimit(-1)
  50  
  51  	nc := &netConn{
  52  		c:       c,
  53  		msgType: msgType,
  54  		readMu:  newMu(c),
  55  		writeMu: newMu(c),
  56  	}
  57  
  58  	nc.writeCtx, nc.writeCancel = context.WithCancel(ctx)
  59  	nc.readCtx, nc.readCancel = context.WithCancel(ctx)
  60  
  61  	nc.writeTimer = time.AfterFunc(math.MaxInt64, func() {
  62  		if !nc.writeMu.tryLock() {
  63  			// If the lock cannot be acquired, then there is an
  64  			// active write goroutine and so we should cancel the context.
  65  			nc.writeCancel()
  66  			return
  67  		}
  68  		defer nc.writeMu.unlock()
  69  
  70  		// Prevents future writes from writing until the deadline is reset.
  71  		atomic.StoreInt64(&nc.writeExpired, 1)
  72  	})
  73  	if !nc.writeTimer.Stop() {
  74  		<-nc.writeTimer.C
  75  	}
  76  
  77  	nc.readTimer = time.AfterFunc(math.MaxInt64, func() {
  78  		if !nc.readMu.tryLock() {
  79  			// If the lock cannot be acquired, then there is an
  80  			// active read goroutine and so we should cancel the context.
  81  			nc.readCancel()
  82  			return
  83  		}
  84  		defer nc.readMu.unlock()
  85  
  86  		// Prevents future reads from reading until the deadline is reset.
  87  		atomic.StoreInt64(&nc.readExpired, 1)
  88  	})
  89  	if !nc.readTimer.Stop() {
  90  		<-nc.readTimer.C
  91  	}
  92  
  93  	return nc
  94  }
  95  
  96  type netConn struct {
  97  	// These must be first to be aligned on 32 bit platforms.
  98  	// https://github.com/nhooyr/websocket/pull/438
  99  	readExpired  int64
 100  	writeExpired int64
 101  
 102  	c       *Conn
 103  	msgType MessageType
 104  
 105  	writeTimer  *time.Timer
 106  	writeMu     *mu
 107  	writeCtx    context.Context
 108  	writeCancel context.CancelFunc
 109  
 110  	readTimer  *time.Timer
 111  	readMu     *mu
 112  	readCtx    context.Context
 113  	readCancel context.CancelFunc
 114  	readEOFed  bool
 115  	reader     io.Reader
 116  }
 117  
 118  var _ net.Conn = &netConn{}
 119  
 120  func (nc *netConn) Close() error {
 121  	nc.writeTimer.Stop()
 122  	nc.writeCancel()
 123  	nc.readTimer.Stop()
 124  	nc.readCancel()
 125  	return nc.c.Close(StatusNormalClosure, "")
 126  }
 127  
 128  func (nc *netConn) Write(p []byte) (int, error) {
 129  	nc.writeMu.forceLock()
 130  	defer nc.writeMu.unlock()
 131  
 132  	if atomic.LoadInt64(&nc.writeExpired) == 1 {
 133  		return 0, fmt.Errorf("failed to write: %w", context.DeadlineExceeded)
 134  	}
 135  
 136  	err := nc.c.Write(nc.writeCtx, nc.msgType, p)
 137  	if err != nil {
 138  		return 0, err
 139  	}
 140  	return len(p), nil
 141  }
 142  
 143  func (nc *netConn) Read(p []byte) (int, error) {
 144  	nc.readMu.forceLock()
 145  	defer nc.readMu.unlock()
 146  
 147  	for {
 148  		n, err := nc.read(p)
 149  		if err != nil {
 150  			return n, err
 151  		}
 152  		if n == 0 {
 153  			continue
 154  		}
 155  		return n, nil
 156  	}
 157  }
 158  
 159  func (nc *netConn) read(p []byte) (int, error) {
 160  	if atomic.LoadInt64(&nc.readExpired) == 1 {
 161  		return 0, fmt.Errorf("failed to read: %w", context.DeadlineExceeded)
 162  	}
 163  
 164  	if nc.readEOFed {
 165  		return 0, io.EOF
 166  	}
 167  
 168  	if nc.reader == nil {
 169  		typ, r, err := nc.c.Reader(nc.readCtx)
 170  		if err != nil {
 171  			switch CloseStatus(err) {
 172  			case StatusNormalClosure, StatusGoingAway:
 173  				nc.readEOFed = true
 174  				return 0, io.EOF
 175  			}
 176  			return 0, err
 177  		}
 178  		if typ != nc.msgType {
 179  			err := fmt.Errorf("unexpected frame type read (expected %v): %v", nc.msgType, typ)
 180  			nc.c.Close(StatusUnsupportedData, err.Error())
 181  			return 0, err
 182  		}
 183  		nc.reader = r
 184  	}
 185  
 186  	n, err := nc.reader.Read(p)
 187  	if err == io.EOF {
 188  		nc.reader = nil
 189  		err = nil
 190  	}
 191  	return n, err
 192  }
 193  
 194  type websocketAddr struct {
 195  }
 196  
 197  func (a websocketAddr) Network() string {
 198  	return "websocket"
 199  }
 200  
 201  func (a websocketAddr) String() string {
 202  	return "websocket/unknown-addr"
 203  }
 204  
 205  func (nc *netConn) SetDeadline(t time.Time) error {
 206  	nc.SetWriteDeadline(t)
 207  	nc.SetReadDeadline(t)
 208  	return nil
 209  }
 210  
 211  func (nc *netConn) SetWriteDeadline(t time.Time) error {
 212  	atomic.StoreInt64(&nc.writeExpired, 0)
 213  	if t.IsZero() {
 214  		nc.writeTimer.Stop()
 215  	} else {
 216  		dur := time.Until(t)
 217  		if dur <= 0 {
 218  			dur = 1
 219  		}
 220  		nc.writeTimer.Reset(dur)
 221  	}
 222  	return nil
 223  }
 224  
 225  func (nc *netConn) SetReadDeadline(t time.Time) error {
 226  	atomic.StoreInt64(&nc.readExpired, 0)
 227  	if t.IsZero() {
 228  		nc.readTimer.Stop()
 229  	} else {
 230  		dur := time.Until(t)
 231  		if dur <= 0 {
 232  			dur = 1
 233  		}
 234  		nc.readTimer.Reset(dur)
 235  	}
 236  	return nil
 237  }
 238