1 package transport
2
3 import (
4 "crypto/rand"
5 "crypto/sha1"
6 "encoding/base64"
7 "encoding/hex"
8 "syscall"
9 )
10
const (
	// wsMagic is the fixed GUID from RFC 6455 that a server appends to the
	// client's Sec-WebSocket-Key before hashing it into Sec-WebSocket-Accept.
	wsMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
)
12
13 // parseWSFrame extracts one WS frame from data.
14 // Returns (op, payload, consumed). consumed=0 means incomplete.
15 func parseWSFrame(data []byte) (byte, []byte, int) {
16 if len(data) < 2 {
17 return 0, nil, 0
18 }
19 op := data[0] & 0x0f
20 masked := data[1]&0x80 != 0
21 length := int(data[1] & 0x7f)
22 pos := 2
23
24 if length == 126 {
25 if len(data) < 4 {
26 return 0, nil, 0
27 }
28 length = int(data[2])<<8 | int(data[3])
29 pos = 4
30 } else if length == 127 {
31 if len(data) < 10 {
32 return 0, nil, 0
33 }
34 length = int(data[6])<<24 | int(data[7])<<16 | int(data[8])<<8 | int(data[9])
35 pos = 10
36 }
37
38 var mask [4]byte
39 if masked {
40 if len(data) < pos+4 {
41 return 0, nil, 0
42 }
43 copy(mask[:], data[pos:pos+4])
44 pos += 4
45 }
46 if len(data) < pos+length {
47 return 0, nil, 0
48 }
49
50 payload := makeCopy(data[pos : pos+length])
51 if masked {
52 for i := range payload {
53 payload[i] ^= mask[i%4]
54 }
55 }
56 return op, payload, pos + length
57 }
58
59 func buildWSFrame(op byte, payload []byte) []byte {
60 plen := len(payload)
61 var hdr [10]byte
62 hdr[0] = 0x80 | op
63 n := 2
64 if plen < 126 {
65 hdr[1] = byte(plen)
66 } else if plen < 65536 {
67 hdr[1] = 126
68 hdr[2] = byte(plen >> 8)
69 hdr[3] = byte(plen)
70 n = 4
71 } else {
72 hdr[1] = 127
73 hdr[6] = byte(plen >> 24)
74 hdr[7] = byte(plen >> 16)
75 hdr[8] = byte(plen >> 8)
76 hdr[9] = byte(plen)
77 n = 10
78 }
79 buf := []byte{:0:n + plen}
80 buf = append(buf, hdr[:n]...)
81 buf = append(buf, payload...)
82 return buf
83 }
84
85 func writeWSFrame(fd int, op byte, payload []byte) {
86 writeAll(fd, buildWSFrame(op, payload))
87 }
88
89 func writeWSFrameErr(fd int, op byte, payload []byte) error {
90 return writeAll(fd, buildWSFrame(op, payload))
91 }
92
93 func writeWSClose(fd int) {
94 writeWSFrame(fd, opClose, []byte{0x03, 0xe8})
95 }
96
// writeAll writes every byte of data to fd with raw write(2) calls, looping
// over short writes.
//
// It returns nil once all of data has been written, including when data is
// empty. Otherwise it returns the first error from syscall.Write other than
// EINTR, for example syscall.EAGAIN when fd is non-blocking and cannot accept
// more. In that case the unwritten bytes are dropped and the caller is not
// told how many bytes went out.
//
// syscall.EINTR is retried. Since Go 1.14 the runtime's asynchronous
// preemption signals can interrupt slow system calls, and that is not a
// write failure.
func writeAll(fd int, data []byte) error {
	for len(data) > 0 {
		n, err := syscall.Write(fd, data)
		if n > 0 {
			data = data[n:]
		}
		if err == syscall.EINTR {
			continue
		}
		if err != nil {
			return err
		}
	}
	return nil
}
112
113 func computeAccept(key string) string {
114 h := sha1.New()
115 h.Write([]byte(key))
116 h.Write([]byte(wsMagic))
117 return base64.StdEncoding.EncodeToString(h.Sum(nil))
118 }
119
120 func (s *Server) upgradeWS(c *tconn, req *httpReq) {
121 key := req.headers["sec-websocket-key"]
122 if key == "" {
123 writeHTTPResponse(c.fd, 400, nil, []byte("missing Sec-WebSocket-Key"))
124 s.closeConn(c)
125 return
126 }
127 wl, allow := s.handler.OnWSUpgrade(c.fd, c.remoteIP, s.ipConns[c.remoteIP])
128 if !allow {
129 writeHTTPResponse(c.fd, 429, nil, []byte("too many websockets"))
130 s.closeConn(c)
131 return
132 }
133 s.ipConns[c.remoteIP]++
134 c.whitelisted = wl
135 accept := computeAccept(key)
136 resp := "HTTP/1.1 101 Switching Protocols\r\n" |
137 "Upgrade: websocket\r\n" |
138 "Connection: Upgrade\r\n" |
139 "Sec-WebSocket-Accept: " | accept | "\r\n" |
140 "\r\n"
141 writeAll(c.fd, []byte(resp))
142 c.phase = phaseWS
143 s.handler.OnWSConnected(c.fd)
144 if c.wpos > 0 {
145 s.processWS(c)
146 }
147 }
148
149 func (s *Server) processWS(c *tconn) {
150 op, payload, consumed := parseWSFrame(c.buf[:c.wpos])
151 if consumed == 0 {
152 return
153 }
154 copy(c.buf, c.buf[consumed:c.wpos])
155 c.wpos -= consumed
156
157 switch op {
158 case opClose:
159 writeWSClose(c.fd)
160 s.closeConn(c)
161 return
162 case opPing:
163 writeWSFrame(c.fd, opPong, payload)
164 case opText, opBin:
165 s.handler.OnWSMessage(c.fd, payload)
166 }
167 }
168
// Challenge returns a fresh NIP-42 authentication challenge: 32 bytes read
// from crypto/rand, hex-encoded into a 64-byte slice of lowercase ASCII hex
// digits.
//
// It panics if the secure random source fails. A zero-filled, and therefore
// predictable, challenge would defeat authentication. On Go 1.24+
// crypto/rand.Read never returns an error, so the check only matters on
// older toolchains.
func Challenge() []byte {
	var cb [32]byte
	if _, err := rand.Read(cb[:]); err != nil {
		panic(err)
	}
	return []byte(hex.EncodeToString(cb[:]))
}
177