websocket.go raw

   1  // Copyright 2009 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Package websocket implements a client and server for the WebSocket protocol
   6  // as specified in RFC 6455.
   7  //
// This package currently lacks some features found in alternative
// and more actively maintained WebSocket packages:
  10  //
  11  //   - [github.com/gorilla/websocket]
  12  //   - [github.com/coder/websocket]
  13  package websocket // import "golang.org/x/net/websocket"
  14  
  15  import (
  16  	"bufio"
  17  	"crypto/tls"
  18  	"encoding/json"
  19  	"errors"
  20  	"io"
  21  	"net"
  22  	"net/http"
  23  	"net/url"
  24  	"sync"
  25  	"time"
  26  )
  27  
const (
	// Protocol version identifiers. SupportedProtocolVersion is the string
	// form of ProtocolVersionHybi13, the version this package implements.
	ProtocolVersionHybi13    = 13
	ProtocolVersionHybi      = ProtocolVersionHybi13
	SupportedProtocolVersion = "13"

	// Frame types, used as payload-type values (Conn.PayloadType and the
	// payloadType returned by a Codec's Marshal). The values 0-10 match the
	// RFC 6455 frame opcodes. UnknownFrame is the payload type the Message
	// codec reports when it cannot marshal a value.
	ContinuationFrame = 0
	TextFrame         = 1
	BinaryFrame       = 2
	CloseFrame        = 8
	PingFrame         = 9
	PongFrame         = 10
	UnknownFrame      = 255

	// DefaultMaxPayloadBytes is the payload size limit Codec's Receive
	// applies when Conn.MaxPayloadBytes is zero.
	DefaultMaxPayloadBytes = 32 << 20 // 32MB
)
  43  
  44  // ProtocolError represents WebSocket protocol errors.
  45  type ProtocolError struct {
  46  	ErrorString string
  47  }
  48  
  49  func (err *ProtocolError) Error() string { return err.ErrorString }
  50  
// Sentinel protocol errors. Each is a *ProtocolError whose Error method
// returns the quoted text. Because they are shared pointers, callers can
// compare against them with == or errors.Is. Within this file only
// ErrNotSupported is returned (by the Message codec for unsupported value
// types); the others are presumably returned by the handshake and framing
// code elsewhere in the package.
var (
	ErrBadProtocolVersion   = &ProtocolError{"bad protocol version"}
	ErrBadScheme            = &ProtocolError{"bad scheme"}
	ErrBadStatus            = &ProtocolError{"bad status"}
	ErrBadUpgrade           = &ProtocolError{"missing or bad upgrade"}
	ErrBadWebSocketOrigin   = &ProtocolError{"missing or bad WebSocket-Origin"}
	ErrBadWebSocketLocation = &ProtocolError{"missing or bad WebSocket-Location"}
	ErrBadWebSocketProtocol = &ProtocolError{"missing or bad WebSocket-Protocol"}
	ErrBadWebSocketVersion  = &ProtocolError{"missing or bad WebSocket Version"}
	ErrChallengeResponse    = &ProtocolError{"mismatch challenge/response"}
	ErrBadFrame             = &ProtocolError{"bad frame"}
	ErrBadFrameBoundary     = &ProtocolError{"not on frame boundary"}
	ErrNotWebSocket         = &ProtocolError{"not websocket protocol"}
	ErrBadRequestMethod     = &ProtocolError{"bad method"}
	ErrNotSupported         = &ProtocolError{"not supported"}
)
  67  
// ErrFrameTooLarge is returned by Codec's Receive method if the frame payload
// size exceeds the limit set by Conn.MaxPayloadBytes (or
// DefaultMaxPayloadBytes when that field is zero).
var ErrFrameTooLarge = errors.New("websocket: frame payload size exceeds limit")
  71  
  72  // Addr is an implementation of net.Addr for WebSocket.
  73  type Addr struct {
  74  	*url.URL
  75  }
  76  
  77  // Network returns the network type for a WebSocket, "websocket".
  78  func (addr *Addr) Network() string { return "websocket" }
  79  
// Config is a WebSocket configuration.
type Config struct {
	// Location is the WebSocket server address. A client connection reports
	// it as its RemoteAddr; a server connection reports it as its LocalAddr.
	Location *url.URL

	// Origin is the WebSocket client origin. A client connection reports it
	// as its LocalAddr; a server connection reports it as its RemoteAddr.
	Origin *url.URL

	// Protocol lists the WebSocket subprotocols.
	Protocol []string

	// Version is the WebSocket protocol version (see ProtocolVersionHybi13).
	Version int

	// TlsConfig is the TLS configuration for secure WebSocket (wss).
	TlsConfig *tls.Config

	// Header holds additional header fields to be sent in the WebSocket
	// opening handshake.
	Header http.Header

	// Dialer is used when opening WebSocket connections.
	Dialer *net.Dialer

	// handshakeData holds string key/value data; it is not used in this
	// file. NOTE(review): presumably populated during the handshake — confirm.
	handshakeData map[string]string
}
 105  
// serverHandshaker is an interface to handle the server side of the WebSocket
// opening handshake.
type serverHandshaker interface {
	// ReadHandshake reads the handshake request message from the client.
	// It returns an HTTP response status code and an error, if any.
	ReadHandshake(buf *bufio.Reader, req *http.Request) (code int, err error)

	// AcceptHandshake accepts the client handshake request and sends the
	// handshake response back to the client.
	AcceptHandshake(buf *bufio.Writer) (err error)

	// NewServerConn creates a new WebSocket connection for the upgraded
	// request, using buf for buffered I/O over rwc.
	NewServerConn(buf *bufio.ReadWriter, rwc io.ReadWriteCloser, request *http.Request) (conn *Conn)
}
 119  
// frameReader is an interface to read a single WebSocket frame.
type frameReader interface {
	// Reader reads the payload of the frame; it returns io.EOF at the end
	// of the payload (Conn.Read relies on this to advance to the next frame).
	io.Reader

	// PayloadType returns the payload type (frame type) of the frame.
	PayloadType() byte

	// HeaderReader returns a reader to read the header of the frame.
	HeaderReader() io.Reader

	// TrailerReader returns a reader to read the trailer of the frame.
	// If it returns nil, there is no trailer in the frame.
	TrailerReader() io.Reader

	// Len returns the total length of the frame, including header and trailer.
	Len() int
}
 138  
// frameReaderFactory is an interface to create a reader for the next
// incoming frame.
type frameReaderFactory interface {
	NewFrameReader() (r frameReader, err error)
}
 143  
// frameWriter is an interface to write a single WebSocket frame.
type frameWriter interface {
	// WriteCloser writes the payload of the frame. Callers in this file
	// (Conn.Write, Codec.Send) issue one Write and then Close the writer.
	io.WriteCloser
}
 149  
// frameWriterFactory is an interface to create a writer for a new outgoing
// frame with the given payload type.
type frameWriterFactory interface {
	NewFrameWriter(payloadType byte) (w frameWriter, err error)
}
 154  
// frameHandler processes incoming frames and handles connection closing.
//
// HandleFrame returns the reader the caller should consume for the frame, or
// nil if the handler consumed the frame itself. In that case Conn.Read and
// Codec.Receive move on to the next frame.
//
// WriteClose is called by Conn.Close with the connection's default close
// status. NOTE(review): presumably it writes a close frame — confirm in the
// implementation.
type frameHandler interface {
	HandleFrame(frame frameReader) (r frameReader, err error)
	WriteClose(status int) (err error)
}
 159  
// Conn represents a WebSocket connection.
//
// Multiple goroutines may invoke methods on a Conn simultaneously.
type Conn struct {
	// config is returned by Config and supplies LocalAddr/RemoteAddr.
	config  *Config
	// request is the upgraded HTTP request; it is nil for client-side
	// connections (see IsClientConn).
	request *http.Request

	// buf is the buffered reader/writer. NOTE(review): presumably it wraps
	// rwc — confirm at construction sites.
	buf *bufio.ReadWriter
	// rwc is the underlying transport. Close closes it, and the Set*Deadline
	// methods use it when it is a net.Conn.
	rwc io.ReadWriteCloser

	// rio serializes readers (Read, Codec.Receive). The embedded frameReader
	// holds a partially consumed frame between calls, or nil.
	rio sync.Mutex
	frameReaderFactory
	frameReader

	// wio serializes writers (Write, Codec.Send).
	wio sync.Mutex
	frameWriterFactory

	// frameHandler filters incoming frames and is called by Close.
	frameHandler
	// PayloadType is the frame type used by Write.
	PayloadType        byte
	// defaultCloseStatus is the status Close passes to WriteClose.
	defaultCloseStatus int

	// MaxPayloadBytes limits the size of frame payload received over Conn
	// by Codec's Receive method. If zero, DefaultMaxPayloadBytes is used.
	MaxPayloadBytes int
}
 185  
 186  // Read implements the io.Reader interface:
 187  // it reads data of a frame from the WebSocket connection.
 188  // if msg is not large enough for the frame data, it fills the msg and next Read
 189  // will read the rest of the frame data.
 190  // it reads Text frame or Binary frame.
 191  func (ws *Conn) Read(msg []byte) (n int, err error) {
 192  	ws.rio.Lock()
 193  	defer ws.rio.Unlock()
 194  again:
 195  	if ws.frameReader == nil {
 196  		frame, err := ws.frameReaderFactory.NewFrameReader()
 197  		if err != nil {
 198  			return 0, err
 199  		}
 200  		ws.frameReader, err = ws.frameHandler.HandleFrame(frame)
 201  		if err != nil {
 202  			return 0, err
 203  		}
 204  		if ws.frameReader == nil {
 205  			goto again
 206  		}
 207  	}
 208  	n, err = ws.frameReader.Read(msg)
 209  	if err == io.EOF {
 210  		if trailer := ws.frameReader.TrailerReader(); trailer != nil {
 211  			io.Copy(io.Discard, trailer)
 212  		}
 213  		ws.frameReader = nil
 214  		goto again
 215  	}
 216  	return n, err
 217  }
 218  
 219  // Write implements the io.Writer interface:
 220  // it writes data as a frame to the WebSocket connection.
 221  func (ws *Conn) Write(msg []byte) (n int, err error) {
 222  	ws.wio.Lock()
 223  	defer ws.wio.Unlock()
 224  	w, err := ws.frameWriterFactory.NewFrameWriter(ws.PayloadType)
 225  	if err != nil {
 226  		return 0, err
 227  	}
 228  	n, err = w.Write(msg)
 229  	w.Close()
 230  	return n, err
 231  }
 232  
 233  // Close implements the io.Closer interface.
 234  func (ws *Conn) Close() error {
 235  	err := ws.frameHandler.WriteClose(ws.defaultCloseStatus)
 236  	err1 := ws.rwc.Close()
 237  	if err != nil {
 238  		return err
 239  	}
 240  	return err1
 241  }
 242  
 243  // IsClientConn reports whether ws is a client-side connection.
 244  func (ws *Conn) IsClientConn() bool { return ws.request == nil }
 245  
 246  // IsServerConn reports whether ws is a server-side connection.
 247  func (ws *Conn) IsServerConn() bool { return ws.request != nil }
 248  
 249  // LocalAddr returns the WebSocket Origin for the connection for client, or
 250  // the WebSocket location for server.
 251  func (ws *Conn) LocalAddr() net.Addr {
 252  	if ws.IsClientConn() {
 253  		return &Addr{ws.config.Origin}
 254  	}
 255  	return &Addr{ws.config.Location}
 256  }
 257  
 258  // RemoteAddr returns the WebSocket location for the connection for client, or
 259  // the Websocket Origin for server.
 260  func (ws *Conn) RemoteAddr() net.Addr {
 261  	if ws.IsClientConn() {
 262  		return &Addr{ws.config.Location}
 263  	}
 264  	return &Addr{ws.config.Origin}
 265  }
 266  
 267  var errSetDeadline = errors.New("websocket: cannot set deadline: not using a net.Conn")
 268  
 269  // SetDeadline sets the connection's network read & write deadlines.
 270  func (ws *Conn) SetDeadline(t time.Time) error {
 271  	if conn, ok := ws.rwc.(net.Conn); ok {
 272  		return conn.SetDeadline(t)
 273  	}
 274  	return errSetDeadline
 275  }
 276  
 277  // SetReadDeadline sets the connection's network read deadline.
 278  func (ws *Conn) SetReadDeadline(t time.Time) error {
 279  	if conn, ok := ws.rwc.(net.Conn); ok {
 280  		return conn.SetReadDeadline(t)
 281  	}
 282  	return errSetDeadline
 283  }
 284  
 285  // SetWriteDeadline sets the connection's network write deadline.
 286  func (ws *Conn) SetWriteDeadline(t time.Time) error {
 287  	if conn, ok := ws.rwc.(net.Conn); ok {
 288  		return conn.SetWriteDeadline(t)
 289  	}
 290  	return errSetDeadline
 291  }
 292  
 293  // Config returns the WebSocket config.
 294  func (ws *Conn) Config() *Config { return ws.config }
 295  
 296  // Request returns the http request upgraded to the WebSocket.
 297  // It is nil for client side.
 298  func (ws *Conn) Request() *http.Request { return ws.request }
 299  
// Codec represents a symmetric pair of functions that implement a codec.
//
// Marshal converts a value into a frame payload and the payload type (frame
// type) to send it with. Unmarshal decodes a received payload, given its
// payload type, into v. Send and Receive use them to transfer one value per
// frame.
type Codec struct {
	Marshal   func(v interface{}) (data []byte, payloadType byte, err error)
	Unmarshal func(data []byte, payloadType byte, v interface{}) (err error)
}
 305  
 306  // Send sends v marshaled by cd.Marshal as single frame to ws.
 307  func (cd Codec) Send(ws *Conn, v interface{}) (err error) {
 308  	data, payloadType, err := cd.Marshal(v)
 309  	if err != nil {
 310  		return err
 311  	}
 312  	ws.wio.Lock()
 313  	defer ws.wio.Unlock()
 314  	w, err := ws.frameWriterFactory.NewFrameWriter(payloadType)
 315  	if err != nil {
 316  		return err
 317  	}
 318  	_, err = w.Write(data)
 319  	w.Close()
 320  	return err
 321  }
 322  
 323  // Receive receives single frame from ws, unmarshaled by cd.Unmarshal and stores
 324  // in v. The whole frame payload is read to an in-memory buffer; max size of
 325  // payload is defined by ws.MaxPayloadBytes. If frame payload size exceeds
 326  // limit, ErrFrameTooLarge is returned; in this case frame is not read off wire
 327  // completely. The next call to Receive would read and discard leftover data of
 328  // previous oversized frame before processing next frame.
 329  func (cd Codec) Receive(ws *Conn, v interface{}) (err error) {
 330  	ws.rio.Lock()
 331  	defer ws.rio.Unlock()
 332  	if ws.frameReader != nil {
 333  		_, err = io.Copy(io.Discard, ws.frameReader)
 334  		if err != nil {
 335  			return err
 336  		}
 337  		ws.frameReader = nil
 338  	}
 339  again:
 340  	frame, err := ws.frameReaderFactory.NewFrameReader()
 341  	if err != nil {
 342  		return err
 343  	}
 344  	frame, err = ws.frameHandler.HandleFrame(frame)
 345  	if err != nil {
 346  		return err
 347  	}
 348  	if frame == nil {
 349  		goto again
 350  	}
 351  	maxPayloadBytes := ws.MaxPayloadBytes
 352  	if maxPayloadBytes == 0 {
 353  		maxPayloadBytes = DefaultMaxPayloadBytes
 354  	}
 355  	if hf, ok := frame.(*hybiFrameReader); ok && hf.header.Length > int64(maxPayloadBytes) {
 356  		// payload size exceeds limit, no need to call Unmarshal
 357  		//
 358  		// set frameReader to current oversized frame so that
 359  		// the next call to this function can drain leftover
 360  		// data before processing the next frame
 361  		ws.frameReader = frame
 362  		return ErrFrameTooLarge
 363  	}
 364  	payloadType := frame.PayloadType()
 365  	data, err := io.ReadAll(frame)
 366  	if err != nil {
 367  		return err
 368  	}
 369  	return cd.Unmarshal(data, payloadType, v)
 370  }
 371  
 372  func marshal(v interface{}) (msg []byte, payloadType byte, err error) {
 373  	switch data := v.(type) {
 374  	case string:
 375  		return []byte(data), TextFrame, nil
 376  	case []byte:
 377  		return data, BinaryFrame, nil
 378  	}
 379  	return nil, UnknownFrame, ErrNotSupported
 380  }
 381  
 382  func unmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
 383  	switch data := v.(type) {
 384  	case *string:
 385  		*data = string(msg)
 386  		return nil
 387  	case *[]byte:
 388  		*data = msg
 389  		return nil
 390  	}
 391  	return ErrNotSupported
 392  }
 393  
/*
Message is a codec to send/receive text/binary data in a frame on a WebSocket connection.
To send a text frame, use the string type.
To send a binary frame, use the []byte type.
Receive accepts a *string or a *[]byte destination regardless of the
received frame's payload type; other types yield ErrNotSupported.

Trivial usage:

	import "golang.org/x/net/websocket"

	// receive text frame
	var message string
	websocket.Message.Receive(ws, &message)

	// send text frame
	message = "hello"
	websocket.Message.Send(ws, message)

	// receive binary frame
	var data []byte
	websocket.Message.Receive(ws, &data)

	// send binary frame
	data = []byte{0, 1, 2}
	websocket.Message.Send(ws, data)
*/
var Message = Codec{marshal, unmarshal}
 420  
 421  func jsonMarshal(v interface{}) (msg []byte, payloadType byte, err error) {
 422  	msg, err = json.Marshal(v)
 423  	return msg, TextFrame, err
 424  }
 425  
 426  func jsonUnmarshal(msg []byte, payloadType byte, v interface{}) (err error) {
 427  	return json.Unmarshal(msg, v)
 428  }
 429  
/*
JSON is a codec to send/receive JSON data in a frame from a WebSocket connection.
Values are always sent as text frames; received frames are decoded as JSON
regardless of their payload type.

Trivial usage:

	import "golang.org/x/net/websocket"

	type T struct {
		Msg string
		Count int
	}

	// receive JSON type T
	var data T
	websocket.JSON.Receive(ws, &data)

	// send JSON type T
	websocket.JSON.Send(ws, data)
*/
var JSON = Codec{jsonMarshal, jsonUnmarshal}
 450