accept.go raw

   1  //go:build !js
   2  // +build !js
   3  
   4  package websocket
   5  
   6  import (
   7  	"bytes"
   8  	"crypto/sha1"
   9  	"encoding/base64"
  10  	"errors"
  11  	"fmt"
  12  	"io"
  13  	"log"
  14  	"net/http"
  15  	"net/textproto"
  16  	"net/url"
  17  	"path/filepath"
  18  	"strings"
  19  
  20  	"github.com/coder/websocket/internal/errd"
  21  )
  22  
// AcceptOptions represents Accept's options.
type AcceptOptions struct {
	// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
	// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
	// reject it, close the connection when c.Subprotocol() == "".
	Subprotocols []string

	// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
	//
	// You probably want to use OriginPatterns instead.
	InsecureSkipVerify bool

	// OriginPatterns lists the host patterns for authorized origins.
	// The request host is always authorized.
	// Use this to enable cross origin WebSockets.
	//
	// For example, JavaScript running on example.com wants to access a WebSocket
	// server at chat.example.com. In such a case, example.com is the origin and
	// chat.example.com is the request host.
	// One would set this field to []string{"example.com"} to authorize example.com to connect.
	//
	// Each pattern is matched case insensitively against the request origin host
	// with filepath.Match.
	// See https://golang.org/pkg/path/filepath/#Match
	//
	// Please ensure you understand the ramifications of enabling this.
	// If used incorrectly your WebSocket server will be open to CSRF attacks.
	//
	// Do not use * as a pattern to allow any origin; prefer InsecureSkipVerify instead
	// to bring attention to the danger of such a setting.
	OriginPatterns []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionDisabled.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
  66  
  67  func (opts *AcceptOptions) cloneWithDefaults() *AcceptOptions {
  68  	var o AcceptOptions
  69  	if opts != nil {
  70  		o = *opts
  71  	}
  72  	return &o
  73  }
  74  
// Accept accepts a WebSocket handshake from a client and upgrades
// the connection to a WebSocket.
//
// Accept will not allow cross origin requests by default.
// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
//
// opts may be nil, in which case the zero AcceptOptions is used.
//
// Accept will write a response to w on all errors.
func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
	return accept(w, r, opts)
}
  85  
  86  func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
  87  	defer errd.Wrap(&err, "failed to accept WebSocket connection")
  88  
  89  	errCode, err := verifyClientRequest(w, r)
  90  	if err != nil {
  91  		http.Error(w, err.Error(), errCode)
  92  		return nil, err
  93  	}
  94  
  95  	opts = opts.cloneWithDefaults()
  96  	if !opts.InsecureSkipVerify {
  97  		err = authenticateOrigin(r, opts.OriginPatterns)
  98  		if err != nil {
  99  			if errors.Is(err, filepath.ErrBadPattern) {
 100  				log.Printf("websocket: %v", err)
 101  				err = errors.New(http.StatusText(http.StatusForbidden))
 102  			}
 103  			http.Error(w, err.Error(), http.StatusForbidden)
 104  			return nil, err
 105  		}
 106  	}
 107  
 108  	hj, ok := w.(http.Hijacker)
 109  	if !ok {
 110  		err = errors.New("http.ResponseWriter does not implement http.Hijacker")
 111  		http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
 112  		return nil, err
 113  	}
 114  
 115  	w.Header().Set("Upgrade", "websocket")
 116  	w.Header().Set("Connection", "Upgrade")
 117  
 118  	key := r.Header.Get("Sec-WebSocket-Key")
 119  	w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
 120  
 121  	subproto := selectSubprotocol(r, opts.Subprotocols)
 122  	if subproto != "" {
 123  		w.Header().Set("Sec-WebSocket-Protocol", subproto)
 124  	}
 125  
 126  	copts, ok := selectDeflate(websocketExtensions(r.Header), opts.CompressionMode)
 127  	if ok {
 128  		w.Header().Set("Sec-WebSocket-Extensions", copts.String())
 129  	}
 130  
 131  	w.WriteHeader(http.StatusSwitchingProtocols)
 132  	// See https://github.com/nhooyr/websocket/issues/166
 133  	if ginWriter, ok := w.(interface {
 134  		WriteHeaderNow()
 135  	}); ok {
 136  		ginWriter.WriteHeaderNow()
 137  	}
 138  
 139  	netConn, brw, err := hj.Hijack()
 140  	if err != nil {
 141  		err = fmt.Errorf("failed to hijack connection: %w", err)
 142  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
 143  		return nil, err
 144  	}
 145  
 146  	// https://github.com/golang/go/issues/32314
 147  	b, _ := brw.Reader.Peek(brw.Reader.Buffered())
 148  	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
 149  
 150  	return newConn(connConfig{
 151  		subprotocol:    w.Header().Get("Sec-WebSocket-Protocol"),
 152  		rwc:            netConn,
 153  		client:         false,
 154  		copts:          copts,
 155  		flateThreshold: opts.CompressionThreshold,
 156  
 157  		br: brw.Reader,
 158  		bw: brw.Writer,
 159  	}), nil
 160  }
 161  
 162  func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
 163  	if !r.ProtoAtLeast(1, 1) {
 164  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
 165  	}
 166  
 167  	if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
 168  		w.Header().Set("Connection", "Upgrade")
 169  		w.Header().Set("Upgrade", "websocket")
 170  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
 171  	}
 172  
 173  	if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
 174  		w.Header().Set("Connection", "Upgrade")
 175  		w.Header().Set("Upgrade", "websocket")
 176  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
 177  	}
 178  
 179  	if r.Method != "GET" {
 180  		return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
 181  	}
 182  
 183  	if r.Header.Get("Sec-WebSocket-Version") != "13" {
 184  		w.Header().Set("Sec-WebSocket-Version", "13")
 185  		return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
 186  	}
 187  
 188  	websocketSecKeys := r.Header.Values("Sec-WebSocket-Key")
 189  	if len(websocketSecKeys) == 0 {
 190  		return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
 191  	}
 192  
 193  	if len(websocketSecKeys) > 1 {
 194  		return http.StatusBadRequest, errors.New("WebSocket protocol violation: multiple Sec-WebSocket-Key headers")
 195  	}
 196  
 197  	// The RFC states to remove any leading or trailing whitespace.
 198  	websocketSecKey := strings.TrimSpace(websocketSecKeys[0])
 199  	if v, err := base64.StdEncoding.DecodeString(websocketSecKey); err != nil || len(v) != 16 {
 200  		return http.StatusBadRequest, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Key %q, must be a 16 byte base64 encoded string", websocketSecKey)
 201  	}
 202  
 203  	return 0, nil
 204  }
 205  
 206  func authenticateOrigin(r *http.Request, originHosts []string) error {
 207  	origin := r.Header.Get("Origin")
 208  	if origin == "" {
 209  		return nil
 210  	}
 211  
 212  	u, err := url.Parse(origin)
 213  	if err != nil {
 214  		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
 215  	}
 216  
 217  	if strings.EqualFold(r.Host, u.Host) {
 218  		return nil
 219  	}
 220  
 221  	for _, hostPattern := range originHosts {
 222  		matched, err := match(hostPattern, u.Host)
 223  		if err != nil {
 224  			return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
 225  		}
 226  		if matched {
 227  			return nil
 228  		}
 229  	}
 230  	if u.Host == "" {
 231  		return fmt.Errorf("request Origin %q is not a valid URL with a host", origin)
 232  	}
 233  	return fmt.Errorf("request Origin %q is not authorized for Host %q", u.Host, r.Host)
 234  }
 235  
 236  func match(pattern, s string) (bool, error) {
 237  	return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
 238  }
 239  
 240  func selectSubprotocol(r *http.Request, subprotocols []string) string {
 241  	cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
 242  	for _, sp := range subprotocols {
 243  		for _, cp := range cps {
 244  			if strings.EqualFold(sp, cp) {
 245  				return cp
 246  			}
 247  		}
 248  	}
 249  	return ""
 250  }
 251  
 252  func selectDeflate(extensions []websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
 253  	if mode == CompressionDisabled {
 254  		return nil, false
 255  	}
 256  	for _, ext := range extensions {
 257  		switch ext.name {
 258  		// We used to implement x-webkit-deflate-frame too for Safari but Safari has bugs...
 259  		// See https://github.com/nhooyr/websocket/issues/218
 260  		case "permessage-deflate":
 261  			copts, ok := acceptDeflate(ext, mode)
 262  			if ok {
 263  				return copts, true
 264  			}
 265  		}
 266  	}
 267  	return nil, false
 268  }
 269  
 270  func acceptDeflate(ext websocketExtension, mode CompressionMode) (*compressionOptions, bool) {
 271  	copts := mode.opts()
 272  	for _, p := range ext.params {
 273  		switch p {
 274  		case "client_no_context_takeover":
 275  			copts.clientNoContextTakeover = true
 276  			continue
 277  		case "server_no_context_takeover":
 278  			copts.serverNoContextTakeover = true
 279  			continue
 280  		case "client_max_window_bits",
 281  			"server_max_window_bits=15":
 282  			continue
 283  		}
 284  
 285  		if strings.HasPrefix(p, "client_max_window_bits=") {
 286  			// We can't adjust the deflate window, but decoding with a larger window is acceptable.
 287  			continue
 288  		}
 289  		return nil, false
 290  	}
 291  	return copts, true
 292  }
 293  
 294  func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
 295  	for _, t := range headerTokens(h, key) {
 296  		if strings.EqualFold(t, token) {
 297  			return true
 298  		}
 299  	}
 300  	return false
 301  }
 302  
// websocketExtension is one parsed entry of a Sec-WebSocket-Extensions header.
type websocketExtension struct {
	// name is the extension token, the part before the first ";".
	name   string
	// params holds the remaining ";" separated parts, each trimmed of
	// surrounding whitespace and otherwise unparsed.
	params []string
}
 307  
 308  func websocketExtensions(h http.Header) []websocketExtension {
 309  	var exts []websocketExtension
 310  	extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
 311  	for _, extStr := range extStrs {
 312  		if extStr == "" {
 313  			continue
 314  		}
 315  
 316  		vals := strings.Split(extStr, ";")
 317  		for i := range vals {
 318  			vals[i] = strings.TrimSpace(vals[i])
 319  		}
 320  
 321  		e := websocketExtension{
 322  			name:   vals[0],
 323  			params: vals[1:],
 324  		}
 325  
 326  		exts = append(exts, e)
 327  	}
 328  	return exts
 329  }
 330  
 331  func headerTokens(h http.Header, key string) []string {
 332  	key = textproto.CanonicalMIMEHeaderKey(key)
 333  	var tokens []string
 334  	for _, v := range h[key] {
 335  		v = strings.TrimSpace(v)
 336  		for _, t := range strings.Split(v, ",") {
 337  			t = strings.TrimSpace(t)
 338  			tokens = append(tokens, t)
 339  		}
 340  	}
 341  	return tokens
 342  }
 343  
 344  var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
 345  
 346  func secWebSocketAccept(secWebSocketKey string) string {
 347  	h := sha1.New()
 348  	h.Write([]byte(secWebSocketKey))
 349  	h.Write(keyGUID)
 350  
 351  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
 352  }
 353