// Package ws provides a minimal WebSocket client for Nostr relays. // Implements RFC 6455 over raw TCP/TLS — no net/http dependency. package ws import ( "bufio" "bytes" "crypto/rand" "crypto/sha1" "crypto/tls" "encoding/base64" "encoding/binary" "fmt" "io" "net" "net/url" "time" ) const ( OpText byte = 0x1 OpBinary byte = 0x2 OpClose byte = 0x8 OpPing byte = 0x9 OpPong byte = 0xA wsMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" maxPayload = 33 << 20 // 33 MB ) // Conn is a WebSocket connection (client or server mode). type Conn struct { raw net.Conn br *bufio.Reader server bool // true = server mode (unmasked writes) } // NewServerConn wraps an already-upgraded connection in server mode. func NewServerConn(conn net.Conn, br *bufio.Reader) *Conn { return &Conn{raw: conn, br: br, server: true} } // ComputeAccept calculates Sec-WebSocket-Accept from a client key. func ComputeAccept(key string) string { return computeAccept(key) } // Dial opens a WebSocket connection to the given URL (ws:// or wss://). func Dial(rawURL string) (*Conn, error) { u, err := url.Parse(rawURL) if err != nil { return nil, err } host := u.Hostname() port := u.Port() useTLS := u.Scheme == "wss" if port == "" { if useTLS { port = "443" } else { port = "80" } } // Resolve hostname via DNS cache (24h TTL). ip := host if net.ParseIP(host) == nil { ip, err = resolveHost(host) if err != nil { return nil, fmt.Errorf("ws: resolve %s: %w", host, err) } } addr := net.JoinHostPort(ip, port) var conn net.Conn conn, err = net.Dial("tcp", addr) if err != nil { return nil, fmt.Errorf("ws: dial %s: %w", addr, err) } if useTLS { tlsConn := tls.Client(conn, &tls.Config{ServerName: []byte(host)}) if err = tlsConn.Handshake(); err != nil { conn.Close() return nil, fmt.Errorf("ws: tls %s: %w", host, err) } conn = tlsConn } path := u.RequestURI() if path == "" { path = "/" } // Generate Sec-WebSocket-Key. var keyRaw [16]byte io.ReadFull(rand.Reader, keyRaw[:]) wsKey := base64.StdEncoding.EncodeToString(keyRaw[:]) // Send HTTP upgrade. req := "GET " + path + " HTTP/1.1\r\n" + "Host: " + host + "\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: " + wsKey + "\r\n" + "Sec-WebSocket-Version: 13\r\n" + "\r\n" if _, err = conn.Write([]byte(req)); err != nil { conn.Close() return nil, fmt.Errorf("ws: write upgrade: %w", err) } br := bufio.NewReaderSize(conn, 32768) // Read status line. status, err := br.ReadString('\n') if err != nil { conn.Close() return nil, fmt.Errorf("ws: read status: %w", err) } if !bytes.Contains(status, "101") { conn.Close() return nil, fmt.Errorf("ws: upgrade rejected: %s", bytes.TrimSpace(status)) } // Consume headers, validate accept. expectedAccept := computeAccept(wsKey) var accepted bool for { line, err := br.ReadString('\n') if err != nil { conn.Close() return nil, fmt.Errorf("ws: read header: %w", err) } trimmed := bytes.TrimSpace(line) if trimmed == "" { break } lower := bytes.ToLower(trimmed) if bytes.HasPrefix(lower, "sec-websocket-accept:") { val := bytes.TrimSpace(trimmed[len("sec-websocket-accept:"):]) if val == expectedAccept { accepted = true } } } if !accepted { conn.Close() return nil, fmt.Errorf("ws: bad accept key") } return &Conn{raw: conn, br: br}, nil } func computeAccept(key string) string { h := sha1.New() h.Write([]byte(key)) h.Write([]byte(wsMagic)) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } // WriteText sends a text frame. func (c *Conn) WriteText(msg []byte) error { return c.writeFrame(OpText, msg) } // WritePong sends a pong control frame. func (c *Conn) WritePong(data []byte) error { return c.writeFrame(OpPong, data) } func (c *Conn) writeFrame(op byte, payload []byte) error { if c.server { return c.writeServerFrame(op, payload) } return c.writeClientFrame(op, payload) } // writeServerFrame sends an unmasked frame (RFC 6455: servers MUST NOT mask). func (c *Conn) writeServerFrame(op byte, payload []byte) error { plen := len(payload) var hdr []byte if plen < 126 { hdr = []byte{:2} hdr[1] = byte(plen) } else if plen < 65536 { hdr = []byte{:4} hdr[1] = 126 binary.BigEndian.PutUint16(hdr[2:], uint16(plen)) } else { hdr = []byte{:10} hdr[1] = 127 binary.BigEndian.PutUint64(hdr[2:], uint64(plen)) } hdr[0] = 0x80 | op if _, err := c.raw.Write(hdr); err != nil { return err } _, err := c.raw.Write(payload) return err } // writeClientFrame sends a masked frame (RFC 6455: clients MUST mask). func (c *Conn) writeClientFrame(op byte, payload []byte) error { plen := len(payload) var hdr []byte if plen < 126 { hdr = []byte{:6} // 2 + 4 mask hdr[1] = 0x80 | byte(plen) } else if plen < 65536 { hdr = []byte{:8} // 4 + 4 mask hdr[1] = 0x80 | 126 binary.BigEndian.PutUint16(hdr[2:], uint16(plen)) } else { hdr = []byte{:14} // 10 + 4 mask hdr[1] = 0x80 | 127 binary.BigEndian.PutUint64(hdr[2:], uint64(plen)) } hdr[0] = 0x80 | op maskOff := len(hdr) - 4 io.ReadFull(rand.Reader, hdr[maskOff:]) mask := [4]byte{hdr[maskOff], hdr[maskOff+1], hdr[maskOff+2], hdr[maskOff+3]} masked := []byte{:plen} for i, b := range payload { masked[i] = b ^ mask[i%4] } if _, err := c.raw.Write(hdr); err != nil { return err } _, err := c.raw.Write(masked) return err } // ReadMessage reads the next data frame, automatically handling ping/pong. func (c *Conn) ReadMessage() (op byte, payload []byte, err error) { for { op, payload, err = c.readFrame() if err != nil { return } switch op { case OpPing: c.WritePong(payload) case OpPong: // ignore case OpClose: c.writeFrame(OpClose, payload) return default: return } } } func (c *Conn) readFrame() (op byte, payload []byte, err error) { var hdr [2]byte if _, err = io.ReadFull(c.br, hdr[:]); err != nil { return } op = hdr[0] & 0x0F masked := hdr[1]&0x80 != 0 plen := uint64(hdr[1] & 0x7F) if plen == 126 { var ext [2]byte if _, err = io.ReadFull(c.br, ext[:]); err != nil { return } plen = uint64(binary.BigEndian.Uint16(ext[:])) } else if plen == 127 { var ext [8]byte if _, err = io.ReadFull(c.br, ext[:]); err != nil { return } plen = binary.BigEndian.Uint64(ext[:]) } if plen > uint64(maxPayload) { err = fmt.Errorf("ws: payload %d exceeds limit %d", plen, maxPayload) return } var mask [4]byte if masked { if _, err = io.ReadFull(c.br, mask[:]); err != nil { return } } payload = []byte{:plen} if _, err = io.ReadFull(c.br, payload); err != nil { return } if masked { for i := range payload { payload[i] ^= mask[i%4] } } return } // SetReadDeadline sets the read deadline on the underlying connection. func (c *Conn) SetReadDeadline(t time.Time) error { return c.raw.SetReadDeadline(t) } // Close sends a close frame and closes the TCP connection. func (c *Conn) Close() error { data := []byte{0x03, 0xE8} // code 1000 normal closure c.writeFrame(OpClose, data) return c.raw.Close() }