transport.mx raw

   1  // Package transport provides the network layer: epoll event loop, TCP accept,
   2  // HTTP/1.1 request parsing, and WebSocket framing. It has no knowledge of
   3  // Nostr protocol or any application domain logic. It calls back to a Handler
   4  // interface for all domain events.
   5  //
   6  // Dependency direction: server imports transport, never the reverse.
   7  // No smesh package imports are allowed here - only stdlib.
   8  package transport
   9  
  10  import (
  11  	"bytes"
  12  	"fmt"
  13  	"syscall"
  14  )
  15  
  16  
// maxBuf is the per-connection buffering threshold (20 << 20 bytes).
// readConn closes a connection whose read buffer has already reached this
// length and still needs to grow, and processHTTP only buffers request bodies
// whose Content-Length does not exceed it.
const maxBuf = 20 << 20 // 20MB max per-connection buffer
  18  
// Connection phases stored in tconn.phase. readConn dispatches incoming bytes
// to a different parser depending on the current phase.
const (
	phaseHTTP         = 0 // waiting for / parsing HTTP request headers
	phaseWS           = 1 // upgraded connection; bytes go to processWS
	phaseHTTPBody     = 2 // headers parsed; collecting bodyNeeded body bytes
	phaseHTTPDeferred = 3 // waiting for async worker response
)
  25  
// WebSocket frame opcodes. The numeric values match the opcode table of
// RFC 6455 section 5.2. Within this part of the file only opText is used
// (passed to buildWSFrame by SendWS and SendWSErr).
const (
	opText  byte = 0x1 // text frame
	opBin   byte = 0x2 // binary frame
	opClose byte = 0x8 // connection close
	opPing  byte = 0x9 // ping
	opPong  byte = 0xA // pong
)
  33  
// HTTPDeferred is the status value Handler.OnHTTP returns to request
// asynchronous processing. When processHTTP or processHTTPBody sees it, no
// response is written; the connection is switched to phaseHTTPDeferred and
// stays there until CompleteHTTP is called for its fd. Bytes received while
// deferred are discarded by readConn.
const HTTPDeferred = -1
  37  
// Handler is implemented by the server layer. Transport calls these methods
// for all connection and message events. All calls visible in this file are
// made from the event loop in ListenAndServe (directly or via its helpers).
type Handler interface {
	// OnAccept: new TCP connection. Return false to close immediately;
	// acceptAll then closes fd without tracking it.
	OnAccept(fd int, ip string) bool
	// OnWSUpgrade: WS handshake requested. currentIPWSCount = existing WS conns from ip.
	// Return (whitelisted, allow). allow=false → 429.
	// NOTE(review): the call site (upgradeWS) is outside this part of the file.
	OnWSUpgrade(fd int, ip string, currentIPWSCount int) (whitelisted bool, allow bool)
	// OnWSConnected: 101 response sent. Handler sets up conn state, sends auth challenge.
	OnWSConnected(fd int)
	// OnWSMessage: decoded WS payload.
	OnWSMessage(fd int, payload []byte)
	// OnWSClose: WS connection closed. Handler cleans up conn state.
	// closeConn invokes this after the fd has been closed and removed from the
	// connection table, and only for connections in the WS phase.
	OnWSClose(fd int)
	// OnHTTP: complete HTTP request. body is nil when the request carried no
	// buffered body. Return (HTTPDeferred, nil, nil, false) for async; the
	// response is then delivered later through Server.CompleteHTTP.
	OnHTTP(fd int, method, path string, headers map[string]string, body []byte) (status int, respHeaders map[string]string, respBody []byte, connClose bool)
	// OnFD: an epoll event arrived for a worker FD registered via RegisterFD
	// (such FDs are registered for readability).
	OnFD(fd int32)
	// OnTick: called when epoll_wait returns no events after its 5000 ms
	// timeout, i.e. roughly every 5s while the loop is idle.
	OnTick()
}
  59  
// Server runs the epoll event loop. It owns the listening socket, the epoll
// instance and the table of live connections, and forwards domain events to
// its Handler. Create it with New; the zero value has nil maps.
type Server struct {
	BotBlock     bool           // block known bot User-Agents
	OnReady      func()         // called after bind+listen, before epoll loop
	handler      Handler        // receiver of all callbacks
	epfd         int            // epoll instance fd; 0 until ListenAndServe creates it
	lnFD         int            // listening socket fd
	sigFD        int            // signal pipe fd watched by the event loop
	conns        map[int]*tconn // live connections keyed by fd
	ipConns      map[string]int // WS connection counts per IP
	maxConnPerIP int            // stored by New (0 = unlimited); not read in this part of the file
	extraFDs     map[int32]bool // registered worker FDs
}
  73  
// tconn is the per-connection state tracked by Server.conns.
type tconn struct {
	fd          int      // socket file descriptor (also the key in Server.conns)
	phase       int      // one of the phase* constants
	buf         []byte   // read buffer; starts at 4096 bytes and is doubled by readConn
	wpos        int      // number of valid bytes currently held in buf
	remoteIP    string   // peer IP; replaced by the X-Forwarded-For value for 127.0.0.1 peers
	whitelisted bool     // reported by ConnIsWhitelisted; assigned outside this part of the file
	pendingReq  *httpReq // parsed headers of a request whose body is still being collected
	bodyNeeded  int      // body length (Content-Length) expected for pendingReq
	wbuf        []byte // pending write data (EAGAIN buffered)
	wbufClose   bool   // close connection after wbuf drains
}
  86  
// httpReq is a parsed HTTP request as produced by parseHTTPHeaders.
type httpReq struct {
	method  string            // request method, passed through to Handler.OnHTTP
	path    string            // request target, passed through to Handler.OnHTTP
	headers map[string]string // looked up with lower-case names here; presumably lower-cased by the parser (confirm)
	body    []byte            // copy of the request body; filled in by processHTTPBody
}
  93  
// moxie_signal_enable is declared without a body; its implementation lives
// outside this file (presumably in the runtime; confirm). InitSignals calls
// it with 15 and 2, i.e. SIGTERM and SIGINT.
//
//export moxie_signal_enable
func moxie_signal_enable(s uint32)

// moxie_signal_pipe_init is declared without a body. InitSignals stores its
// result in globalSigFD, and ListenAndServe treats a non-negative value as a
// file descriptor to watch with epoll.
//
//export moxie_signal_pipe_init
func moxie_signal_pipe_init() int32

// moxie_signal_pipe_read is declared without a body. The event loop calls it
// once when the signal fd reports an event, immediately before returning.
// NOTE(review): the meaning of the returned int32 is not visible here.
//
//export moxie_signal_pipe_read
func moxie_signal_pipe_read() int32
 102  
// globalSigFD holds the signal pipe fd returned by moxie_signal_pipe_init.
// It stays -1 until InitSignals runs; ListenAndServe only watches it when it
// is non-negative.
var globalSigFD int32 = -1
 104  
 105  // InitSignals sets up SIGTERM/SIGINT handling. Must be called before any store
 106  // operations so that shutdown signals during slow startup are handled cleanly.
 107  func InitSignals() {
 108  	globalSigFD = moxie_signal_pipe_init()
 109  	moxie_signal_enable(15) // SIGTERM
 110  	moxie_signal_enable(2)  // SIGINT
 111  }
 112  
 113  // New creates a transport Server. maxConnPerIP=0 means unlimited.
 114  func New(handler Handler, maxConnPerIP int) *Server {
 115  	return &Server{
 116  		handler:      handler,
 117  		conns:        map[int]*tconn{},
 118  		ipConns:      map[string]int{},
 119  		maxConnPerIP: maxConnPerIP,
 120  		extraFDs:     map[int32]bool{},
 121  	}
 122  }
 123  
 124  // RegisterFD adds a file descriptor to the epoll set.
 125  // Readable events on this FD call handler.OnFD(fd).
 126  // Sets the fd non-blocking so parent-side reads/writes never stall the
 127  // epoll event loop (prevents bilateral deadlock on spawn socketpairs).
 128  func (s *Server) RegisterFD(fd int32) {
 129  	s.extraFDs[fd] = true
 130  	syscall.SetNonblock(int(fd), true)
 131  	if s.epfd != 0 {
 132  		epollAdd(s.epfd, int(fd))
 133  	}
 134  }
 135  
 136  // RemoveFD removes a file descriptor from the epoll set.
 137  func (s *Server) RemoveFD(fd int32) {
 138  	if s.extraFDs[fd] {
 139  		delete(s.extraFDs, fd)
 140  		epollDel(s.epfd, int(fd))
 141  	}
 142  }
 143  
 144  // ConnCount returns the current number of tracked TCP connections.
 145  func (s *Server) ConnCount() int { return len(s.conns) }
 146  
 147  // ConnIP returns the effective remote IP for a connection (XFF-substituted).
 148  func (s *Server) ConnIP(fd int) string {
 149  	if c := s.conns[fd]; c != nil {
 150  		return c.remoteIP
 151  	}
 152  	return ""
 153  }
 154  
 155  // ConnIsWhitelisted reports whether fd's IP is whitelisted.
 156  func (s *Server) ConnIsWhitelisted(fd int) bool {
 157  	if c := s.conns[fd]; c != nil {
 158  		return c.whitelisted
 159  	}
 160  	return false
 161  }
 162  
 163  // ConnIsWS reports whether fd is in WS phase.
 164  func (s *Server) ConnIsWS(fd int) bool {
 165  	if c := s.conns[fd]; c != nil {
 166  		return c.phase == phaseWS
 167  	}
 168  	return false
 169  }
 170  
 171  // IPConnCount returns the current WS connection count for ip.
 172  func (s *Server) IPConnCount(ip string) int {
 173  	return s.ipConns[ip]
 174  }
 175  
 176  // SendWS writes a WS text frame. Buffers on EAGAIN; closes on error.
 177  func (s *Server) SendWS(fd int, payload []byte) {
 178  	c := s.conns[fd]
 179  	if c == nil {
 180  		return
 181  	}
 182  	s.connWrite(c, buildWSFrame(opText, payload), false)
 183  }
 184  
 185  // SendWSErr writes a WS text frame. Returns error on EAGAIN (closes conn).
 186  func (s *Server) SendWSErr(fd int, payload []byte) error {
 187  	c := s.conns[fd]
 188  	if c == nil {
 189  		return fmt.Errorf("no conn")
 190  	}
 191  	data := buildWSFrame(opText, payload)
 192  	for len(data) > 0 {
 193  		n, err := syscall.Write(c.fd, data)
 194  		if n > 0 {
 195  			data = data[n:]
 196  		}
 197  		if err == syscall.EAGAIN {
 198  			return syscall.EAGAIN
 199  		}
 200  		if err != nil {
 201  			return err
 202  		}
 203  	}
 204  	return nil
 205  }
 206  
 207  // connWrite writes data to a connection, buffering on EAGAIN.
 208  func (s *Server) connWrite(c *tconn, data []byte, closeAfter bool) {
 209  	if c.wbuf != nil {
 210  		c.wbuf = append(c.wbuf, data...)
 211  		if closeAfter {
 212  			c.wbufClose = true
 213  		}
 214  		return
 215  	}
 216  	for len(data) > 0 {
 217  		n, err := syscall.Write(c.fd, data)
 218  		if n > 0 {
 219  			data = data[n:]
 220  		}
 221  		if err == syscall.EAGAIN {
 222  			c.wbuf = []byte{:len(data)}
 223  			copy(c.wbuf, data)
 224  			c.wbufClose = closeAfter
 225  			epollModWrite(s.epfd, c.fd)
 226  			return
 227  		}
 228  		if err != nil {
 229  			s.closeConn(c)
 230  			return
 231  		}
 232  	}
 233  	if closeAfter {
 234  		s.closeConn(c)
 235  	}
 236  }
 237  
 238  // SendHTTP writes an HTTP response. Uses buffered writes; closes on error.
 239  func (s *Server) SendHTTP(fd int, status int, headers map[string]string, body []byte) {
 240  	c := s.conns[fd]
 241  	if c == nil {
 242  		return
 243  	}
 244  	s.sendHTTPBuffered(c, status, headers, body, false)
 245  }
 246  
 247  // CloseConn closes a connection.
 248  func (s *Server) CloseConn(fd int) {
 249  	if c := s.conns[fd]; c != nil {
 250  		s.closeConn(c)
 251  	}
 252  }
 253  
 254  // CompleteHTTP sends an HTTP response for a previously deferred connection.
 255  func (s *Server) CompleteHTTP(fd int, status int, headers map[string]string, body []byte, connClose bool) {
 256  	c := s.conns[fd]
 257  	if c == nil {
 258  		return
 259  	}
 260  	c.phase = phaseHTTP
 261  	s.sendHTTPBuffered(c, status, headers, body, connClose)
 262  }
 263  
 264  // ListenAndServe runs the epoll event loop. Returns on signal or error.
 265  func (s *Server) ListenAndServe(addr string) error {
 266  	ip, port := ParseAddr(addr)
 267  
 268  	fd, err := syscall.Socket(syscall.AF_INET, syscall.SOCK_STREAM, 0)
 269  	if err != nil {
 270  		return fmt.Errorf("transport: socket: %w", err)
 271  	}
 272  	syscall.SetsockoptInt(fd, syscall.SOL_SOCKET, syscall.SO_REUSEADDR, 1)
 273  	if err := syscall.SetNonblock(fd, true); err != nil {
 274  		syscall.Close(fd)
 275  		return fmt.Errorf("transport: nonblock: %w", err)
 276  	}
 277  	sa := &syscall.SockaddrInet4{Port: port, Addr: ip}
 278  	if err := syscall.Bind(fd, sa); err != nil {
 279  		syscall.Close(fd)
 280  		return fmt.Errorf("transport: bind %s: %w", addr, err)
 281  	}
 282  	if err := syscall.Listen(fd, 4096); err != nil {
 283  		syscall.Close(fd)
 284  		return fmt.Errorf("transport: listen: %w", err)
 285  	}
 286  	s.lnFD = fd
 287  
 288  	s.epfd, err = syscall.EpollCreate1(0)
 289  	if err != nil {
 290  		syscall.Close(fd)
 291  		return fmt.Errorf("transport: epoll: %w", err)
 292  	}
 293  	if err := epollAdd(s.epfd, fd); err != nil {
 294  		syscall.Close(s.epfd)
 295  		syscall.Close(fd)
 296  		return fmt.Errorf("transport: epoll add: %w", err)
 297  	}
 298  
 299  	if globalSigFD >= 0 {
 300  		s.sigFD = int(globalSigFD)
 301  		if err := epollAdd(s.epfd, s.sigFD); err != nil {
 302  			return fmt.Errorf("transport: epoll add signal pipe: %w", err)
 303  		}
 304  	}
 305  
 306  	// Register pre-registered FDs (worker pipes added before ListenAndServe).
 307  	for fd := range s.extraFDs {
 308  		if err := epollAdd(s.epfd, int(fd)); err != nil {
 309  			return fmt.Errorf("transport: epoll add fd %d: %w", fd, err)
 310  		}
 311  	}
 312  
 313  	if s.OnReady != nil {
 314  		s.OnReady()
 315  	}
 316  
 317  	events := []syscall.EpollEvent{:64}
 318  	for {
 319  		n, err := syscall.EpollWait(s.epfd, events, 5000)
 320  		if err != nil {
 321  			if err == syscall.EINTR {
 322  				return nil
 323  			}
 324  			return fmt.Errorf("transport: epoll wait: %w", err)
 325  		}
 326  		if n == 0 {
 327  			s.handler.OnTick()
 328  			continue
 329  		}
 330  		for i := 0; i < n; i++ {
 331  			evFD := int(events[i].Fd)
 332  			if evFD == s.sigFD {
 333  				moxie_signal_pipe_read()
 334  				return nil
 335  			} else if evFD == s.lnFD {
 336  				s.acceptAll()
 337  			} else if s.extraFDs[int32(evFD)] {
 338  				s.handler.OnFD(int32(evFD))
 339  			} else if c := s.conns[evFD]; c != nil {
 340  				if events[i].Events&(syscall.EPOLLERR|syscall.EPOLLHUP) != 0 {
 341  					s.closeConn(c)
 342  				} else if events[i].Events&syscall.EPOLLOUT != 0 {
 343  					s.drainWrite(c)
 344  				} else {
 345  					s.readConn(c)
 346  				}
 347  			}
 348  		}
 349  	}
 350  }
 351  
 352  func (s *Server) acceptAll() {
 353  	for {
 354  		nfd, sa, err := syscall.Accept4(s.lnFD, syscall.SOCK_NONBLOCK)
 355  		if err != nil {
 356  			return
 357  		}
 358  		ip := peerAddr(sa)
 359  		if !s.handler.OnAccept(nfd, ip) {
 360  			syscall.Close(nfd)
 361  			continue
 362  		}
 363  		if err := epollAdd(s.epfd, nfd); err != nil {
 364  			syscall.Close(nfd)
 365  			continue
 366  		}
 367  		s.conns[nfd] = &tconn{
 368  			fd:       nfd,
 369  			phase:    phaseHTTP,
 370  			buf:      []byte{:4096},
 371  			remoteIP: ip,
 372  		}
 373  	}
 374  }
 375  
 376  func peerAddr(sa syscall.Sockaddr) string {
 377  	if sa4, ok := sa.(*syscall.SockaddrInet4); ok {
 378  		b := []byte{:0:20}
 379  		b = appendInt(b, int(sa4.Addr[0]))
 380  		b = append(b, '.')
 381  		b = appendInt(b, int(sa4.Addr[1]))
 382  		b = append(b, '.')
 383  		b = appendInt(b, int(sa4.Addr[2]))
 384  		b = append(b, '.')
 385  		b = appendInt(b, int(sa4.Addr[3]))
 386  		return string(makeCopy(b))
 387  	}
 388  	return "unknown"
 389  }
 390  
 391  func (s *Server) readConn(c *tconn) {
 392  	if c.wbuf != nil {
 393  		return
 394  	}
 395  	avail := len(c.buf) - c.wpos
 396  	if avail < 512 {
 397  		if len(c.buf) >= maxBuf {
 398  			s.closeConn(c)
 399  			return
 400  		}
 401  		nb := []byte{:len(c.buf) * 2}
 402  		copy(nb, c.buf[:c.wpos])
 403  		c.buf = nb
 404  	}
 405  	n, err := syscall.Read(c.fd, c.buf[c.wpos:])
 406  	if n <= 0 {
 407  		if err == nil || (err != syscall.EAGAIN && err != syscall.EINTR) {
 408  			s.closeConn(c)
 409  		}
 410  		return
 411  	}
 412  	c.wpos += n
 413  
 414  	switch c.phase {
 415  	case phaseHTTP:
 416  		s.processHTTP(c)
 417  	case phaseHTTPBody:
 418  		s.processHTTPBody(c)
 419  	case phaseWS:
 420  		s.processWS(c)
 421  	case phaseHTTPDeferred:
 422  		c.wpos = 0 // drop bytes while awaiting async response
 423  	}
 424  }
 425  
 426  func (s *Server) closeConn(c *tconn) {
 427  	epollDel(s.epfd, c.fd)
 428  	syscall.Close(c.fd)
 429  	delete(s.conns, c.fd)
 430  	if c.phase == phaseWS {
 431  		s.handler.OnWSClose(c.fd)
 432  		if n := s.ipConns[c.remoteIP] - 1; n <= 0 {
 433  			delete(s.ipConns, c.remoteIP)
 434  		} else {
 435  			s.ipConns[c.remoteIP] = n
 436  		}
 437  	}
 438  }
 439  
 440  func (s *Server) keepAlive(c *tconn) {
 441  	c.phase = phaseHTTP
 442  	c.pendingReq = nil
 443  	c.bodyNeeded = 0
 444  	if c.wpos > 0 {
 445  		s.processHTTP(c)
 446  	}
 447  }
 448  
 449  func (s *Server) processHTTP(c *tconn) {
 450  	data := c.buf[:c.wpos]
 451  	end := bytes.Index(data, []byte("\r\n\r\n"))
 452  	if end < 0 {
 453  		return
 454  	}
 455  	consumed := end + 4
 456  
 457  	req := parseHTTPHeaders(data[:end])
 458  	if req == nil {
 459  		s.closeConn(c)
 460  		return
 461  	}
 462  
 463  	copy(c.buf, c.buf[consumed:c.wpos])
 464  	c.wpos -= consumed
 465  
 466  	// Reverse proxy: substitute real IP from X-Forwarded-For before any checks.
 467  	if c.remoteIP == "127.0.0.1" {
 468  		if xff := req.headers["x-forwarded-for"]; xff != "" {
 469  			realIP := FirstXFF(xff)
 470  			if len(realIP) > 0 {
 471  				c.remoteIP = realIP
 472  			}
 473  		}
 474  	}
 475  
 476  	if bytes.EqualFold([]byte(req.headers["upgrade"]), []byte("websocket")) {
 477  		s.upgradeWS(c, req)
 478  		return
 479  	}
 480  
 481  	if s.BotBlock && isBot(req.headers["user-agent"]) {
 482  		writeHTTPResponse(c.fd, 403, nil, []byte("forbidden"))
 483  		s.closeConn(c)
 484  		return
 485  	}
 486  
 487  	cl := parseContentLength(req.headers["content-length"])
 488  	if cl > 0 && cl <= maxBuf {
 489  		c.pendingReq = req
 490  		c.bodyNeeded = cl
 491  		c.phase = phaseHTTPBody
 492  		s.processHTTPBody(c)
 493  		return
 494  	}
 495  
 496  	status, headers, body, connClose := s.handler.OnHTTP(c.fd, req.method, req.path, req.headers, nil)
 497  	if status == HTTPDeferred {
 498  		c.phase = phaseHTTPDeferred
 499  		return
 500  	}
 501  	s.sendHTTPBuffered(c, status, headers, body, connClose || req.headers["connection"] == "close")
 502  }
 503  
 504  func (s *Server) processHTTPBody(c *tconn) {
 505  	if c.wpos < c.bodyNeeded {
 506  		return
 507  	}
 508  	c.pendingReq.body = makeCopy(c.buf[:c.bodyNeeded])
 509  	copy(c.buf, c.buf[c.bodyNeeded:c.wpos])
 510  	c.wpos -= c.bodyNeeded
 511  	req := c.pendingReq
 512  	c.pendingReq = nil
 513  
 514  	status, headers, body, connClose := s.handler.OnHTTP(c.fd, req.method, req.path, req.headers, req.body)
 515  	if status == HTTPDeferred {
 516  		c.phase = phaseHTTPDeferred
 517  		return
 518  	}
 519  	s.sendHTTPBuffered(c, status, headers, body, connClose || req.headers["connection"] == "close")
 520  }
 521  
 522  // sendHTTPBuffered writes an HTTP response, buffering on EAGAIN.
 523  func (s *Server) sendHTTPBuffered(c *tconn, status int, headers map[string]string, body []byte, connClose bool) {
 524  	s.connWrite(c, buildHTTPResponse(status, headers, body), connClose)
 525  	if c.wbuf == nil && s.conns[c.fd] != nil && !connClose {
 526  		s.keepAlive(c)
 527  	}
 528  }
 529  
 530  // drainWrite flushes pending write buffer when socket becomes writable.
 531  func (s *Server) drainWrite(c *tconn) {
 532  	for len(c.wbuf) > 0 {
 533  		n, err := syscall.Write(c.fd, c.wbuf)
 534  		if n > 0 {
 535  			c.wbuf = c.wbuf[n:]
 536  		}
 537  		if err == syscall.EAGAIN {
 538  			return
 539  		}
 540  		if err != nil {
 541  			s.closeConn(c)
 542  			return
 543  		}
 544  	}
 545  	c.wbuf = nil
 546  	epollModRead(s.epfd, c.fd)
 547  	if c.wbufClose {
 548  		s.closeConn(c)
 549  		return
 550  	}
 551  	if c.phase == phaseHTTP || c.phase == phaseHTTPBody {
 552  		s.keepAlive(c)
 553  	}
 554  }
 555  
 556  // Epoll helpers.
 557  
 558  func epollAdd(epfd, fd int) error {
 559  	ev := syscall.EpollEvent{Events: syscall.EPOLLIN, Fd: int32(fd)}
 560  	return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_ADD, fd, &ev)
 561  }
 562  
 563  func epollModWrite(epfd, fd int) {
 564  	ev := syscall.EpollEvent{Events: syscall.EPOLLIN | syscall.EPOLLOUT, Fd: int32(fd)}
 565  	syscall.EpollCtl(epfd, syscall.EPOLL_CTL_MOD, fd, &ev)
 566  }
 567  
 568  func epollModRead(epfd, fd int) {
 569  	ev := syscall.EpollEvent{Events: syscall.EPOLLIN, Fd: int32(fd)}
 570  	syscall.EpollCtl(epfd, syscall.EPOLL_CTL_MOD, fd, &ev)
 571  }
 572  
 573  func epollDel(epfd, fd int) error {
 574  	return syscall.EpollCtl(epfd, syscall.EPOLL_CTL_DEL, fd, nil)
 575  }
 576  
 577