conn.go raw

   1  //go:build !js
   2  // +build !js
   3  
   4  package websocket
   5  
   6  import (
   7  	"bufio"
   8  	"context"
   9  	"fmt"
  10  	"io"
  11  	"net"
  12  	"runtime"
  13  	"strconv"
  14  	"sync"
  15  	"sync/atomic"
  16  )
  17  
  18  // MessageType represents the type of a WebSocket message.
  19  // See https://tools.ietf.org/html/rfc6455#section-5.6
  20  type MessageType int
  21  
  22  // MessageType constants.
  23  const (
  24  	// MessageText is for UTF-8 encoded text messages like JSON.
  25  	MessageText MessageType = iota + 1
  26  	// MessageBinary is for binary messages like protobufs.
  27  	MessageBinary
  28  )
  29  
// Conn represents a WebSocket connection.
// All methods may be called concurrently except for Reader and Read.
//
// You must always read from the connection. Otherwise control
// frames will not be handled. See Reader and CloseRead.
//
// Be sure to call Close on the connection when you
// are finished with it to release associated resources.
//
// On any error from any method, the connection is closed
// with an appropriate reason.
//
// This applies to context expirations as well unfortunately.
// See https://github.com/nhooyr/websocket/issues/242#issuecomment-633182220
type Conn struct {
	// noCopy marks Conn as a type that must not be copied; it is intended
	// to let `go vet` flag copies. A Conn is always used through *Conn.
	noCopy noCopy

	// Connection parameters, copied from connConfig by newConn.
	// subprotocol is returned by Subprotocol. rwc is the underlying
	// transport and is closed by close. When client is set, newConn takes
	// writeBuf from bw via extractBufioWriterBuf. A non-nil copts means
	// compression is in use (see flate). When compression is in use and
	// flateThreshold is zero, newConn defaults it to 128, or 512 without
	// flate context takeover.
	// NOTE(review): flateThreshold is presumably the minimum message size
	// worth compressing; confirm in the message writer.
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	flateThreshold int
	br             *bufio.Reader
	bw             *bufio.Writer

	// Contexts sent on readTimeout/writeTimeout are watched by timeoutLoop,
	// which closes the connection once the most recent one is done.
	// timeoutLoopDone is closed when timeoutLoop returns.
	readTimeout     chan context.Context
	writeTimeout    chan context.Context
	timeoutLoopDone chan struct{}

	// Read state.
	// readMu is the read-side lock created by newConn.
	// NOTE(review): readHeaderBuf and readControlBuf look like scratch
	// buffers for frame headers and control payloads; confirm in the reader.
	readMu         *mu
	readHeaderBuf  [8]byte
	readControlBuf [maxControlPayload]byte
	msgReader      *msgReader

	// Write state.
	// writeFrameMu is the write-side lock created by newConn. writeBuf is
	// only set by newConn for clients.
	// NOTE(review): writeHeaderBuf/writeHeader presumably hold the header of
	// the frame being written; confirm in the writer.
	msgWriter      *msgWriter
	writeFrameMu   *mu
	writeBuf       []byte
	writeHeaderBuf [8]byte
	writeHeader    header

	// State used by CloseRead, which is defined outside this section.
	// TODO(review): confirm how these fields are used.
	closeReadMu   sync.Mutex
	closeReadCtx  context.Context
	closeReadDone chan struct{}

	// closed is closed exactly once, by close, while closeMu is held.
	// NOTE(review): closing is not used in this section; presumably it
	// tracks an in-progress close handshake. Verify.
	closed  chan struct{}
	closeMu sync.Mutex
	closing bool

	// pingCounter is incremented atomically by Ping to give each ping a
	// distinct payload. activePings maps an in-flight ping payload to the
	// channel ping waits on for the matching pong. It is guarded by
	// activePingsMu.
	pingCounter   int32
	activePingsMu sync.Mutex
	activePings   map[string]chan<- struct{}
}
  84  
// connConfig carries the parameters that newConn copies into a new Conn.
// Each field is copied into the identically named field of Conn.
type connConfig struct {
	subprotocol    string
	rwc            io.ReadWriteCloser
	client         bool
	copts          *compressionOptions
	// A zero flateThreshold lets newConn choose a default when copts is
	// non-nil.
	flateThreshold int

	// NOTE(review): br and bw presumably buffer reads from and writes to
	// rwc; confirm at the dial/accept call sites.
	br *bufio.Reader
	bw *bufio.Writer
}
  95  
  96  func newConn(cfg connConfig) *Conn {
  97  	c := &Conn{
  98  		subprotocol:    cfg.subprotocol,
  99  		rwc:            cfg.rwc,
 100  		client:         cfg.client,
 101  		copts:          cfg.copts,
 102  		flateThreshold: cfg.flateThreshold,
 103  
 104  		br: cfg.br,
 105  		bw: cfg.bw,
 106  
 107  		readTimeout:     make(chan context.Context),
 108  		writeTimeout:    make(chan context.Context),
 109  		timeoutLoopDone: make(chan struct{}),
 110  
 111  		closed:      make(chan struct{}),
 112  		activePings: make(map[string]chan<- struct{}),
 113  	}
 114  
 115  	c.readMu = newMu(c)
 116  	c.writeFrameMu = newMu(c)
 117  
 118  	c.msgReader = newMsgReader(c)
 119  
 120  	c.msgWriter = newMsgWriter(c)
 121  	if c.client {
 122  		c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc)
 123  	}
 124  
 125  	if c.flate() && c.flateThreshold == 0 {
 126  		c.flateThreshold = 128
 127  		if !c.msgWriter.flateContextTakeover() {
 128  			c.flateThreshold = 512
 129  		}
 130  	}
 131  
 132  	runtime.SetFinalizer(c, func(c *Conn) {
 133  		c.close()
 134  	})
 135  
 136  	go c.timeoutLoop()
 137  
 138  	return c
 139  }
 140  
 141  // Subprotocol returns the negotiated subprotocol.
 142  // An empty string means the default protocol.
 143  func (c *Conn) Subprotocol() string {
 144  	return c.subprotocol
 145  }
 146  
// close tears down the connection.
//
// With closeMu held for the whole sequence, it does the following:
//   - closes c.closed,
//   - clears the finalizer registered by newConn,
//   - closes the underlying transport,
//   - releases the message writer and then the reader.
//
// It returns the error from c.rwc.Close, or net.ErrClosed if the Conn was
// already closed. It is therefore safe to call more than once.
// NOTE(review): isClosed is defined elsewhere and presumably reports
// whether c.closed has been closed. Confirm.
func (c *Conn) close() error {
	c.closeMu.Lock()
	defer c.closeMu.Unlock()

	if c.isClosed() {
		return net.ErrClosed
	}
	// The Conn is being closed explicitly, so the finalizer is no longer needed.
	runtime.SetFinalizer(c, nil)
	close(c.closed)

	// Close rwc only after c.closed is closed. Any goroutine woken by the
	// transport closing then also sees c.closed closed and reports
	// net.ErrClosed (as mu.lock and ping do when c.closed is closed).
	err := c.rwc.Close()
	// With the close of rwc, these become safe to close.
	// NOTE(review): presumably because any I/O blocked in the writer or
	// reader returns once rwc is closed, so their close methods do not wait
	// forever. Verify.
	c.msgWriter.close()
	c.msgReader.close()
	return err
}
 166  
// timeoutLoop closes the connection when the context of the current read
// or write is done.
//
// It remembers the most recent context received on c.readTimeout and on
// c.writeTimeout. If either context is done, it calls c.close and
// returns. It also returns when c.closed is closed. c.timeoutLoopDone is
// closed on return, presumably so other code can wait for this goroutine
// to exit.
//
// NOTE(review): presumably the read/write paths send a never-done context
// (such as context.Background()) when an operation finishes. Otherwise a
// finished operation's deadline would later close the connection. Confirm
// at the send sites.
func (c *Conn) timeoutLoop() {
	defer close(c.timeoutLoopDone)

	// context.Background is never canceled, so its Done cases below block
	// until a real context is installed.
	readCtx := context.Background()
	writeCtx := context.Background()

	for {
		// The Done channels are re-evaluated on every iteration, so
		// replacing readCtx/writeCtx takes effect immediately.
		select {
		case <-c.closed:
			return

		// Replace the watched context with one sent by a writer or reader.
		case writeCtx = <-c.writeTimeout:
		case readCtx = <-c.readTimeout:

		// The watched context is done: tear the connection down. close's
		// error is deliberately ignored here.
		case <-readCtx.Done():
			c.close()
			return
		case <-writeCtx.Done():
			c.close()
			return
		}
	}
}
 190  
 191  func (c *Conn) flate() bool {
 192  	return c.copts != nil
 193  }
 194  
 195  // Ping sends a ping to the peer and waits for a pong.
 196  // Use this to measure latency or ensure the peer is responsive.
 197  // Ping must be called concurrently with Reader as it does
 198  // not read from the connection but instead waits for a Reader call
 199  // to read the pong.
 200  //
 201  // TCP Keepalives should suffice for most use cases.
 202  func (c *Conn) Ping(ctx context.Context) error {
 203  	p := atomic.AddInt32(&c.pingCounter, 1)
 204  
 205  	err := c.ping(ctx, strconv.Itoa(int(p)))
 206  	if err != nil {
 207  		return fmt.Errorf("failed to ping: %w", err)
 208  	}
 209  	return nil
 210  }
 211  
 212  func (c *Conn) ping(ctx context.Context, p string) error {
 213  	pong := make(chan struct{}, 1)
 214  
 215  	c.activePingsMu.Lock()
 216  	c.activePings[p] = pong
 217  	c.activePingsMu.Unlock()
 218  
 219  	defer func() {
 220  		c.activePingsMu.Lock()
 221  		delete(c.activePings, p)
 222  		c.activePingsMu.Unlock()
 223  	}()
 224  
 225  	err := c.writeControl(ctx, opPing, []byte(p))
 226  	if err != nil {
 227  		return err
 228  	}
 229  
 230  	select {
 231  	case <-c.closed:
 232  		return net.ErrClosed
 233  	case <-ctx.Done():
 234  		return fmt.Errorf("failed to wait for pong: %w", ctx.Err())
 235  	case <-pong:
 236  		return nil
 237  	}
 238  }
 239  
 240  type mu struct {
 241  	c  *Conn
 242  	ch chan struct{}
 243  }
 244  
 245  func newMu(c *Conn) *mu {
 246  	return &mu{
 247  		c:  c,
 248  		ch: make(chan struct{}, 1),
 249  	}
 250  }
 251  
 252  func (m *mu) forceLock() {
 253  	m.ch <- struct{}{}
 254  }
 255  
 256  func (m *mu) tryLock() bool {
 257  	select {
 258  	case m.ch <- struct{}{}:
 259  		return true
 260  	default:
 261  		return false
 262  	}
 263  }
 264  
 265  func (m *mu) lock(ctx context.Context) error {
 266  	select {
 267  	case <-m.c.closed:
 268  		return net.ErrClosed
 269  	case <-ctx.Done():
 270  		return fmt.Errorf("failed to acquire lock: %w", ctx.Err())
 271  	case m.ch <- struct{}{}:
 272  		// To make sure the connection is certainly alive.
 273  		// As it's possible the send on m.ch was selected
 274  		// over the receive on closed.
 275  		select {
 276  		case <-m.c.closed:
 277  			// Make sure to release.
 278  			m.unlock()
 279  			return net.ErrClosed
 280  		default:
 281  		}
 282  		return nil
 283  	}
 284  }
 285  
 286  func (m *mu) unlock() {
 287  	select {
 288  	case <-m.ch:
 289  	default:
 290  	}
 291  }
 292  
 293  type noCopy struct{}
 294  
 295  func (*noCopy) Lock() {}
 296