// Package server provides a Nostr relay server. // Uses an epoll event loop with non-blocking I/O to handle multiple // WebSocket connections in a single-threaded moxie domain. package server import ( "bytes" "crypto/sha1" "encoding/base64" "fmt" "syscall" "smesh.lol/pkg/nostr/envelope" "smesh.lol/pkg/nostr/event" "smesh.lol/pkg/nostr/filter" "smesh.lol/pkg/relay/worker" ) const wsMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" const maxBuf = 1 << 20 // 1MB max per-connection buffer const ( phaseHTTP = 0 phaseWS = 1 ) const ( opText byte = 0x1 opBin byte = 0x2 opClose byte = 0x8 opPing byte = 0x9 opPong byte = 0xA ) // Server is a Nostr relay. type Server struct { Fallback func(path string, headers map[string]string) (int, map[string]string, []byte) OnReady func() // called once after listener binds store *worker.Store epfd int lnFD int conns map[int]*conn } type conn struct { fd int phase int buf []byte wpos int subs map[string]*sub srv *Server } type sub struct { id string filters filter.S } type httpReq struct { method string path string headers map[string]string } // New creates a relay server backed by the given store. func New(store *worker.Store) *Server { return &Server{ store: store, conns: map[int]*conn{}, } } // ListenAndServe runs the epoll event loop. func (s *Server) ListenAndServe(addr string) error { ip, port := parseAddr(addr) fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0) if err != nil { return fmt.Errorf("relay: socket: %w", err) } syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1) if err := syscall.SetNonblock(fd, true); err != nil { syscall.Close(fd) return fmt.Errorf("relay: nonblock: %w", err) } sa := &syscall.SockaddrInet4{Port: port, Addr: ip} if err := syscall.Bind(fd, sa); err != nil { syscall.Close(fd) return fmt.Errorf("relay: bind %s: %w", addr, err) } if err := syscall.Listen(fd, 128); err != nil { syscall.Close(fd) return fmt.Errorf("relay: listen: %w", err) } s.lnFD = fd s.epfd, err = syscall.EpollCreate1(0) if err != nil { syscall.Close(fd) return fmt.Errorf("relay: epoll: %w", err) } if err := epollAdd(s.epfd, fd); err != nil { syscall.Close(s.epfd) syscall.Close(fd) return fmt.Errorf("relay: epoll add: %w", err) } if s.OnReady != nil { s.OnReady() } events := []syscall.EpollEvent{:64} for { n, err := syscall.EpollWait(s.epfd, events, -1) if err != nil { if err == syscall.EINTR { continue } return fmt.Errorf("relay: epoll wait: %w", err) } for i := 0; i < n; i++ { evFD := int(events[i].Fd) if evFD == s.lnFD { s.acceptAll() } else if c := s.conns[evFD]; c != nil { if events[i].Events&(syscall.EPOLLERR|syscall.EPOLLHUP) != 0 { s.closeConn(c) } else { s.readConn(c) } } } } } func (s *Server) acceptAll() { for { nfd, _, err := syscall.Accept4(s.lnFD, syscall.SOCK_NONBLOCK) if err != nil { return // EAGAIN — no more pending connections } if err := epollAdd(s.epfd, nfd); err != nil { syscall.Close(nfd) continue } s.conns[nfd] = &conn{ fd: nfd, phase: phaseHTTP, buf: []byte{:4096}, subs: map[string]*sub{}, srv: s, } } } func (s *Server) readConn(c *conn) { avail := len(c.buf) - c.wpos if avail < 512 { if len(c.buf) >= maxBuf { s.closeConn(c) return } nb := []byte{:len(c.buf) * 2} copy(nb, c.buf[:c.wpos]) c.buf = nb } n, err := syscall.Read(c.fd, c.buf[c.wpos:]) if n <= 0 { if err == nil || (err != syscall.EAGAIN && err != syscall.EINTR) { s.closeConn(c) } return } c.wpos += n switch c.phase { case phaseHTTP: s.processHTTP(c) case phaseWS: s.processWS(c) } } func (s *Server) closeConn(c *conn) { epollDel(s.epfd, c.fd) syscall.Close(c.fd) delete(s.conns, c.fd) } // --- HTTP phase --- func (s *Server) processHTTP(c *conn) { data := c.buf[:c.wpos] end := bytes.Index(data, []byte("\r\n\r\n")) if end < 0 { return // incomplete headers } consumed := end + 4 req := parseHTTPHeaders(data[:end]) if req == nil { s.closeConn(c) return } // Consume parsed bytes. copy(c.buf, c.buf[consumed:c.wpos]) c.wpos -= consumed upgrade := req.headers["upgrade"] if bytes.EqualFold(upgrade, "websocket") { s.upgradeWS(c, req) return } // Serve HTTP and close. status, headers, body := s.routeHTTP(req) writeHTTPResponse(c.fd, status, headers, body) s.closeConn(c) } func parseHTTPHeaders(data []byte) *httpReq { lineEnd := bytes.IndexByte(data, '\n') if lineEnd < 0 { return nil } line := data[:lineEnd] if len(line) > 0 && line[len(line)-1] == '\r' { line = line[:len(line)-1] } sp1 := bytes.IndexByte(line, ' ') if sp1 < 0 { return nil } method := string(makeCopy(line[:sp1])) rest := line[sp1+1:] sp2 := bytes.IndexByte(rest, ' ') var path string if sp2 >= 0 { path = string(makeCopy(rest[:sp2])) } else { path = string(makeCopy(rest)) } headers := map[string]string{} pos := lineEnd + 1 for pos < len(data) { nlPos := bytes.IndexByte(data[pos:], '\n') if nlPos < 0 { break } nlPos += pos hline := data[pos:nlPos] if len(hline) > 0 && hline[len(hline)-1] == '\r' { hline = hline[:len(hline)-1] } colon := bytes.IndexByte(hline, ':') if colon >= 0 { key := string(toLower(makeCopy(hline[:colon]))) val := string(makeCopy(trimSpace(hline[colon+1:]))) headers[key] = val } pos = nlPos + 1 } return &httpReq{method: method, path: path, headers: headers} } // --- WebSocket upgrade --- func (s *Server) upgradeWS(c *conn, req *httpReq) { key := req.headers["sec-websocket-key"] if key == "" { writeHTTPResponse(c.fd, 400, nil, []byte("missing Sec-WebSocket-Key")) s.closeConn(c) return } accept := computeAccept(key) resp := "HTTP/1.1 101 Switching Protocols\r\n" | "Upgrade: websocket\r\n" | "Connection: Upgrade\r\n" | "Sec-WebSocket-Accept: " | accept | "\r\n" | "\r\n" writeAll(c.fd, []byte(resp)) c.phase = phaseWS if c.wpos > 0 { s.processWS(c) } } // --- WebSocket frame handling --- func (s *Server) processWS(c *conn) { for { op, payload, consumed := parseWSFrame(c.buf[:c.wpos]) if consumed == 0 { return // incomplete frame } copy(c.buf, c.buf[consumed:c.wpos]) c.wpos -= consumed switch op { case opClose: writeWSClose(c.fd) s.closeConn(c) return case opPing: writeWSFrame(c.fd, opPong, payload) case opText, opBin: c.dispatch(payload) } } } // parseWSFrame extracts one WS frame from data. // Returns (op, payload, consumed). consumed=0 means incomplete. func parseWSFrame(data []byte) (byte, []byte, int) { if len(data) < 2 { return 0, nil, 0 } op := data[0] & 0x0f masked := data[1]&0x80 != 0 length := int(data[1] & 0x7f) pos := 2 if length == 126 { if len(data) < 4 { return 0, nil, 0 } length = int(data[2])<<8 | int(data[3]) pos = 4 } else if length == 127 { if len(data) < 10 { return 0, nil, 0 } length = int(data[6])<<24 | int(data[7])<<16 | int(data[8])<<8 | int(data[9]) pos = 10 } var mask [4]byte if masked { if len(data) < pos+4 { return 0, nil, 0 } copy(mask[:], data[pos:pos+4]) pos += 4 } if len(data) < pos+length { return 0, nil, 0 } payload := makeCopy(data[pos : pos+length]) if masked { for i := range payload { payload[i] ^= mask[i%4] } } return op, payload, pos + length } func writeWSFrame(fd int, op byte, payload []byte) { plen := len(payload) var hdr [10]byte hdr[0] = 0x80 | op n := 2 if plen < 126 { hdr[1] = byte(plen) } else if plen < 65536 { hdr[1] = 126 hdr[2] = byte(plen >> 8) hdr[3] = byte(plen) n = 4 } else { hdr[1] = 127 hdr[6] = byte(plen >> 24) hdr[7] = byte(plen >> 16) hdr[8] = byte(plen >> 8) hdr[9] = byte(plen) n = 10 } buf := []byte{:0:n + plen} buf = append(buf, hdr[:n]...) buf = append(buf, payload...) writeAll(fd, buf) } func writeWSClose(fd int) { writeWSFrame(fd, opClose, []byte{0x03, 0xe8}) } func writeAll(fd int, data []byte) error { for len(data) > 0 { n, err := syscall.Write(fd, data) if err == syscall.EAGAIN { continue } if err != nil { return err } if n > 0 { data = data[n:] } } return nil } // --- Nostr message dispatch --- func (c *conn) dispatch(msg []byte) { label, _, _ := envelope.Identify(msg) switch label { case envelope.EventLabel: c.handleEvent(msg) case envelope.ReqLabel: c.handleReq(msg) case envelope.CloseLabel: c.handleClose(msg) case envelope.CountLabel: c.handleCount(msg) } } func (c *conn) handleEvent(msg []byte) { result := c.srv.store.Ingest(msg) ok := &envelope.OK{ EventID: result.EventID, OK: result.OK, Reason: result.Reason, } writeWSFrame(c.fd, opText, ok.Marshal(nil)) if result.OK && result.Event != nil { c.srv.broadcastEvent(result.Event, c) } } func (c *conn) handleReq(msg []byte) { _, rem, _ := envelope.Identify(msg) var req envelope.Req if _, err := req.Unmarshal(rem); err != nil { return } id := string(req.Subscription) filters := filter.S(*req.Filters) c.subs[id] = &sub{id: id, filters: filters} responses := c.srv.store.Query(msg) for _, resp := range responses { writeWSFrame(c.fd, opText, resp) } } func (c *conn) handleClose(msg []byte) { _, rem, _ := envelope.Identify(msg) var cl envelope.Close if _, err := cl.Unmarshal(rem); err != nil { return } delete(c.subs, string(cl.ID)) } func (c *conn) handleCount(msg []byte) { resp := c.srv.store.Count(msg) if resp != nil { writeWSFrame(c.fd, opText, resp) } } func (s *Server) broadcastEvent(ev *event.E, sender *conn) { for _, other := range s.conns { if other == sender || other.phase != phaseWS { continue } for _, sub := range other.subs { if sub.filters.Match(ev) { er := &envelope.EventResult{ Subscription: []byte(sub.id), Event: ev, } writeWSFrame(other.fd, opText, er.Marshal(nil)) } } } } // --- HTTP helpers --- func (s *Server) routeHTTP(req *httpReq) (int, map[string]string, []byte) { if req.headers["accept"] == "application/nostr+json" { return 200, map[string]string{ "Content-Type": "application/nostr+json", "Access-Control-Allow-Origin": "*", }, []byte(`{"name":"smesh","software":"smesh","version":"0.0.1","supported_nips":[1,9,11,40,45]}`) } if s.Fallback != nil { return s.Fallback(req.path, req.headers) } return 404, map[string]string{"Content-Type": "text/plain"}, []byte("404 page not found\n") } func writeHTTPResponse(fd int, status int, headers map[string]string, body []byte) { bodyCopy := []byte{:len(body)} copy(bodyCopy, body) statusText := "OK" switch status { case 200: statusText = "OK" case 301: statusText = "Moved Permanently" case 400: statusText = "Bad Request" case 404: statusText = "Not Found" case 405: statusText = "Method Not Allowed" } var buf []byte buf = append(buf, "HTTP/1.1 "...) buf = appendInt(buf, status) buf = append(buf, ' ') buf = append(buf, statusText...) buf = append(buf, "\r\n"...) for k, v := range headers { buf = append(buf, k...) buf = append(buf, ": "...) buf = append(buf, v...) buf = append(buf, "\r\n"...) } buf = append(buf, "Content-Length: "...) buf = appendInt(buf, len(bodyCopy)) buf = append(buf, "\r\n"...) buf = append(buf, "Connection: close\r\n"...) buf = append(buf, "\r\n"...) buf = append(buf, bodyCopy...) writeAll(fd, buf) } func appendInt(buf []byte, n int) []byte { if n == 0 { return append(buf, '0') } if n < 0 { buf = append(buf, '-') n = -n } var digits [20]byte i := len(digits) for n > 0 { i-- digits[i] = byte('0' + n%10) n /= 10 } return append(buf, digits[i:]...) } func computeAccept(key string) string { h := sha1.New() h.Write([]byte(key)) h.Write([]byte(wsMagic)) return base64.StdEncoding.EncodeToString(h.Sum(nil)) } // --- Utility --- func makeCopy(b []byte) []byte { c := []byte{:len(b)} copy(c, b) return c } func trimSpace(b []byte) []byte { for len(b) > 0 && (b[0] == ' ' || b[0] == '\t') { b = b[1:] } for len(b) > 0 && (b[len(b)-1] == ' ' || b[len(b)-1] == '\t') { b = b[:len(b)-1] } return b } func toLower(b []byte) []byte { for i := range b { if b[i] >= 'A' && b[i] <= 'Z' { b[i] = b[i] + 32 } } return b } func parseAddr(addr string) ([4]byte, int) { ab := []byte(addr) var ip [4]byte colon := -1 for i := len(ab) - 1; i >= 0; i-- { if ab[i] == ':' { colon = i break } } if colon < 0 { return ip, 0 } port := 0 for i := colon + 1; i < len(ab); i++ { port = port*10 + int(ab[i]-'0') } host := ab[:colon] if len(host) > 0 { octet := 0 idx := 0 for i := 0; i < len(host); i++ { if host[i] == '.' { if idx < 4 { ip[idx] = byte(octet) } idx++ octet = 0 } else { octet = octet*10 + int(host[i]-'0') } } if idx < 4 { ip[idx] = byte(octet) } } return ip, port } // --- Epoll helpers --- func epollAdd(epfd, fd int) error { ev := syscall.EpollEvent{Events: syscall.EPOLLIN, Fd: int32(fd)} return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, fd, &ev) } func epollDel(epfd, fd int) error { return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_DEL, fd, nil) }