dial.go raw
1 //go:build !js
2 // +build !js
3
4 package websocket
5
6 import (
7 "bufio"
8 "bytes"
9 "context"
10 "crypto/rand"
11 "encoding/base64"
12 "fmt"
13 "io"
14 "net/http"
15 "net/url"
16 "strings"
17 "sync"
18 "time"
19
20 "github.com/coder/websocket/internal/errd"
21 )
22
// DialOptions represents Dial's options.
//
// A nil *DialOptions is valid; cloneWithDefaults fills in any unset fields on
// a private copy before the handshake is performed.
type DialOptions struct {
	// HTTPClient is used for the connection.
	// Its Transport must return writable bodies for WebSocket handshakes.
	// http.Transport does beginning with Go 1.12.
	//
	// Defaults to http.DefaultClient when nil. A positive Timeout is applied
	// to the dial context as a deadline rather than being left on the client.
	HTTPClient *http.Client

	// HTTPHeader specifies the HTTP headers included in the handshake request.
	//
	// The headers are cloned for the request. Connection, Upgrade,
	// Sec-WebSocket-Version and Sec-WebSocket-Key are always overwritten with
	// the values required for the handshake.
	HTTPHeader http.Header

	// Host optionally overrides the Host HTTP header to send. If empty, the value
	// of URL.Host will be used.
	Host string

	// Subprotocols lists the WebSocket subprotocols to negotiate with the server.
	//
	// They are sent comma separated in the Sec-WebSocket-Protocol header. If
	// the server selects a subprotocol, it must match one of these
	// (compared case-insensitively) or the dial fails.
	Subprotocols []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionDisabled.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
52
53 func (opts *DialOptions) cloneWithDefaults(ctx context.Context) (context.Context, context.CancelFunc, *DialOptions) {
54 var cancel context.CancelFunc
55
56 var o DialOptions
57 if opts != nil {
58 o = *opts
59 }
60 if o.HTTPClient == nil {
61 o.HTTPClient = http.DefaultClient
62 }
63 if o.HTTPClient.Timeout > 0 {
64 ctx, cancel = context.WithTimeout(ctx, o.HTTPClient.Timeout)
65
66 newClient := *o.HTTPClient
67 newClient.Timeout = 0
68 o.HTTPClient = &newClient
69 }
70 if o.HTTPHeader == nil {
71 o.HTTPHeader = http.Header{}
72 }
73 newClient := *o.HTTPClient
74 oldCheckRedirect := o.HTTPClient.CheckRedirect
75 newClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
76 switch req.URL.Scheme {
77 case "ws":
78 req.URL.Scheme = "http"
79 case "wss":
80 req.URL.Scheme = "https"
81 }
82 if oldCheckRedirect != nil {
83 return oldCheckRedirect(req, via)
84 }
85 return nil
86 }
87 o.HTTPClient = &newClient
88
89 return ctx, cancel, &o
90 }
91
// Dial performs a WebSocket handshake on the URL u.
//
// opts may be nil, in which case defaults are used; see DialOptions.
//
// The response is the WebSocket handshake response from the server.
// You never need to close resp.Body yourself.
//
// If an error occurs, the returned response may be non nil.
// However, you can only read the first 1024 bytes of the body.
//
// This function requires at least Go 1.12 as it uses a new feature
// in net/http to perform WebSocket handshakes.
// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861
//
// URLs with http/https schemes will work and are interpreted as ws/wss.
func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) {
	// A nil reader makes secWebSocketKey fall back to crypto/rand.Reader.
	return dial(ctx, u, opts, nil)
}
108
109 func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) {
110 defer errd.Wrap(&err, "failed to WebSocket dial")
111
112 var cancel context.CancelFunc
113 ctx, cancel, opts = opts.cloneWithDefaults(ctx)
114 if cancel != nil {
115 defer cancel()
116 }
117
118 secWebSocketKey, err := secWebSocketKey(rand)
119 if err != nil {
120 return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err)
121 }
122
123 var copts *compressionOptions
124 if opts.CompressionMode != CompressionDisabled {
125 copts = opts.CompressionMode.opts()
126 }
127
128 resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey)
129 if err != nil {
130 return nil, resp, err
131 }
132 respBody := resp.Body
133 resp.Body = nil
134 defer func() {
135 if err != nil {
136 // We read a bit of the body for easier debugging.
137 r := io.LimitReader(respBody, 1024)
138
139 timer := time.AfterFunc(time.Second*3, func() {
140 respBody.Close()
141 })
142 defer timer.Stop()
143
144 b, _ := io.ReadAll(r)
145 respBody.Close()
146 resp.Body = io.NopCloser(bytes.NewReader(b))
147 }
148 }()
149
150 copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp)
151 if err != nil {
152 return nil, resp, err
153 }
154
155 rwc, ok := respBody.(io.ReadWriteCloser)
156 if !ok {
157 return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody)
158 }
159
160 return newConn(connConfig{
161 subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"),
162 rwc: rwc,
163 client: true,
164 copts: copts,
165 flateThreshold: opts.CompressionThreshold,
166 br: getBufioReader(rwc),
167 bw: getBufioWriter(rwc),
168 }), resp, nil
169 }
170
171 func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) {
172 u, err := url.Parse(urls)
173 if err != nil {
174 return nil, fmt.Errorf("failed to parse url: %w", err)
175 }
176
177 switch u.Scheme {
178 case "ws":
179 u.Scheme = "http"
180 case "wss":
181 u.Scheme = "https"
182 case "http", "https":
183 default:
184 return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme)
185 }
186
187 req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil)
188 if err != nil {
189 return nil, fmt.Errorf("failed to create new http request: %w", err)
190 }
191 if len(opts.Host) > 0 {
192 req.Host = opts.Host
193 }
194 req.Header = opts.HTTPHeader.Clone()
195 req.Header.Set("Connection", "Upgrade")
196 req.Header.Set("Upgrade", "websocket")
197 req.Header.Set("Sec-WebSocket-Version", "13")
198 req.Header.Set("Sec-WebSocket-Key", secWebSocketKey)
199 if len(opts.Subprotocols) > 0 {
200 req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ","))
201 }
202 if copts != nil {
203 req.Header.Set("Sec-WebSocket-Extensions", copts.String())
204 }
205
206 resp, err := opts.HTTPClient.Do(req)
207 if err != nil {
208 return nil, fmt.Errorf("failed to send handshake request: %w", err)
209 }
210 return resp, nil
211 }
212
// secWebSocketKey returns a Sec-WebSocket-Key value: 16 bytes read from rr,
// encoded with standard base64.
//
// If rr is nil, crypto/rand.Reader is used. An error is returned when fewer
// than 16 bytes can be read.
func secWebSocketKey(rr io.Reader) (string, error) {
	src := rr
	if src == nil {
		src = rand.Reader
	}

	var nonce [16]byte
	if _, err := io.ReadFull(src, nonce[:]); err != nil {
		return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err)
	}
	return base64.StdEncoding.EncodeToString(nonce[:]), nil
}
224
225 func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) {
226 if resp.StatusCode != http.StatusSwitchingProtocols {
227 return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode)
228 }
229
230 if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") {
231 return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection"))
232 }
233
234 if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") {
235 return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade"))
236 }
237
238 if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) {
239 return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q",
240 resp.Header.Get("Sec-WebSocket-Accept"),
241 secWebSocketKey,
242 )
243 }
244
245 err := verifySubprotocol(opts.Subprotocols, resp)
246 if err != nil {
247 return nil, err
248 }
249
250 return verifyServerExtensions(copts, resp.Header)
251 }
252
// verifySubprotocol checks the subprotocol selected by the server.
//
// A response without a Sec-WebSocket-Protocol header is always accepted.
// Otherwise the selected value must equal one of subprotos, compared
// case-insensitively; if it does not, an error is returned.
func verifySubprotocol(subprotos []string, resp *http.Response) error {
	selected := resp.Header.Get("Sec-WebSocket-Protocol")

	// No selection by the server counts as acceptable.
	accepted := selected == ""
	for i := 0; !accepted && i < len(subprotos); i++ {
		accepted = strings.EqualFold(subprotos[i], selected)
	}
	if accepted {
		return nil
	}

	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", selected)
}
267
268 func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) {
269 exts := websocketExtensions(h)
270 if len(exts) == 0 {
271 return nil, nil
272 }
273
274 ext := exts[0]
275 if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil {
276 return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:])
277 }
278
279 _copts := *copts
280 copts = &_copts
281
282 for _, p := range ext.params {
283 switch p {
284 case "client_no_context_takeover":
285 copts.clientNoContextTakeover = true
286 continue
287 case "server_no_context_takeover":
288 copts.serverNoContextTakeover = true
289 continue
290 }
291 if strings.HasPrefix(p, "server_max_window_bits=") {
292 // We can't adjust the deflate window, but decoding with a larger window is acceptable.
293 continue
294 }
295
296 return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
297 }
298
299 return copts, nil
300 }
301
// bufioReaderPool holds *bufio.Reader values handed back via putBufioReader
// so they can be reused by getBufioReader.
var bufioReaderPool sync.Pool

// getBufioReader returns a *bufio.Reader that reads from r. A pooled reader
// is reset onto r and reused when one is available; otherwise a new reader
// is allocated.
func getBufioReader(r io.Reader) *bufio.Reader {
	if pooled, ok := bufioReaderPool.Get().(*bufio.Reader); ok {
		pooled.Reset(r)
		return pooled
	}
	return bufio.NewReader(r)
}
312
// putBufioReader places br in bufioReaderPool so that a later call to
// getBufioReader can reuse it. br is stored as is; getBufioReader resets a
// pooled reader onto its new source before returning it.
func putBufioReader(br *bufio.Reader) {
	bufioReaderPool.Put(br)
}
316
// bufioWriterPool holds *bufio.Writer values handed back via putBufioWriter
// so they can be reused by getBufioWriter.
var bufioWriterPool sync.Pool

// getBufioWriter returns a *bufio.Writer that writes to w. A pooled writer
// is reset onto w and reused when one is available; otherwise a new writer
// is allocated.
func getBufioWriter(w io.Writer) *bufio.Writer {
	if pooled, ok := bufioWriterPool.Get().(*bufio.Writer); ok {
		pooled.Reset(w)
		return pooled
	}
	return bufio.NewWriter(w)
}
327
// putBufioWriter places bw in bufioWriterPool so that a later call to
// getBufioWriter can reuse it. bw is stored as is; getBufioWriter resets a
// pooled writer onto its new destination before returning it.
func putBufioWriter(bw *bufio.Writer) {
	bufioWriterPool.Put(bw)
}
331