close.go raw

   1  //go:build !js
   2  // +build !js
   3  
   4  package websocket
   5  
import (
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"net"
	"time"

	"github.com/coder/websocket/internal/errd"
)
  16  
  17  // StatusCode represents a WebSocket status code.
  18  // https://tools.ietf.org/html/rfc6455#section-7.4
  19  type StatusCode int
  20  
  21  // https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
  22  //
  23  // These are only the status codes defined by the protocol.
  24  //
  25  // You can define custom codes in the 3000-4999 range.
  26  // The 3000-3999 range is reserved for use by libraries, frameworks and applications.
  27  // The 4000-4999 range is reserved for private use.
  28  const (
  29  	StatusNormalClosure   StatusCode = 1000
  30  	StatusGoingAway       StatusCode = 1001
  31  	StatusProtocolError   StatusCode = 1002
  32  	StatusUnsupportedData StatusCode = 1003
  33  
  34  	// 1004 is reserved and so unexported.
  35  	statusReserved StatusCode = 1004
  36  
  37  	// StatusNoStatusRcvd cannot be sent in a close message.
  38  	// It is reserved for when a close message is received without
  39  	// a status code.
  40  	StatusNoStatusRcvd StatusCode = 1005
  41  
  42  	// StatusAbnormalClosure is exported for use only with Wasm.
  43  	// In non Wasm Go, the returned error will indicate whether the
  44  	// connection was closed abnormally.
  45  	StatusAbnormalClosure StatusCode = 1006
  46  
  47  	StatusInvalidFramePayloadData StatusCode = 1007
  48  	StatusPolicyViolation         StatusCode = 1008
  49  	StatusMessageTooBig           StatusCode = 1009
  50  	StatusMandatoryExtension      StatusCode = 1010
  51  	StatusInternalError           StatusCode = 1011
  52  	StatusServiceRestart          StatusCode = 1012
  53  	StatusTryAgainLater           StatusCode = 1013
  54  	StatusBadGateway              StatusCode = 1014
  55  
  56  	// StatusTLSHandshake is only exported for use with Wasm.
  57  	// In non Wasm Go, the returned error will indicate whether there was
  58  	// a TLS handshake failure.
  59  	StatusTLSHandshake StatusCode = 1015
  60  )
  61  
  62  // CloseError is returned when the connection is closed with a status and reason.
  63  //
  64  // Use Go 1.13's errors.As to check for this error.
  65  // Also see the CloseStatus helper.
  66  type CloseError struct {
  67  	Code   StatusCode
  68  	Reason string
  69  }
  70  
  71  func (ce CloseError) Error() string {
  72  	return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason)
  73  }
  74  
  75  // CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab
  76  // the status code from a CloseError.
  77  //
  78  // -1 will be returned if the passed error is nil or not a CloseError.
  79  func CloseStatus(err error) StatusCode {
  80  	var ce CloseError
  81  	if errors.As(err, &ce) {
  82  		return ce.Code
  83  	}
  84  	return -1
  85  }
  86  
  87  // Close performs the WebSocket close handshake with the given status code and reason.
  88  //
  89  // It will write a WebSocket close frame with a timeout of 5s and then wait 5s for
  90  // the peer to send a close frame.
  91  // All data messages received from the peer during the close handshake will be discarded.
  92  //
  93  // The connection can only be closed once. Additional calls to Close
  94  // are no-ops.
  95  //
  96  // The maximum length of reason must be 125 bytes. Avoid sending a dynamic reason.
  97  //
  98  // Close will unblock all goroutines interacting with the connection once
  99  // complete.
 100  func (c *Conn) Close(code StatusCode, reason string) (err error) {
 101  	defer errd.Wrap(&err, "failed to close WebSocket")
 102  
 103  	if !c.casClosing() {
 104  		err = c.waitGoroutines()
 105  		if err != nil {
 106  			return err
 107  		}
 108  		return net.ErrClosed
 109  	}
 110  	defer func() {
 111  		if errors.Is(err, net.ErrClosed) {
 112  			err = nil
 113  		}
 114  	}()
 115  
 116  	err = c.closeHandshake(code, reason)
 117  
 118  	err2 := c.close()
 119  	if err == nil && err2 != nil {
 120  		err = err2
 121  	}
 122  
 123  	err2 = c.waitGoroutines()
 124  	if err == nil && err2 != nil {
 125  		err = err2
 126  	}
 127  
 128  	return err
 129  }
 130  
 131  // CloseNow closes the WebSocket connection without attempting a close handshake.
 132  // Use when you do not want the overhead of the close handshake.
 133  func (c *Conn) CloseNow() (err error) {
 134  	defer errd.Wrap(&err, "failed to immediately close WebSocket")
 135  
 136  	if !c.casClosing() {
 137  		err = c.waitGoroutines()
 138  		if err != nil {
 139  			return err
 140  		}
 141  		return net.ErrClosed
 142  	}
 143  	defer func() {
 144  		if errors.Is(err, net.ErrClosed) {
 145  			err = nil
 146  		}
 147  	}()
 148  
 149  	err = c.close()
 150  
 151  	err2 := c.waitGoroutines()
 152  	if err == nil && err2 != nil {
 153  		err = err2
 154  	}
 155  	return err
 156  }
 157  
 158  func (c *Conn) closeHandshake(code StatusCode, reason string) error {
 159  	err := c.writeClose(code, reason)
 160  	if err != nil {
 161  		return err
 162  	}
 163  
 164  	err = c.waitCloseHandshake()
 165  	if CloseStatus(err) != code {
 166  		return err
 167  	}
 168  	return nil
 169  }
 170  
 171  func (c *Conn) writeClose(code StatusCode, reason string) error {
 172  	ce := CloseError{
 173  		Code:   code,
 174  		Reason: reason,
 175  	}
 176  
 177  	var p []byte
 178  	var err error
 179  	if ce.Code != StatusNoStatusRcvd {
 180  		p, err = ce.bytes()
 181  		if err != nil {
 182  			return err
 183  		}
 184  	}
 185  
 186  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
 187  	defer cancel()
 188  
 189  	err = c.writeControl(ctx, opClose, p)
 190  	// If the connection closed as we're writing we ignore the error as we might
 191  	// have written the close frame, the peer responded and then someone else read it
 192  	// and closed the connection.
 193  	if err != nil && !errors.Is(err, net.ErrClosed) {
 194  		return err
 195  	}
 196  	return nil
 197  }
 198  
 199  func (c *Conn) waitCloseHandshake() error {
 200  	ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
 201  	defer cancel()
 202  
 203  	err := c.readMu.lock(ctx)
 204  	if err != nil {
 205  		return err
 206  	}
 207  	defer c.readMu.unlock()
 208  
 209  	for i := int64(0); i < c.msgReader.payloadLength; i++ {
 210  		_, err := c.br.ReadByte()
 211  		if err != nil {
 212  			return err
 213  		}
 214  	}
 215  
 216  	for {
 217  		h, err := c.readLoop(ctx)
 218  		if err != nil {
 219  			return err
 220  		}
 221  
 222  		for i := int64(0); i < h.payloadLength; i++ {
 223  			_, err := c.br.ReadByte()
 224  			if err != nil {
 225  				return err
 226  			}
 227  		}
 228  	}
 229  }
 230  
 231  func (c *Conn) waitGoroutines() error {
 232  	t := time.NewTimer(time.Second * 15)
 233  	defer t.Stop()
 234  
 235  	select {
 236  	case <-c.timeoutLoopDone:
 237  	case <-t.C:
 238  		return errors.New("failed to wait for timeoutLoop goroutine to exit")
 239  	}
 240  
 241  	c.closeReadMu.Lock()
 242  	closeRead := c.closeReadCtx != nil
 243  	c.closeReadMu.Unlock()
 244  	if closeRead {
 245  		select {
 246  		case <-c.closeReadDone:
 247  		case <-t.C:
 248  			return errors.New("failed to wait for close read goroutine to exit")
 249  		}
 250  	}
 251  
 252  	select {
 253  	case <-c.closed:
 254  	case <-t.C:
 255  		return errors.New("failed to wait for connection to be closed")
 256  	}
 257  
 258  	return nil
 259  }
 260  
 261  func parseClosePayload(p []byte) (CloseError, error) {
 262  	if len(p) == 0 {
 263  		return CloseError{
 264  			Code: StatusNoStatusRcvd,
 265  		}, nil
 266  	}
 267  
 268  	if len(p) < 2 {
 269  		return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p)
 270  	}
 271  
 272  	ce := CloseError{
 273  		Code:   StatusCode(binary.BigEndian.Uint16(p)),
 274  		Reason: string(p[2:]),
 275  	}
 276  
 277  	if !validWireCloseCode(ce.Code) {
 278  		return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code)
 279  	}
 280  
 281  	return ce, nil
 282  }
 283  
 284  // See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number
 285  // and https://tools.ietf.org/html/rfc6455#section-7.4.1
 286  func validWireCloseCode(code StatusCode) bool {
 287  	switch code {
 288  	case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake:
 289  		return false
 290  	}
 291  
 292  	if code >= StatusNormalClosure && code <= StatusBadGateway {
 293  		return true
 294  	}
 295  	if code >= 3000 && code <= 4999 {
 296  		return true
 297  	}
 298  
 299  	return false
 300  }
 301  
 302  func (ce CloseError) bytes() ([]byte, error) {
 303  	p, err := ce.bytesErr()
 304  	if err != nil {
 305  		err = fmt.Errorf("failed to marshal close frame: %w", err)
 306  		ce = CloseError{
 307  			Code: StatusInternalError,
 308  		}
 309  		p, _ = ce.bytesErr()
 310  	}
 311  	return p, err
 312  }
 313  
 314  const maxCloseReason = maxControlPayload - 2
 315  
 316  func (ce CloseError) bytesErr() ([]byte, error) {
 317  	if len(ce.Reason) > maxCloseReason {
 318  		return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason))
 319  	}
 320  
 321  	if !validWireCloseCode(ce.Code) {
 322  		return nil, fmt.Errorf("status code %v cannot be set", ce.Code)
 323  	}
 324  
 325  	buf := make([]byte, 2+len(ce.Reason))
 326  	binary.BigEndian.PutUint16(buf, uint16(ce.Code))
 327  	copy(buf[2:], ce.Reason)
 328  	return buf, nil
 329  }
 330  
 331  func (c *Conn) casClosing() bool {
 332  	c.closeMu.Lock()
 333  	defer c.closeMu.Unlock()
 334  	if !c.closing {
 335  		c.closing = true
 336  		return true
 337  	}
 338  	return false
 339  }
 340  
 341  func (c *Conn) isClosed() bool {
 342  	select {
 343  	case <-c.closed:
 344  		return true
 345  	default:
 346  		return false
 347  	}
 348  }
 349