server.mx raw
// Package server provides a Nostr relay server.
// It uses an epoll event loop with non-blocking I/O to serve many
// WebSocket connections from a single-threaded moxie domain.
4 package server
5
6 import (
7 "bytes"
8 "crypto/sha1"
9 "encoding/base64"
10 "fmt"
11 "syscall"
12
13 "smesh.lol/pkg/nostr/envelope"
14 "smesh.lol/pkg/nostr/event"
15 "smesh.lol/pkg/nostr/filter"
16 "smesh.lol/pkg/relay/worker"
17 )
18
// wsMagic is the fixed GUID defined by RFC 6455. It is appended to the
// client's Sec-WebSocket-Key before hashing to build Sec-WebSocket-Accept.
const wsMagic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

// maxBuf is the largest per-connection read buffer (1 MiB). readConn closes
// a connection whose buffer is already this large and still nearly full.
const maxBuf = 1024 * 1024

// Connection phases. Every connection starts in phaseHTTP and moves to
// phaseWS once the WebSocket upgrade handshake has been answered.
const (
	phaseHTTP = iota
	phaseWS
)

// WebSocket frame opcodes (RFC 6455 section 5.2) used by this server.
const (
	opText  byte = 0x1 // text data frame
	opBin   byte = 0x2 // binary data frame
	opClose byte = 0x8 // connection close
	opPing  byte = 0x9
	opPong  byte = 0xA
)
34
35 // Server is a Nostr relay.
36 type Server struct {
37 Fallback func(path string, headers map[string]string) (int, map[string]string, []byte)
38 OnReady func() // called once after listener binds
39 store *worker.Store
40 epfd int
41 lnFD int
42 conns map[int]*conn
43 }
44
45 type conn struct {
46 fd int
47 phase int
48 buf []byte
49 wpos int
50 subs map[string]*sub
51 srv *Server
52 }
53
54 type sub struct {
55 id string
56 filters filter.S
57 }
58
59 type httpReq struct {
60 method string
61 path string
62 headers map[string]string
63 }
64
65 // New creates a relay server backed by the given store.
66 func New(store *worker.Store) *Server {
67 return &Server{
68 store: store,
69 conns: map[int]*conn{},
70 }
71 }
72
73 // ListenAndServe runs the epoll event loop.
74 func (s *Server) ListenAndServe(addr string) error {
75 ip, port := parseAddr(addr)
76
77 fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
78 if err != nil {
79 return fmt.Errorf("relay: socket: %w", err)
80 }
81 syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
82 if err := syscall.SetNonblock(fd, true); err != nil {
83 syscall.Close(fd)
84 return fmt.Errorf("relay: nonblock: %w", err)
85 }
86 sa := &syscall.SockaddrInet4{Port: port, Addr: ip}
87 if err := syscall.Bind(fd, sa); err != nil {
88 syscall.Close(fd)
89 return fmt.Errorf("relay: bind %s: %w", addr, err)
90 }
91 if err := syscall.Listen(fd, 128); err != nil {
92 syscall.Close(fd)
93 return fmt.Errorf("relay: listen: %w", err)
94 }
95 s.lnFD = fd
96
97 s.epfd, err = syscall.EpollCreate1(0)
98 if err != nil {
99 syscall.Close(fd)
100 return fmt.Errorf("relay: epoll: %w", err)
101 }
102 if err := epollAdd(s.epfd, fd); err != nil {
103 syscall.Close(s.epfd)
104 syscall.Close(fd)
105 return fmt.Errorf("relay: epoll add: %w", err)
106 }
107
108 if s.OnReady != nil {
109 s.OnReady()
110 }
111
112 events := []syscall.EpollEvent{:64}
113 for {
114 n, err := syscall.EpollWait(s.epfd, events, -1)
115 if err != nil {
116 if err == syscall.EINTR {
117 continue
118 }
119 return fmt.Errorf("relay: epoll wait: %w", err)
120 }
121 for i := 0; i < n; i++ {
122 evFD := int(events[i].Fd)
123 if evFD == s.lnFD {
124 s.acceptAll()
125 } else if c := s.conns[evFD]; c != nil {
126 if events[i].Events&(syscall.EPOLLERR|syscall.EPOLLHUP) != 0 {
127 s.closeConn(c)
128 } else {
129 s.readConn(c)
130 }
131 }
132 }
133 }
134 }
135
136 func (s *Server) acceptAll() {
137 for {
138 nfd, _, err := syscall.Accept4(s.lnFD, syscall.SOCK_NONBLOCK)
139 if err != nil {
140 return // EAGAIN — no more pending connections
141 }
142 if err := epollAdd(s.epfd, nfd); err != nil {
143 syscall.Close(nfd)
144 continue
145 }
146 s.conns[nfd] = &conn{
147 fd: nfd,
148 phase: phaseHTTP,
149 buf: []byte{:4096},
150 subs: map[string]*sub{},
151 srv: s,
152 }
153 }
154 }
155
156 func (s *Server) readConn(c *conn) {
157 avail := len(c.buf) - c.wpos
158 if avail < 512 {
159 if len(c.buf) >= maxBuf {
160 s.closeConn(c)
161 return
162 }
163 nb := []byte{:len(c.buf) * 2}
164 copy(nb, c.buf[:c.wpos])
165 c.buf = nb
166 }
167 n, err := syscall.Read(c.fd, c.buf[c.wpos:])
168 if n <= 0 {
169 if err == nil || (err != syscall.EAGAIN && err != syscall.EINTR) {
170 s.closeConn(c)
171 }
172 return
173 }
174 c.wpos += n
175
176 switch c.phase {
177 case phaseHTTP:
178 s.processHTTP(c)
179 case phaseWS:
180 s.processWS(c)
181 }
182 }
183
184 func (s *Server) closeConn(c *conn) {
185 epollDel(s.epfd, c.fd)
186 syscall.Close(c.fd)
187 delete(s.conns, c.fd)
188 }
189
190 // --- HTTP phase ---
191
192 func (s *Server) processHTTP(c *conn) {
193 data := c.buf[:c.wpos]
194 end := bytes.Index(data, []byte("\r\n\r\n"))
195 if end < 0 {
196 return // incomplete headers
197 }
198 consumed := end + 4
199
200 req := parseHTTPHeaders(data[:end])
201 if req == nil {
202 s.closeConn(c)
203 return
204 }
205
206 // Consume parsed bytes.
207 copy(c.buf, c.buf[consumed:c.wpos])
208 c.wpos -= consumed
209
210 upgrade := req.headers["upgrade"]
211 if bytes.EqualFold(upgrade, "websocket") {
212 s.upgradeWS(c, req)
213 return
214 }
215
216 // Serve HTTP and close.
217 status, headers, body := s.routeHTTP(req)
218 writeHTTPResponse(c.fd, status, headers, body)
219 s.closeConn(c)
220 }
221
222 func parseHTTPHeaders(data []byte) *httpReq {
223 lineEnd := bytes.IndexByte(data, '\n')
224 if lineEnd < 0 {
225 return nil
226 }
227 line := data[:lineEnd]
228 if len(line) > 0 && line[len(line)-1] == '\r' {
229 line = line[:len(line)-1]
230 }
231
232 sp1 := bytes.IndexByte(line, ' ')
233 if sp1 < 0 {
234 return nil
235 }
236 method := string(makeCopy(line[:sp1]))
237 rest := line[sp1+1:]
238 sp2 := bytes.IndexByte(rest, ' ')
239 var path string
240 if sp2 >= 0 {
241 path = string(makeCopy(rest[:sp2]))
242 } else {
243 path = string(makeCopy(rest))
244 }
245
246 headers := map[string]string{}
247 pos := lineEnd + 1
248 for pos < len(data) {
249 nlPos := bytes.IndexByte(data[pos:], '\n')
250 if nlPos < 0 {
251 break
252 }
253 nlPos += pos
254 hline := data[pos:nlPos]
255 if len(hline) > 0 && hline[len(hline)-1] == '\r' {
256 hline = hline[:len(hline)-1]
257 }
258 colon := bytes.IndexByte(hline, ':')
259 if colon >= 0 {
260 key := string(toLower(makeCopy(hline[:colon])))
261 val := string(makeCopy(trimSpace(hline[colon+1:])))
262 headers[key] = val
263 }
264 pos = nlPos + 1
265 }
266
267 return &httpReq{method: method, path: path, headers: headers}
268 }
269
270 // --- WebSocket upgrade ---
271
272 func (s *Server) upgradeWS(c *conn, req *httpReq) {
273 key := req.headers["sec-websocket-key"]
274 if key == "" {
275 writeHTTPResponse(c.fd, 400, nil, []byte("missing Sec-WebSocket-Key"))
276 s.closeConn(c)
277 return
278 }
279 accept := computeAccept(key)
280 resp := "HTTP/1.1 101 Switching Protocols\r\n" |
281 "Upgrade: websocket\r\n" |
282 "Connection: Upgrade\r\n" |
283 "Sec-WebSocket-Accept: " | accept | "\r\n" |
284 "\r\n"
285 writeAll(c.fd, []byte(resp))
286 c.phase = phaseWS
287 if c.wpos > 0 {
288 s.processWS(c)
289 }
290 }
291
292 // --- WebSocket frame handling ---
293
294 func (s *Server) processWS(c *conn) {
295 for {
296 op, payload, consumed := parseWSFrame(c.buf[:c.wpos])
297 if consumed == 0 {
298 return // incomplete frame
299 }
300 copy(c.buf, c.buf[consumed:c.wpos])
301 c.wpos -= consumed
302
303 switch op {
304 case opClose:
305 writeWSClose(c.fd)
306 s.closeConn(c)
307 return
308 case opPing:
309 writeWSFrame(c.fd, opPong, payload)
310 case opText, opBin:
311 c.dispatch(payload)
312 }
313 }
314 }
315
316 // parseWSFrame extracts one WS frame from data.
317 // Returns (op, payload, consumed). consumed=0 means incomplete.
318 func parseWSFrame(data []byte) (byte, []byte, int) {
319 if len(data) < 2 {
320 return 0, nil, 0
321 }
322 op := data[0] & 0x0f
323 masked := data[1]&0x80 != 0
324 length := int(data[1] & 0x7f)
325 pos := 2
326
327 if length == 126 {
328 if len(data) < 4 {
329 return 0, nil, 0
330 }
331 length = int(data[2])<<8 | int(data[3])
332 pos = 4
333 } else if length == 127 {
334 if len(data) < 10 {
335 return 0, nil, 0
336 }
337 length = int(data[6])<<24 | int(data[7])<<16 | int(data[8])<<8 | int(data[9])
338 pos = 10
339 }
340
341 var mask [4]byte
342 if masked {
343 if len(data) < pos+4 {
344 return 0, nil, 0
345 }
346 copy(mask[:], data[pos:pos+4])
347 pos += 4
348 }
349 if len(data) < pos+length {
350 return 0, nil, 0
351 }
352
353 payload := makeCopy(data[pos : pos+length])
354 if masked {
355 for i := range payload {
356 payload[i] ^= mask[i%4]
357 }
358 }
359 return op, payload, pos + length
360 }
361
362 func writeWSFrame(fd int, op byte, payload []byte) {
363 plen := len(payload)
364 var hdr [10]byte
365 hdr[0] = 0x80 | op
366 n := 2
367 if plen < 126 {
368 hdr[1] = byte(plen)
369 } else if plen < 65536 {
370 hdr[1] = 126
371 hdr[2] = byte(plen >> 8)
372 hdr[3] = byte(plen)
373 n = 4
374 } else {
375 hdr[1] = 127
376 hdr[6] = byte(plen >> 24)
377 hdr[7] = byte(plen >> 16)
378 hdr[8] = byte(plen >> 8)
379 hdr[9] = byte(plen)
380 n = 10
381 }
382 buf := []byte{:0:n + plen}
383 buf = append(buf, hdr[:n]...)
384 buf = append(buf, payload...)
385 writeAll(fd, buf)
386 }
387
388 func writeWSClose(fd int) {
389 writeWSFrame(fd, opClose, []byte{0x03, 0xe8})
390 }
391
// writeAll writes all of data to fd, retrying after short writes, EINTR and
// EAGAIN. It returns nil once everything is written, or the first other
// write error.
//
// The relay is single-threaded and its sockets are non-blocking, so no other
// connection is serviced while writeAll waits for send-buffer space. EAGAIN
// is therefore retried with a 1ms back-off instead of a busy spin, and only
// for a bounded time. After stallLimit consecutive EAGAINs without progress
// (about 5s), writeAll gives up and shuts the socket down in both
// directions; a partly written frame cannot be recovered anyway. Later
// writes to that socket then fail at once, and epoll should report a hangup
// so the event loop closes it. EAGAIN is returned to the caller.
func writeAll(fd int, data []byte) error {
	const stallLimit = 5000 // consecutive EAGAINs (x 1ms back-off) before giving up
	backoff := syscall.Timespec{Nsec: 1000000}
	stalls := 0
	for len(data) > 0 {
		n, err := syscall.Write(fd, data)
		switch {
		case err == syscall.EINTR:
			continue
		case err == syscall.EAGAIN:
			stalls++
			if stalls > stallLimit {
				// Best effort: fd may not be a socket (ENOTSOCK), which is harmless.
				syscall.Shutdown(fd, syscall.SHUT_RDWR)
				return err
			}
			syscall.Nanosleep(&backoff, nil)
			continue
		case err != nil:
			return err
		}
		if n > 0 {
			data = data[n:]
			stalls = 0
		}
	}
	return nil
}
407
408 // --- Nostr message dispatch ---
409
410 func (c *conn) dispatch(msg []byte) {
411 label, _, _ := envelope.Identify(msg)
412 switch label {
413 case envelope.EventLabel:
414 c.handleEvent(msg)
415 case envelope.ReqLabel:
416 c.handleReq(msg)
417 case envelope.CloseLabel:
418 c.handleClose(msg)
419 case envelope.CountLabel:
420 c.handleCount(msg)
421 }
422 }
423
424 func (c *conn) handleEvent(msg []byte) {
425 result := c.srv.store.Ingest(msg)
426 ok := &envelope.OK{
427 EventID: result.EventID,
428 OK: result.OK,
429 Reason: result.Reason,
430 }
431 writeWSFrame(c.fd, opText, ok.Marshal(nil))
432 if result.OK && result.Event != nil {
433 c.srv.broadcastEvent(result.Event, c)
434 }
435 }
436
437 func (c *conn) handleReq(msg []byte) {
438 _, rem, _ := envelope.Identify(msg)
439 var req envelope.Req
440 if _, err := req.Unmarshal(rem); err != nil {
441 return
442 }
443 id := string(req.Subscription)
444 filters := filter.S(*req.Filters)
445 c.subs[id] = &sub{id: id, filters: filters}
446
447 responses := c.srv.store.Query(msg)
448 for _, resp := range responses {
449 writeWSFrame(c.fd, opText, resp)
450 }
451 }
452
453 func (c *conn) handleClose(msg []byte) {
454 _, rem, _ := envelope.Identify(msg)
455 var cl envelope.Close
456 if _, err := cl.Unmarshal(rem); err != nil {
457 return
458 }
459 delete(c.subs, string(cl.ID))
460 }
461
462 func (c *conn) handleCount(msg []byte) {
463 resp := c.srv.store.Count(msg)
464 if resp != nil {
465 writeWSFrame(c.fd, opText, resp)
466 }
467 }
468
469 func (s *Server) broadcastEvent(ev *event.E, sender *conn) {
470 for _, other := range s.conns {
471 if other == sender || other.phase != phaseWS {
472 continue
473 }
474 for _, sub := range other.subs {
475 if sub.filters.Match(ev) {
476 er := &envelope.EventResult{
477 Subscription: []byte(sub.id),
478 Event: ev,
479 }
480 writeWSFrame(other.fd, opText, er.Marshal(nil))
481 }
482 }
483 }
484 }
485
486 // --- HTTP helpers ---
487
488 func (s *Server) routeHTTP(req *httpReq) (int, map[string]string, []byte) {
489 if req.headers["accept"] == "application/nostr+json" {
490 return 200, map[string]string{
491 "Content-Type": "application/nostr+json",
492 "Access-Control-Allow-Origin": "*",
493 }, []byte(`{"name":"smesh","software":"smesh","version":"0.0.1","supported_nips":[1,9,11,40,45]}`)
494 }
495 if s.Fallback != nil {
496 return s.Fallback(req.path, req.headers)
497 }
498 return 404, map[string]string{"Content-Type": "text/plain"}, []byte("404 page not found\n")
499 }
500
501 func writeHTTPResponse(fd int, status int, headers map[string]string, body []byte) {
502 bodyCopy := []byte{:len(body)}
503 copy(bodyCopy, body)
504
505 statusText := "OK"
506 switch status {
507 case 200:
508 statusText = "OK"
509 case 301:
510 statusText = "Moved Permanently"
511 case 400:
512 statusText = "Bad Request"
513 case 404:
514 statusText = "Not Found"
515 case 405:
516 statusText = "Method Not Allowed"
517 }
518
519 var buf []byte
520 buf = append(buf, "HTTP/1.1 "...)
521 buf = appendInt(buf, status)
522 buf = append(buf, ' ')
523 buf = append(buf, statusText...)
524 buf = append(buf, "\r\n"...)
525 for k, v := range headers {
526 buf = append(buf, k...)
527 buf = append(buf, ": "...)
528 buf = append(buf, v...)
529 buf = append(buf, "\r\n"...)
530 }
531 buf = append(buf, "Content-Length: "...)
532 buf = appendInt(buf, len(bodyCopy))
533 buf = append(buf, "\r\n"...)
534 buf = append(buf, "Connection: close\r\n"...)
535 buf = append(buf, "\r\n"...)
536 buf = append(buf, bodyCopy...)
537 writeAll(fd, buf)
538 }
539
// appendInt appends the base-10 representation of n to buf and returns the
// extended slice, equivalent to strconv.AppendInt(buf, int64(n), 10).
//
// The magnitude is computed as a uint so that the most negative int is
// handled. Negating that value as an int overflows back to itself and would
// leave no digits to print.
func appendInt(buf []byte, n int) []byte {
	if n == 0 {
		return append(buf, '0')
	}
	u := uint(n)
	if n < 0 {
		buf = append(buf, '-')
		u = -u // unsigned negation wraps to the exact magnitude
	}
	var digits [20]byte // the largest uint64 has 20 decimal digits
	i := len(digits)
	for u > 0 {
		i--
		digits[i] = byte('0' + u%10)
		u /= 10
	}
	return append(buf, digits[i:]...)
}
557
558 func computeAccept(key string) string {
559 h := sha1.New()
560 h.Write([]byte(key))
561 h.Write([]byte(wsMagic))
562 return base64.StdEncoding.EncodeToString(h.Sum(nil))
563 }
564
565 // --- Utility ---
566
// makeCopy returns a new slice with the same contents as b that shares no
// memory with it. An empty input yields an empty, non-nil slice.
func makeCopy(b []byte) []byte {
	return append([]byte{}, b...)
}
572
// trimSpace returns the subslice of b without leading or trailing spaces and
// tabs. Other whitespace, such as CR or LF, is kept. The result aliases b.
func trimSpace(b []byte) []byte {
	lo, hi := 0, len(b)
	for lo < hi && (b[lo] == ' ' || b[lo] == '\t') {
		lo++
	}
	for hi > lo && (b[hi-1] == ' ' || b[hi-1] == '\t') {
		hi--
	}
	return b[lo:hi]
}
582
// toLower converts the ASCII letters A-Z in b to lower case in place and
// returns b. All other bytes, including non-ASCII UTF-8, are left untouched.
func toLower(b []byte) []byte {
	for i, ch := range b {
		if 'A' <= ch && ch <= 'Z' {
			b[i] = ch | 0x20
		}
	}
	return b
}
591
// parseAddr splits a "host:port" listen address into an IPv4 address and a
// port for syscall.SockaddrInet4.
//
// The port is the text after the last colon, and an empty port yields 0.
// The host may be empty (0.0.0.0, all interfaces), "localhost" (127.0.0.1),
// or a dotted-quad IPv4 literal. An address with no colon at all yields
// 0.0.0.0 and port 0.
//
// The function has no error result, so malformed input is reported by
// returning port -1. Malformed input means a non-numeric or out-of-range
// port, or a host that is not exactly four decimal octets of 0-255.
// ListenAndServe checks for this, and Go's syscall also rejects a port
// outside 0-65535 with EINVAL at bind time.
func parseAddr(addr string) ([4]byte, int) {
	var ip [4]byte
	ab := []byte(addr)
	colon := bytes.LastIndexByte(ab, ':')
	if colon < 0 {
		return ip, 0
	}

	port := 0
	for _, ch := range ab[colon+1:] {
		if ch < '0' || ch > '9' {
			return ip, -1
		}
		port = port*10 + int(ch-'0')
		if port > 0xFFFF {
			return ip, -1
		}
	}

	host := ab[:colon]
	if len(host) == 0 {
		return ip, port
	}
	if string(host) == "localhost" {
		return [4]byte{127, 0, 0, 1}, port
	}

	idx, octet, digits := 0, 0, 0
	for _, ch := range host {
		if ch == '.' {
			if digits == 0 || idx == 3 {
				return [4]byte{}, -1 // empty octet or more than four octets
			}
			ip[idx] = byte(octet)
			idx, octet, digits = idx+1, 0, 0
			continue
		}
		if ch < '0' || ch > '9' {
			return [4]byte{}, -1
		}
		octet = octet*10 + int(ch-'0')
		digits++
		if octet > 255 {
			return [4]byte{}, -1
		}
	}
	if idx != 3 || digits == 0 {
		return [4]byte{}, -1 // fewer than four octets, or trailing dot
	}
	ip[3] = byte(octet)
	return ip, port
}
630
631 // --- Epoll helpers ---
632
// epollAdd registers fd with the epoll instance epfd for level-triggered
// readability (EPOLLIN only). fd itself is stored as the event data so the
// event loop can map each event back to its descriptor.
func epollAdd(epfd, fd int) error {
	var ev syscall.EpollEvent
	ev.Events = syscall.EPOLLIN
	ev.Fd = int32(fd)
	return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, fd, &ev)
}
637
// epollDel removes fd from the epoll instance epfd. The event argument is
// unused for EPOLL_CTL_DEL (see epoll_ctl(2)), so nil is passed.
func epollDel(epfd, fd int) error {
	const op = syscall.EPOLL_CTL_DEL
	return syscall.EpollCtl(epfd, op, fd, nil)
}
641