relay.go raw

   1  package nostr
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"crypto/tls"
   7  	"errors"
   8  	"fmt"
   9  	"log"
  10  	"net/http"
  11  	"strconv"
  12  	"strings"
  13  	"sync"
  14  	"sync/atomic"
  15  	"time"
  16  
  17  	"github.com/puzpuzpuz/xsync/v3"
  18  )
  19  
  // subscriptionIDCounter is a package-level counter that hands every
// subscription a distinct serial number (incremented in PrepareSubscription).
var (
	subscriptionIDCounter atomic.Int64
)
  21  
  22  // Relay represents a connection to a Nostr relay.
  23  type Relay struct {
  24  	closeMutex sync.Mutex
  25  
  26  	URL           string
  27  	requestHeader http.Header // e.g. for origin header
  28  
  29  	Connection    *Connection
  30  	Subscriptions *xsync.MapOf[int64, *Subscription]
  31  
  32  	ConnectionError         error
  33  	connectionContext       context.Context // will be canceled when the connection closes
  34  	connectionContextCancel context.CancelCauseFunc
  35  
  36  	challenge                     string       // NIP-42 challenge, we only keep the last
  37  	noticeHandler                 func(string) // NIP-01 NOTICEs
  38  	customHandler                 func(string) // nonstandard unparseable messages
  39  	okCallbacks                   *xsync.MapOf[string, func(bool, string)]
  40  	writeQueue                    chan writeRequest
  41  	subscriptionChannelCloseQueue chan *Subscription
  42  
  43  	// custom things that aren't often used
  44  	//
  45  	AssumeValid bool // this will skip verifying signatures for events received from this relay
  46  }
  47  
  // writeRequest is one outgoing websocket message handed to the connection's
// writer goroutine.
type writeRequest struct {
	// msg is the raw payload to send.
	msg []byte
	// answer receives the write error, if there is one, and is then closed.
	answer chan error
}
  52  
  53  // NewRelay returns a new relay. It takes a context that, when canceled, will close the relay connection.
  54  func NewRelay(ctx context.Context, url string, opts ...RelayOption) *Relay {
  55  	ctx, cancel := context.WithCancelCause(ctx)
  56  	r := &Relay{
  57  		URL:                           NormalizeURL(url),
  58  		connectionContext:             ctx,
  59  		connectionContextCancel:       cancel,
  60  		Subscriptions:                 xsync.NewMapOf[int64, *Subscription](),
  61  		okCallbacks:                   xsync.NewMapOf[string, func(bool, string)](),
  62  		writeQueue:                    make(chan writeRequest),
  63  		subscriptionChannelCloseQueue: make(chan *Subscription),
  64  		requestHeader:                 nil,
  65  	}
  66  
  67  	for _, opt := range opts {
  68  		opt.ApplyRelayOption(r)
  69  	}
  70  
  71  	return r
  72  }
  73  
  74  // RelayConnect returns a relay object connected to url.
  75  //
  76  // The given subscription is only used during the connection phase. Once successfully connected, cancelling ctx has no effect.
  77  //
  78  // The ongoing relay connection uses a background context. To close the connection, call r.Close().
  79  // If you need fine grained long-term connection contexts, use NewRelay() instead.
  80  func RelayConnect(ctx context.Context, url string, opts ...RelayOption) (*Relay, error) {
  81  	r := NewRelay(context.Background(), url, opts...)
  82  	err := r.Connect(ctx)
  83  	return r, err
  84  }
  85  
  86  // RelayOption is the type of the argument passed when instantiating relay connections.
  87  type RelayOption interface {
  88  	ApplyRelayOption(*Relay)
  89  }
  90  
  91  var (
  92  	_ RelayOption = (WithNoticeHandler)(nil)
  93  	_ RelayOption = (WithCustomHandler)(nil)
  94  	_ RelayOption = (WithRequestHeader)(nil)
  95  )
  96  
  97  // WithNoticeHandler just takes notices and is expected to do something with them.
  98  // when not given, defaults to logging the notices.
  99  type WithNoticeHandler func(notice string)
 100  
 101  func (nh WithNoticeHandler) ApplyRelayOption(r *Relay) {
 102  	r.noticeHandler = nh
 103  }
 104  
 105  // WithCustomHandler must be a function that handles any relay message that couldn't be
 106  // parsed as a standard envelope.
 107  type WithCustomHandler func(data string)
 108  
 109  func (ch WithCustomHandler) ApplyRelayOption(r *Relay) {
 110  	r.customHandler = ch
 111  }
 112  
 113  // WithRequestHeader sets the HTTP request header of the websocket preflight request.
 114  type WithRequestHeader http.Header
 115  
 116  func (ch WithRequestHeader) ApplyRelayOption(r *Relay) {
 117  	r.requestHeader = http.Header(ch)
 118  }
 119  
 120  // String just returns the relay URL.
 121  func (r *Relay) String() string {
 122  	return r.URL
 123  }
 124  
 125  // Context retrieves the context that is associated with this relay connection.
 126  // It will be closed when the relay is disconnected.
 127  func (r *Relay) Context() context.Context { return r.connectionContext }
 128  
 129  // IsConnected returns true if the connection to this relay seems to be active.
 130  func (r *Relay) IsConnected() bool { return r.connectionContext.Err() == nil }
 131  
 132  // Connect tries to establish a websocket connection to r.URL.
 133  // If the context expires before the connection is complete, an error is returned.
 134  // Once successfully connected, context expiration has no effect: call r.Close
 135  // to close the connection.
 136  //
 137  // The given context here is only used during the connection phase. The long-living
 138  // relay connection will be based on the context given to NewRelay().
 139  func (r *Relay) Connect(ctx context.Context) error {
 140  	return r.ConnectWithTLS(ctx, nil)
 141  }
 142  
 143  // ConnectWithTLS is like Connect(), but takes a special tls.Config if you need that.
 144  func (r *Relay) ConnectWithTLS(ctx context.Context, tlsConfig *tls.Config) error {
 145  	if r.connectionContext == nil || r.Subscriptions == nil {
 146  		return fmt.Errorf("relay must be initialized with a call to NewRelay()")
 147  	}
 148  
 149  	if r.URL == "" {
 150  		return fmt.Errorf("invalid relay URL '%s'", r.URL)
 151  	}
 152  
 153  	if _, ok := ctx.Deadline(); !ok {
 154  		// if no timeout is set, force it to 7 seconds
 155  		var cancel context.CancelFunc
 156  		ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("connection took too long"))
 157  		defer cancel()
 158  	}
 159  
 160  	conn, err := NewConnection(ctx, r.URL, r.requestHeader, tlsConfig)
 161  	if err != nil {
 162  		return fmt.Errorf("error opening websocket to '%s': %w", r.URL, err)
 163  	}
 164  	r.Connection = conn
 165  
 166  	// ping every 29 seconds
 167  	ticker := time.NewTicker(29 * time.Second)
 168  
 169  	// queue all write operations here so we don't do mutex spaghetti
 170  	go func() {
 171  		for {
 172  			select {
 173  			case <-r.connectionContext.Done():
 174  				ticker.Stop()
 175  				r.Connection = nil
 176  
 177  				for _, sub := range r.Subscriptions.Range {
 178  					sub.unsub(fmt.Errorf("relay connection closed: %w / %w", context.Cause(r.connectionContext), r.ConnectionError))
 179  				}
 180  				return
 181  
 182  			case <-ticker.C:
 183  				err := r.Connection.Ping(r.connectionContext)
 184  				if err != nil && !strings.Contains(err.Error(), "failed to wait for pong") {
 185  					InfoLogger.Printf("{%s} error writing ping: %v; closing websocket", r.URL, err)
 186  					r.Close() // this should trigger a context cancelation
 187  					return
 188  				}
 189  
 190  			case writeRequest := <-r.writeQueue:
 191  				// all write requests will go through this to prevent races
 192  				debugLogf("{%s} sending %v\n", r.URL, string(writeRequest.msg))
 193  				if err := r.Connection.WriteMessage(r.connectionContext, writeRequest.msg); err != nil {
 194  					writeRequest.answer <- err
 195  				}
 196  				close(writeRequest.answer)
 197  			}
 198  		}
 199  	}()
 200  
 201  	// general message reader loop
 202  	go func() {
 203  		buf := new(bytes.Buffer)
 204  		mp := NewMessageParser()
 205  
 206  		for {
 207  			buf.Reset()
 208  
 209  			if err := conn.ReadMessage(r.connectionContext, buf); err != nil {
 210  				r.ConnectionError = err
 211  				r.close(err)
 212  				break
 213  			}
 214  
 215  			message := string(buf.Bytes())
 216  			debugLogf("{%s} received %v\n", r.URL, message)
 217  
 218  			// if this is an "EVENT" we will have this preparser logic that should speed things up a little
 219  			// as we skip handling duplicate events
 220  			subid := extractSubID(message)
 221  			sub, ok := r.Subscriptions.Load(subIdToSerial(subid))
 222  			if ok {
 223  				if sub.checkDuplicate != nil {
 224  					if sub.checkDuplicate(extractEventID(message[10+len(subid):]), r.URL) {
 225  						continue
 226  					}
 227  				} else if sub.checkDuplicateReplaceable != nil {
 228  					if sub.checkDuplicateReplaceable(
 229  						ReplaceableKey{extractEventPubKey(message), extractDTag(message)},
 230  						extractTimestamp(message),
 231  					) {
 232  						continue
 233  					}
 234  				}
 235  			}
 236  
 237  			envelope, err := mp.ParseMessage(message)
 238  			if envelope == nil {
 239  				if r.customHandler != nil && err == UnknownLabel {
 240  					r.customHandler(message)
 241  				}
 242  				continue
 243  			}
 244  
 245  			switch env := envelope.(type) {
 246  			case *NoticeEnvelope:
 247  				// see WithNoticeHandler
 248  				if r.noticeHandler != nil {
 249  					r.noticeHandler(string(*env))
 250  				} else {
 251  					log.Printf("NOTICE from %s: '%s'\n", r.URL, string(*env))
 252  				}
 253  			case *AuthEnvelope:
 254  				if env.Challenge == nil {
 255  					continue
 256  				}
 257  				r.challenge = *env.Challenge
 258  			case *EventEnvelope:
 259  				// we already have the subscription from the pre-check above, so we can just reuse it
 260  				if sub == nil {
 261  					// InfoLogger.Printf("{%s} no subscription with id '%s'\n", r.URL, *env.SubscriptionID)
 262  					continue
 263  				} else {
 264  					// check if the event matches the desired filter, ignore otherwise
 265  					if !sub.match(&env.Event) {
 266  						InfoLogger.Printf("{%s} filter does not match: %v ~ %v\n", r.URL, sub.Filters, env.Event)
 267  						continue
 268  					}
 269  
 270  					// check signature, ignore invalid, except from trusted (AssumeValid) relays
 271  					if !r.AssumeValid {
 272  						if ok, _ := env.Event.CheckSignature(); !ok {
 273  							InfoLogger.Printf("{%s} bad signature on %s\n", r.URL, env.Event.ID)
 274  							continue
 275  						}
 276  					}
 277  
 278  					// dispatch this to the internal .events channel of the subscription
 279  					sub.dispatchEvent(&env.Event)
 280  				}
 281  			case *EOSEEnvelope:
 282  				if subscription, ok := r.Subscriptions.Load(subIdToSerial(string(*env))); ok {
 283  					subscription.dispatchEose()
 284  				}
 285  			case *ClosedEnvelope:
 286  				if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok {
 287  					subscription.handleClosed(env.Reason)
 288  				}
 289  			case *CountEnvelope:
 290  				if subscription, ok := r.Subscriptions.Load(subIdToSerial(env.SubscriptionID)); ok && env.Count != nil && subscription.countResult != nil {
 291  					subscription.countResult <- *env
 292  				}
 293  			case *OKEnvelope:
 294  				if okCallback, exist := r.okCallbacks.Load(env.EventID); exist {
 295  					okCallback(env.OK, env.Reason)
 296  				} else {
 297  					InfoLogger.Printf("{%s} got an unexpected OK message for event %s", r.URL, env.EventID)
 298  				}
 299  			}
 300  		}
 301  	}()
 302  
 303  	return nil
 304  }
 305  
 306  // Write queues an arbitrary message to be sent to the relay.
 307  func (r *Relay) Write(msg []byte) <-chan error {
 308  	ch := make(chan error)
 309  	select {
 310  	case r.writeQueue <- writeRequest{msg: msg, answer: ch}:
 311  	case <-r.connectionContext.Done():
 312  		go func() { ch <- fmt.Errorf("connection closed") }()
 313  	}
 314  	return ch
 315  }
 316  
 317  // Publish sends an "EVENT" command to the relay r as in NIP-01 and waits for an OK response.
 318  func (r *Relay) Publish(ctx context.Context, event Event) error {
 319  	return r.publish(ctx, event.ID, &EventEnvelope{Event: event})
 320  }
 321  
 322  // Auth sends an "AUTH" command client->relay as in NIP-42 and waits for an OK response.
 323  //
 324  // You don't have to build the AUTH event yourself, this function takes a function to which the
 325  // event that must be signed will be passed, so it's only necessary to sign that.
 326  func (r *Relay) Auth(ctx context.Context, sign func(event *Event) error) error {
 327  	authEvent := Event{
 328  		CreatedAt: Now(),
 329  		Kind:      KindClientAuthentication,
 330  		Tags: Tags{
 331  			Tag{"relay", r.URL},
 332  			Tag{"challenge", r.challenge},
 333  		},
 334  		Content: "",
 335  	}
 336  	if err := sign(&authEvent); err != nil {
 337  		return fmt.Errorf("error signing auth event: %w", err)
 338  	}
 339  
 340  	return r.publish(ctx, authEvent.ID, &AuthEnvelope{Event: authEvent})
 341  }
 342  
 343  func (r *Relay) publish(ctx context.Context, id string, env Envelope) error {
 344  	var err error
 345  	var cancel context.CancelFunc
 346  
 347  	if _, ok := ctx.Deadline(); !ok {
 348  		// if no timeout is set, force it to 7 seconds
 349  		ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, fmt.Errorf("given up waiting for an OK"))
 350  		defer cancel()
 351  	} else {
 352  		// otherwise make the context cancellable so we can stop everything upon receiving an "OK"
 353  		ctx, cancel = context.WithCancel(ctx)
 354  		defer cancel()
 355  	}
 356  
 357  	// listen for an OK callback
 358  	gotOk := false
 359  	r.okCallbacks.Store(id, func(ok bool, reason string) {
 360  		gotOk = true
 361  		if !ok {
 362  			err = fmt.Errorf("msg: %s", reason)
 363  		}
 364  		cancel()
 365  	})
 366  	defer r.okCallbacks.Delete(id)
 367  
 368  	// publish event
 369  	envb, _ := env.MarshalJSON()
 370  	if err := <-r.Write(envb); err != nil {
 371  		return err
 372  	}
 373  
 374  	for {
 375  		select {
 376  		case <-ctx.Done():
 377  			// this will be called when we get an OK or when the context has been canceled
 378  			if gotOk {
 379  				return err
 380  			}
 381  			return ctx.Err()
 382  		case <-r.connectionContext.Done():
 383  			// this is caused when we lose connectivity
 384  			return err
 385  		}
 386  	}
 387  }
 388  
 389  // Subscribe sends a "REQ" command to the relay r as in NIP-01.
 390  // Events are returned through the channel sub.Events.
 391  // The subscription is closed when context ctx is cancelled ("CLOSE" in NIP-01).
 392  //
 393  // Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
 394  // Failure to do that will result in a huge number of halted goroutines being created.
 395  func (r *Relay) Subscribe(ctx context.Context, filters Filters, opts ...SubscriptionOption) (*Subscription, error) {
 396  	sub := r.PrepareSubscription(ctx, filters, opts...)
 397  
 398  	if r.Connection == nil {
 399  		return nil, fmt.Errorf("not connected to %s", r.URL)
 400  	}
 401  
 402  	if err := sub.Fire(); err != nil {
 403  		return nil, fmt.Errorf("couldn't subscribe to %v at %s: %w", filters, r.URL, err)
 404  	}
 405  
 406  	return sub, nil
 407  }
 408  
 409  // PrepareSubscription creates a subscription, but doesn't fire it.
 410  //
 411  // Remember to cancel subscriptions, either by calling `.Unsub()` on them or ensuring their `context.Context` will be canceled at some point.
 412  // Failure to do that will result in a huge number of halted goroutines being created.
 413  func (r *Relay) PrepareSubscription(ctx context.Context, filters Filters, opts ...SubscriptionOption) *Subscription {
 414  	current := subscriptionIDCounter.Add(1)
 415  	ctx, cancel := context.WithCancelCause(ctx)
 416  
 417  	sub := &Subscription{
 418  		Relay:             r,
 419  		Context:           ctx,
 420  		cancel:            cancel,
 421  		counter:           current,
 422  		Events:            make(chan *Event),
 423  		EndOfStoredEvents: make(chan struct{}, 1),
 424  		ClosedReason:      make(chan string, 1),
 425  		Filters:           filters,
 426  		match:             filters.Match,
 427  	}
 428  
 429  	label := ""
 430  	for _, opt := range opts {
 431  		switch o := opt.(type) {
 432  		case WithLabel:
 433  			label = string(o)
 434  		case WithCheckDuplicate:
 435  			sub.checkDuplicate = o
 436  		case WithCheckDuplicateReplaceable:
 437  			sub.checkDuplicateReplaceable = o
 438  		}
 439  	}
 440  
 441  	// subscription id computation
 442  	buf := subIdPool.Get().([]byte)[:0]
 443  	buf = strconv.AppendInt(buf, sub.counter, 10)
 444  	buf = append(buf, ':')
 445  	buf = append(buf, label...)
 446  	defer subIdPool.Put(buf)
 447  	sub.id = string(buf)
 448  
 449  	// we track subscriptions only by their counter, no need for the full id
 450  	r.Subscriptions.Store(int64(sub.counter), sub)
 451  
 452  	// start handling events, eose, unsub etc:
 453  	go sub.start()
 454  
 455  	return sub
 456  }
 457  
 458  // QueryEvents subscribes to events matching the given filter and returns a channel of events.
 459  //
 460  // In most cases it's better to use SimplePool instead of this method.
 461  func (r *Relay) QueryEvents(ctx context.Context, filter Filter) (chan *Event, error) {
 462  	sub, err := r.Subscribe(ctx, Filters{filter})
 463  	if err != nil {
 464  		return nil, err
 465  	}
 466  
 467  	go func() {
 468  		for {
 469  			select {
 470  			case <-sub.ClosedReason:
 471  			case <-sub.EndOfStoredEvents:
 472  			case <-ctx.Done():
 473  			case <-r.Context().Done():
 474  			}
 475  			sub.unsub(errors.New("QueryEvents() ended"))
 476  			return
 477  		}
 478  	}()
 479  
 480  	return sub.Events, nil
 481  }
 482  
 483  // QuerySync subscribes to events matching the given filter and returns a slice of events.
 484  // This method blocks until all events are received or the context is canceled.
 485  //
 486  // In most cases it's better to use SimplePool instead of this method.
 487  func (r *Relay) QuerySync(ctx context.Context, filter Filter) ([]*Event, error) {
 488  	if _, ok := ctx.Deadline(); !ok {
 489  		// if no timeout is set, force it to 7 seconds
 490  		var cancel context.CancelFunc
 491  		ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("QuerySync() took too long"))
 492  		defer cancel()
 493  	}
 494  
 495  	events := make([]*Event, 0, max(filter.Limit, 250))
 496  	ch, err := r.QueryEvents(ctx, filter)
 497  	if err != nil {
 498  		return nil, err
 499  	}
 500  
 501  	for evt := range ch {
 502  		events = append(events, evt)
 503  	}
 504  
 505  	return events, nil
 506  }
 507  
 508  // Count sends a "COUNT" command to the relay and returns the count of events matching the filters.
 509  func (r *Relay) Count(
 510  	ctx context.Context,
 511  	filters Filters,
 512  	opts ...SubscriptionOption,
 513  ) (int64, []byte, error) {
 514  	v, err := r.countInternal(ctx, filters, opts...)
 515  	if err != nil {
 516  		return 0, nil, err
 517  	}
 518  
 519  	return *v.Count, v.HyperLogLog, nil
 520  }
 521  
 522  func (r *Relay) countInternal(ctx context.Context, filters Filters, opts ...SubscriptionOption) (CountEnvelope, error) {
 523  	sub := r.PrepareSubscription(ctx, filters, opts...)
 524  	sub.countResult = make(chan CountEnvelope)
 525  
 526  	if err := sub.Fire(); err != nil {
 527  		return CountEnvelope{}, err
 528  	}
 529  
 530  	defer sub.unsub(errors.New("countInternal() ended"))
 531  
 532  	if _, ok := ctx.Deadline(); !ok {
 533  		// if no timeout is set, force it to 7 seconds
 534  		var cancel context.CancelFunc
 535  		ctx, cancel = context.WithTimeoutCause(ctx, 7*time.Second, errors.New("countInternal took too long"))
 536  		defer cancel()
 537  	}
 538  
 539  	for {
 540  		select {
 541  		case count := <-sub.countResult:
 542  			return count, nil
 543  		case <-ctx.Done():
 544  			return CountEnvelope{}, ctx.Err()
 545  		}
 546  	}
 547  }
 548  
 549  // Close closes the relay connection.
 550  func (r *Relay) Close() error {
 551  	return r.close(errors.New("Close() called"))
 552  }
 553  
 554  func (r *Relay) close(reason error) error {
 555  	r.closeMutex.Lock()
 556  	defer r.closeMutex.Unlock()
 557  
 558  	if r.connectionContextCancel == nil {
 559  		return fmt.Errorf("relay already closed")
 560  	}
 561  	r.connectionContextCancel(reason)
 562  	r.connectionContextCancel = nil
 563  
 564  	if r.Connection == nil {
 565  		return fmt.Errorf("relay not connected")
 566  	}
 567  
 568  	err := r.Connection.Close()
 569  	if err != nil {
 570  		return err
 571  	}
 572  
 573  	return nil
 574  }
 575  
 // subIdPool recycles the small scratch buffers used to format subscription IDs.
var subIdPool = sync.Pool{
	New: func() any {
		buf := make([]byte, 0, 15)
		return buf
	},
}
 579