server.mx raw

   1  // Package server provides a Nostr relay server.
   2  // Uses an epoll event loop with non-blocking I/O to handle multiple
   3  // WebSocket connections in a single-threaded moxie domain.
   4  package server
   5  
   6  import (
   7  	"bytes"
   8  	"crypto/sha1"
   9  	"encoding/base64"
  10  	"fmt"
  11  	"syscall"
  12  
  13  	"smesh.lol/pkg/nostr/envelope"
  14  	"smesh.lol/pkg/nostr/event"
  15  	"smesh.lol/pkg/nostr/filter"
  16  	"smesh.lol/pkg/relay/worker"
  17  )
  18  
// wsMagic is the fixed GUID from RFC 6455 that is appended to the client's
// Sec-WebSocket-Key before SHA-1 hashing to build Sec-WebSocket-Accept
// (see computeAccept).
const wsMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// maxBuf caps a connection's read buffer: readConn closes a connection whose
// buffer is already this large and would need to grow again.
const maxBuf = 1 << 20 // 1MB max per-connection buffer

// Connection phases. A connection starts in phaseHTTP while its request head
// is parsed and switches to phaseWS after a successful WebSocket upgrade.
const (
	phaseHTTP = 0
	phaseWS   = 1
)

// WebSocket frame opcodes (RFC 6455 §5.2) that this server reads or writes.
const (
	opText  byte = 0x1
	opBin   byte = 0x2
	opClose byte = 0x8
	opPing  byte = 0x9
	opPong  byte = 0xA
)
  34  
// Server is a Nostr relay.
//
// All state is owned by the single event loop in ListenAndServe; no field is
// guarded by a lock.
type Server struct {
	// Fallback, if set, serves plain HTTP requests that are neither WebSocket
	// upgrades nor NIP-11 requests (Accept: application/nostr+json). It gets
	// the request path and the headers (names lowercased) and returns the
	// status code, response headers and body.
	Fallback func(path string, headers map[string]string) (int, map[string]string, []byte)
	OnReady  func() // called once after the listener is bound and registered with epoll, before the loop starts
	store    *worker.Store // event storage: Ingest, Query and Count are called on it
	epfd     int           // epoll instance descriptor
	lnFD     int           // listening socket descriptor
	conns    map[int]*conn // open connections keyed by their socket fd
}
  44  
// conn is the per-client state of one accepted socket.
type conn struct {
	fd    int             // non-blocking client socket
	phase int             // phaseHTTP until upgraded, then phaseWS
	buf   []byte          // read buffer; buf[:wpos] is received but not yet consumed
	wpos  int             // number of valid bytes at the front of buf
	subs  map[string]*sub // active REQ subscriptions keyed by subscription ID
	srv   *Server         // owning server
}

// sub is one REQ subscription: its ID plus the filters that newly ingested
// events are matched against in broadcastEvent.
type sub struct {
	id      string
	filters filter.S
}

// httpReq is a parsed HTTP request head; a request body, if any, is never read.
type httpReq struct {
	method  string
	path    string            // request target as sent, including any query string
	headers map[string]string // names lowercased; a repeated header keeps its last value
}
  64  
  65  // New creates a relay server backed by the given store.
  66  func New(store *worker.Store) *Server {
  67  	return &Server{
  68  		store: store,
  69  		conns: map[int]*conn{},
  70  	}
  71  }
  72  
  73  // ListenAndServe runs the epoll event loop.
  74  func (s *Server) ListenAndServe(addr string) error {
  75  	ip, port := parseAddr(addr)
  76  
  77  	fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
  78  	if err != nil {
  79  		return fmt.Errorf("relay: socket: %w", err)
  80  	}
  81  	syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
  82  	if err := syscall.SetNonblock(fd, true); err != nil {
  83  		syscall.Close(fd)
  84  		return fmt.Errorf("relay: nonblock: %w", err)
  85  	}
  86  	sa := &syscall.SockaddrInet4{Port: port, Addr: ip}
  87  	if err := syscall.Bind(fd, sa); err != nil {
  88  		syscall.Close(fd)
  89  		return fmt.Errorf("relay: bind %s: %w", addr, err)
  90  	}
  91  	if err := syscall.Listen(fd, 128); err != nil {
  92  		syscall.Close(fd)
  93  		return fmt.Errorf("relay: listen: %w", err)
  94  	}
  95  	s.lnFD = fd
  96  
  97  	s.epfd, err = syscall.EpollCreate1(0)
  98  	if err != nil {
  99  		syscall.Close(fd)
 100  		return fmt.Errorf("relay: epoll: %w", err)
 101  	}
 102  	if err := epollAdd(s.epfd, fd); err != nil {
 103  		syscall.Close(s.epfd)
 104  		syscall.Close(fd)
 105  		return fmt.Errorf("relay: epoll add: %w", err)
 106  	}
 107  
 108  	if s.OnReady != nil {
 109  		s.OnReady()
 110  	}
 111  
 112  	events := []syscall.EpollEvent{:64}
 113  	for {
 114  		n, err := syscall.EpollWait(s.epfd, events, -1)
 115  		if err != nil {
 116  			if err == syscall.EINTR {
 117  				continue
 118  			}
 119  			return fmt.Errorf("relay: epoll wait: %w", err)
 120  		}
 121  		for i := 0; i < n; i++ {
 122  			evFD := int(events[i].Fd)
 123  			if evFD == s.lnFD {
 124  				s.acceptAll()
 125  			} else if c := s.conns[evFD]; c != nil {
 126  				if events[i].Events&(syscall.EPOLLERR|syscall.EPOLLHUP) != 0 {
 127  					s.closeConn(c)
 128  				} else {
 129  					s.readConn(c)
 130  				}
 131  			}
 132  		}
 133  	}
 134  }
 135  
 136  func (s *Server) acceptAll() {
 137  	for {
 138  		nfd, _, err := syscall.Accept4(s.lnFD, syscall.SOCK_NONBLOCK)
 139  		if err != nil {
 140  			return // EAGAIN — no more pending connections
 141  		}
 142  		if err := epollAdd(s.epfd, nfd); err != nil {
 143  			syscall.Close(nfd)
 144  			continue
 145  		}
 146  		s.conns[nfd] = &conn{
 147  			fd:    nfd,
 148  			phase: phaseHTTP,
 149  			buf:   []byte{:4096},
 150  			subs:  map[string]*sub{},
 151  			srv:   s,
 152  		}
 153  	}
 154  }
 155  
 156  func (s *Server) readConn(c *conn) {
 157  	avail := len(c.buf) - c.wpos
 158  	if avail < 512 {
 159  		if len(c.buf) >= maxBuf {
 160  			s.closeConn(c)
 161  			return
 162  		}
 163  		nb := []byte{:len(c.buf) * 2}
 164  		copy(nb, c.buf[:c.wpos])
 165  		c.buf = nb
 166  	}
 167  	n, err := syscall.Read(c.fd, c.buf[c.wpos:])
 168  	if n <= 0 {
 169  		if err == nil || (err != syscall.EAGAIN && err != syscall.EINTR) {
 170  			s.closeConn(c)
 171  		}
 172  		return
 173  	}
 174  	c.wpos += n
 175  
 176  	switch c.phase {
 177  	case phaseHTTP:
 178  		s.processHTTP(c)
 179  	case phaseWS:
 180  		s.processWS(c)
 181  	}
 182  }
 183  
 184  func (s *Server) closeConn(c *conn) {
 185  	epollDel(s.epfd, c.fd)
 186  	syscall.Close(c.fd)
 187  	delete(s.conns, c.fd)
 188  }
 189  
 190  // --- HTTP phase ---
 191  
 192  func (s *Server) processHTTP(c *conn) {
 193  	data := c.buf[:c.wpos]
 194  	end := bytes.Index(data, []byte("\r\n\r\n"))
 195  	if end < 0 {
 196  		return // incomplete headers
 197  	}
 198  	consumed := end + 4
 199  
 200  	req := parseHTTPHeaders(data[:end])
 201  	if req == nil {
 202  		s.closeConn(c)
 203  		return
 204  	}
 205  
 206  	// Consume parsed bytes.
 207  	copy(c.buf, c.buf[consumed:c.wpos])
 208  	c.wpos -= consumed
 209  
 210  	upgrade := req.headers["upgrade"]
 211  	if bytes.EqualFold(upgrade, "websocket") {
 212  		s.upgradeWS(c, req)
 213  		return
 214  	}
 215  
 216  	// Serve HTTP and close.
 217  	status, headers, body := s.routeHTTP(req)
 218  	writeHTTPResponse(c.fd, status, headers, body)
 219  	s.closeConn(c)
 220  }
 221  
 222  func parseHTTPHeaders(data []byte) *httpReq {
 223  	lineEnd := bytes.IndexByte(data, '\n')
 224  	if lineEnd < 0 {
 225  		return nil
 226  	}
 227  	line := data[:lineEnd]
 228  	if len(line) > 0 && line[len(line)-1] == '\r' {
 229  		line = line[:len(line)-1]
 230  	}
 231  
 232  	sp1 := bytes.IndexByte(line, ' ')
 233  	if sp1 < 0 {
 234  		return nil
 235  	}
 236  	method := string(makeCopy(line[:sp1]))
 237  	rest := line[sp1+1:]
 238  	sp2 := bytes.IndexByte(rest, ' ')
 239  	var path string
 240  	if sp2 >= 0 {
 241  		path = string(makeCopy(rest[:sp2]))
 242  	} else {
 243  		path = string(makeCopy(rest))
 244  	}
 245  
 246  	headers := map[string]string{}
 247  	pos := lineEnd + 1
 248  	for pos < len(data) {
 249  		nlPos := bytes.IndexByte(data[pos:], '\n')
 250  		if nlPos < 0 {
 251  			break
 252  		}
 253  		nlPos += pos
 254  		hline := data[pos:nlPos]
 255  		if len(hline) > 0 && hline[len(hline)-1] == '\r' {
 256  			hline = hline[:len(hline)-1]
 257  		}
 258  		colon := bytes.IndexByte(hline, ':')
 259  		if colon >= 0 {
 260  			key := string(toLower(makeCopy(hline[:colon])))
 261  			val := string(makeCopy(trimSpace(hline[colon+1:])))
 262  			headers[key] = val
 263  		}
 264  		pos = nlPos + 1
 265  	}
 266  
 267  	return &httpReq{method: method, path: path, headers: headers}
 268  }
 269  
 270  // --- WebSocket upgrade ---
 271  
 272  func (s *Server) upgradeWS(c *conn, req *httpReq) {
 273  	key := req.headers["sec-websocket-key"]
 274  	if key == "" {
 275  		writeHTTPResponse(c.fd, 400, nil, []byte("missing Sec-WebSocket-Key"))
 276  		s.closeConn(c)
 277  		return
 278  	}
 279  	accept := computeAccept(key)
 280  	resp := "HTTP/1.1 101 Switching Protocols\r\n" |
 281  		"Upgrade: websocket\r\n" |
 282  		"Connection: Upgrade\r\n" |
 283  		"Sec-WebSocket-Accept: " | accept | "\r\n" |
 284  		"\r\n"
 285  	writeAll(c.fd, []byte(resp))
 286  	c.phase = phaseWS
 287  	if c.wpos > 0 {
 288  		s.processWS(c)
 289  	}
 290  }
 291  
 292  // --- WebSocket frame handling ---
 293  
 294  func (s *Server) processWS(c *conn) {
 295  	for {
 296  		op, payload, consumed := parseWSFrame(c.buf[:c.wpos])
 297  		if consumed == 0 {
 298  			return // incomplete frame
 299  		}
 300  		copy(c.buf, c.buf[consumed:c.wpos])
 301  		c.wpos -= consumed
 302  
 303  		switch op {
 304  		case opClose:
 305  			writeWSClose(c.fd)
 306  			s.closeConn(c)
 307  			return
 308  		case opPing:
 309  			writeWSFrame(c.fd, opPong, payload)
 310  		case opText, opBin:
 311  			c.dispatch(payload)
 312  		}
 313  	}
 314  }
 315  
 316  // parseWSFrame extracts one WS frame from data.
 317  // Returns (op, payload, consumed). consumed=0 means incomplete.
 318  func parseWSFrame(data []byte) (byte, []byte, int) {
 319  	if len(data) < 2 {
 320  		return 0, nil, 0
 321  	}
 322  	op := data[0] & 0x0f
 323  	masked := data[1]&0x80 != 0
 324  	length := int(data[1] & 0x7f)
 325  	pos := 2
 326  
 327  	if length == 126 {
 328  		if len(data) < 4 {
 329  			return 0, nil, 0
 330  		}
 331  		length = int(data[2])<<8 | int(data[3])
 332  		pos = 4
 333  	} else if length == 127 {
 334  		if len(data) < 10 {
 335  			return 0, nil, 0
 336  		}
 337  		length = int(data[6])<<24 | int(data[7])<<16 | int(data[8])<<8 | int(data[9])
 338  		pos = 10
 339  	}
 340  
 341  	var mask [4]byte
 342  	if masked {
 343  		if len(data) < pos+4 {
 344  			return 0, nil, 0
 345  		}
 346  		copy(mask[:], data[pos:pos+4])
 347  		pos += 4
 348  	}
 349  	if len(data) < pos+length {
 350  		return 0, nil, 0
 351  	}
 352  
 353  	payload := makeCopy(data[pos : pos+length])
 354  	if masked {
 355  		for i := range payload {
 356  			payload[i] ^= mask[i%4]
 357  		}
 358  	}
 359  	return op, payload, pos + length
 360  }
 361  
 362  func writeWSFrame(fd int, op byte, payload []byte) {
 363  	plen := len(payload)
 364  	var hdr [10]byte
 365  	hdr[0] = 0x80 | op
 366  	n := 2
 367  	if plen < 126 {
 368  		hdr[1] = byte(plen)
 369  	} else if plen < 65536 {
 370  		hdr[1] = 126
 371  		hdr[2] = byte(plen >> 8)
 372  		hdr[3] = byte(plen)
 373  		n = 4
 374  	} else {
 375  		hdr[1] = 127
 376  		hdr[6] = byte(plen >> 24)
 377  		hdr[7] = byte(plen >> 16)
 378  		hdr[8] = byte(plen >> 8)
 379  		hdr[9] = byte(plen)
 380  		n = 10
 381  	}
 382  	buf := []byte{:0:n + plen}
 383  	buf = append(buf, hdr[:n]...)
 384  	buf = append(buf, payload...)
 385  	writeAll(fd, buf)
 386  }
 387  
 388  func writeWSClose(fd int) {
 389  	writeWSFrame(fd, opClose, []byte{0x03, 0xe8})
 390  }
 391  
 392  func writeAll(fd int, data []byte) error {
 393  	for len(data) > 0 {
 394  		n, err := syscall.Write(fd, data)
 395  		if err == syscall.EAGAIN {
 396  			continue
 397  		}
 398  		if err != nil {
 399  			return err
 400  		}
 401  		if n > 0 {
 402  			data = data[n:]
 403  		}
 404  	}
 405  	return nil
 406  }
 407  
 408  // --- Nostr message dispatch ---
 409  
 410  func (c *conn) dispatch(msg []byte) {
 411  	label, _, _ := envelope.Identify(msg)
 412  	switch label {
 413  	case envelope.EventLabel:
 414  		c.handleEvent(msg)
 415  	case envelope.ReqLabel:
 416  		c.handleReq(msg)
 417  	case envelope.CloseLabel:
 418  		c.handleClose(msg)
 419  	case envelope.CountLabel:
 420  		c.handleCount(msg)
 421  	}
 422  }
 423  
 424  func (c *conn) handleEvent(msg []byte) {
 425  	result := c.srv.store.Ingest(msg)
 426  	ok := &envelope.OK{
 427  		EventID: result.EventID,
 428  		OK:      result.OK,
 429  		Reason:  result.Reason,
 430  	}
 431  	writeWSFrame(c.fd, opText, ok.Marshal(nil))
 432  	if result.OK && result.Event != nil {
 433  		c.srv.broadcastEvent(result.Event, c)
 434  	}
 435  }
 436  
 437  func (c *conn) handleReq(msg []byte) {
 438  	_, rem, _ := envelope.Identify(msg)
 439  	var req envelope.Req
 440  	if _, err := req.Unmarshal(rem); err != nil {
 441  		return
 442  	}
 443  	id := string(req.Subscription)
 444  	filters := filter.S(*req.Filters)
 445  	c.subs[id] = &sub{id: id, filters: filters}
 446  
 447  	responses := c.srv.store.Query(msg)
 448  	for _, resp := range responses {
 449  		writeWSFrame(c.fd, opText, resp)
 450  	}
 451  }
 452  
 453  func (c *conn) handleClose(msg []byte) {
 454  	_, rem, _ := envelope.Identify(msg)
 455  	var cl envelope.Close
 456  	if _, err := cl.Unmarshal(rem); err != nil {
 457  		return
 458  	}
 459  	delete(c.subs, string(cl.ID))
 460  }
 461  
 462  func (c *conn) handleCount(msg []byte) {
 463  	resp := c.srv.store.Count(msg)
 464  	if resp != nil {
 465  		writeWSFrame(c.fd, opText, resp)
 466  	}
 467  }
 468  
 469  func (s *Server) broadcastEvent(ev *event.E, sender *conn) {
 470  	for _, other := range s.conns {
 471  		if other == sender || other.phase != phaseWS {
 472  			continue
 473  		}
 474  		for _, sub := range other.subs {
 475  			if sub.filters.Match(ev) {
 476  				er := &envelope.EventResult{
 477  					Subscription: []byte(sub.id),
 478  					Event:        ev,
 479  				}
 480  				writeWSFrame(other.fd, opText, er.Marshal(nil))
 481  			}
 482  		}
 483  	}
 484  }
 485  
 486  // --- HTTP helpers ---
 487  
 488  func (s *Server) routeHTTP(req *httpReq) (int, map[string]string, []byte) {
 489  	if req.headers["accept"] == "application/nostr+json" {
 490  		return 200, map[string]string{
 491  			"Content-Type":                "application/nostr+json",
 492  			"Access-Control-Allow-Origin": "*",
 493  		}, []byte(`{"name":"smesh","software":"smesh","version":"0.0.1","supported_nips":[1,9,11,40,45]}`)
 494  	}
 495  	if s.Fallback != nil {
 496  		return s.Fallback(req.path, req.headers)
 497  	}
 498  	return 404, map[string]string{"Content-Type": "text/plain"}, []byte("404 page not found\n")
 499  }
 500  
 501  func writeHTTPResponse(fd int, status int, headers map[string]string, body []byte) {
 502  	bodyCopy := []byte{:len(body)}
 503  	copy(bodyCopy, body)
 504  
 505  	statusText := "OK"
 506  	switch status {
 507  	case 200:
 508  		statusText = "OK"
 509  	case 301:
 510  		statusText = "Moved Permanently"
 511  	case 400:
 512  		statusText = "Bad Request"
 513  	case 404:
 514  		statusText = "Not Found"
 515  	case 405:
 516  		statusText = "Method Not Allowed"
 517  	}
 518  
 519  	var buf []byte
 520  	buf = append(buf, "HTTP/1.1 "...)
 521  	buf = appendInt(buf, status)
 522  	buf = append(buf, ' ')
 523  	buf = append(buf, statusText...)
 524  	buf = append(buf, "\r\n"...)
 525  	for k, v := range headers {
 526  		buf = append(buf, k...)
 527  		buf = append(buf, ": "...)
 528  		buf = append(buf, v...)
 529  		buf = append(buf, "\r\n"...)
 530  	}
 531  	buf = append(buf, "Content-Length: "...)
 532  	buf = appendInt(buf, len(bodyCopy))
 533  	buf = append(buf, "\r\n"...)
 534  	buf = append(buf, "Connection: close\r\n"...)
 535  	buf = append(buf, "\r\n"...)
 536  	buf = append(buf, bodyCopy...)
 537  	writeAll(fd, buf)
 538  }
 539  
 540  func appendInt(buf []byte, n int) []byte {
 541  	if n == 0 {
 542  		return append(buf, '0')
 543  	}
 544  	if n < 0 {
 545  		buf = append(buf, '-')
 546  		n = -n
 547  	}
 548  	var digits [20]byte
 549  	i := len(digits)
 550  	for n > 0 {
 551  		i--
 552  		digits[i] = byte('0' + n%10)
 553  		n /= 10
 554  	}
 555  	return append(buf, digits[i:]...)
 556  }
 557  
 558  func computeAccept(key string) string {
 559  	h := sha1.New()
 560  	h.Write([]byte(key))
 561  	h.Write([]byte(wsMagic))
 562  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
 563  }
 564  
 565  // --- Utility ---
 566  
 567  func makeCopy(b []byte) []byte {
 568  	c := []byte{:len(b)}
 569  	copy(c, b)
 570  	return c
 571  }
 572  
 573  func trimSpace(b []byte) []byte {
 574  	for len(b) > 0 && (b[0] == ' ' || b[0] == '\t') {
 575  		b = b[1:]
 576  	}
 577  	for len(b) > 0 && (b[len(b)-1] == ' ' || b[len(b)-1] == '\t') {
 578  		b = b[:len(b)-1]
 579  	}
 580  	return b
 581  }
 582  
 583  func toLower(b []byte) []byte {
 584  	for i := range b {
 585  		if b[i] >= 'A' && b[i] <= 'Z' {
 586  			b[i] = b[i] + 32
 587  		}
 588  	}
 589  	return b
 590  }
 591  
 592  func parseAddr(addr string) ([4]byte, int) {
 593  	ab := []byte(addr)
 594  	var ip [4]byte
 595  	colon := -1
 596  	for i := len(ab) - 1; i >= 0; i-- {
 597  		if ab[i] == ':' {
 598  			colon = i
 599  			break
 600  		}
 601  	}
 602  	if colon < 0 {
 603  		return ip, 0
 604  	}
 605  	port := 0
 606  	for i := colon + 1; i < len(ab); i++ {
 607  		port = port*10 + int(ab[i]-'0')
 608  	}
 609  	host := ab[:colon]
 610  	if len(host) > 0 {
 611  		octet := 0
 612  		idx := 0
 613  		for i := 0; i < len(host); i++ {
 614  			if host[i] == '.' {
 615  				if idx < 4 {
 616  					ip[idx] = byte(octet)
 617  				}
 618  				idx++
 619  				octet = 0
 620  			} else {
 621  				octet = octet*10 + int(host[i]-'0')
 622  			}
 623  		}
 624  		if idx < 4 {
 625  			ip[idx] = byte(octet)
 626  		}
 627  	}
 628  	return ip, port
 629  }
 630  
 631  // --- Epoll helpers ---
 632  
 633  func epollAdd(epfd, fd int) error {
 634  	ev := syscall.EpollEvent{Events: syscall.EPOLLIN, Fd: int32(fd)}
 635  	return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, fd, &ev)
 636  }
 637  
 638  func epollDel(epfd, fd int) error {
 639  	return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_DEL, fd, nil)
 640  }
 641