ws.mx raw

   1  package transport
   2  
   3  import (
   4  	"crypto/rand"
   5  	"crypto/sha1"
   6  	"encoding/base64"
   7  	"encoding/hex"
   8  	"syscall"
   9  )
  10  
// wsMagic is the fixed GUID from RFC 6455 that computeAccept hashes together
// with the client's Sec-WebSocket-Key to form the Sec-WebSocket-Accept value.
const wsMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  12  
  13  // parseWSFrame extracts one WS frame from data.
  14  // Returns (op, payload, consumed). consumed=0 means incomplete.
  15  func parseWSFrame(data []byte) (byte, []byte, int) {
  16  	if len(data) < 2 {
  17  		return 0, nil, 0
  18  	}
  19  	op := data[0] & 0x0f
  20  	masked := data[1]&0x80 != 0
  21  	length := int(data[1] & 0x7f)
  22  	pos := 2
  23  
  24  	if length == 126 {
  25  		if len(data) < 4 {
  26  			return 0, nil, 0
  27  		}
  28  		length = int(data[2])<<8 | int(data[3])
  29  		pos = 4
  30  	} else if length == 127 {
  31  		if len(data) < 10 {
  32  			return 0, nil, 0
  33  		}
  34  		length = int(data[6])<<24 | int(data[7])<<16 | int(data[8])<<8 | int(data[9])
  35  		pos = 10
  36  	}
  37  
  38  	var mask [4]byte
  39  	if masked {
  40  		if len(data) < pos+4 {
  41  			return 0, nil, 0
  42  		}
  43  		copy(mask[:], data[pos:pos+4])
  44  		pos += 4
  45  	}
  46  	if len(data) < pos+length {
  47  		return 0, nil, 0
  48  	}
  49  
  50  	payload := makeCopy(data[pos : pos+length])
  51  	if masked {
  52  		for i := range payload {
  53  			payload[i] ^= mask[i%4]
  54  		}
  55  	}
  56  	return op, payload, pos + length
  57  }
  58  
  59  func buildWSFrame(op byte, payload []byte) []byte {
  60  	plen := len(payload)
  61  	var hdr [10]byte
  62  	hdr[0] = 0x80 | op
  63  	n := 2
  64  	if plen < 126 {
  65  		hdr[1] = byte(plen)
  66  	} else if plen < 65536 {
  67  		hdr[1] = 126
  68  		hdr[2] = byte(plen >> 8)
  69  		hdr[3] = byte(plen)
  70  		n = 4
  71  	} else {
  72  		hdr[1] = 127
  73  		hdr[6] = byte(plen >> 24)
  74  		hdr[7] = byte(plen >> 16)
  75  		hdr[8] = byte(plen >> 8)
  76  		hdr[9] = byte(plen)
  77  		n = 10
  78  	}
  79  	buf := []byte{:0:n + plen}
  80  	buf = append(buf, hdr[:n]...)
  81  	buf = append(buf, payload...)
  82  	return buf
  83  }
  84  
  85  func writeWSFrame(fd int, op byte, payload []byte) {
  86  	writeAll(fd, buildWSFrame(op, payload))
  87  }
  88  
  89  func writeWSFrameErr(fd int, op byte, payload []byte) error {
  90  	return writeAll(fd, buildWSFrame(op, payload))
  91  }
  92  
  93  func writeWSClose(fd int) {
  94  	writeWSFrame(fd, opClose, []byte{0x03, 0xe8})
  95  }
  96  
  97  func writeAll(fd int, data []byte) error {
  98  	for len(data) > 0 {
  99  		n, err := syscall.Write(fd, data)
 100  		if n > 0 {
 101  			data = data[n:]
 102  		}
 103  		if err == syscall.EAGAIN {
 104  			return syscall.EAGAIN
 105  		}
 106  		if err != nil {
 107  			return err
 108  		}
 109  	}
 110  	return nil
 111  }
 112  
 113  func computeAccept(key string) string {
 114  	h := sha1.New()
 115  	h.Write([]byte(key))
 116  	h.Write([]byte(wsMagic))
 117  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
 118  }
 119  
 120  func (s *Server) upgradeWS(c *tconn, req *httpReq) {
 121  	key := req.headers["sec-websocket-key"]
 122  	if key == "" {
 123  		writeHTTPResponse(c.fd, 400, nil, []byte("missing Sec-WebSocket-Key"))
 124  		s.closeConn(c)
 125  		return
 126  	}
 127  	wl, allow := s.handler.OnWSUpgrade(c.fd, c.remoteIP, s.ipConns[c.remoteIP])
 128  	if !allow {
 129  		writeHTTPResponse(c.fd, 429, nil, []byte("too many websockets"))
 130  		s.closeConn(c)
 131  		return
 132  	}
 133  	s.ipConns[c.remoteIP]++
 134  	c.whitelisted = wl
 135  	accept := computeAccept(key)
 136  	resp := "HTTP/1.1 101 Switching Protocols\r\n" |
 137  		"Upgrade: websocket\r\n" |
 138  		"Connection: Upgrade\r\n" |
 139  		"Sec-WebSocket-Accept: " | accept | "\r\n" |
 140  		"\r\n"
 141  	writeAll(c.fd, []byte(resp))
 142  	c.phase = phaseWS
 143  	s.handler.OnWSConnected(c.fd)
 144  	if c.wpos > 0 {
 145  		s.processWS(c)
 146  	}
 147  }
 148  
 149  func (s *Server) processWS(c *tconn) {
 150  	op, payload, consumed := parseWSFrame(c.buf[:c.wpos])
 151  	if consumed == 0 {
 152  		return
 153  	}
 154  	copy(c.buf, c.buf[consumed:c.wpos])
 155  	c.wpos -= consumed
 156  
 157  	switch op {
 158  	case opClose:
 159  		writeWSClose(c.fd)
 160  		s.closeConn(c)
 161  		return
 162  	case opPing:
 163  		writeWSFrame(c.fd, opPong, payload)
 164  	case opText, opBin:
 165  		s.handler.OnWSMessage(c.fd, payload)
 166  	}
 167  }
 168  
 169  // Challenge generates a 32-byte random hex challenge string for NIP-42 auth.
 170  func Challenge() []byte {
 171  	var cb [32]byte
 172  	rand.Read(cb[:])
 173  	out := []byte{:64}
 174  	hex.Encode(out, cb[:])
 175  	return out
 176  }
 177