ws_js.go raw
1 package websocket // import "github.com/coder/websocket"
2
3 import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "net"
10 "net/http"
11 "reflect"
12 "runtime"
13 "strings"
14 "sync"
15 "syscall/js"
16
17 "github.com/coder/websocket/internal/bpool"
18 "github.com/coder/websocket/internal/wsjs"
19 "github.com/coder/websocket/internal/xsync"
20 )
21
// opcode represents a WebSocket frame opcode.
// See https://tools.ietf.org/html/rfc6455#section-11.8.
type opcode int

// Opcode values from the RFC 6455 opcode registry.
const (
	opContinuation opcode = iota // 0x0
	opText                       // 0x1
	opBinary                     // 0x2
	// 3-7 (0x3-0x7) are reserved for further non-control frames.
	_
	_
	_
	_
	_
	opClose // 0x8
	opPing  // 0x9
	opPong  // 0xA
	// 11-15 (0xB-0xF) are reserved for further control frames.
)
41
42 // Conn provides a wrapper around the browser WebSocket API.
43 type Conn struct {
44 noCopy noCopy
45 ws wsjs.WebSocket
46
47 // read limit for a message in bytes.
48 msgReadLimit xsync.Int64
49
50 closeReadMu sync.Mutex
51 closeReadCtx context.Context
52
53 closingMu sync.Mutex
54 closeOnce sync.Once
55 closed chan struct{}
56 closeErrOnce sync.Once
57 closeErr error
58 closeWasClean bool
59
60 releaseOnClose func()
61 releaseOnError func()
62 releaseOnMessage func()
63
64 readSignal chan struct{}
65 readBufMu sync.Mutex
66 readBuf []wsjs.MessageEvent
67 }
68
69 func (c *Conn) close(err error, wasClean bool) {
70 c.closeOnce.Do(func() {
71 runtime.SetFinalizer(c, nil)
72
73 if !wasClean {
74 err = fmt.Errorf("unclean connection close: %w", err)
75 }
76 c.setCloseErr(err)
77 c.closeWasClean = wasClean
78 close(c.closed)
79 })
80 }
81
82 func (c *Conn) init() {
83 c.closed = make(chan struct{})
84 c.readSignal = make(chan struct{}, 1)
85
86 c.msgReadLimit.Store(32768)
87
88 c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) {
89 err := CloseError{
90 Code: StatusCode(e.Code),
91 Reason: e.Reason,
92 }
93 // We do not know if we sent or received this close as
94 // its possible the browser triggered it without us
95 // explicitly sending it.
96 c.close(err, e.WasClean)
97
98 c.releaseOnClose()
99 c.releaseOnError()
100 c.releaseOnMessage()
101 })
102
103 c.releaseOnError = c.ws.OnError(func(v js.Value) {
104 c.setCloseErr(errors.New(v.Get("message").String()))
105 c.closeWithInternal()
106 })
107
108 c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) {
109 c.readBufMu.Lock()
110 defer c.readBufMu.Unlock()
111
112 c.readBuf = append(c.readBuf, e)
113
114 // Lets the read goroutine know there is definitely something in readBuf.
115 select {
116 case c.readSignal <- struct{}{}:
117 default:
118 }
119 })
120
121 runtime.SetFinalizer(c, func(c *Conn) {
122 c.setCloseErr(errors.New("connection garbage collected"))
123 c.closeWithInternal()
124 })
125 }
126
127 func (c *Conn) closeWithInternal() {
128 c.Close(StatusInternalError, "something went wrong")
129 }
130
131 // Read attempts to read a message from the connection.
132 // The maximum time spent waiting is bounded by the context.
133 func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) {
134 c.closeReadMu.Lock()
135 closedRead := c.closeReadCtx != nil
136 c.closeReadMu.Unlock()
137 if closedRead {
138 return 0, nil, errors.New("WebSocket connection read closed")
139 }
140
141 typ, p, err := c.read(ctx)
142 if err != nil {
143 return 0, nil, fmt.Errorf("failed to read: %w", err)
144 }
145 readLimit := c.msgReadLimit.Load()
146 if readLimit >= 0 && int64(len(p)) > readLimit {
147 err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load())
148 c.Close(StatusMessageTooBig, err.Error())
149 return 0, nil, err
150 }
151 return typ, p, nil
152 }
153
154 func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) {
155 select {
156 case <-ctx.Done():
157 c.Close(StatusPolicyViolation, "read timed out")
158 return 0, nil, ctx.Err()
159 case <-c.readSignal:
160 case <-c.closed:
161 return 0, nil, net.ErrClosed
162 }
163
164 c.readBufMu.Lock()
165 defer c.readBufMu.Unlock()
166
167 me := c.readBuf[0]
168 // We copy the messages forward and decrease the size
169 // of the slice to avoid reallocating.
170 copy(c.readBuf, c.readBuf[1:])
171 c.readBuf = c.readBuf[:len(c.readBuf)-1]
172
173 if len(c.readBuf) > 0 {
174 // Next time we read, we'll grab the message.
175 select {
176 case c.readSignal <- struct{}{}:
177 default:
178 }
179 }
180
181 switch p := me.Data.(type) {
182 case string:
183 return MessageText, []byte(p), nil
184 case []byte:
185 return MessageBinary, p, nil
186 default:
187 panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String())
188 }
189 }
190
191 // Ping is mocked out for Wasm.
192 func (c *Conn) Ping(ctx context.Context) error {
193 return nil
194 }
195
196 // Write writes a message of the given type to the connection.
197 // Always non blocking.
198 func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error {
199 err := c.write(ctx, typ, p)
200 if err != nil {
201 // Have to ensure the WebSocket is closed after a write error
202 // to match the Go API. It can only error if the message type
203 // is unexpected or the passed bytes contain invalid UTF-8 for
204 // MessageText.
205 err := fmt.Errorf("failed to write: %w", err)
206 c.setCloseErr(err)
207 c.closeWithInternal()
208 return err
209 }
210 return nil
211 }
212
213 func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error {
214 if c.isClosed() {
215 return net.ErrClosed
216 }
217 switch typ {
218 case MessageBinary:
219 return c.ws.SendBytes(p)
220 case MessageText:
221 return c.ws.SendText(string(p))
222 default:
223 return fmt.Errorf("unexpected message type: %v", typ)
224 }
225 }
226
227 // Close closes the WebSocket with the given code and reason.
228 // It will wait until the peer responds with a close frame
229 // or the connection is closed.
230 // It thus performs the full WebSocket close handshake.
231 func (c *Conn) Close(code StatusCode, reason string) error {
232 err := c.exportedClose(code, reason)
233 if err != nil {
234 return fmt.Errorf("failed to close WebSocket: %w", err)
235 }
236 return nil
237 }
238
239 // CloseNow closes the WebSocket connection without attempting a close handshake.
240 // Use when you do not want the overhead of the close handshake.
241 //
242 // note: No different from Close(StatusGoingAway, "") in WASM as there is no way to close
243 // a WebSocket without the close handshake.
244 func (c *Conn) CloseNow() error {
245 return c.Close(StatusGoingAway, "")
246 }
247
248 func (c *Conn) exportedClose(code StatusCode, reason string) error {
249 c.closingMu.Lock()
250 defer c.closingMu.Unlock()
251
252 if c.isClosed() {
253 return net.ErrClosed
254 }
255
256 ce := fmt.Errorf("sent close: %w", CloseError{
257 Code: code,
258 Reason: reason,
259 })
260
261 c.setCloseErr(ce)
262 err := c.ws.Close(int(code), reason)
263 if err != nil {
264 return err
265 }
266
267 <-c.closed
268 if !c.closeWasClean {
269 return c.closeErr
270 }
271 return nil
272 }
273
274 // Subprotocol returns the negotiated subprotocol.
275 // An empty string means the default protocol.
276 func (c *Conn) Subprotocol() string {
277 return c.ws.Subprotocol()
278 }
279
// DialOptions holds the settings accepted by Dial.
type DialOptions struct {
	// Subprotocols are the subprotocols offered to the server during the
	// handshake.
	Subprotocols []string
}
285
286 // Dial creates a new WebSocket connection to the given url with the given options.
287 // The passed context bounds the maximum time spent waiting for the connection to open.
288 // The returned *http.Response is always nil or a mock. It's only in the signature
289 // to match the core API.
290 func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
291 c, resp, err := dial(ctx, url, opts)
292 if err != nil {
293 return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err)
294 }
295 return c, resp, nil
296 }
297
298 func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) {
299 if opts == nil {
300 opts = &DialOptions{}
301 }
302
303 url = strings.Replace(url, "http://", "ws://", 1)
304 url = strings.Replace(url, "https://", "wss://", 1)
305
306 ws, err := wsjs.New(url, opts.Subprotocols)
307 if err != nil {
308 return nil, nil, err
309 }
310
311 c := &Conn{
312 ws: ws,
313 }
314 c.init()
315
316 opench := make(chan struct{})
317 releaseOpen := ws.OnOpen(func(e js.Value) {
318 close(opench)
319 })
320 defer releaseOpen()
321
322 select {
323 case <-ctx.Done():
324 c.Close(StatusPolicyViolation, "dial timed out")
325 return nil, nil, ctx.Err()
326 case <-opench:
327 return c, &http.Response{
328 StatusCode: http.StatusSwitchingProtocols,
329 }, nil
330 case <-c.closed:
331 return nil, nil, net.ErrClosed
332 }
333 }
334
335 // Reader attempts to read a message from the connection.
336 // The maximum time spent waiting is bounded by the context.
337 func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
338 typ, p, err := c.Read(ctx)
339 if err != nil {
340 return 0, nil, err
341 }
342 return typ, bytes.NewReader(p), nil
343 }
344
345 // Writer returns a writer to write a WebSocket data message to the connection.
346 // It buffers the entire message in memory and then sends it when the writer
347 // is closed.
348 func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) {
349 return &writer{
350 c: c,
351 ctx: ctx,
352 typ: typ,
353 b: bpool.Get(),
354 }, nil
355 }
356
357 type writer struct {
358 closed bool
359
360 c *Conn
361 ctx context.Context
362 typ MessageType
363
364 b *bytes.Buffer
365 }
366
367 func (w *writer) Write(p []byte) (int, error) {
368 if w.closed {
369 return 0, errors.New("cannot write to closed writer")
370 }
371 n, err := w.b.Write(p)
372 if err != nil {
373 return n, fmt.Errorf("failed to write message: %w", err)
374 }
375 return n, nil
376 }
377
378 func (w *writer) Close() error {
379 if w.closed {
380 return errors.New("cannot close closed writer")
381 }
382 w.closed = true
383 defer bpool.Put(w.b)
384
385 err := w.c.Write(w.ctx, w.typ, w.b.Bytes())
386 if err != nil {
387 return fmt.Errorf("failed to close writer: %w", err)
388 }
389 return nil
390 }
391
392 // CloseRead implements *Conn.CloseRead for wasm.
393 func (c *Conn) CloseRead(ctx context.Context) context.Context {
394 c.closeReadMu.Lock()
395 ctx2 := c.closeReadCtx
396 if ctx2 != nil {
397 c.closeReadMu.Unlock()
398 return ctx2
399 }
400 ctx, cancel := context.WithCancel(ctx)
401 c.closeReadCtx = ctx
402 c.closeReadMu.Unlock()
403
404 go func() {
405 defer cancel()
406 defer c.CloseNow()
407 _, _, err := c.read(ctx)
408 if err != nil {
409 c.Close(StatusPolicyViolation, "unexpected data message")
410 }
411 }()
412 return ctx
413 }
414
415 // SetReadLimit implements *Conn.SetReadLimit for wasm.
416 func (c *Conn) SetReadLimit(n int64) {
417 c.msgReadLimit.Store(n)
418 }
419
420 func (c *Conn) setCloseErr(err error) {
421 c.closeErrOnce.Do(func() {
422 c.closeErr = fmt.Errorf("WebSocket closed: %w", err)
423 })
424 }
425
426 func (c *Conn) isClosed() bool {
427 select {
428 case <-c.closed:
429 return true
430 default:
431 return false
432 }
433 }
434
435 // AcceptOptions represents Accept's options.
436 type AcceptOptions struct {
437 Subprotocols []string
438 InsecureSkipVerify bool
439 OriginPatterns []string
440 CompressionMode CompressionMode
441 CompressionThreshold int
442 }
443
444 // Accept is stubbed out for Wasm.
445 func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
446 return nil, errors.New("unimplemented")
447 }
448
// StatusCode is a WebSocket close status code.
// See https://tools.ietf.org/html/rfc6455#section-7.4
type StatusCode int

// Status codes defined by the protocol, as registered at
// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
//
// Custom codes may be defined in the 3000-4999 range:
// 3000-3999 is reserved for libraries, frameworks and applications, and
// 4000-4999 is reserved for private use.
const (
	StatusNormalClosure   StatusCode = 1000
	StatusGoingAway       StatusCode = 1001
	StatusProtocolError   StatusCode = 1002
	StatusUnsupportedData StatusCode = 1003

	// statusReserved (1004) is reserved by the protocol, hence unexported.
	statusReserved StatusCode = 1004

	// StatusNoStatusRcvd must never be sent in a close message. It stands
	// for a close message that arrived without a status code.
	StatusNoStatusRcvd StatusCode = 1005

	// StatusAbnormalClosure is exported for Wasm use only. Outside Wasm,
	// the returned error itself indicates an abnormal close.
	StatusAbnormalClosure StatusCode = 1006

	StatusInvalidFramePayloadData StatusCode = 1007
	StatusPolicyViolation         StatusCode = 1008
	StatusMessageTooBig           StatusCode = 1009
	StatusMandatoryExtension      StatusCode = 1010
	StatusInternalError           StatusCode = 1011
	StatusServiceRestart          StatusCode = 1012
	StatusTryAgainLater           StatusCode = 1013
	StatusBadGateway              StatusCode = 1014

	// StatusTLSHandshake is exported for Wasm use only. Outside Wasm, the
	// returned error itself indicates a TLS handshake failure.
	StatusTLSHandshake StatusCode = 1015
)
493
494 // CloseError is returned when the connection is closed with a status and reason.
495 //
496 // Use Go 1.13's errors.As to check for this error.
497 // Also see the CloseStatus helper.
498 type CloseError struct {
499 Code StatusCode
500 Reason string
501 }
502
503 func (ce CloseError) Error() string {
504 return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
505 }
506
507 // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
508 // the status code from a CloseError.
509 //
510 // -1 will be returned if the passed error is nil or not a CloseError.
511 func CloseStatus(err error) StatusCode {
512 var ce CloseError
513 if errors.As(err, &ce) {
514 return ce.Code
515 }
516 return -1
517 }
518
// CompressionMode selects how the deflate extension is used.
// See https://tools.ietf.org/html/rfc7692
// Works in all browsers except Safari, which does not implement the deflate
// extension.
type CompressionMode int

const (
	// CompressionNoContextTakeover obtains a fresh flate.Reader and
	// flate.Writer for every message, on both client and server side.
	//
	// Compression is less efficient because the sliding window from
	// earlier messages is not reused. In exchange, memory overhead is
	// lower for long-lived, rarely used connections.
	//
	// A message is compressed only if it is larger than 512 bytes.
	CompressionNoContextTakeover CompressionMode = iota

	// CompressionContextTakeover keeps one flate.Reader and flate.Writer per
	// connection, so the sliding window from earlier messages is reused.
	// Most WebSocket protocols are repetitive, so this can be very
	// effective. It costs 8 kB more per connection than
	// CompressionNoContextTakeover.
	//
	// As the RFC requires, NoContextTakeover is used instead if the peer
	// negotiates it for either the client or the server side.
	CompressionContextTakeover

	// CompressionDisabled turns the deflate extension off.
	//
	// Choose it for mostly binary protocols with little repetition between
	// messages, or when CPU and memory matter more than bandwidth.
	CompressionDisabled
)
551
// MessageType is the kind of a WebSocket data message.
// See https://tools.ietf.org/html/rfc6455#section-5.6
type MessageType int

// Message types. The zero value is deliberately not a valid type.
const (
	// MessageText carries UTF-8 encoded text such as JSON.
	MessageText MessageType = iota + 1
	// MessageBinary carries arbitrary bytes such as protobufs.
	MessageBinary
)
563
564 type mu struct {
565 c *Conn
566 ch chan struct{}
567 }
568
569 func newMu(c *Conn) *mu {
570 return &mu{
571 c: c,
572 ch: make(chan struct{}, 1),
573 }
574 }
575
576 func (m *mu) forceLock() {
577 m.ch <- struct{}{}
578 }
579
580 func (m *mu) tryLock() bool {
581 select {
582 case m.ch <- struct{}{}:
583 return true
584 default:
585 return false
586 }
587 }
588
589 func (m *mu) unlock() {
590 select {
591 case <-m.ch:
592 default:
593 }
594 }
595
// noCopy may be added as a named (non-embedded) field to structs that must
// not be copied after first use.
//
// go vet's copylocks check only reports copies of values whose pointer type
// implements sync.Locker. noCopy therefore needs both Lock and Unlock; with
// Lock alone the check never fires.
//
// See https://golang.org/issues/8005#issuecomment-190753527.
type noCopy struct{}

// Lock is a no-op used by the -copylocks checker from go vet.
func (*noCopy) Lock() {}

// Unlock is a no-op used by the -copylocks checker from go vet.
func (*noCopy) Unlock() {}
599