connection.go raw

   1  package ws
   2  
   3  import (
   4  	"context"
   5  	"crypto/tls"
   6  	"fmt"
   7  	"io"
   8  	"net/http"
   9  	"time"
  10  
  11  	"next.orly.dev/pkg/nostr/utils/units"
  12  	"github.com/gorilla/websocket"
  13  	"next.orly.dev/pkg/lol/errorf"
  14  )
  15  
// Connection represents a websocket connection to a Nostr relay.
//
// It is a thin wrapper around a gorilla *websocket.Conn. Instances are built
// by NewConnection, and every method in this file delegates to conn.
type Connection struct {
	// conn is the underlying gorilla websocket connection.
	conn *websocket.Conn
}
  20  
  21  // NewConnection creates a new websocket connection to a Nostr relay.
  22  func NewConnection(
  23  	ctx context.Context, url string, reqHeader http.Header,
  24  	tlsConfig *tls.Config,
  25  ) (c *Connection, err error) {
  26  	var conn *websocket.Conn
  27  	var resp *http.Response
  28  	dialer := getConnectionOptions(reqHeader, tlsConfig)
  29  
  30  	// Prepare headers with default User-Agent if not present
  31  	headers := reqHeader
  32  	if headers == nil {
  33  		headers = make(http.Header)
  34  	}
  35  	if headers.Get("User-Agent") == "" {
  36  		headers.Set("User-Agent", "github.com/nbd-wtf/go-nostr")
  37  	}
  38  
  39  	if conn, resp, err = dialer.DialContext(ctx, url, headers); err != nil {
  40  		if resp != nil {
  41  			resp.Body.Close()
  42  		}
  43  		return
  44  	}
  45  	conn.SetReadLimit(33 * units.Mb)
  46  	// Set a pong handler to extend the read deadline when pong is received.
  47  	// Without this, the 60-second read deadline in ReadMessage expires even
  48  	// though pong frames are being received, because NextReader processes
  49  	// pong frames without resetting the deadline.
  50  	conn.SetPongHandler(func(string) error {
  51  		return conn.SetReadDeadline(time.Now().Add(60 * time.Second))
  52  	})
  53  	return &Connection{
  54  		conn: conn,
  55  	}, nil
  56  }
  57  
  58  // WriteMessage writes arbitrary bytes to the websocket connection.
  59  func (c *Connection) WriteMessage(
  60  	ctx context.Context, data []byte,
  61  ) (err error) {
  62  	deadline := time.Now().Add(10 * time.Second)
  63  	if ctx != nil {
  64  		if d, ok := ctx.Deadline(); ok {
  65  			deadline = d
  66  		}
  67  	}
  68  	c.conn.SetWriteDeadline(deadline)
  69  	if err = c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
  70  		err = errorf.E("failed to write message: %w", err)
  71  		return
  72  	}
  73  	return nil
  74  }
  75  
  76  // ReadMessage reads arbitrary bytes from the websocket connection into the provided buffer.
  77  func (c *Connection) ReadMessage(
  78  	ctx context.Context, buf io.Writer,
  79  ) (err error) {
  80  	deadline := time.Now().Add(60 * time.Second)
  81  	if ctx != nil {
  82  		if d, ok := ctx.Deadline(); ok {
  83  			deadline = d
  84  		}
  85  	}
  86  	c.conn.SetReadDeadline(deadline)
  87  	messageType, reader, err := c.conn.NextReader()
  88  	if err != nil {
  89  		err = fmt.Errorf("failed to get reader: %w", err)
  90  		return
  91  	}
  92  	if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
  93  		err = fmt.Errorf("unexpected message type: %d", messageType)
  94  		return
  95  	}
  96  	if _, err = io.Copy(buf, reader); err != nil {
  97  		err = fmt.Errorf("failed to read message: %w", err)
  98  		return
  99  	}
 100  	return
 101  }
 102  
 103  // Close closes the websocket connection.
 104  func (c *Connection) Close() error {
 105  	c.conn.WriteControl(
 106  		websocket.CloseMessage,
 107  		websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
 108  		time.Now().Add(time.Second),
 109  	)
 110  	return c.conn.Close()
 111  }
 112  
 113  // Ping sends a ping message to the websocket connection.
 114  func (c *Connection) Ping(ctx context.Context) error {
 115  	deadline := time.Now().Add(800 * time.Millisecond)
 116  	if ctx != nil {
 117  		if d, ok := ctx.Deadline(); ok {
 118  			deadline = d
 119  		}
 120  	}
 121  	c.conn.SetWriteDeadline(deadline)
 122  	return c.conn.WriteControl(websocket.PingMessage, []byte{}, deadline)
 123  }
 124