ws_js.go raw

   1  package websocket // import "github.com/coder/websocket"
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"errors"
   7  	"fmt"
   8  	"io"
   9  	"net"
  10  	"net/http"
  11  	"reflect"
  12  	"runtime"
  13  	"strings"
  14  	"sync"
  15  	"syscall/js"
  16  
  17  	"github.com/coder/websocket/internal/bpool"
  18  	"github.com/coder/websocket/internal/wsjs"
  19  	"github.com/coder/websocket/internal/xsync"
  20  )
  21  
  22  // opcode represents a WebSocket opcode.
  23  type opcode int
  24  
  25  // https://tools.ietf.org/html/rfc6455#section-11.8.
  26  const (
  27  	opContinuation opcode = iota
  28  	opText
  29  	opBinary
  30  	// 3 - 7 are reserved for further non-control frames.
  31  	_
  32  	_
  33  	_
  34  	_
  35  	_
  36  	opClose
  37  	opPing
  38  	opPong
  39  	// 11-16 are reserved for further control frames.
  40  )
  41  
  42  // Conn provides a wrapper around the browser WebSocket API.
  43  type Conn struct {
  44  	noCopy noCopy
  45  	ws     wsjs.WebSocket
  46  
  47  	// read limit for a message in bytes.
  48  	msgReadLimit xsync.Int64
  49  
  50  	closeReadMu  sync.Mutex
  51  	closeReadCtx context.Context
  52  
  53  	closingMu     sync.Mutex
  54  	closeOnce     sync.Once
  55  	closed        chan struct{}
  56  	closeErrOnce  sync.Once
  57  	closeErr      error
  58  	closeWasClean bool
  59  
  60  	releaseOnClose   func()
  61  	releaseOnError   func()
  62  	releaseOnMessage func()
  63  
  64  	readSignal chan struct{}
  65  	readBufMu  sync.Mutex
  66  	readBuf    []wsjs.MessageEvent
  67  }
  68  
  69  func (c *Conn) close(err error, wasClean bool) {
  70  	c.closeOnce.Do(func() {
  71  		runtime.SetFinalizer(c, nil)
  72  
  73  		if !wasClean {
  74  			err = fmt.Errorf("unclean connection close: %w", err)
  75  		}
  76  		c.setCloseErr(err)
  77  		c.closeWasClean = wasClean
  78  		close(c.closed)
  79  	})
  80  }
  81  
  82  func (c *Conn) init() {
  83  	c.closed = make(chan struct{})
  84  	c.readSignal = make(chan struct{}, 1)
  85  
  86  	c.msgReadLimit.Store(32768)
  87  
  88  	c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
  89  		err := CloseError{
  90  			Code:   StatusCode(e.Code),
  91  			Reason: e.Reason,
  92  		}
  93  		// We do not know if we sent or received this close as
  94  		// its possible the browser triggered it without us
  95  		// explicitly sending it.
  96  		c.close(err, e.WasClean)
  97  
  98  		c.releaseOnClose()
  99  		c.releaseOnError()
 100  		c.releaseOnMessage()
 101  	})
 102  
 103  	c.releaseOnError = c.ws.OnError(func(v js.Value) {
 104  		c.setCloseErr(errors.New(v.Get("message").String()))
 105  		c.closeWithInternal()
 106  	})
 107  
 108  	c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
 109  		c.readBufMu.Lock()
 110  		defer c.readBufMu.Unlock()
 111  
 112  		c.readBuf = append(c.readBuf, e)
 113  
 114  		// Lets the read goroutine know there is definitely something in readBuf.
 115  		select {
 116  		case c.readSignal <- struct{}{}:
 117  		default:
 118  		}
 119  	})
 120  
 121  	runtime.SetFinalizer(c, func(c *Conn) {
 122  		c.setCloseErr(errors.New("connection garbage collected"))
 123  		c.closeWithInternal()
 124  	})
 125  }
 126  
 127  func (c *Conn) closeWithInternal() {
 128  	c.Close(StatusInternalError, "something went wrong")
 129  }
 130  
 131  // Read attempts to read a message from the connection.
 132  // The maximum time spent waiting is bounded by the context.
 133  func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
 134  	c.closeReadMu.Lock()
 135  	closedRead := c.closeReadCtx != nil
 136  	c.closeReadMu.Unlock()
 137  	if closedRead {
 138  		return 0, nil, errors.New("WebSocket connection read closed")
 139  	}
 140  
 141  	typ, p, err := c.read(ctx)
 142  	if err != nil {
 143  		return 0, nil, fmt.Errorf("failed to read: %w", err)
 144  	}
 145  	readLimit := c.msgReadLimit.Load()
 146  	if readLimit >= 0 && int64(len(p)) > readLimit {
 147  		err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
 148  		c.Close(StatusMessageTooBig, err.Error())
 149  		return 0, nil, err
 150  	}
 151  	return typ, p, nil
 152  }
 153  
 154  func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
 155  	select {
 156  	case <-ctx.Done():
 157  		c.Close(StatusPolicyViolation, "read timed out")
 158  		return 0, nil, ctx.Err()
 159  	case <-c.readSignal:
 160  	case <-c.closed:
 161  		return 0, nil, net.ErrClosed
 162  	}
 163  
 164  	c.readBufMu.Lock()
 165  	defer c.readBufMu.Unlock()
 166  
 167  	me := c.readBuf[0]
 168  	// We copy the messages forward and decrease the size
 169  	// of the slice to avoid reallocating.
 170  	copy(c.readBuf, c.readBuf[1:])
 171  	c.readBuf = c.readBuf[:len(c.readBuf)-1]
 172  
 173  	if len(c.readBuf) > 0 {
 174  		// Next time we read, we'll grab the message.
 175  		select {
 176  		case c.readSignal <- struct{}{}:
 177  		default:
 178  		}
 179  	}
 180  
 181  	switch p := me.Data.(type) {
 182  	case string:
 183  		return MessageText, []byte(p), nil
 184  	case []byte:
 185  		return MessageBinary, p, nil
 186  	default:
 187  		panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
 188  	}
 189  }
 190  
 191  // Ping is mocked out for Wasm.
 192  func (c *Conn) Ping(ctx context.Context) error {
 193  	return nil
 194  }
 195  
 196  // Write writes a message of the given type to the connection.
 197  // Always non blocking.
 198  func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
 199  	err := c.write(ctx, typ, p)
 200  	if err != nil {
 201  		// Have to ensure the WebSocket is closed after a write error
 202  		// to match the Go API. It can only error if the message type
 203  		// is unexpected or the passed bytes contain invalid UTF-8 for
 204  		// MessageText.
 205  		err := fmt.Errorf("failed to write: %w", err)
 206  		c.setCloseErr(err)
 207  		c.closeWithInternal()
 208  		return err
 209  	}
 210  	return nil
 211  }
 212  
 213  func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
 214  	if c.isClosed() {
 215  		return net.ErrClosed
 216  	}
 217  	switch typ {
 218  	case MessageBinary:
 219  		return c.ws.SendBytes(p)
 220  	case MessageText:
 221  		return c.ws.SendText(string(p))
 222  	default:
 223  		return fmt.Errorf("unexpected message type: %v", typ)
 224  	}
 225  }
 226  
 227  // Close closes the WebSocket with the given code and reason.
 228  // It will wait until the peer responds with a close frame
 229  // or the connection is closed.
 230  // It thus performs the full WebSocket close handshake.
 231  func (c *Conn) Close(code StatusCode, reason string) error {
 232  	err := c.exportedClose(code, reason)
 233  	if err != nil {
 234  		return fmt.Errorf("failed to close WebSocket: %w", err)
 235  	}
 236  	return nil
 237  }
 238  
 239  // CloseNow closes the WebSocket connection without attempting a close handshake.
 240  // Use when you do not want the overhead of the close handshake.
 241  //
 242  // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
 243  // a WebSocket without the close handshake.
 244  func (c *Conn) CloseNow() error {
 245  	return c.Close(StatusGoingAway, "")
 246  }
 247  
 248  func (c *Conn) exportedClose(code StatusCode, reason string) error {
 249  	c.closingMu.Lock()
 250  	defer c.closingMu.Unlock()
 251  
 252  	if c.isClosed() {
 253  		return net.ErrClosed
 254  	}
 255  
 256  	ce := fmt.Errorf("sent close: %w", CloseError{
 257  		Code:   code,
 258  		Reason: reason,
 259  	})
 260  
 261  	c.setCloseErr(ce)
 262  	err := c.ws.Close(int(code), reason)
 263  	if err != nil {
 264  		return err
 265  	}
 266  
 267  	<-c.closed
 268  	if !c.closeWasClean {
 269  		return c.closeErr
 270  	}
 271  	return nil
 272  }
 273  
 274  // Subprotocol returns the negotiated subprotocol.
 275  // An empty string means the default protocol.
 276  func (c *Conn) Subprotocol() string {
 277  	return c.ws.Subprotocol()
 278  }
 279  
 280  // DialOptions represents the options available to pass to Dial.
 281  type DialOptions struct {
 282  	// Subprotocols lists the subprotocols to negotiate with the server.
 283  	Subprotocols []string
 284  }
 285  
 286  // Dial creates a new WebSocket connection to the given url with the given options.
 287  // The passed context bounds the maximum time spent waiting for the connection to open.
 288  // The returned *http.Response is always nil or a mock. It's only in the signature
 289  // to match the core API.
 290  func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
 291  	c, resp, err := dial(ctx, url, opts)
 292  	if err != nil {
 293  		return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
 294  	}
 295  	return c, resp, nil
 296  }
 297  
 298  func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
 299  	if opts == nil {
 300  		opts = &DialOptions{}
 301  	}
 302  
 303  	url = strings.Replace(url, "http://", "ws://", 1)
 304  	url = strings.Replace(url, "https://", "wss://", 1)
 305  
 306  	ws, err := wsjs.New(url, opts.Subprotocols)
 307  	if err != nil {
 308  		return nil, nil, err
 309  	}
 310  
 311  	c := &Conn{
 312  		ws: ws,
 313  	}
 314  	c.init()
 315  
 316  	opench := make(chan struct{})
 317  	releaseOpen := ws.OnOpen(func(e js.Value) {
 318  		close(opench)
 319  	})
 320  	defer releaseOpen()
 321  
 322  	select {
 323  	case <-ctx.Done():
 324  		c.Close(StatusPolicyViolation, "dial timed out")
 325  		return nil, nil, ctx.Err()
 326  	case <-opench:
 327  		return c, &http.Response{
 328  			StatusCode: http.StatusSwitchingProtocols,
 329  		}, nil
 330  	case <-c.closed:
 331  		return nil, nil, net.ErrClosed
 332  	}
 333  }
 334  
 335  // Reader attempts to read a message from the connection.
 336  // The maximum time spent waiting is bounded by the context.
 337  func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
 338  	typ, p, err := c.Read(ctx)
 339  	if err != nil {
 340  		return 0, nil, err
 341  	}
 342  	return typ, bytes.NewReader(p), nil
 343  }
 344  
 345  // Writer returns a writer to write a WebSocket data message to the connection.
 346  // It buffers the entire message in memory and then sends it when the writer
 347  // is closed.
 348  func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
 349  	return &writer{
 350  		c:   c,
 351  		ctx: ctx,
 352  		typ: typ,
 353  		b:   bpool.Get(),
 354  	}, nil
 355  }
 356  
 357  type writer struct {
 358  	closed bool
 359  
 360  	c   *Conn
 361  	ctx context.Context
 362  	typ MessageType
 363  
 364  	b *bytes.Buffer
 365  }
 366  
 367  func (w *writer) Write(p []byte) (int, error) {
 368  	if w.closed {
 369  		return 0, errors.New("cannot write to closed writer")
 370  	}
 371  	n, err := w.b.Write(p)
 372  	if err != nil {
 373  		return n, fmt.Errorf("failed to write message: %w", err)
 374  	}
 375  	return n, nil
 376  }
 377  
 378  func (w *writer) Close() error {
 379  	if w.closed {
 380  		return errors.New("cannot close closed writer")
 381  	}
 382  	w.closed = true
 383  	defer bpool.Put(w.b)
 384  
 385  	err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
 386  	if err != nil {
 387  		return fmt.Errorf("failed to close writer: %w", err)
 388  	}
 389  	return nil
 390  }
 391  
 392  // CloseRead implements *Conn.CloseRead for wasm.
 393  func (c *Conn) CloseRead(ctx context.Context) context.Context {
 394  	c.closeReadMu.Lock()
 395  	ctx2 := c.closeReadCtx
 396  	if ctx2 != nil {
 397  		c.closeReadMu.Unlock()
 398  		return ctx2
 399  	}
 400  	ctx, cancel := context.WithCancel(ctx)
 401  	c.closeReadCtx = ctx
 402  	c.closeReadMu.Unlock()
 403  
 404  	go func() {
 405  		defer cancel()
 406  		defer c.CloseNow()
 407  		_, _, err := c.read(ctx)
 408  		if err != nil {
 409  			c.Close(StatusPolicyViolation, "unexpected data message")
 410  		}
 411  	}()
 412  	return ctx
 413  }
 414  
 415  // SetReadLimit implements *Conn.SetReadLimit for wasm.
 416  func (c *Conn) SetReadLimit(n int64) {
 417  	c.msgReadLimit.Store(n)
 418  }
 419  
 420  func (c *Conn) setCloseErr(err error) {
 421  	c.closeErrOnce.Do(func() {
 422  		c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
 423  	})
 424  }
 425  
 426  func (c *Conn) isClosed() bool {
 427  	select {
 428  	case <-c.closed:
 429  		return true
 430  	default:
 431  		return false
 432  	}
 433  }
 434  
 435  // AcceptOptions represents Accept's options.
 436  type AcceptOptions struct {
 437  	Subprotocols         []string
 438  	InsecureSkipVerify   bool
 439  	OriginPatterns       []string
 440  	CompressionMode      CompressionMode
 441  	CompressionThreshold int
 442  }
 443  
 444  // Accept is stubbed out for Wasm.
 445  func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
 446  	return nil, errors.New("unimplemented")
 447  }
 448  
 449  // StatusCode represents a WebSocket status code.
 450  // https://tools.ietf.org/html/rfc6455#section-7.4
 451  type StatusCode int
 452  
 453  // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
 454  //
 455  // These are only the status codes defined by the protocol.
 456  //
 457  // You can define custom codes in the 3000-4999 range.
 458  // The 3000-3999 range is reserved for use by libraries, frameworks and applications.
 459  // The 4000-4999 range is reserved for private use.
 460  const (
 461  	StatusNormalClosure   StatusCode = 1000
 462  	StatusGoingAway       StatusCode = 1001
 463  	StatusProtocolError   StatusCode = 1002
 464  	StatusUnsupportedData StatusCode = 1003
 465  
 466  	// 1004 is reserved and so unexported.
 467  	statusReserved StatusCode = 1004
 468  
 469  	// StatusNoStatusRcvd cannot be sent in a close message.
 470  	// It is reserved for when a close message is received without
 471  	// a status code.
 472  	StatusNoStatusRcvd StatusCode = 1005
 473  
 474  	// StatusAbnormalClosure is exported for use only with Wasm.
 475  	// In non Wasm Go, the returned error will indicate whether the
 476  	// connection was closed abnormally.
 477  	StatusAbnormalClosure StatusCode = 1006
 478  
 479  	StatusInvalidFramePayloadData StatusCode = 1007
 480  	StatusPolicyViolation         StatusCode = 1008
 481  	StatusMessageTooBig           StatusCode = 1009
 482  	StatusMandatoryExtension      StatusCode = 1010
 483  	StatusInternalError           StatusCode = 1011
 484  	StatusServiceRestart          StatusCode = 1012
 485  	StatusTryAgainLater           StatusCode = 1013
 486  	StatusBadGateway              StatusCode = 1014
 487  
 488  	// StatusTLSHandshake is only exported for use with Wasm.
 489  	// In non Wasm Go, the returned error will indicate whether there was
 490  	// a TLS handshake failure.
 491  	StatusTLSHandshake StatusCode = 1015
 492  )
 493  
 494  // CloseError is returned when the connection is closed with a status and reason.
 495  //
 496  // Use Go 1.13's errors.As to check for this error.
 497  // Also see the CloseStatus helper.
 498  type CloseError struct {
 499  	Code   StatusCode
 500  	Reason string
 501  }
 502  
 503  func (ce CloseError) Error() string {
 504  	return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
 505  }
 506  
 507  // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
 508  // the status code from a CloseError.
 509  //
 510  // -1 will be returned if the passed error is nil or not a CloseError.
 511  func CloseStatus(err error) StatusCode {
 512  	var ce CloseError
 513  	if errors.As(err, &ce) {
 514  		return ce.Code
 515  	}
 516  	return -1
 517  }
 518  
 519  // CompressionMode represents the modes available to the deflate extension.
 520  // See https://tools.ietf.org/html/rfc7692
 521  // Works in all browsers except Safari which does not implement the deflate extension.
 522  type CompressionMode int
 523  
 524  const (
 525  	// CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed
 526  	// for every message. This applies to both server and client side.
 527  	//
 528  	// This means less efficient compression as the sliding window from previous messages
 529  	// will not be used but the memory overhead will be lower if the connections
 530  	// are long lived and seldom used.
 531  	//
 532  	// The message will only be compressed if greater than 512 bytes.
 533  	CompressionNoContextTakeover CompressionMode = iota
 534  
 535  	// CompressionContextTakeover uses a flate.Reader and flate.Writer per connection.
 536  	// This enables reusing the sliding window from previous messages.
 537  	// As most WebSocket protocols are repetitive, this can be very efficient.
 538  	// It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover.
 539  	//
 540  	// If the peer negotiates NoContextTakeover on the client or server side, it will be
 541  	// used instead as this is required by the RFC.
 542  	CompressionContextTakeover
 543  
 544  	// CompressionDisabled disables the deflate extension.
 545  	//
 546  	// Use this if you are using a predominantly binary protocol with very
 547  	// little duplication in between messages or CPU and memory are more
 548  	// important than bandwidth.
 549  	CompressionDisabled
 550  )
 551  
 552  // MessageType represents the type of a WebSocket message.
 553  // See https://tools.ietf.org/html/rfc6455#section-5.6
 554  type MessageType int
 555  
 556  // MessageType constants.
 557  const (
 558  	// MessageText is for UTF-8 encoded text messages like JSON.
 559  	MessageText MessageType = iota + 1
 560  	// MessageBinary is for binary messages like protobufs.
 561  	MessageBinary
 562  )
 563  
 564  type mu struct {
 565  	c  *Conn
 566  	ch chan struct{}
 567  }
 568  
 569  func newMu(c *Conn) *mu {
 570  	return &mu{
 571  		c:  c,
 572  		ch: make(chan struct{}, 1),
 573  	}
 574  }
 575  
 576  func (m *mu) forceLock() {
 577  	m.ch <- struct{}{}
 578  }
 579  
 580  func (m *mu) tryLock() bool {
 581  	select {
 582  	case m.ch <- struct{}{}:
 583  		return true
 584  	default:
 585  		return false
 586  	}
 587  }
 588  
 589  func (m *mu) unlock() {
 590  	select {
 591  	case <-m.ch:
 592  	default:
 593  	}
 594  }
 595  
 596  type noCopy struct{}
 597  
 598  func (*noCopy) Lock() {}
 599