util.go raw

   1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package websocket
   6  
   7  import (
   8  	"crypto/rand"
   9  	"crypto/sha1"
  10  	"encoding/base64"
  11  	"io"
  12  	"net/http"
  13  	"strings"
  14  	"unicode/utf8"
  15  )
  16  
  17  var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
  18  
  19  func computeAcceptKey(challengeKey string) string {
  20  	h := sha1.New()
  21  	h.Write([]byte(challengeKey))
  22  	h.Write(keyGUID)
  23  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
  24  }
  25  
  26  func generateChallengeKey() (string, error) {
  27  	p := make([]byte, 16)
  28  	if _, err := io.ReadFull(rand.Reader, p); err != nil {
  29  		return "", err
  30  	}
  31  	return base64.StdEncoding.EncodeToString(p), nil
  32  }
  33  
  34  // Token octets per RFC 2616.
  35  var isTokenOctet = [256]bool{
  36  	'!':  true,
  37  	'#':  true,
  38  	'$':  true,
  39  	'%':  true,
  40  	'&':  true,
  41  	'\'': true,
  42  	'*':  true,
  43  	'+':  true,
  44  	'-':  true,
  45  	'.':  true,
  46  	'0':  true,
  47  	'1':  true,
  48  	'2':  true,
  49  	'3':  true,
  50  	'4':  true,
  51  	'5':  true,
  52  	'6':  true,
  53  	'7':  true,
  54  	'8':  true,
  55  	'9':  true,
  56  	'A':  true,
  57  	'B':  true,
  58  	'C':  true,
  59  	'D':  true,
  60  	'E':  true,
  61  	'F':  true,
  62  	'G':  true,
  63  	'H':  true,
  64  	'I':  true,
  65  	'J':  true,
  66  	'K':  true,
  67  	'L':  true,
  68  	'M':  true,
  69  	'N':  true,
  70  	'O':  true,
  71  	'P':  true,
  72  	'Q':  true,
  73  	'R':  true,
  74  	'S':  true,
  75  	'T':  true,
  76  	'U':  true,
  77  	'W':  true,
  78  	'V':  true,
  79  	'X':  true,
  80  	'Y':  true,
  81  	'Z':  true,
  82  	'^':  true,
  83  	'_':  true,
  84  	'`':  true,
  85  	'a':  true,
  86  	'b':  true,
  87  	'c':  true,
  88  	'd':  true,
  89  	'e':  true,
  90  	'f':  true,
  91  	'g':  true,
  92  	'h':  true,
  93  	'i':  true,
  94  	'j':  true,
  95  	'k':  true,
  96  	'l':  true,
  97  	'm':  true,
  98  	'n':  true,
  99  	'o':  true,
 100  	'p':  true,
 101  	'q':  true,
 102  	'r':  true,
 103  	's':  true,
 104  	't':  true,
 105  	'u':  true,
 106  	'v':  true,
 107  	'w':  true,
 108  	'x':  true,
 109  	'y':  true,
 110  	'z':  true,
 111  	'|':  true,
 112  	'~':  true,
 113  }
 114  
 115  // skipSpace returns a slice of the string s with all leading RFC 2616 linear
 116  // whitespace removed.
 117  func skipSpace(s string) (rest string) {
 118  	i := 0
 119  	for ; i < len(s); i++ {
 120  		if b := s[i]; b != ' ' && b != '\t' {
 121  			break
 122  		}
 123  	}
 124  	return s[i:]
 125  }
 126  
 127  // nextToken returns the leading RFC 2616 token of s and the string following
 128  // the token.
 129  func nextToken(s string) (token, rest string) {
 130  	i := 0
 131  	for ; i < len(s); i++ {
 132  		if !isTokenOctet[s[i]] {
 133  			break
 134  		}
 135  	}
 136  	return s[:i], s[i:]
 137  }
 138  
 139  // nextTokenOrQuoted returns the leading token or quoted string per RFC 2616
 140  // and the string following the token or quoted string.
 141  func nextTokenOrQuoted(s string) (value string, rest string) {
 142  	if !strings.HasPrefix(s, "\"") {
 143  		return nextToken(s)
 144  	}
 145  	s = s[1:]
 146  	for i := 0; i < len(s); i++ {
 147  		switch s[i] {
 148  		case '"':
 149  			return s[:i], s[i+1:]
 150  		case '\\':
 151  			p := make([]byte, len(s)-1)
 152  			j := copy(p, s[:i])
 153  			escape := true
 154  			for i = i + 1; i < len(s); i++ {
 155  				b := s[i]
 156  				switch {
 157  				case escape:
 158  					escape = false
 159  					p[j] = b
 160  					j++
 161  				case b == '\\':
 162  					escape = true
 163  				case b == '"':
 164  					return string(p[:j]), s[i+1:]
 165  				default:
 166  					p[j] = b
 167  					j++
 168  				}
 169  			}
 170  			return "", ""
 171  		}
 172  	}
 173  	return "", ""
 174  }
 175  
 176  // equalASCIIFold returns true if s is equal to t with ASCII case folding as
 177  // defined in RFC 4790.
 178  func equalASCIIFold(s, t string) bool {
 179  	for s != "" && t != "" {
 180  		sr, size := utf8.DecodeRuneInString(s)
 181  		s = s[size:]
 182  		tr, size := utf8.DecodeRuneInString(t)
 183  		t = t[size:]
 184  		if sr == tr {
 185  			continue
 186  		}
 187  		if 'A' <= sr && sr <= 'Z' {
 188  			sr = sr + 'a' - 'A'
 189  		}
 190  		if 'A' <= tr && tr <= 'Z' {
 191  			tr = tr + 'a' - 'A'
 192  		}
 193  		if sr != tr {
 194  			return false
 195  		}
 196  	}
 197  	return s == t
 198  }
 199  
 200  // tokenListContainsValue returns true if the 1#token header with the given
 201  // name contains a token equal to value with ASCII case folding.
 202  func tokenListContainsValue(header http.Header, name string, value string) bool {
 203  headers:
 204  	for _, s := range header[name] {
 205  		for {
 206  			var t string
 207  			t, s = nextToken(skipSpace(s))
 208  			if t == "" {
 209  				continue headers
 210  			}
 211  			s = skipSpace(s)
 212  			if s != "" && s[0] != ',' {
 213  				continue headers
 214  			}
 215  			if equalASCIIFold(t, value) {
 216  				return true
 217  			}
 218  			if s == "" {
 219  				continue headers
 220  			}
 221  			s = s[1:]
 222  		}
 223  	}
 224  	return false
 225  }
 226  
 227  // parseExtensions parses WebSocket extensions from a header.
 228  func parseExtensions(header http.Header) []map[string]string {
 229  	// From RFC 6455:
 230  	//
 231  	//  Sec-WebSocket-Extensions = extension-list
 232  	//  extension-list = 1#extension
 233  	//  extension = extension-token *( ";" extension-param )
 234  	//  extension-token = registered-token
 235  	//  registered-token = token
 236  	//  extension-param = token [ "=" (token | quoted-string) ]
 237  	//     ;When using the quoted-string syntax variant, the value
 238  	//     ;after quoted-string unescaping MUST conform to the
 239  	//     ;'token' ABNF.
 240  
 241  	var result []map[string]string
 242  headers:
 243  	for _, s := range header["Sec-Websocket-Extensions"] {
 244  		for {
 245  			var t string
 246  			t, s = nextToken(skipSpace(s))
 247  			if t == "" {
 248  				continue headers
 249  			}
 250  			ext := map[string]string{"": t}
 251  			for {
 252  				s = skipSpace(s)
 253  				if !strings.HasPrefix(s, ";") {
 254  					break
 255  				}
 256  				var k string
 257  				k, s = nextToken(skipSpace(s[1:]))
 258  				if k == "" {
 259  					continue headers
 260  				}
 261  				s = skipSpace(s)
 262  				var v string
 263  				if strings.HasPrefix(s, "=") {
 264  					v, s = nextTokenOrQuoted(skipSpace(s[1:]))
 265  					s = skipSpace(s)
 266  				}
 267  				if s != "" && s[0] != ',' && s[0] != ';' {
 268  					continue headers
 269  				}
 270  				ext[k] = v
 271  			}
 272  			if s != "" && s[0] != ',' {
 273  				continue headers
 274  			}
 275  			result = append(result, ext)
 276  			if s == "" {
 277  				continue headers
 278  			}
 279  			s = s[1:]
 280  		}
 281  	}
 282  	return result
 283  }
 284  
 285  // isValidChallengeKey checks if the argument meets RFC6455 specification.
 286  func isValidChallengeKey(s string) bool {
 287  	// From RFC6455:
 288  	//
 289  	// A |Sec-WebSocket-Key| header field with a base64-encoded (see
 290  	// Section 4 of [RFC4648]) value that, when decoded, is 16 bytes in
 291  	// length.
 292  
 293  	if s == "" {
 294  		return false
 295  	}
 296  	decoded, err := base64.StdEncoding.DecodeString(s)
 297  	return err == nil && len(decoded) == 16
 298  }
 299