server.go raw

   1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package websocket
   6  
   7  import (
   8  	"bufio"
   9  	"errors"
  10  	"io"
  11  	"net/http"
  12  	"net/url"
  13  	"strings"
  14  	"time"
  15  )
  16  
  17  // HandshakeError describes an error with the handshake from the peer.
  18  type HandshakeError struct {
  19  	message string
  20  }
  21  
  22  func (e HandshakeError) Error() string { return e.message }
  23  
  24  // Upgrader specifies parameters for upgrading an HTTP connection to a
  25  // WebSocket connection.
  26  //
  27  // It is safe to call Upgrader's methods concurrently.
  28  type Upgrader struct {
  29  	// HandshakeTimeout specifies the duration for the handshake to complete.
  30  	HandshakeTimeout time.Duration
  31  
  32  	// ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer
  33  	// size is zero, then buffers allocated by the HTTP server are used. The
  34  	// I/O buffer sizes do not limit the size of the messages that can be sent
  35  	// or received.
  36  	ReadBufferSize, WriteBufferSize int
  37  
  38  	// WriteBufferPool is a pool of buffers for write operations. If the value
  39  	// is not set, then write buffers are allocated to the connection for the
  40  	// lifetime of the connection.
  41  	//
  42  	// A pool is most useful when the application has a modest volume of writes
  43  	// across a large number of connections.
  44  	//
  45  	// Applications should use a single pool for each unique value of
  46  	// WriteBufferSize.
  47  	WriteBufferPool BufferPool
  48  
  49  	// Subprotocols specifies the server's supported protocols in order of
  50  	// preference. If this field is not nil, then the Upgrade method negotiates a
  51  	// subprotocol by selecting the first match in this list with a protocol
  52  	// requested by the client. If there's no match, then no protocol is
  53  	// negotiated (the Sec-Websocket-Protocol header is not included in the
  54  	// handshake response).
  55  	Subprotocols []string
  56  
  57  	// Error specifies the function for generating HTTP error responses. If Error
  58  	// is nil, then http.Error is used to generate the HTTP response.
  59  	Error func(w http.ResponseWriter, r *http.Request, status int, reason error)
  60  
  61  	// CheckOrigin returns true if the request Origin header is acceptable. If
  62  	// CheckOrigin is nil, then a safe default is used: return false if the
  63  	// Origin request header is present and the origin host is not equal to
  64  	// request Host header.
  65  	//
  66  	// A CheckOrigin function should carefully validate the request origin to
  67  	// prevent cross-site request forgery.
  68  	CheckOrigin func(r *http.Request) bool
  69  
  70  	// EnableCompression specify if the server should attempt to negotiate per
  71  	// message compression (RFC 7692). Setting this value to true does not
  72  	// guarantee that compression will be supported. Currently only "no context
  73  	// takeover" modes are supported.
  74  	EnableCompression bool
  75  }
  76  
  77  func (u *Upgrader) returnError(w http.ResponseWriter, r *http.Request, status int, reason string) (*Conn, error) {
  78  	err := HandshakeError{reason}
  79  	if u.Error != nil {
  80  		u.Error(w, r, status, err)
  81  	} else {
  82  		w.Header().Set("Sec-Websocket-Version", "13")
  83  		http.Error(w, http.StatusText(status), status)
  84  	}
  85  	return nil, err
  86  }
  87  
  88  // checkSameOrigin returns true if the origin is not set or is equal to the request host.
  89  func checkSameOrigin(r *http.Request) bool {
  90  	origin := r.Header["Origin"]
  91  	if len(origin) == 0 {
  92  		return true
  93  	}
  94  	u, err := url.Parse(origin[0])
  95  	if err != nil {
  96  		return false
  97  	}
  98  	return equalASCIIFold(u.Host, r.Host)
  99  }
 100  
 101  func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header) string {
 102  	if u.Subprotocols != nil {
 103  		clientProtocols := Subprotocols(r)
 104  		for _, serverProtocol := range u.Subprotocols {
 105  			for _, clientProtocol := range clientProtocols {
 106  				if clientProtocol == serverProtocol {
 107  					return clientProtocol
 108  				}
 109  			}
 110  		}
 111  	} else if responseHeader != nil {
 112  		return responseHeader.Get("Sec-Websocket-Protocol")
 113  	}
 114  	return ""
 115  }
 116  
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
//
// The request is checked in the order below. The first check that fails is
// reported through returnError, which calls u.Error (or, when u.Error is nil,
// replies with http.Error) and returns a HandshakeError:
//
//   - Connection header lacks the "upgrade" token: 400
//   - Upgrade header lacks the "websocket" token: 400
//   - request method is not GET: 405
//   - Sec-Websocket-Version does not contain "13": 400
//   - responseHeader carries Sec-Websocket-Extensions: 500
//   - CheckOrigin (default checkSameOrigin) rejects the request: 403
//   - Sec-Websocket-Key fails isValidChallengeKey: 400
//   - w does not implement http.Hijacker, or Hijack fails: 500
//
// After a successful Hijack, later failures do not produce an HTTP response.
// These are the client sending data before the handshake completed, or the
// 101 response failing to write. In both cases the network connection is
// closed and a plain (non-HandshakeError) error is returned.
//
// If HandshakeTimeout is positive, it is applied as a write deadline for the
// 101 response only. Deadlines left by the HTTP server are cleared first.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) {
	const badHandshake = "websocket: the client is not using the websocket protocol: "

	if !tokenListContainsValue(r.Header, "Connection", "upgrade") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'upgrade' token not found in 'Connection' header")
	}

	if !tokenListContainsValue(r.Header, "Upgrade", "websocket") {
		return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
	}

	if r.Method != http.MethodGet {
		return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
	}

	if !tokenListContainsValue(r.Header, "Sec-Websocket-Version", "13") {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header")
	}

	// Upgrade writes the Sec-WebSocket-Extensions response header itself (see
	// the compression negotiation below), so an application-supplied value is
	// rejected as a server-side error.
	// NOTE(review): only the canonical map key is checked; a non-canonical key
	// inserted directly into the map would bypass this check.
	if _, ok := responseHeader["Sec-Websocket-Extensions"]; ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported")
	}

	checkOrigin := u.CheckOrigin
	if checkOrigin == nil {
		checkOrigin = checkSameOrigin
	}
	if !checkOrigin(r) {
		return u.returnError(w, r, http.StatusForbidden, "websocket: request origin not allowed by Upgrader.CheckOrigin")
	}

	challengeKey := r.Header.Get("Sec-Websocket-Key")
	if !isValidChallengeKey(challengeKey) {
		return u.returnError(w, r, http.StatusBadRequest, "websocket: not a websocket handshake: 'Sec-WebSocket-Key' header must be Base64 encoded value of 16-byte in length")
	}

	subprotocol := u.selectSubprotocol(r, responseHeader)

	// Negotiate PMCE (permessage-deflate, RFC 7692). Compression is accepted as
	// soon as any offered extension is named permessage-deflate. The offer's
	// parameters are not inspected, and the response below always selects
	// no context takeover in both directions.
	var compress bool
	if u.EnableCompression {
		for _, ext := range parseExtensions(r.Header) {
			if ext[""] != "permessage-deflate" {
				continue
			}
			compress = true
			break
		}
	}

	h, ok := w.(http.Hijacker)
	if !ok {
		return u.returnError(w, r, http.StatusInternalServerError, "websocket: response does not implement http.Hijacker")
	}
	var brw *bufio.ReadWriter
	netConn, brw, err := h.Hijack()
	if err != nil {
		return u.returnError(w, r, http.StatusInternalServerError, err.Error())
	}

	// The connection has been hijacked. From here on, failures close netConn
	// rather than writing an HTTP error response.
	if brw.Reader.Buffered() > 0 {
		netConn.Close()
		return nil, errors.New("websocket: client sent data before handshake is complete")
	}

	// bufioReaderSize resets brw.Reader. That is safe because the check above
	// guarantees the reader holds no unread data.
	var br *bufio.Reader
	if u.ReadBufferSize == 0 && bufioReaderSize(netConn, brw.Reader) > 256 {
		// Reuse hijacked buffered reader as connection reader.
		br = brw.Reader
	}

	// buf is the hijacked bufio.Writer's backing array, sliced to full capacity.
	buf := bufioWriterBuffer(netConn, brw.Writer)

	var writeBuf []byte
	if u.WriteBufferPool == nil && u.WriteBufferSize == 0 && len(buf) >= maxFrameHeaderSize+256 {
		// Reuse hijacked write buffer as connection buffer.
		writeBuf = buf
	}

	c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, br, writeBuf)
	c.subprotocol = subprotocol

	if compress {
		c.newCompressionWriter = compressNoContextTakeover
		c.newDecompressionReader = decompressNoContextTakeover
	}

	// Use larger of hijacked buffer and connection write buffer for header.
	// The chosen buffer may be the same array as c.writeBuf. That is safe
	// because the response is fully written below before c is returned.
	p := buf
	if len(c.writeBuf) > len(p) {
		p = c.writeBuf
	}
	p = p[:0]

	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
	p = append(p, computeAcceptKey(challengeKey)...)
	p = append(p, "\r\n"...)
	if c.subprotocol != "" {
		p = append(p, "Sec-WebSocket-Protocol: "...)
		p = append(p, c.subprotocol...)
		p = append(p, "\r\n"...)
	}
	if compress {
		p = append(p, "Sec-WebSocket-Extensions: permessage-deflate; server_no_context_takeover; client_no_context_takeover\r\n"...)
	}
	// Copy the application's headers. Sec-Websocket-Protocol is skipped because
	// the negotiated value was already written above. Only the canonical key is
	// matched. Header names (k) are written verbatim; only values are sanitized.
	for k, vs := range responseHeader {
		if k == "Sec-Websocket-Protocol" {
			continue
		}
		for _, v := range vs {
			p = append(p, k...)
			p = append(p, ": "...)
			for i := 0; i < len(v); i++ {
				b := v[i]
				if b <= 31 {
					// prevent response splitting.
					// Every control byte (CR, LF, tab, ...) becomes a space.
					b = ' '
				}
				p = append(p, b)
			}
			p = append(p, "\r\n"...)
		}
	}
	p = append(p, "\r\n"...)

	// Clear deadlines set by HTTP server.
	// Errors from the deadline setters are ignored (best effort).
	netConn.SetDeadline(time.Time{})

	if u.HandshakeTimeout > 0 {
		netConn.SetWriteDeadline(time.Now().Add(u.HandshakeTimeout))
	}
	if _, err = netConn.Write(p); err != nil {
		netConn.Close()
		return nil, err
	}
	if u.HandshakeTimeout > 0 {
		netConn.SetWriteDeadline(time.Time{})
	}

	return c, nil
}
 266  
 267  // Upgrade upgrades the HTTP server connection to the WebSocket protocol.
 268  //
 269  // Deprecated: Use websocket.Upgrader instead.
 270  //
 271  // Upgrade does not perform origin checking. The application is responsible for
 272  // checking the Origin header before calling Upgrade. An example implementation
 273  // of the same origin policy check is:
 274  //
 275  //	if req.Header.Get("Origin") != "http://"+req.Host {
 276  //		http.Error(w, "Origin not allowed", http.StatusForbidden)
 277  //		return
 278  //	}
 279  //
 280  // If the endpoint supports subprotocols, then the application is responsible
 281  // for negotiating the protocol used on the connection. Use the Subprotocols()
 282  // function to get the subprotocols requested by the client. Use the
 283  // Sec-Websocket-Protocol response header to specify the subprotocol selected
 284  // by the application.
 285  //
 286  // The responseHeader is included in the response to the client's upgrade
 287  // request. Use the responseHeader to specify cookies (Set-Cookie) and the
 288  // negotiated subprotocol (Sec-Websocket-Protocol).
 289  //
 290  // The connection buffers IO to the underlying network connection. The
 291  // readBufSize and writeBufSize parameters specify the size of the buffers to
 292  // use. Messages can be larger than the buffers.
 293  //
 294  // If the request is not a valid WebSocket handshake, then Upgrade returns an
 295  // error of type HandshakeError. Applications should handle this error by
 296  // replying to the client with an HTTP error response.
 297  func Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header, readBufSize, writeBufSize int) (*Conn, error) {
 298  	u := Upgrader{ReadBufferSize: readBufSize, WriteBufferSize: writeBufSize}
 299  	u.Error = func(w http.ResponseWriter, r *http.Request, status int, reason error) {
 300  		// don't return errors to maintain backwards compatibility
 301  	}
 302  	u.CheckOrigin = func(r *http.Request) bool {
 303  		// allow all connections by default
 304  		return true
 305  	}
 306  	return u.Upgrade(w, r, responseHeader)
 307  }
 308  
 309  // Subprotocols returns the subprotocols requested by the client in the
 310  // Sec-Websocket-Protocol header.
 311  func Subprotocols(r *http.Request) []string {
 312  	h := strings.TrimSpace(r.Header.Get("Sec-Websocket-Protocol"))
 313  	if h == "" {
 314  		return nil
 315  	}
 316  	protocols := strings.Split(h, ",")
 317  	for i := range protocols {
 318  		protocols[i] = strings.TrimSpace(protocols[i])
 319  	}
 320  	return protocols
 321  }
 322  
 323  // IsWebSocketUpgrade returns true if the client requested upgrade to the
 324  // WebSocket protocol.
 325  func IsWebSocketUpgrade(r *http.Request) bool {
 326  	return tokenListContainsValue(r.Header, "Connection", "upgrade") &&
 327  		tokenListContainsValue(r.Header, "Upgrade", "websocket")
 328  }
 329  
 330  // bufioReaderSize size returns the size of a bufio.Reader.
 331  func bufioReaderSize(originalReader io.Reader, br *bufio.Reader) int {
 332  	// This code assumes that peek on a reset reader returns
 333  	// bufio.Reader.buf[:0].
 334  	// TODO: Use bufio.Reader.Size() after Go 1.10
 335  	br.Reset(originalReader)
 336  	if p, err := br.Peek(0); err == nil {
 337  		return cap(p)
 338  	}
 339  	return 0
 340  }
 341  
 342  // writeHook is an io.Writer that records the last slice passed to it vio
 343  // io.Writer.Write.
 344  type writeHook struct {
 345  	p []byte
 346  }
 347  
 348  func (wh *writeHook) Write(p []byte) (int, error) {
 349  	wh.p = p
 350  	return len(p), nil
 351  }
 352  
 353  // bufioWriterBuffer grabs the buffer from a bufio.Writer.
 354  func bufioWriterBuffer(originalWriter io.Writer, bw *bufio.Writer) []byte {
 355  	// This code assumes that bufio.Writer.buf[:1] is passed to the
 356  	// bufio.Writer's underlying writer.
 357  	var wh writeHook
 358  	bw.Reset(&wh)
 359  	bw.WriteByte(0)
 360  	bw.Flush()
 361  
 362  	bw.Reset(originalWriter)
 363  
 364  	return wh.p[:cap(wh.p)]
 365  }
 366