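// Package ws provides a minimal WebSocket implementation for Nostr relays:
// a client (Dial) plus unmasked server-side framing for connections whose
// HTTP upgrade was performed elsewhere (NewServerConn).
// Implements RFC 6455 framing over raw TCP/TLS — no net/http dependency.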
3 package ws
4
5 import (
6 "bufio"
7 "bytes"
8 "crypto/rand"
9 "crypto/sha1"
10 "crypto/tls"
11 "encoding/base64"
12 "encoding/binary"
13 "fmt"
14 "io"
15 "net"
16 "net/url"
17 "time"
18 )
19
// WebSocket frame opcodes (RFC 6455 §5.2) understood by this package.
const (
	OpText   byte = 0x1 // data frame carrying UTF-8 text
	OpBinary byte = 0x2 // data frame carrying arbitrary bytes
	OpClose  byte = 0x8 // control frame: close the connection
	OpPing   byte = 0x9 // control frame: ping
	OpPong   byte = 0xA // control frame: pong
)

const (
	// wsMagic is the fixed GUID concatenated with Sec-WebSocket-Key before
	// hashing to form Sec-WebSocket-Accept (RFC 6455 §1.3).
	wsMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

	// maxPayload is the largest single-frame payload readFrame accepts
	// (33 MiB).
	maxPayload = 33 * 1024 * 1024
)
30
// Conn is one WebSocket connection, operating either as a client (Dial) or
// as a server (NewServerConn). Frames are read through br and written
// directly to raw. Conn holds no lock of its own.
type Conn struct {
	raw    net.Conn      // underlying connection (TCP or TLS when created by Dial)
	br     *bufio.Reader // buffered reader that incoming frames are read from
	server bool          // server mode: outgoing frames are sent unmasked
}
37
38 // NewServerConn wraps an already-upgraded connection in server mode.
39 func NewServerConn(conn net.Conn, br *bufio.Reader) *Conn {
40 return &Conn{raw: conn, br: br, server: true}
41 }
42
43 // ComputeAccept calculates Sec-WebSocket-Accept from a client key.
44 func ComputeAccept(key string) string { return computeAccept(key) }
45
46 // Dial opens a WebSocket connection to the given URL (ws:// or wss://).
47 func Dial(rawURL string) (*Conn, error) {
48 u, err := url.Parse(rawURL)
49 if err != nil {
50 return nil, err
51 }
52 host := u.Hostname()
53 port := u.Port()
54 useTLS := u.Scheme == "wss"
55 if port == "" {
56 if useTLS {
57 port = "443"
58 } else {
59 port = "80"
60 }
61 }
62 // Resolve hostname via DNS cache (24h TTL).
63 ip := host
64 if net.ParseIP(host) == nil {
65 ip, err = resolveHost(host)
66 if err != nil {
67 return nil, fmt.Errorf("ws: resolve %s: %w", host, err)
68 }
69 }
70 addr := net.JoinHostPort(ip, port)
71
72 var conn net.Conn
73 conn, err = net.Dial("tcp", addr)
74 if err != nil {
75 return nil, fmt.Errorf("ws: dial %s: %w", addr, err)
76 }
77 if useTLS {
78 tlsConn := tls.Client(conn, &tls.Config{ServerName: []byte(host)})
79 if err = tlsConn.Handshake(); err != nil {
80 conn.Close()
81 return nil, fmt.Errorf("ws: tls %s: %w", host, err)
82 }
83 conn = tlsConn
84 }
85
86 path := u.RequestURI()
87 if path == "" {
88 path = "/"
89 }
90
91 // Generate Sec-WebSocket-Key.
92 var keyRaw [16]byte
93 io.ReadFull(rand.Reader, keyRaw[:])
94 wsKey := base64.StdEncoding.EncodeToString(keyRaw[:])
95
96 // Send HTTP upgrade.
97 req := "GET " + path + " HTTP/1.1\r\n" +
98 "Host: " + host + "\r\n" +
99 "Upgrade: websocket\r\n" +
100 "Connection: Upgrade\r\n" +
101 "Sec-WebSocket-Key: " + wsKey + "\r\n" +
102 "Sec-WebSocket-Version: 13\r\n" +
103 "\r\n"
104 if _, err = conn.Write([]byte(req)); err != nil {
105 conn.Close()
106 return nil, fmt.Errorf("ws: write upgrade: %w", err)
107 }
108
109 br := bufio.NewReaderSize(conn, 32768)
110
111 // Read status line.
112 status, err := br.ReadString('\n')
113 if err != nil {
114 conn.Close()
115 return nil, fmt.Errorf("ws: read status: %w", err)
116 }
117 if !bytes.Contains(status, "101") {
118 conn.Close()
119 return nil, fmt.Errorf("ws: upgrade rejected: %s", bytes.TrimSpace(status))
120 }
121
122 // Consume headers, validate accept.
123 expectedAccept := computeAccept(wsKey)
124 var accepted bool
125 for {
126 line, err := br.ReadString('\n')
127 if err != nil {
128 conn.Close()
129 return nil, fmt.Errorf("ws: read header: %w", err)
130 }
131 trimmed := bytes.TrimSpace(line)
132 if trimmed == "" {
133 break
134 }
135 lower := bytes.ToLower(trimmed)
136 if bytes.HasPrefix(lower, "sec-websocket-accept:") {
137 val := bytes.TrimSpace(trimmed[len("sec-websocket-accept:"):])
138 if val == expectedAccept {
139 accepted = true
140 }
141 }
142 }
143 if !accepted {
144 conn.Close()
145 return nil, fmt.Errorf("ws: bad accept key")
146 }
147 return &Conn{raw: conn, br: br}, nil
148 }
149
150 func computeAccept(key string) string {
151 h := sha1.New()
152 h.Write([]byte(key))
153 h.Write([]byte(wsMagic))
154 return base64.StdEncoding.EncodeToString(h.Sum(nil))
155 }
156
157 // WriteText sends a text frame.
158 func (c *Conn) WriteText(msg []byte) error { return c.writeFrame(OpText, msg) }
159
160 // WritePong sends a pong control frame.
161 func (c *Conn) WritePong(data []byte) error { return c.writeFrame(OpPong, data) }
162
163 func (c *Conn) writeFrame(op byte, payload []byte) error {
164 if c.server {
165 return c.writeServerFrame(op, payload)
166 }
167 return c.writeClientFrame(op, payload)
168 }
169
170 // writeServerFrame sends an unmasked frame (RFC 6455: servers MUST NOT mask).
171 func (c *Conn) writeServerFrame(op byte, payload []byte) error {
172 plen := len(payload)
173 var hdr []byte
174 if plen < 126 {
175 hdr = []byte{:2}
176 hdr[1] = byte(plen)
177 } else if plen < 65536 {
178 hdr = []byte{:4}
179 hdr[1] = 126
180 binary.BigEndian.PutUint16(hdr[2:], uint16(plen))
181 } else {
182 hdr = []byte{:10}
183 hdr[1] = 127
184 binary.BigEndian.PutUint64(hdr[2:], uint64(plen))
185 }
186 hdr[0] = 0x80 | op
187 if _, err := c.raw.Write(hdr); err != nil {
188 return err
189 }
190 _, err := c.raw.Write(payload)
191 return err
192 }
193
194 // writeClientFrame sends a masked frame (RFC 6455: clients MUST mask).
195 func (c *Conn) writeClientFrame(op byte, payload []byte) error {
196 plen := len(payload)
197 var hdr []byte
198 if plen < 126 {
199 hdr = []byte{:6} // 2 + 4 mask
200 hdr[1] = 0x80 | byte(plen)
201 } else if plen < 65536 {
202 hdr = []byte{:8} // 4 + 4 mask
203 hdr[1] = 0x80 | 126
204 binary.BigEndian.PutUint16(hdr[2:], uint16(plen))
205 } else {
206 hdr = []byte{:14} // 10 + 4 mask
207 hdr[1] = 0x80 | 127
208 binary.BigEndian.PutUint64(hdr[2:], uint64(plen))
209 }
210 hdr[0] = 0x80 | op
211
212 maskOff := len(hdr) - 4
213 io.ReadFull(rand.Reader, hdr[maskOff:])
214 mask := [4]byte{hdr[maskOff], hdr[maskOff+1], hdr[maskOff+2], hdr[maskOff+3]}
215
216 masked := []byte{:plen}
217 for i, b := range payload {
218 masked[i] = b ^ mask[i%4]
219 }
220 if _, err := c.raw.Write(hdr); err != nil {
221 return err
222 }
223 _, err := c.raw.Write(masked)
224 return err
225 }
226
227 // ReadMessage reads the next data frame, automatically handling ping/pong.
228 func (c *Conn) ReadMessage() (op byte, payload []byte, err error) {
229 for {
230 op, payload, err = c.readFrame()
231 if err != nil {
232 return
233 }
234 switch op {
235 case OpPing:
236 c.WritePong(payload)
237 case OpPong:
238 // ignore
239 case OpClose:
240 c.writeFrame(OpClose, payload)
241 return
242 default:
243 return
244 }
245 }
246 }
247
248 func (c *Conn) readFrame() (op byte, payload []byte, err error) {
249 var hdr [2]byte
250 if _, err = io.ReadFull(c.br, hdr[:]); err != nil {
251 return
252 }
253 op = hdr[0] & 0x0F
254 masked := hdr[1]&0x80 != 0
255 plen := uint64(hdr[1] & 0x7F)
256
257 if plen == 126 {
258 var ext [2]byte
259 if _, err = io.ReadFull(c.br, ext[:]); err != nil {
260 return
261 }
262 plen = uint64(binary.BigEndian.Uint16(ext[:]))
263 } else if plen == 127 {
264 var ext [8]byte
265 if _, err = io.ReadFull(c.br, ext[:]); err != nil {
266 return
267 }
268 plen = binary.BigEndian.Uint64(ext[:])
269 }
270 if plen > uint64(maxPayload) {
271 err = fmt.Errorf("ws: payload %d exceeds limit %d", plen, maxPayload)
272 return
273 }
274
275 var mask [4]byte
276 if masked {
277 if _, err = io.ReadFull(c.br, mask[:]); err != nil {
278 return
279 }
280 }
281
282 payload = []byte{:plen}
283 if _, err = io.ReadFull(c.br, payload); err != nil {
284 return
285 }
286 if masked {
287 for i := range payload {
288 payload[i] ^= mask[i%4]
289 }
290 }
291 return
292 }
293
294 // SetReadDeadline sets the read deadline on the underlying connection.
295 func (c *Conn) SetReadDeadline(t time.Time) error {
296 return c.raw.SetReadDeadline(t)
297 }
298
299 // Close sends a close frame and closes the TCP connection.
300 func (c *Conn) Close() error {
301 data := []byte{0x03, 0xE8} // code 1000 normal closure
302 c.writeFrame(OpClose, data)
303 return c.raw.Close()
304 }
305