hybi.go raw
1 // Copyright 2011 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4
5 package websocket
6
7 // This file implements a protocol of hybi draft.
8 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17
9
10 import (
11 "bufio"
12 "bytes"
13 "crypto/rand"
14 "crypto/sha1"
15 "encoding/base64"
16 "encoding/binary"
17 "fmt"
18 "io"
19 "net/http"
20 "net/url"
21 "strings"
22 )
23
const (
	// websocketGUID is the fixed string that getNonceAccept hashes together
	// with the handshake nonce to produce the Sec-WebSocket-Accept value.
	websocketGUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

	// Close status codes. WriteClose sends one of these as the 2-byte
	// big-endian payload of a close frame.
	closeStatusNormal            = 1000
	closeStatusGoingAway         = 1001
	closeStatusProtocolError     = 1002
	closeStatusUnsupportedData   = 1003
	closeStatusFrameTooLarge     = 1004
	closeStatusNoStatusRcvd      = 1005
	closeStatusAbnormalClosure   = 1006
	closeStatusBadMessageData    = 1007
	closeStatusPolicyViolation   = 1008
	closeStatusTooBigData        = 1009
	closeStatusExtensionMismatch = 1010

	// maxControlFramePayloadLength is the size, in bytes, of the buffer
	// HandleFrame uses to read the payload of a ping or pong frame.
	maxControlFramePayloadLength = 125
)
41
var (
	// Protocol errors specific to the hybi implementation.
	ErrBadMaskingKey         = &ProtocolError{"bad masking key"}
	ErrBadPongMessage        = &ProtocolError{"bad pong message"}
	ErrBadClosingStatus      = &ProtocolError{"bad closing status"}
	ErrUnsupportedExtensions = &ProtocolError{"unsupported extensions"}
	ErrNotImplemented        = &ProtocolError{"not implemented"}

	// handshakeHeader is a set of handshake-related header names. Both
	// handshake functions pass it as the second argument to
	// Header.WriteSubset; assuming Header is an http.Header, that argument
	// is the exclusion set, so caller-supplied headers with these names
	// are not written.
	handshakeHeader = map[string]bool{
		"Host":                   true,
		"Upgrade":                true,
		"Connection":             true,
		"Sec-Websocket-Key":      true,
		"Sec-Websocket-Origin":   true,
		"Sec-Websocket-Version":  true,
		"Sec-Websocket-Protocol": true,
		"Sec-Websocket-Accept":   true,
	}
)
60
// A hybiFrameHeader is a frame header as defined in the hybi draft.
type hybiFrameHeader struct {
	Fin        bool    // FIN flag: the top bit of the first header byte
	Rsv        [3]bool // RSV1, RSV2 and RSV3 flags, in that order
	OpCode     byte    // opcode: the low 4 bits of the first header byte
	Length     int64   // payload length in bytes
	MaskingKey []byte  // 4-byte masking key; nil when the frame is not masked

	data *bytes.Buffer // header bytes collected by NewFrameReader while parsing
}
71
// A hybiFrameReader is a reader for a single hybi frame.
type hybiFrameReader struct {
	reader io.Reader // payload source; NewFrameReader limits it to header.Length bytes

	header hybiFrameHeader
	pos    int64 // count of payload bytes unmasked so far; selects the masking key byte
	length int   // number of header bytes plus the payload length
}
80
81 func (frame *hybiFrameReader) Read(msg []byte) (n int, err error) {
82 n, err = frame.reader.Read(msg)
83 if frame.header.MaskingKey != nil {
84 for i := 0; i < n; i++ {
85 msg[i] = msg[i] ^ frame.header.MaskingKey[frame.pos%4]
86 frame.pos++
87 }
88 }
89 return n, err
90 }
91
92 func (frame *hybiFrameReader) PayloadType() byte { return frame.header.OpCode }
93
94 func (frame *hybiFrameReader) HeaderReader() io.Reader {
95 if frame.header.data == nil {
96 return nil
97 }
98 if frame.header.data.Len() == 0 {
99 return nil
100 }
101 return frame.header.data
102 }
103
104 func (frame *hybiFrameReader) TrailerReader() io.Reader { return nil }
105
106 func (frame *hybiFrameReader) Len() (n int) { return frame.length }
107
// A hybiFrameReaderFactory creates a new frame reader for each frame header
// it parses from the embedded buffered reader.
type hybiFrameReaderFactory struct {
	*bufio.Reader // source of the frame bytes
}
112
113 // NewFrameReader reads a frame header from the connection, and creates new reader for the frame.
114 // See Section 5.2 Base Framing protocol for detail.
115 // http://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-17#section-5.2
116 func (buf hybiFrameReaderFactory) NewFrameReader() (frame frameReader, err error) {
117 hybiFrame := new(hybiFrameReader)
118 frame = hybiFrame
119 var header []byte
120 var b byte
121 // First byte. FIN/RSV1/RSV2/RSV3/OpCode(4bits)
122 b, err = buf.ReadByte()
123 if err != nil {
124 return
125 }
126 header = append(header, b)
127 hybiFrame.header.Fin = ((header[0] >> 7) & 1) != 0
128 for i := 0; i < 3; i++ {
129 j := uint(6 - i)
130 hybiFrame.header.Rsv[i] = ((header[0] >> j) & 1) != 0
131 }
132 hybiFrame.header.OpCode = header[0] & 0x0f
133
134 // Second byte. Mask/Payload len(7bits)
135 b, err = buf.ReadByte()
136 if err != nil {
137 return
138 }
139 header = append(header, b)
140 mask := (b & 0x80) != 0
141 b &= 0x7f
142 lengthFields := 0
143 switch {
144 case b <= 125: // Payload length 7bits.
145 hybiFrame.header.Length = int64(b)
146 case b == 126: // Payload length 7+16bits
147 lengthFields = 2
148 case b == 127: // Payload length 7+64bits
149 lengthFields = 8
150 }
151 for i := 0; i < lengthFields; i++ {
152 b, err = buf.ReadByte()
153 if err != nil {
154 return
155 }
156 if lengthFields == 8 && i == 0 { // MSB must be zero when 7+64 bits
157 b &= 0x7f
158 }
159 header = append(header, b)
160 hybiFrame.header.Length = hybiFrame.header.Length*256 + int64(b)
161 }
162 if mask {
163 // Masking key. 4 bytes.
164 for i := 0; i < 4; i++ {
165 b, err = buf.ReadByte()
166 if err != nil {
167 return
168 }
169 header = append(header, b)
170 hybiFrame.header.MaskingKey = append(hybiFrame.header.MaskingKey, b)
171 }
172 }
173 hybiFrame.reader = io.LimitReader(buf.Reader, hybiFrame.header.Length)
174 hybiFrame.header.data = bytes.NewBuffer(header)
175 hybiFrame.length = len(header) + int(hybiFrame.header.Length)
176 return
177 }
178
// A hybiFrameWriter is a writer for hybi frames.
type hybiFrameWriter struct {
	writer *bufio.Writer // destination; flushed at the end of every Write

	header *hybiFrameHeader // Fin, Rsv, OpCode and MaskingKey used for each frame written
}
185
186 func (frame *hybiFrameWriter) Write(msg []byte) (n int, err error) {
187 var header []byte
188 var b byte
189 if frame.header.Fin {
190 b |= 0x80
191 }
192 for i := 0; i < 3; i++ {
193 if frame.header.Rsv[i] {
194 j := uint(6 - i)
195 b |= 1 << j
196 }
197 }
198 b |= frame.header.OpCode
199 header = append(header, b)
200 if frame.header.MaskingKey != nil {
201 b = 0x80
202 } else {
203 b = 0
204 }
205 lengthFields := 0
206 length := len(msg)
207 switch {
208 case length <= 125:
209 b |= byte(length)
210 case length < 65536:
211 b |= 126
212 lengthFields = 2
213 default:
214 b |= 127
215 lengthFields = 8
216 }
217 header = append(header, b)
218 for i := 0; i < lengthFields; i++ {
219 j := uint((lengthFields - i - 1) * 8)
220 b = byte((length >> j) & 0xff)
221 header = append(header, b)
222 }
223 if frame.header.MaskingKey != nil {
224 if len(frame.header.MaskingKey) != 4 {
225 return 0, ErrBadMaskingKey
226 }
227 header = append(header, frame.header.MaskingKey...)
228 frame.writer.Write(header)
229 data := make([]byte, length)
230 for i := range data {
231 data[i] = msg[i] ^ frame.header.MaskingKey[i%4]
232 }
233 frame.writer.Write(data)
234 err = frame.writer.Flush()
235 return length, err
236 }
237 frame.writer.Write(header)
238 frame.writer.Write(msg)
239 err = frame.writer.Flush()
240 return length, err
241 }
242
243 func (frame *hybiFrameWriter) Close() error { return nil }
244
// A hybiFrameWriterFactory creates frame writers that write to the embedded
// buffered writer.
type hybiFrameWriterFactory struct {
	*bufio.Writer
	needMaskingKey bool // when true, every new frame writer gets a freshly generated masking key
}
249
250 func (buf hybiFrameWriterFactory) NewFrameWriter(payloadType byte) (frame frameWriter, err error) {
251 frameHeader := &hybiFrameHeader{Fin: true, OpCode: payloadType}
252 if buf.needMaskingKey {
253 frameHeader.MaskingKey, err = generateMaskingKey()
254 if err != nil {
255 return nil, err
256 }
257 }
258 return &hybiFrameWriter{writer: buf.Writer, header: frameHeader}, nil
259 }
260
// A hybiFrameHandler checks incoming frames against the masking rule and
// handles control frames for one connection.
type hybiFrameHandler struct {
	conn        *Conn
	payloadType byte // opcode of the latest text or binary frame; applied to continuation frames
}
265
266 func (handler *hybiFrameHandler) HandleFrame(frame frameReader) (frameReader, error) {
267 if handler.conn.IsServerConn() {
268 // The client MUST mask all frames sent to the server.
269 if frame.(*hybiFrameReader).header.MaskingKey == nil {
270 handler.WriteClose(closeStatusProtocolError)
271 return nil, io.EOF
272 }
273 } else {
274 // The server MUST NOT mask all frames.
275 if frame.(*hybiFrameReader).header.MaskingKey != nil {
276 handler.WriteClose(closeStatusProtocolError)
277 return nil, io.EOF
278 }
279 }
280 if header := frame.HeaderReader(); header != nil {
281 io.Copy(io.Discard, header)
282 }
283 switch frame.PayloadType() {
284 case ContinuationFrame:
285 frame.(*hybiFrameReader).header.OpCode = handler.payloadType
286 case TextFrame, BinaryFrame:
287 handler.payloadType = frame.PayloadType()
288 case CloseFrame:
289 return nil, io.EOF
290 case PingFrame, PongFrame:
291 b := make([]byte, maxControlFramePayloadLength)
292 n, err := io.ReadFull(frame, b)
293 if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
294 return nil, err
295 }
296 io.Copy(io.Discard, frame)
297 if frame.PayloadType() == PingFrame {
298 if _, err := handler.WritePong(b[:n]); err != nil {
299 return nil, err
300 }
301 }
302 return nil, nil
303 }
304 return frame, nil
305 }
306
307 func (handler *hybiFrameHandler) WriteClose(status int) (err error) {
308 handler.conn.wio.Lock()
309 defer handler.conn.wio.Unlock()
310 w, err := handler.conn.frameWriterFactory.NewFrameWriter(CloseFrame)
311 if err != nil {
312 return err
313 }
314 msg := make([]byte, 2)
315 binary.BigEndian.PutUint16(msg, uint16(status))
316 _, err = w.Write(msg)
317 w.Close()
318 return err
319 }
320
321 func (handler *hybiFrameHandler) WritePong(msg []byte) (n int, err error) {
322 handler.conn.wio.Lock()
323 defer handler.conn.wio.Unlock()
324 w, err := handler.conn.frameWriterFactory.NewFrameWriter(PongFrame)
325 if err != nil {
326 return 0, err
327 }
328 n, err = w.Write(msg)
329 w.Close()
330 return n, err
331 }
332
333 // newHybiConn creates a new WebSocket connection speaking hybi draft protocol.
334 func newHybiConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
335 if buf == nil {
336 br := bufio.NewReader(rwc)
337 bw := bufio.NewWriter(rwc)
338 buf = bufio.NewReadWriter(br, bw)
339 }
340 ws := &Conn{config: config, request: request, buf: buf, rwc: rwc,
341 frameReaderFactory: hybiFrameReaderFactory{buf.Reader},
342 frameWriterFactory: hybiFrameWriterFactory{
343 buf.Writer, request == nil},
344 PayloadType: TextFrame,
345 defaultCloseStatus: closeStatusNormal}
346 ws.frameHandler = &hybiFrameHandler{conn: ws}
347 return ws
348 }
349
// generateMaskingKey returns a 4-byte masking key for a frame, filled from
// crypto/rand. Any error from the random source is returned as err.
func generateMaskingKey() (maskingKey []byte, err error) {
	maskingKey = make([]byte, 4)
	_, err = io.ReadFull(rand.Reader, maskingKey)
	return maskingKey, err
}
358
// generateNonce returns the base64 encoding (24 bytes) of 16 bytes read
// from crypto/rand. It panics if the random source fails.
func generateNonce() (nonce []byte) {
	var raw [16]byte
	if _, err := io.ReadFull(rand.Reader, raw[:]); err != nil {
		panic(err)
	}
	nonce = make([]byte, base64.StdEncoding.EncodedLen(len(raw)))
	base64.StdEncoding.Encode(nonce, raw[:])
	return nonce
}
370
// removeZone strips the IPv6 zone identifier from a bracketed host.
// E.g., "[fe80::1%en0]:8080" becomes "[fe80::1]:8080".
// A host that does not start with "[", has no "]", or has no "%" before the
// last "]" is returned unchanged.
func removeZone(host string) string {
	if len(host) == 0 || host[0] != '[' {
		return host
	}
	closing := strings.LastIndexByte(host, ']')
	if closing < 0 {
		return host
	}
	zone := strings.LastIndexByte(host[:closing], '%')
	if zone < 0 {
		return host
	}
	return host[:zone] + host[closing:]
}
387
388 // getNonceAccept computes the base64-encoded SHA-1 of the concatenation of
389 // the nonce ("Sec-WebSocket-Key" value) with the websocket GUID string.
390 func getNonceAccept(nonce []byte) (expected []byte, err error) {
391 h := sha1.New()
392 if _, err = h.Write(nonce); err != nil {
393 return
394 }
395 if _, err = h.Write([]byte(websocketGUID)); err != nil {
396 return
397 }
398 expected = make([]byte, 28)
399 base64.StdEncoding.Encode(expected, h.Sum(nil))
400 return
401 }
402
403 // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
404 func hybiClientHandshake(config *Config, br *bufio.Reader, bw *bufio.Writer) (err error) {
405 bw.WriteString("GET " + config.Location.RequestURI() + " HTTP/1.1\r\n")
406
407 // According to RFC 6874, an HTTP client, proxy, or other
408 // intermediary must remove any IPv6 zone identifier attached
409 // to an outgoing URI.
410 bw.WriteString("Host: " + removeZone(config.Location.Host) + "\r\n")
411 bw.WriteString("Upgrade: websocket\r\n")
412 bw.WriteString("Connection: Upgrade\r\n")
413 nonce := generateNonce()
414 if config.handshakeData != nil {
415 nonce = []byte(config.handshakeData["key"])
416 }
417 bw.WriteString("Sec-WebSocket-Key: " + string(nonce) + "\r\n")
418 bw.WriteString("Origin: " + strings.ToLower(config.Origin.String()) + "\r\n")
419
420 if config.Version != ProtocolVersionHybi13 {
421 return ErrBadProtocolVersion
422 }
423
424 bw.WriteString("Sec-WebSocket-Version: " + fmt.Sprintf("%d", config.Version) + "\r\n")
425 if len(config.Protocol) > 0 {
426 bw.WriteString("Sec-WebSocket-Protocol: " + strings.Join(config.Protocol, ", ") + "\r\n")
427 }
428 // TODO(ukai): send Sec-WebSocket-Extensions.
429 err = config.Header.WriteSubset(bw, handshakeHeader)
430 if err != nil {
431 return err
432 }
433
434 bw.WriteString("\r\n")
435 if err = bw.Flush(); err != nil {
436 return err
437 }
438
439 resp, err := http.ReadResponse(br, &http.Request{Method: "GET"})
440 if err != nil {
441 return err
442 }
443 defer resp.Body.Close()
444 if resp.StatusCode != 101 {
445 return ErrBadStatus
446 }
447 if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" ||
448 strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
449 return ErrBadUpgrade
450 }
451 expectedAccept, err := getNonceAccept(nonce)
452 if err != nil {
453 return err
454 }
455 if resp.Header.Get("Sec-WebSocket-Accept") != string(expectedAccept) {
456 return ErrChallengeResponse
457 }
458 if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
459 return ErrUnsupportedExtensions
460 }
461 offeredProtocol := resp.Header.Get("Sec-WebSocket-Protocol")
462 if offeredProtocol != "" {
463 protocolMatched := false
464 for i := 0; i < len(config.Protocol); i++ {
465 if config.Protocol[i] == offeredProtocol {
466 protocolMatched = true
467 break
468 }
469 }
470 if !protocolMatched {
471 return ErrBadWebSocketProtocol
472 }
473 config.Protocol = []string{offeredProtocol}
474 }
475
476 return nil
477 }
478
479 // newHybiClientConn creates a client WebSocket connection after handshake.
480 func newHybiClientConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser) *Conn {
481 return newHybiConn(config, buf, rwc, nil)
482 }
483
// A hybiServerHandshaker performs a server handshake using the hybi draft
// protocol.
type hybiServerHandshaker struct {
	*Config
	accept []byte // Sec-WebSocket-Accept value: computed by ReadHandshake, sent by AcceptHandshake
}
489
490 func (c *hybiServerHandshaker) ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error) {
491 c.Version = ProtocolVersionHybi13
492 if req.Method != "GET" {
493 return http.StatusMethodNotAllowed, ErrBadRequestMethod
494 }
495 // HTTP version can be safely ignored.
496
497 if strings.ToLower(req.Header.Get("Upgrade")) != "websocket" ||
498 !strings.Contains(strings.ToLower(req.Header.Get("Connection")), "upgrade") {
499 return http.StatusBadRequest, ErrNotWebSocket
500 }
501
502 key := req.Header.Get("Sec-Websocket-Key")
503 if key == "" {
504 return http.StatusBadRequest, ErrChallengeResponse
505 }
506 version := req.Header.Get("Sec-Websocket-Version")
507 switch version {
508 case "13":
509 c.Version = ProtocolVersionHybi13
510 default:
511 return http.StatusBadRequest, ErrBadWebSocketVersion
512 }
513 var scheme string
514 if req.TLS != nil {
515 scheme = "wss"
516 } else {
517 scheme = "ws"
518 }
519 c.Location, err = url.ParseRequestURI(scheme + "://" + req.Host + req.URL.RequestURI())
520 if err != nil {
521 return http.StatusBadRequest, err
522 }
523 protocol := strings.TrimSpace(req.Header.Get("Sec-Websocket-Protocol"))
524 if protocol != "" {
525 protocols := strings.Split(protocol, ",")
526 for i := 0; i < len(protocols); i++ {
527 c.Protocol = append(c.Protocol, strings.TrimSpace(protocols[i]))
528 }
529 }
530 c.accept, err = getNonceAccept([]byte(key))
531 if err != nil {
532 return http.StatusInternalServerError, err
533 }
534 return http.StatusSwitchingProtocols, nil
535 }
536
537 // Origin parses the Origin header in req.
538 // If the Origin header is not set, it returns nil and nil.
539 func Origin(config *Config, req *http.Request) (*url.URL, error) {
540 var origin string
541 switch config.Version {
542 case ProtocolVersionHybi13:
543 origin = req.Header.Get("Origin")
544 }
545 if origin == "" {
546 return nil, nil
547 }
548 return url.ParseRequestURI(origin)
549 }
550
551 func (c *hybiServerHandshaker) AcceptHandshake(buf *bufio.Writer) (err error) {
552 if len(c.Protocol) > 0 {
553 if len(c.Protocol) != 1 {
554 // You need choose a Protocol in Handshake func in Server.
555 return ErrBadWebSocketProtocol
556 }
557 }
558 buf.WriteString("HTTP/1.1 101 Switching Protocols\r\n")
559 buf.WriteString("Upgrade: websocket\r\n")
560 buf.WriteString("Connection: Upgrade\r\n")
561 buf.WriteString("Sec-WebSocket-Accept: " + string(c.accept) + "\r\n")
562 if len(c.Protocol) > 0 {
563 buf.WriteString("Sec-WebSocket-Protocol: " + c.Protocol[0] + "\r\n")
564 }
565 // TODO(ukai): send Sec-WebSocket-Extensions.
566 if c.Header != nil {
567 err := c.Header.WriteSubset(buf, handshakeHeader)
568 if err != nil {
569 return err
570 }
571 }
572 buf.WriteString("\r\n")
573 return buf.Flush()
574 }
575
576 func (c *hybiServerHandshaker) NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
577 return newHybiServerConn(c.Config, buf, rwc, request)
578 }
579
580 // newHybiServerConn returns a new WebSocket connection speaking hybi draft protocol.
581 func newHybiServerConn(config *Config, buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) *Conn {
582 return newHybiConn(config, buf, rwc, request)
583 }
584