server.go raw

   1  // Copyright 2009 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package websocket
   6  
   7  import (
   8  	"bufio"
   9  	"fmt"
  10  	"io"
  11  	"net/http"
  12  )
  13  
  14  func newServerConn(rwc io.ReadWriteCloser, buf *bufio.ReadWriter, req *http.Request, config *Config, handshake func(*Config, *http.Request) error) (conn *Conn, err error) {
  15  	var hs serverHandshaker = &hybiServerHandshaker{Config: config}
  16  	code, err := hs.ReadHandshake(buf.Reader, req)
  17  	if err == ErrBadWebSocketVersion {
  18  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
  19  		fmt.Fprintf(buf, "Sec-WebSocket-Version: %s\r\n", SupportedProtocolVersion)
  20  		buf.WriteString("\r\n")
  21  		buf.WriteString(err.Error())
  22  		buf.Flush()
  23  		return
  24  	}
  25  	if err != nil {
  26  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
  27  		buf.WriteString("\r\n")
  28  		buf.WriteString(err.Error())
  29  		buf.Flush()
  30  		return
  31  	}
  32  	if handshake != nil {
  33  		err = handshake(config, req)
  34  		if err != nil {
  35  			code = http.StatusForbidden
  36  			fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
  37  			buf.WriteString("\r\n")
  38  			buf.Flush()
  39  			return
  40  		}
  41  	}
  42  	err = hs.AcceptHandshake(buf.Writer)
  43  	if err != nil {
  44  		code = http.StatusBadRequest
  45  		fmt.Fprintf(buf, "HTTP/1.1 %03d %s\r\n", code, http.StatusText(code))
  46  		buf.WriteString("\r\n")
  47  		buf.Flush()
  48  		return
  49  	}
  50  	conn = hs.NewServerConn(buf, rwc, req)
  51  	return
  52  }
  53  
  54  // Server represents a server of a WebSocket.
  55  type Server struct {
  56  	// Config is a WebSocket configuration for new WebSocket connection.
  57  	Config
  58  
  59  	// Handshake is an optional function in WebSocket handshake.
  60  	// For example, you can check, or don't check Origin header.
  61  	// Another example, you can select config.Protocol.
  62  	Handshake func(*Config, *http.Request) error
  63  
  64  	// Handler handles a WebSocket connection.
  65  	Handler
  66  }
  67  
  68  // ServeHTTP implements the http.Handler interface for a WebSocket
  69  func (s Server) ServeHTTP(w http.ResponseWriter, req *http.Request) {
  70  	s.serveWebSocket(w, req)
  71  }
  72  
  73  func (s Server) serveWebSocket(w http.ResponseWriter, req *http.Request) {
  74  	rwc, buf, err := w.(http.Hijacker).Hijack()
  75  	if err != nil {
  76  		panic("Hijack failed: " + err.Error())
  77  	}
  78  	// The server should abort the WebSocket connection if it finds
  79  	// the client did not send a handshake that matches with protocol
  80  	// specification.
  81  	defer rwc.Close()
  82  	conn, err := newServerConn(rwc, buf, req, &s.Config, s.Handshake)
  83  	if err != nil {
  84  		return
  85  	}
  86  	if conn == nil {
  87  		panic("unexpected nil conn")
  88  	}
  89  	s.Handler(conn)
  90  }
  91  
  92  // Handler is a simple interface to a WebSocket browser client.
  93  // It checks if Origin header is valid URL by default.
  94  // You might want to verify websocket.Conn.Config().Origin in the func.
  95  // If you use Server instead of Handler, you could call websocket.Origin and
  96  // check the origin in your Handshake func. So, if you want to accept
  97  // non-browser clients, which do not send an Origin header, set a
  98  // Server.Handshake that does not check the origin.
  99  type Handler func(*Conn)
 100  
 101  func checkOrigin(config *Config, req *http.Request) (err error) {
 102  	config.Origin, err = Origin(config, req)
 103  	if err == nil && config.Origin == nil {
 104  		return fmt.Errorf("null origin")
 105  	}
 106  	return err
 107  }
 108  
 109  // ServeHTTP implements the http.Handler interface for a WebSocket
 110  func (h Handler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 111  	s := Server{Handler: h, Handshake: checkOrigin}
 112  	s.serveWebSocket(w, req)
 113  }
 114