read.go raw

   1  //go:build !js
   2  // +build !js
   3  
   4  package websocket
   5  
   6  import (
   7  	"bufio"
   8  	"context"
   9  	"errors"
  10  	"fmt"
  11  	"io"
  12  	"net"
  13  	"strings"
  14  	"time"
  15  
  16  	"github.com/coder/websocket/internal/errd"
  17  	"github.com/coder/websocket/internal/util"
  18  	"github.com/coder/websocket/internal/xsync"
  19  )
  20  
  21  // Reader reads from the connection until there is a WebSocket
  22  // data message to be read. It will handle ping, pong and close frames as appropriate.
  23  //
  24  // It returns the type of the message and an io.Reader to read it.
  25  // The passed context will also bound the reader.
  26  // Ensure you read to EOF otherwise the connection will hang.
  27  //
  28  // Call CloseRead if you do not expect any data messages from the peer.
  29  //
  30  // Only one Reader may be open at a time.
  31  //
  32  // If you need a separate timeout on the Reader call and the Read itself,
  33  // use time.AfterFunc to cancel the context passed in.
  34  // See https://github.com/nhooyr/websocket/issues/87#issue-451703332
  35  // Most users should not need this.
  36  func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
  37  	return c.reader(ctx)
  38  }
  39  
  40  // Read is a convenience method around Reader to read a single message
  41  // from the connection.
  42  func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
  43  	typ, r, err := c.Reader(ctx)
  44  	if err != nil {
  45  		return 0, nil, err
  46  	}
  47  
  48  	b, err := io.ReadAll(r)
  49  	return typ, b, err
  50  }
  51  
  52  // CloseRead starts a goroutine to read from the connection until it is closed
  53  // or a data message is received.
  54  //
  55  // Once CloseRead is called you cannot read any messages from the connection.
  56  // The returned context will be cancelled when the connection is closed.
  57  //
  58  // If a data message is received, the connection will be closed with StatusPolicyViolation.
  59  //
  60  // Call CloseRead when you do not expect to read any more messages.
  61  // Since it actively reads from the connection, it will ensure that ping, pong and close
  62  // frames are responded to. This means c.Ping and c.Close will still work as expected.
  63  //
  64  // This function is idempotent.
  65  func (c *Conn) CloseRead(ctx context.Context) context.Context {
  66  	c.closeReadMu.Lock()
  67  	ctx2 := c.closeReadCtx
  68  	if ctx2 != nil {
  69  		c.closeReadMu.Unlock()
  70  		return ctx2
  71  	}
  72  	ctx, cancel := context.WithCancel(ctx)
  73  	c.closeReadCtx = ctx
  74  	c.closeReadDone = make(chan struct{})
  75  	c.closeReadMu.Unlock()
  76  
  77  	go func() {
  78  		defer close(c.closeReadDone)
  79  		defer cancel()
  80  		defer c.close()
  81  		_, _, err := c.Reader(ctx)
  82  		if err == nil {
  83  			c.Close(StatusPolicyViolation, "unexpected data message")
  84  		}
  85  	}()
  86  	return ctx
  87  }
  88  
  89  // SetReadLimit sets the max number of bytes to read for a single message.
  90  // It applies to the Reader and Read methods.
  91  //
  92  // By default, the connection has a message read limit of 32768 bytes.
  93  //
  94  // When the limit is hit, the connection will be closed with StatusMessageTooBig.
  95  //
  96  // Set to -1 to disable.
  97  func (c *Conn) SetReadLimit(n int64) {
  98  	if n >= 0 {
  99  		// We read one more byte than the limit in case
 100  		// there is a fin frame that needs to be read.
 101  		n++
 102  	}
 103  
 104  	c.msgReader.limitReader.limit.Store(n)
 105  }
 106  
 107  const defaultReadLimit = 32768
 108  
 109  func newMsgReader(c *Conn) *msgReader {
 110  	mr := &msgReader{
 111  		c:   c,
 112  		fin: true,
 113  	}
 114  	mr.readFunc = mr.read
 115  
 116  	mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1)
 117  	return mr
 118  }
 119  
 120  func (mr *msgReader) resetFlate() {
 121  	if mr.flateContextTakeover() {
 122  		if mr.dict == nil {
 123  			mr.dict = &slidingWindow{}
 124  		}
 125  		mr.dict.init(32768)
 126  	}
 127  	if mr.flateBufio == nil {
 128  		mr.flateBufio = getBufioReader(mr.readFunc)
 129  	}
 130  
 131  	if mr.flateContextTakeover() {
 132  		mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf)
 133  	} else {
 134  		mr.flateReader = getFlateReader(mr.flateBufio, nil)
 135  	}
 136  	mr.limitReader.r = mr.flateReader
 137  	mr.flateTail.Reset(deflateMessageTail)
 138  }
 139  
 140  func (mr *msgReader) putFlateReader() {
 141  	if mr.flateReader != nil {
 142  		putFlateReader(mr.flateReader)
 143  		mr.flateReader = nil
 144  	}
 145  }
 146  
 147  func (mr *msgReader) close() {
 148  	mr.c.readMu.forceLock()
 149  	mr.putFlateReader()
 150  	if mr.dict != nil {
 151  		mr.dict.close()
 152  		mr.dict = nil
 153  	}
 154  	if mr.flateBufio != nil {
 155  		putBufioReader(mr.flateBufio)
 156  	}
 157  
 158  	if mr.c.client {
 159  		putBufioReader(mr.c.br)
 160  		mr.c.br = nil
 161  	}
 162  }
 163  
 164  func (mr *msgReader) flateContextTakeover() bool {
 165  	if mr.c.client {
 166  		return !mr.c.copts.serverNoContextTakeover
 167  	}
 168  	return !mr.c.copts.clientNoContextTakeover
 169  }
 170  
 171  func (c *Conn) readRSV1Illegal(h header) bool {
 172  	// If compression is disabled, rsv1 is illegal.
 173  	if !c.flate() {
 174  		return true
 175  	}
 176  	// rsv1 is only allowed on data frames beginning messages.
 177  	if h.opcode != opText && h.opcode != opBinary {
 178  		return true
 179  	}
 180  	return false
 181  }
 182  
 183  func (c *Conn) readLoop(ctx context.Context) (header, error) {
 184  	for {
 185  		h, err := c.readFrameHeader(ctx)
 186  		if err != nil {
 187  			return header{}, err
 188  		}
 189  
 190  		if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 {
 191  			err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3)
 192  			c.writeError(StatusProtocolError, err)
 193  			return header{}, err
 194  		}
 195  
 196  		if !c.client && !h.masked {
 197  			return header{}, errors.New("received unmasked frame from client")
 198  		}
 199  
 200  		switch h.opcode {
 201  		case opClose, opPing, opPong:
 202  			err = c.handleControl(ctx, h)
 203  			if err != nil {
 204  				// Pass through CloseErrors when receiving a close frame.
 205  				if h.opcode == opClose && CloseStatus(err) != -1 {
 206  					return header{}, err
 207  				}
 208  				return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err)
 209  			}
 210  		case opContinuation, opText, opBinary:
 211  			return h, nil
 212  		default:
 213  			err := fmt.Errorf("received unknown opcode %v", h.opcode)
 214  			c.writeError(StatusProtocolError, err)
 215  			return header{}, err
 216  		}
 217  	}
 218  }
 219  
 220  func (c *Conn) readFrameHeader(ctx context.Context) (header, error) {
 221  	select {
 222  	case <-c.closed:
 223  		return header{}, net.ErrClosed
 224  	case c.readTimeout <- ctx:
 225  	}
 226  
 227  	h, err := readFrameHeader(c.br, c.readHeaderBuf[:])
 228  	if err != nil {
 229  		select {
 230  		case <-c.closed:
 231  			return header{}, net.ErrClosed
 232  		case <-ctx.Done():
 233  			return header{}, ctx.Err()
 234  		default:
 235  			return header{}, err
 236  		}
 237  	}
 238  
 239  	select {
 240  	case <-c.closed:
 241  		return header{}, net.ErrClosed
 242  	case c.readTimeout <- context.Background():
 243  	}
 244  
 245  	return h, nil
 246  }
 247  
 248  func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) {
 249  	select {
 250  	case <-c.closed:
 251  		return 0, net.ErrClosed
 252  	case c.readTimeout <- ctx:
 253  	}
 254  
 255  	n, err := io.ReadFull(c.br, p)
 256  	if err != nil {
 257  		select {
 258  		case <-c.closed:
 259  			return n, net.ErrClosed
 260  		case <-ctx.Done():
 261  			return n, ctx.Err()
 262  		default:
 263  			return n, fmt.Errorf("failed to read frame payload: %w", err)
 264  		}
 265  	}
 266  
 267  	select {
 268  	case <-c.closed:
 269  		return n, net.ErrClosed
 270  	case c.readTimeout <- context.Background():
 271  	}
 272  
 273  	return n, err
 274  }
 275  
// handleControl processes one control frame (close, ping or pong) whose
// header h has already been read; it reads the frame's payload itself.
//
// A control frame is a protocol error, failing the connection with
// StatusProtocolError and returning the error, if either:
//   - its payload length is negative or exceeds maxControlPayload;
//   - its FIN bit is clear.
//
// Behaviour per opcode:
//   - ping: the payload is echoed back in a pong frame.
//   - pong: if c.activePings has a channel registered under this exact
//     payload (presumably by Ping), it is signalled without blocking;
//     any other pong is ignored.
//   - close: the payload is parsed, a close frame with the same code and
//     reason is written back, readMu is released, the connection is closed,
//     and an error wrapping the peer's CloseError is returned.
func (c *Conn) handleControl(ctx context.Context, h header) (err error) {
	if h.payloadLength < 0 || h.payloadLength > maxControlPayload {
		err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength)
		c.writeError(StatusProtocolError, err)
		return err
	}

	if !h.fin {
		err := errors.New("received fragmented control frame")
		c.writeError(StatusProtocolError, err)
		return err
	}

	// Reading the payload and writing any reply get at most 5 seconds
	// (less if ctx expires sooner).
	ctx, cancel := context.WithTimeout(ctx, time.Second*5)
	defer cancel()

	// b aliases the connection-owned c.readControlBuf; it is overwritten by
	// the next control frame, so it must not be retained.
	b := c.readControlBuf[:h.payloadLength]
	_, err = c.readFramePayload(ctx, b)
	if err != nil {
		return err
	}

	// Unlike the data path, control payloads are unmasked whenever the
	// header says they are masked, regardless of which side we are.
	if h.masked {
		mask(b, h.maskKey)
	}

	switch h.opcode {
	case opPing:
		return c.writeControl(ctx, opPong, b)
	case opPong:
		c.activePingsMu.Lock()
		pong, ok := c.activePings[string(b)]
		c.activePingsMu.Unlock()
		if ok {
			// Non-blocking send: if nobody is ready to receive, the pong is
			// dropped rather than stalling the read loop.
			select {
			case pong <- struct{}{}:
			default:
			}
		}
		return nil
	}

	// opClose

	ce, err := parseClosePayload(b)
	if err != nil {
		err = fmt.Errorf("received invalid close payload: %w", err)
		c.writeError(StatusProtocolError, err)
		return err
	}

	// Wrapping with %w keeps the CloseError reachable, which readLoop relies
	// on (via CloseStatus) to pass this error through unchanged.
	err = fmt.Errorf("received close frame: %w", ce)
	// The result of echoing the close frame is intentionally not checked.
	c.writeClose(ce.Code, ce.Reason)
	// readMu is released before c.close(). Presumably closing needs to take
	// readMu again (msgReader.close calls readMu.forceLock); confirm in
	// conn.go. The caller (reader) also defers readMu.unlock, so on this path
	// unlock runs twice, which presumably mu.unlock tolerates; verify.
	c.readMu.unlock()
	c.close()
	return err
}
 333  
 334  func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) {
 335  	defer errd.Wrap(&err, "failed to get reader")
 336  
 337  	err = c.readMu.lock(ctx)
 338  	if err != nil {
 339  		return 0, nil, err
 340  	}
 341  	defer c.readMu.unlock()
 342  
 343  	if !c.msgReader.fin {
 344  		return 0, nil, errors.New("previous message not read to completion")
 345  	}
 346  
 347  	h, err := c.readLoop(ctx)
 348  	if err != nil {
 349  		return 0, nil, err
 350  	}
 351  
 352  	if h.opcode == opContinuation {
 353  		err := errors.New("received continuation frame without text or binary frame")
 354  		c.writeError(StatusProtocolError, err)
 355  		return 0, nil, err
 356  	}
 357  
 358  	c.msgReader.reset(ctx, h)
 359  
 360  	return MessageType(h.opcode), c.msgReader, nil
 361  }
 362  
// msgReader is the io.Reader handed out by Conn.Reader. One msgReader is
// reused for every message on a connection; reset re-arms it for each new
// data message.
type msgReader struct {
	c *Conn

	// ctx bounds reads of the current message; set by reset.
	ctx context.Context
	// flate is true when the current message is compressed (RSV1 set on
	// its first frame).
	flate bool
	// flateReader decompresses the current message. It is nil when the
	// message is uncompressed or after putFlateReader released it.
	flateReader io.Reader
	// flateBufio buffers the raw payload (readFunc) for flateReader.
	flateBufio *bufio.Reader
	// flateTail serves deflateMessageTail once the final frame of a
	// compressed message is exhausted; rewound by resetFlate.
	flateTail strings.Reader
	// limitReader enforces the read limit. It reads from flateReader for
	// compressed messages and from readFunc otherwise.
	limitReader *limitReader
	// dict holds decompression history when context takeover applies.
	dict *slidingWindow

	// fin, payloadLength and maskKey describe the frame currently being
	// read (see setFrame). payloadLength counts the bytes still unread in
	// that frame.
	fin           bool
	payloadLength int64
	maskKey       uint32

	// readFunc is mr.read stored once as a util.ReaderFunc, so that each
	// use avoids allocating a new method value.
	readFunc util.ReaderFunc
}
 381  
 382  func (mr *msgReader) reset(ctx context.Context, h header) {
 383  	mr.ctx = ctx
 384  	mr.flate = h.rsv1
 385  	mr.limitReader.reset(mr.readFunc)
 386  
 387  	if mr.flate {
 388  		mr.resetFlate()
 389  	}
 390  
 391  	mr.setFrame(h)
 392  }
 393  
 394  func (mr *msgReader) setFrame(h header) {
 395  	mr.fin = h.fin
 396  	mr.payloadLength = h.payloadLength
 397  	mr.maskKey = h.maskKey
 398  }
 399  
 400  func (mr *msgReader) Read(p []byte) (n int, err error) {
 401  	err = mr.c.readMu.lock(mr.ctx)
 402  	if err != nil {
 403  		return 0, fmt.Errorf("failed to read: %w", err)
 404  	}
 405  	defer mr.c.readMu.unlock()
 406  
 407  	n, err = mr.limitReader.Read(p)
 408  	if mr.flate && mr.flateContextTakeover() {
 409  		p = p[:n]
 410  		mr.dict.write(p)
 411  	}
 412  	if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
 413  		mr.putFlateReader()
 414  		return n, io.EOF
 415  	}
 416  	if err != nil {
 417  		return n, fmt.Errorf("failed to read: %w", err)
 418  	}
 419  	return n, nil
 420  }
 421  
 422  func (mr *msgReader) read(p []byte) (int, error) {
 423  	for {
 424  		if mr.payloadLength == 0 {
 425  			if mr.fin {
 426  				if mr.flate {
 427  					return mr.flateTail.Read(p)
 428  				}
 429  				return 0, io.EOF
 430  			}
 431  
 432  			h, err := mr.c.readLoop(mr.ctx)
 433  			if err != nil {
 434  				return 0, err
 435  			}
 436  			if h.opcode != opContinuation {
 437  				err := errors.New("received new data message without finishing the previous message")
 438  				mr.c.writeError(StatusProtocolError, err)
 439  				return 0, err
 440  			}
 441  			mr.setFrame(h)
 442  
 443  			continue
 444  		}
 445  
 446  		if int64(len(p)) > mr.payloadLength {
 447  			p = p[:mr.payloadLength]
 448  		}
 449  
 450  		n, err := mr.c.readFramePayload(mr.ctx, p)
 451  		if err != nil {
 452  			return n, err
 453  		}
 454  
 455  		mr.payloadLength -= int64(n)
 456  
 457  		if !mr.c.client {
 458  			mr.maskKey = mask(p, mr.maskKey)
 459  		}
 460  
 461  		return n, nil
 462  	}
 463  }
 464  
// limitReader caps the number of bytes read from r for a single message.
// Once the budget is spent it fails the connection with
// StatusMessageTooBig.
type limitReader struct {
	c *Conn
	// r is the source for the current message: the raw payload reader, or
	// the decompressor for compressed messages.
	r io.Reader
	// limit is the configured cap, including the extra byte added by
	// SetReadLimit / newMsgReader; a negative value disables limiting.
	// Stored in an xsync.Int64, presumably so SetReadLimit may run
	// concurrently with reads (confirm).
	limit xsync.Int64
	// n is the budget remaining for the current message, loaded from limit
	// by reset.
	n int64
}
 471  
 472  func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader {
 473  	lr := &limitReader{
 474  		c: c,
 475  	}
 476  	lr.limit.Store(limit)
 477  	lr.reset(r)
 478  	return lr
 479  }
 480  
 481  func (lr *limitReader) reset(r io.Reader) {
 482  	lr.n = lr.limit.Load()
 483  	lr.r = r
 484  }
 485  
 486  func (lr *limitReader) Read(p []byte) (int, error) {
 487  	if lr.n < 0 {
 488  		return lr.r.Read(p)
 489  	}
 490  
 491  	if lr.n == 0 {
 492  		err := fmt.Errorf("read limited at %v bytes", lr.limit.Load())
 493  		lr.c.writeError(StatusMessageTooBig, err)
 494  		return 0, err
 495  	}
 496  
 497  	if int64(len(p)) > lr.n {
 498  		p = p[:lr.n]
 499  	}
 500  	n, err := lr.r.Read(p)
 501  	lr.n -= int64(n)
 502  	if lr.n < 0 {
 503  		lr.n = 0
 504  	}
 505  	return n, err
 506  }
 507