client.go raw

   1  package relaytester
   2  
   3  import (
   4  	"context"
   5  	"encoding/json"
   6  	"sync"
   7  	"time"
   8  
   9  	"github.com/gorilla/websocket"
  10  	"next.orly.dev/pkg/nostr/encoders/event"
  11  	"next.orly.dev/pkg/nostr/encoders/hex"
  12  	"next.orly.dev/pkg/lol/errorf"
  13  	"next.orly.dev/pkg/interfaces/neterr"
  14  )
  15  
  16  // Client wraps a WebSocket connection to a relay for testing.
  17  type Client struct {
  18  	conn     *websocket.Conn
  19  	url      string
  20  	mu       sync.Mutex
  21  	subs     map[string]chan []byte
  22  	complete map[string]bool // Track if subscription is complete (e.g., by ID)
  23  	okCh     chan []byte     // Channel for OK messages
  24  	countCh  chan []byte     // Channel for COUNT messages
  25  	ctx      context.Context
  26  	cancel   context.CancelFunc
  27  }
  28  
  29  // NewClient creates a new test client connected to the relay.
  30  func NewClient(url string) (c *Client, err error) {
  31  	ctx, cancel := context.WithCancel(context.Background())
  32  	var conn *websocket.Conn
  33  	dialer := websocket.Dialer{
  34  		HandshakeTimeout: 5 * time.Second,
  35  	}
  36  	if conn, _, err = dialer.Dial(url, nil); err != nil {
  37  		cancel()
  38  		return
  39  	}
  40  
  41  	c = &Client{
  42  		conn:     conn,
  43  		url:      url,
  44  		subs:     make(map[string]chan []byte),
  45  		complete: make(map[string]bool),
  46  		okCh:     make(chan []byte, 100),
  47  		countCh:  make(chan []byte, 100),
  48  		ctx:      ctx,
  49  		cancel:   cancel,
  50  	}
  51  
  52  	// Set up ping/pong handling to keep connection alive
  53  	pongWait := 60 * time.Second
  54  	conn.SetReadDeadline(time.Now().Add(pongWait))
  55  	conn.SetPongHandler(func(string) error {
  56  		conn.SetReadDeadline(time.Now().Add(pongWait))
  57  		return nil
  58  	})
  59  	conn.SetPingHandler(func(appData string) error {
  60  		conn.SetReadDeadline(time.Now().Add(pongWait))
  61  		deadline := time.Now().Add(10 * time.Second)
  62  		c.mu.Lock()
  63  		err := conn.WriteControl(websocket.PongMessage, []byte(appData), deadline)
  64  		c.mu.Unlock()
  65  		if err != nil {
  66  			return nil
  67  		}
  68  		return nil
  69  	})
  70  	// Also extend deadlines after each successful read in the loop below
  71  
  72  	go c.readLoop()
  73  	return
  74  }
  75  
  76  // Close closes the client connection.
  77  func (c *Client) Close() error {
  78  	c.cancel()
  79  	return c.conn.Close()
  80  }
  81  
  82  // URL returns the relay URL.
  83  func (c *Client) URL() string {
  84  	return c.url
  85  }
  86  
  87  // Send sends a JSON message to the relay.
  88  func (c *Client) Send(msg interface{}) (err error) {
  89  	c.mu.Lock()
  90  	defer c.mu.Unlock()
  91  	var data []byte
  92  	if data, err = json.Marshal(msg); err != nil {
  93  		return errorf.E("failed to marshal message: %w", err)
  94  	}
  95  	if err = c.conn.WriteMessage(websocket.TextMessage, data); err != nil {
  96  		return errorf.E("failed to write message: %w", err)
  97  	}
  98  	return
  99  }
 100  
 101  // readLoop reads messages from the relay and routes them to subscriptions.
 102  func (c *Client) readLoop() {
 103  	defer c.conn.Close()
 104  	pongWait := 60 * time.Second
 105  	for {
 106  		select {
 107  		case <-c.ctx.Done():
 108  			return
 109  		default:
 110  		}
 111  		// Don't set deadline here - let pong handler manage it
 112  		// SetReadDeadline is called initially in NewClient and extended by pong handler
 113  		_, msg, err := c.conn.ReadMessage()
 114  		if err != nil {
 115  			// Check if context is done
 116  			select {
 117  			case <-c.ctx.Done():
 118  				return
 119  			default:
 120  			}
 121  			// Check if it's a timeout - connection might still be alive
 122  			if netErr, ok := err.(neterr.TimeoutError); ok && netErr.Timeout() {
 123  				// Pong handler should have extended deadline, but if we timeout,
 124  				// reset it and continue - connection might still be alive
 125  				// This can happen during idle periods when no messages are received
 126  				c.conn.SetReadDeadline(time.Now().Add(pongWait))
 127  				// Continue reading - connection should still be alive if pings/pongs are working
 128  				continue
 129  			}
 130  			// For other errors, check if it's a close error
 131  			if websocket.IsCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
 132  				return
 133  			}
 134  			// For other errors, return (connection is likely dead)
 135  			return
 136  		}
 137  		// Extend read deadline on successful read
 138  		c.conn.SetReadDeadline(time.Now().Add(pongWait))
 139  		var raw []interface{}
 140  		if err = json.Unmarshal(msg, &raw); err != nil {
 141  			continue
 142  		}
 143  		if len(raw) < 2 {
 144  			continue
 145  		}
 146  		typ, ok := raw[0].(string)
 147  		if !ok {
 148  			continue
 149  		}
 150  		c.mu.Lock()
 151  		switch typ {
 152  		case "EVENT":
 153  			if len(raw) >= 2 {
 154  				if subID, ok := raw[1].(string); ok {
 155  					if ch, exists := c.subs[subID]; exists {
 156  						select {
 157  						case ch <- msg:
 158  						default:
 159  						}
 160  					}
 161  				}
 162  			}
 163  		case "EOSE":
 164  			if len(raw) >= 2 {
 165  				if subID, ok := raw[1].(string); ok {
 166  					if ch, exists := c.subs[subID]; exists {
 167  						// Send EOSE message to channel
 168  						select {
 169  						case ch <- msg:
 170  						default:
 171  						}
 172  						// For complete subscriptions (by ID), close the channel after EOSE
 173  						if c.complete[subID] {
 174  							close(ch)
 175  							delete(c.subs, subID)
 176  							delete(c.complete, subID)
 177  						}
 178  					}
 179  				}
 180  			}
 181  		case "OK":
 182  			// Route OK messages to okCh for WaitForOK
 183  			select {
 184  			case c.okCh <- msg:
 185  			default:
 186  			}
 187  		case "COUNT":
 188  			// Route COUNT messages to countCh for Count
 189  			select {
 190  			case c.countCh <- msg:
 191  			default:
 192  			}
 193  		case "NOTICE":
 194  			// Notice messages are logged
 195  		case "CLOSED":
 196  			// Closed messages indicate subscription ended
 197  		case "AUTH":
 198  			// Auth challenge messages
 199  		}
 200  		c.mu.Unlock()
 201  	}
 202  }
 203  
 204  // Subscribe creates a subscription and returns a channel for events.
 205  func (c *Client) Subscribe(subID string, filters []interface{}) (ch chan []byte, err error) {
 206  	req := []interface{}{"REQ", subID}
 207  	req = append(req, filters...)
 208  	if err = c.Send(req); err != nil {
 209  		return
 210  	}
 211  	c.mu.Lock()
 212  	ch = make(chan []byte, 100)
 213  	c.subs[subID] = ch
 214  	// Check if subscription is complete (has 'ids' filter)
 215  	isComplete := false
 216  	for _, f := range filters {
 217  		if fMap, ok := f.(map[string]interface{}); ok {
 218  			if ids, exists := fMap["ids"]; exists {
 219  				if idList, ok := ids.([]string); ok && len(idList) > 0 {
 220  					isComplete = true
 221  					break
 222  				}
 223  			}
 224  		}
 225  	}
 226  	c.complete[subID] = isComplete
 227  	c.mu.Unlock()
 228  	return
 229  }
 230  
 231  // Unsubscribe closes a subscription.
 232  func (c *Client) Unsubscribe(subID string) error {
 233  	c.mu.Lock()
 234  	if ch, exists := c.subs[subID]; exists {
 235  		// Channel might already be closed by EOSE, so use recover to handle gracefully
 236  		func() {
 237  			defer func() {
 238  				if recover() != nil {
 239  					// Channel was already closed, ignore
 240  				}
 241  			}()
 242  			close(ch)
 243  		}()
 244  		delete(c.subs, subID)
 245  		delete(c.complete, subID)
 246  	}
 247  	c.mu.Unlock()
 248  	return c.Send([]interface{}{"CLOSE", subID})
 249  }
 250  
 251  // Publish sends an EVENT message to the relay.
 252  func (c *Client) Publish(ev *event.E) (err error) {
 253  	evJSON := ev.Serialize()
 254  	var evMap map[string]interface{}
 255  	if err = json.Unmarshal(evJSON, &evMap); err != nil {
 256  		return errorf.E("failed to unmarshal event: %w", err)
 257  	}
 258  	return c.Send([]interface{}{"EVENT", evMap})
 259  }
 260  
 261  // WaitForOK waits for an OK response for the given event ID.
 262  func (c *Client) WaitForOK(eventID []byte, timeout time.Duration) (accepted bool, reason string, err error) {
 263  	ctx, cancel := context.WithTimeout(c.ctx, timeout)
 264  	defer cancel()
 265  	idStr := hex.Enc(eventID)
 266  	for {
 267  		select {
 268  		case <-ctx.Done():
 269  			return false, "", errorf.E("timeout waiting for OK response")
 270  		case msg := <-c.okCh:
 271  			var raw []interface{}
 272  			if err = json.Unmarshal(msg, &raw); err != nil {
 273  				continue
 274  			}
 275  			if len(raw) < 3 {
 276  				continue
 277  			}
 278  			if id, ok := raw[1].(string); ok && id == idStr {
 279  				accepted, _ = raw[2].(bool)
 280  				if len(raw) > 3 {
 281  					reason, _ = raw[3].(string)
 282  				}
 283  				return
 284  			}
 285  		}
 286  	}
 287  }
 288  
 289  // Count sends a COUNT request and returns the count.
 290  func (c *Client) Count(filters []interface{}) (count int64, err error) {
 291  	req := []interface{}{"COUNT", "count-sub"}
 292  	req = append(req, filters...)
 293  	if err = c.Send(req); err != nil {
 294  		return
 295  	}
 296  	ctx, cancel := context.WithTimeout(c.ctx, 5*time.Second)
 297  	defer cancel()
 298  	for {
 299  		select {
 300  		case <-ctx.Done():
 301  			return 0, errorf.E("timeout waiting for COUNT response")
 302  		case msg := <-c.countCh:
 303  			var raw []interface{}
 304  			if err = json.Unmarshal(msg, &raw); err != nil {
 305  				continue
 306  			}
 307  			if len(raw) >= 3 {
 308  				if subID, ok := raw[1].(string); ok && subID == "count-sub" {
 309  					// COUNT response format: ["COUNT", "subscription-id", count, approximate?]
 310  					if cnt, ok := raw[2].(float64); ok {
 311  						return int64(cnt), nil
 312  					}
 313  				}
 314  			}
 315  		}
 316  	}
 317  }
 318  
 319  // Auth sends an AUTH message with the signed event.
 320  func (c *Client) Auth(ev *event.E) error {
 321  	evJSON := ev.Serialize()
 322  	var evMap map[string]interface{}
 323  	if err := json.Unmarshal(evJSON, &evMap); err != nil {
 324  		return errorf.E("failed to unmarshal event: %w", err)
 325  	}
 326  	return c.Send([]interface{}{"AUTH", evMap})
 327  }
 328  
 329  // GetEvents collects all events from a subscription until EOSE.
 330  func (c *Client) GetEvents(subID string, filters []interface{}, timeout time.Duration) (events []*event.E, err error) {
 331  	ch, err := c.Subscribe(subID, filters)
 332  	if err != nil {
 333  		return
 334  	}
 335  	defer c.Unsubscribe(subID)
 336  	ctx, cancel := context.WithTimeout(c.ctx, timeout)
 337  	defer cancel()
 338  	for {
 339  		select {
 340  		case <-ctx.Done():
 341  			return events, nil
 342  		case msg, ok := <-ch:
 343  			if !ok {
 344  				return events, nil
 345  			}
 346  			var raw []interface{}
 347  			if err = json.Unmarshal(msg, &raw); err != nil {
 348  				continue
 349  			}
 350  			if len(raw) < 2 {
 351  				continue
 352  			}
 353  			typ, ok := raw[0].(string)
 354  			if !ok {
 355  				continue
 356  			}
 357  			switch typ {
 358  			case "EVENT":
 359  				if len(raw) >= 3 {
 360  					if evData, ok := raw[2].(map[string]interface{}); ok {
 361  						evJSON, _ := json.Marshal(evData)
 362  						ev := event.New()
 363  						if _, err = ev.Unmarshal(evJSON); err == nil {
 364  							events = append(events, ev)
 365  						}
 366  					}
 367  				}
 368  			case "EOSE":
 369  				// End of stored events - return what we have
 370  				return events, nil
 371  			}
 372  		}
 373  	}
 374  }
 375