ws.mx raw

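// Package ws provides a minimal RFC 6455 WebSocket implementation for Nostr
// relays: a client (Dial) plus server-mode framing for connections whose HTTP
// upgrade was already handled elsewhere (NewServerConn). It works directly over
// raw TCP/TLS, with no net/http dependency.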
   3  package ws
   4  
   5  import (
   6  	"bufio"
   7  	"bytes"
   8  	"crypto/rand"
   9  	"crypto/sha1"
  10  	"crypto/tls"
  11  	"encoding/base64"
  12  	"encoding/binary"
  13  	"fmt"
  14  	"io"
  15  	"net"
  16  	"net/url"
  17  	"time"
  18  )
  19  
  20  const (
  21  	OpText   byte = 0x1
  22  	OpBinary byte = 0x2
  23  	OpClose  byte = 0x8
  24  	OpPing   byte = 0x9
  25  	OpPong   byte = 0xA
  26  
  27  	wsMagic    = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
  28  	maxPayload = 33 << 20 // 33 MB
  29  )
  30  
  31  // Conn is a WebSocket connection (client or server mode).
  32  type Conn struct {
  33  	raw    net.Conn
  34  	br     *bufio.Reader
  35  	server bool // true = server mode (unmasked writes)
  36  }
  37  
  38  // NewServerConn wraps an already-upgraded connection in server mode.
  39  func NewServerConn(conn net.Conn, br *bufio.Reader) *Conn {
  40  	return &Conn{raw: conn, br: br, server: true}
  41  }
  42  
  43  // ComputeAccept calculates Sec-WebSocket-Accept from a client key.
  44  func ComputeAccept(key string) string { return computeAccept(key) }
  45  
  46  // Dial opens a WebSocket connection to the given URL (ws:// or wss://).
  47  func Dial(rawURL string) (*Conn, error) {
  48  	u, err := url.Parse(rawURL)
  49  	if err != nil {
  50  		return nil, err
  51  	}
  52  	host := u.Hostname()
  53  	port := u.Port()
  54  	useTLS := u.Scheme == "wss"
  55  	if port == "" {
  56  		if useTLS {
  57  			port = "443"
  58  		} else {
  59  			port = "80"
  60  		}
  61  	}
  62  	// Resolve hostname via DNS cache (24h TTL).
  63  	ip := host
  64  	if net.ParseIP(host) == nil {
  65  		ip, err = resolveHost(host)
  66  		if err != nil {
  67  			return nil, fmt.Errorf("ws: resolve %s: %w", host, err)
  68  		}
  69  	}
  70  	addr := net.JoinHostPort(ip, port)
  71  
  72  	var conn net.Conn
  73  	conn, err = net.Dial("tcp", addr)
  74  	if err != nil {
  75  		return nil, fmt.Errorf("ws: dial %s: %w", addr, err)
  76  	}
  77  	if useTLS {
  78  		tlsConn := tls.Client(conn, &tls.Config{ServerName: []byte(host)})
  79  		if err = tlsConn.Handshake(); err != nil {
  80  			conn.Close()
  81  			return nil, fmt.Errorf("ws: tls %s: %w", host, err)
  82  		}
  83  		conn = tlsConn
  84  	}
  85  
  86  	path := u.RequestURI()
  87  	if path == "" {
  88  		path = "/"
  89  	}
  90  
  91  	// Generate Sec-WebSocket-Key.
  92  	var keyRaw [16]byte
  93  	io.ReadFull(rand.Reader, keyRaw[:])
  94  	wsKey := base64.StdEncoding.EncodeToString(keyRaw[:])
  95  
  96  	// Send HTTP upgrade.
  97  	req := "GET " + path + " HTTP/1.1\r\n" +
  98  		"Host: " + host + "\r\n" +
  99  		"Upgrade: websocket\r\n" +
 100  		"Connection: Upgrade\r\n" +
 101  		"Sec-WebSocket-Key: " + wsKey + "\r\n" +
 102  		"Sec-WebSocket-Version: 13\r\n" +
 103  		"\r\n"
 104  	if _, err = conn.Write([]byte(req)); err != nil {
 105  		conn.Close()
 106  		return nil, fmt.Errorf("ws: write upgrade: %w", err)
 107  	}
 108  
 109  	br := bufio.NewReaderSize(conn, 32768)
 110  
 111  	// Read status line.
 112  	status, err := br.ReadString('\n')
 113  	if err != nil {
 114  		conn.Close()
 115  		return nil, fmt.Errorf("ws: read status: %w", err)
 116  	}
 117  	if !bytes.Contains(status, "101") {
 118  		conn.Close()
 119  		return nil, fmt.Errorf("ws: upgrade rejected: %s", bytes.TrimSpace(status))
 120  	}
 121  
 122  	// Consume headers, validate accept.
 123  	expectedAccept := computeAccept(wsKey)
 124  	var accepted bool
 125  	for {
 126  		line, err := br.ReadString('\n')
 127  		if err != nil {
 128  			conn.Close()
 129  			return nil, fmt.Errorf("ws: read header: %w", err)
 130  		}
 131  		trimmed := bytes.TrimSpace(line)
 132  		if trimmed == "" {
 133  			break
 134  		}
 135  		lower := bytes.ToLower(trimmed)
 136  		if bytes.HasPrefix(lower, "sec-websocket-accept:") {
 137  			val := bytes.TrimSpace(trimmed[len("sec-websocket-accept:"):])
 138  			if val == expectedAccept {
 139  				accepted = true
 140  			}
 141  		}
 142  	}
 143  	if !accepted {
 144  		conn.Close()
 145  		return nil, fmt.Errorf("ws: bad accept key")
 146  	}
 147  	return &Conn{raw: conn, br: br}, nil
 148  }
 149  
 150  func computeAccept(key string) string {
 151  	h := sha1.New()
 152  	h.Write([]byte(key))
 153  	h.Write([]byte(wsMagic))
 154  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
 155  }
 156  
 157  // WriteText sends a text frame.
 158  func (c *Conn) WriteText(msg []byte) error { return c.writeFrame(OpText, msg) }
 159  
 160  // WritePong sends a pong control frame.
 161  func (c *Conn) WritePong(data []byte) error { return c.writeFrame(OpPong, data) }
 162  
 163  func (c *Conn) writeFrame(op byte, payload []byte) error {
 164  	if c.server {
 165  		return c.writeServerFrame(op, payload)
 166  	}
 167  	return c.writeClientFrame(op, payload)
 168  }
 169  
 170  // writeServerFrame sends an unmasked frame (RFC 6455: servers MUST NOT mask).
 171  func (c *Conn) writeServerFrame(op byte, payload []byte) error {
 172  	plen := len(payload)
 173  	var hdr []byte
 174  	if plen < 126 {
 175  		hdr = []byte{:2}
 176  		hdr[1] = byte(plen)
 177  	} else if plen < 65536 {
 178  		hdr = []byte{:4}
 179  		hdr[1] = 126
 180  		binary.BigEndian.PutUint16(hdr[2:], uint16(plen))
 181  	} else {
 182  		hdr = []byte{:10}
 183  		hdr[1] = 127
 184  		binary.BigEndian.PutUint64(hdr[2:], uint64(plen))
 185  	}
 186  	hdr[0] = 0x80 | op
 187  	if _, err := c.raw.Write(hdr); err != nil {
 188  		return err
 189  	}
 190  	_, err := c.raw.Write(payload)
 191  	return err
 192  }
 193  
 194  // writeClientFrame sends a masked frame (RFC 6455: clients MUST mask).
 195  func (c *Conn) writeClientFrame(op byte, payload []byte) error {
 196  	plen := len(payload)
 197  	var hdr []byte
 198  	if plen < 126 {
 199  		hdr = []byte{:6} // 2 + 4 mask
 200  		hdr[1] = 0x80 | byte(plen)
 201  	} else if plen < 65536 {
 202  		hdr = []byte{:8} // 4 + 4 mask
 203  		hdr[1] = 0x80 | 126
 204  		binary.BigEndian.PutUint16(hdr[2:], uint16(plen))
 205  	} else {
 206  		hdr = []byte{:14} // 10 + 4 mask
 207  		hdr[1] = 0x80 | 127
 208  		binary.BigEndian.PutUint64(hdr[2:], uint64(plen))
 209  	}
 210  	hdr[0] = 0x80 | op
 211  
 212  	maskOff := len(hdr) - 4
 213  	io.ReadFull(rand.Reader, hdr[maskOff:])
 214  	mask := [4]byte{hdr[maskOff], hdr[maskOff+1], hdr[maskOff+2], hdr[maskOff+3]}
 215  
 216  	masked := []byte{:plen}
 217  	for i, b := range payload {
 218  		masked[i] = b ^ mask[i%4]
 219  	}
 220  	if _, err := c.raw.Write(hdr); err != nil {
 221  		return err
 222  	}
 223  	_, err := c.raw.Write(masked)
 224  	return err
 225  }
 226  
 227  // ReadMessage reads the next data frame, automatically handling ping/pong.
 228  func (c *Conn) ReadMessage() (op byte, payload []byte, err error) {
 229  	for {
 230  		op, payload, err = c.readFrame()
 231  		if err != nil {
 232  			return
 233  		}
 234  		switch op {
 235  		case OpPing:
 236  			c.WritePong(payload)
 237  		case OpPong:
 238  			// ignore
 239  		case OpClose:
 240  			c.writeFrame(OpClose, payload)
 241  			return
 242  		default:
 243  			return
 244  		}
 245  	}
 246  }
 247  
 248  func (c *Conn) readFrame() (op byte, payload []byte, err error) {
 249  	var hdr [2]byte
 250  	if _, err = io.ReadFull(c.br, hdr[:]); err != nil {
 251  		return
 252  	}
 253  	op = hdr[0] & 0x0F
 254  	masked := hdr[1]&0x80 != 0
 255  	plen := uint64(hdr[1] & 0x7F)
 256  
 257  	if plen == 126 {
 258  		var ext [2]byte
 259  		if _, err = io.ReadFull(c.br, ext[:]); err != nil {
 260  			return
 261  		}
 262  		plen = uint64(binary.BigEndian.Uint16(ext[:]))
 263  	} else if plen == 127 {
 264  		var ext [8]byte
 265  		if _, err = io.ReadFull(c.br, ext[:]); err != nil {
 266  			return
 267  		}
 268  		plen = binary.BigEndian.Uint64(ext[:])
 269  	}
 270  	if plen > uint64(maxPayload) {
 271  		err = fmt.Errorf("ws: payload %d exceeds limit %d", plen, maxPayload)
 272  		return
 273  	}
 274  
 275  	var mask [4]byte
 276  	if masked {
 277  		if _, err = io.ReadFull(c.br, mask[:]); err != nil {
 278  			return
 279  		}
 280  	}
 281  
 282  	payload = []byte{:plen}
 283  	if _, err = io.ReadFull(c.br, payload); err != nil {
 284  		return
 285  	}
 286  	if masked {
 287  		for i := range payload {
 288  			payload[i] ^= mask[i%4]
 289  		}
 290  	}
 291  	return
 292  }
 293  
 294  // SetReadDeadline sets the read deadline on the underlying connection.
 295  func (c *Conn) SetReadDeadline(t time.Time) error {
 296  	return c.raw.SetReadDeadline(t)
 297  }
 298  
 299  // Close sends a close frame and closes the TCP connection.
 300  func (c *Conn) Close() error {
 301  	data := []byte{0x03, 0xE8} // code 1000 normal closure
 302  	c.writeFrame(OpClose, data)
 303  	return c.raw.Close()
 304  }
 305