dial.go raw

   1  //go:build !js
   2  // +build !js
   3  
   4  package websocket
   5  
   6  import (
   7  	"bufio"
   8  	"bytes"
   9  	"context"
  10  	"crypto/rand"
  11  	"encoding/base64"
  12  	"fmt"
  13  	"io"
  14  	"net/http"
  15  	"net/url"
  16  	"strings"
  17  	"sync"
  18  	"time"
  19  
  20  	"github.com/coder/websocket/internal/errd"
  21  )
  22  
// DialOptions represents Dial's options.
//
// Dial accepts a nil *DialOptions and treats it the same as a zero value.
type DialOptions struct {
	// HTTPClient is used to send the handshake request.
	// Its Transport must return writable response bodies for WebSocket
	// handshakes; http.Transport does so starting with Go 1.12.
	//
	// If nil, http.DefaultClient is used. If the client has a positive
	// Timeout, that duration is applied to the handshake's context instead,
	// and a copy of the client with Timeout set to 0 is used. The client is
	// always copied rather than mutated: the copy's CheckRedirect is wrapped
	// so that redirect targets using ws/wss are rewritten to http/https before
	// the original CheckRedirect (if any) is consulted.
	HTTPClient *http.Client

	// HTTPHeader specifies the HTTP headers included in the handshake request.
	// It is cloned before use. The handshake headers (Connection, Upgrade,
	// Sec-WebSocket-Version, Sec-WebSocket-Key and, when applicable,
	// Sec-WebSocket-Protocol and Sec-WebSocket-Extensions) are set on the
	// clone and override any values supplied here.
	HTTPHeader http.Header

	// Host optionally overrides the Host HTTP header to send. If empty, the value
	// of URL.Host will be used.
	Host string

	// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
	// They are joined with "," into a single Sec-WebSocket-Protocol request
	// header. If the server selects a subprotocol, it must match one of these
	// (compared case-insensitively) or the dial fails.
	Subprotocols []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionDisabled.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
  52  
  53  func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
  54  	var cancel context.CancelFunc
  55  
  56  	var o DialOptions
  57  	if opts != nil {
  58  		o = *opts
  59  	}
  60  	if o.HTTPClient == nil {
  61  		o.HTTPClient = http.DefaultClient
  62  	}
  63  	if o.HTTPClient.Timeout > 0 {
  64  		ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout)
  65  
  66  		newClient := *o.HTTPClient
  67  		newClient.Timeout = 0
  68  		o.HTTPClient = &newClient
  69  	}
  70  	if o.HTTPHeader == nil {
  71  		o.HTTPHeader = http.Header{}
  72  	}
  73  	newClient := *o.HTTPClient
  74  	oldCheckRedirect := o.HTTPClient.CheckRedirect
  75  	newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
  76  		switch req.URL.Scheme {
  77  		case "ws":
  78  			req.URL.Scheme = "http"
  79  		case "wss":
  80  			req.URL.Scheme = "https"
  81  		}
  82  		if oldCheckRedirect != nil {
  83  			return oldCheckRedirect(req, via)
  84  		}
  85  		return nil
  86  	}
  87  	o.HTTPClient = &newClient
  88  
  89  	return ctx, cancel, &o
  90  }
  91  
  92  // Dial performs a WebSocket handshake on url.
  93  //
  94  // The response is the WebSocket handshake response from the server.
  95  // You never need to close resp.Body yourself.
  96  //
  97  // If an error occurs, the returned response may be non nil.
  98  // However, you can only read the first 1024 bytes of the body.
  99  //
 100  // This function requires at least Go 1.12 as it uses a new feature
 101  // in net/http to perform WebSocket handshakes.
 102  // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
 103  //
 104  // URLs with http/https schemes will work and are interpreted as ws/wss.
 105  func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
 106  	return dial(ctx, u, opts, nil)
 107  }
 108  
 109  func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
 110  	defer errd.Wrap(&err, "failed to WebSocket dial")
 111  
 112  	var cancel context.CancelFunc
 113  	ctx, cancel, opts = opts.cloneWithDefaults(ctx)
 114  	if cancel != nil {
 115  		defer cancel()
 116  	}
 117  
 118  	secWebSocketKey, err := secWebSocketKey(rand)
 119  	if err != nil {
 120  		return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
 121  	}
 122  
 123  	var copts *compressionOptions
 124  	if opts.CompressionMode != CompressionDisabled {
 125  		copts = opts.CompressionMode.opts()
 126  	}
 127  
 128  	resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
 129  	if err != nil {
 130  		return nil, resp, err
 131  	}
 132  	respBody := resp.Body
 133  	resp.Body = nil
 134  	defer func() {
 135  		if err != nil {
 136  			// We read a bit of the body for easier debugging.
 137  			r := io.LimitReader(respBody, 1024)
 138  
 139  			timer := time.AfterFunc(time.Second*3, func() {
 140  				respBody.Close()
 141  			})
 142  			defer timer.Stop()
 143  
 144  			b, _ := io.ReadAll(r)
 145  			respBody.Close()
 146  			resp.Body = io.NopCloser(bytes.NewReader(b))
 147  		}
 148  	}()
 149  
 150  	copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
 151  	if err != nil {
 152  		return nil, resp, err
 153  	}
 154  
 155  	rwc, ok := respBody.(io.ReadWriteCloser)
 156  	if !ok {
 157  		return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
 158  	}
 159  
 160  	return newConn(connConfig{
 161  		subprotocol:    resp.Header.Get("Sec-WebSocket-Protocol"),
 162  		rwc:            rwc,
 163  		client:         true,
 164  		copts:          copts,
 165  		flateThreshold: opts.CompressionThreshold,
 166  		br:             getBufioReader(rwc),
 167  		bw:             getBufioWriter(rwc),
 168  	}), resp, nil
 169  }
 170  
 171  func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
 172  	u, err := url.Parse(urls)
 173  	if err != nil {
 174  		return nil, fmt.Errorf("failed to parse url: %w", err)
 175  	}
 176  
 177  	switch u.Scheme {
 178  	case "ws":
 179  		u.Scheme = "http"
 180  	case "wss":
 181  		u.Scheme = "https"
 182  	case "http", "https":
 183  	default:
 184  		return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
 185  	}
 186  
 187  	req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
 188  	if err != nil {
 189  		return nil, fmt.Errorf("failed to create new http request: %w", err)
 190  	}
 191  	if len(opts.Host) > 0 {
 192  		req.Host = opts.Host
 193  	}
 194  	req.Header = opts.HTTPHeader.Clone()
 195  	req.Header.Set("Connection", "Upgrade")
 196  	req.Header.Set("Upgrade", "websocket")
 197  	req.Header.Set("Sec-WebSocket-Version", "13")
 198  	req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
 199  	if len(opts.Subprotocols) > 0 {
 200  		req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
 201  	}
 202  	if copts != nil {
 203  		req.Header.Set("Sec-WebSocket-Extensions", copts.String())
 204  	}
 205  
 206  	resp, err := opts.HTTPClient.Do(req)
 207  	if err != nil {
 208  		return nil, fmt.Errorf("failed to send handshake request: %w", err)
 209  	}
 210  	return resp, nil
 211  }
 212  
 213  func secWebSocketKey(rr io.Reader) (string, error) {
 214  	if rr == nil {
 215  		rr = rand.Reader
 216  	}
 217  	b := make([]byte, 16)
 218  	_, err := io.ReadFull(rr, b)
 219  	if err != nil {
 220  		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
 221  	}
 222  	return base64.StdEncoding.EncodeToString(b), nil
 223  }
 224  
 225  func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
 226  	if resp.StatusCode != http.StatusSwitchingProtocols {
 227  		return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
 228  	}
 229  
 230  	if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
 231  		return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
 232  	}
 233  
 234  	if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
 235  		return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
 236  	}
 237  
 238  	if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
 239  		return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
 240  			resp.Header.Get("Sec-WebSocket-Accept"),
 241  			secWebSocketKey,
 242  		)
 243  	}
 244  
 245  	err := verifySubprotocol(opts.Subprotocols, resp)
 246  	if err != nil {
 247  		return nil, err
 248  	}
 249  
 250  	return verifyServerExtensions(copts, resp.Header)
 251  }
 252  
 253  func verifySubprotocol(subprotos []string, resp *http.Response) error {
 254  	proto := resp.Header.Get("Sec-WebSocket-Protocol")
 255  	if proto == "" {
 256  		return nil
 257  	}
 258  
 259  	for _, sp2 := range subprotos {
 260  		if strings.EqualFold(sp2, proto) {
 261  			return nil
 262  		}
 263  	}
 264  
 265  	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
 266  }
 267  
 268  func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
 269  	exts := websocketExtensions(h)
 270  	if len(exts) == 0 {
 271  		return nil, nil
 272  	}
 273  
 274  	ext := exts[0]
 275  	if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
 276  		return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
 277  	}
 278  
 279  	_copts := *copts
 280  	copts = &_copts
 281  
 282  	for _, p := range ext.params {
 283  		switch p {
 284  		case "client_no_context_takeover":
 285  			copts.clientNoContextTakeover = true
 286  			continue
 287  		case "server_no_context_takeover":
 288  			copts.serverNoContextTakeover = true
 289  			continue
 290  		}
 291  		if strings.HasPrefix(p, "server_max_window_bits=") {
 292  			// We can't adjust the deflate window, but decoding with a larger window is acceptable.
 293  			continue
 294  		}
 295  
 296  		return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
 297  	}
 298  
 299  	return copts, nil
 300  }
 301  
 302  var bufioReaderPool sync.Pool
 303  
 304  func getBufioReader(r io.Reader) *bufio.Reader {
 305  	br, ok := bufioReaderPool.Get().(*bufio.Reader)
 306  	if !ok {
 307  		return bufio.NewReader(r)
 308  	}
 309  	br.Reset(r)
 310  	return br
 311  }
 312  
 313  func putBufioReader(br *bufio.Reader) {
 314  	bufioReaderPool.Put(br)
 315  }
 316  
 317  var bufioWriterPool sync.Pool
 318  
 319  func getBufioWriter(w io.Writer) *bufio.Writer {
 320  	bw, ok := bufioWriterPool.Get().(*bufio.Writer)
 321  	if !ok {
 322  		return bufio.NewWriter(w)
 323  	}
 324  	bw.Reset(w)
 325  	return bw
 326  }
 327  
 328  func putBufioWriter(bw *bufio.Writer) {
 329  	bufioWriterPool.Put(bw)
 330  }
 331