marmot.go raw

   1  package marmot
   2  
import (
	"bytes"
	"context"
	"fmt"
	"sync"

	"github.com/emersion/go-mls"
	"next.orly.dev/pkg/lol/log"
	"next.orly.dev/pkg/nostr/encoders/event"
	"next.orly.dev/pkg/nostr/encoders/filter"
	"next.orly.dev/pkg/nostr/encoders/hex"
	"next.orly.dev/pkg/nostr/encoders/kind"
	"next.orly.dev/pkg/nostr/encoders/tag"
	"next.orly.dev/pkg/nostr/interfaces/signer"
)
  17  
  18  // RelayConnection abstracts the relay interface so the Marmot client can be
  19  // used with any relay transport (WebSocket, in-process channel, mock).
  20  type RelayConnection interface {
  21  	Publish(ctx context.Context, ev *event.E) error
  22  	Subscribe(ctx context.Context, ff *filter.S) (EventStream, error)
  23  }
  24  
  25  // EventStream delivers events from a subscription. Close stops delivery.
  26  type EventStream interface {
  27  	Events() <-chan *event.E
  28  	Close()
  29  }
  30  
  31  // DMHandler is called when an incoming DM is decrypted.
  32  type DMHandler func(senderPub []byte, plaintext []byte)
  33  
  34  // Client manages Marmot DM conversations. It holds MLS group state for
  35  // active 1:1 conversations and handles the lifecycle of key packages,
  36  // welcomes, and encrypted messages.
  37  type Client struct {
  38  	sign   signer.I
  39  	store  GroupStore
  40  	relay  RelayConnection
  41  	onDM   DMHandler
  42  	kpp    *mls.KeyPairPackage // our current key pair package
  43  	groups map[string]*GroupState
  44  	mu     sync.RWMutex
  45  }
  46  
  47  // NewClient creates a Marmot client. The signer provides identity and
  48  // signing. The store persists group state. The relay handles event transport.
  49  func NewClient(sign signer.I, store GroupStore, relay RelayConnection) (*Client, error) {
  50  	kpp, err := GenerateKeyPackage(sign)
  51  	if err != nil {
  52  		return nil, fmt.Errorf("generate key package: %w", err)
  53  	}
  54  
  55  	c := &Client{
  56  		sign:   sign,
  57  		store:  store,
  58  		relay:  relay,
  59  		kpp:    kpp,
  60  		groups: make(map[string]*GroupState),
  61  	}
  62  
  63  	// Load persisted groups
  64  	ids, err := store.ListGroups()
  65  	if err == nil {
  66  		for _, id := range ids {
  67  			data, err := store.LoadGroup(id)
  68  			if err != nil {
  69  				log.W.F("failed to load group %x: %v", id, err)
  70  				continue
  71  			}
  72  			gs, err := unmarshalGroupState(data)
  73  			if err != nil {
  74  				log.W.F("failed to unmarshal group %x: %v", id, err)
  75  				continue
  76  			}
  77  			// We store the serialized state but can't re-hydrate the
  78  			// mls.Group from bytes with the current go-mls API.
  79  			// For now, groups are re-established on restart via welcome
  80  			// re-exchange. Store the metadata so we know about them.
  81  			c.groups[string(gs.GroupID)] = &GroupState{
  82  				GroupID:  gs.GroupID,
  83  				PeerPub:  gs.PeerPub,
  84  				mlsBytes: gs.MLSState,
  85  			}
  86  		}
  87  	}
  88  
  89  	return c, nil
  90  }
  91  
  92  // OnDM registers a handler for incoming decrypted DMs.
  93  func (c *Client) OnDM(handler DMHandler) {
  94  	c.onDM = handler
  95  }
  96  
  97  // PublishKeyPackage publishes our MLS key package as a kind 443 event so
  98  // peers can create DM groups with us.
  99  func (c *Client) PublishKeyPackage(ctx context.Context) error {
 100  	ev, err := KeyPackageToEvent(c.kpp, c.sign)
 101  	if err != nil {
 102  		return err
 103  	}
 104  	return c.relay.Publish(ctx, ev)
 105  }
 106  
 107  // SendDM sends an encrypted DM to the given recipient. If no group exists,
 108  // it fetches the recipient's key package, creates a group, and sends a
 109  // welcome. Then it encrypts and publishes the message.
 110  func (c *Client) SendDM(ctx context.Context, recipientPub []byte, plaintext []byte) error {
 111  	groupID := DMGroupID(c.sign.Pub(), recipientPub)
 112  
 113  	c.mu.RLock()
 114  	gs, ok := c.groups[string(groupID)]
 115  	c.mu.RUnlock()
 116  
 117  	if !ok || gs.group == nil {
 118  		// Need to establish a new group
 119  		var err error
 120  		gs, err = c.establishGroup(ctx, recipientPub)
 121  		if err != nil {
 122  			return fmt.Errorf("establish group: %w", err)
 123  		}
 124  	}
 125  
 126  	ciphertext, err := gs.Encrypt(plaintext)
 127  	if err != nil {
 128  		return fmt.Errorf("encrypt: %w", err)
 129  	}
 130  
 131  	ev, err := MessageToEvent(groupID, ciphertext, c.sign)
 132  	if err != nil {
 133  		return err
 134  	}
 135  
 136  	return c.relay.Publish(ctx, ev)
 137  }
 138  
 139  // establishGroup fetches the peer's key package and creates a DM group.
 140  func (c *Client) establishGroup(ctx context.Context, peerPub []byte) (*GroupState, error) {
 141  	// Fetch the peer's latest key package (kind 443)
 142  	f := filter.New()
 143  	f.Kinds = kind.NewS(kind.New(KindKeyPackage))
 144  	f.Authors = &tag.T{T: [][]byte{peerPub}}
 145  	limit := uint(1)
 146  	f.Limit = &limit
 147  
 148  	stream, err := c.relay.Subscribe(ctx, filter.NewS(f))
 149  	if err != nil {
 150  		return nil, fmt.Errorf("subscribe for key package: %w", err)
 151  	}
 152  	defer stream.Close()
 153  
 154  	// Wait for one event
 155  	var peerKPEvent *event.E
 156  	select {
 157  	case ev := <-stream.Events():
 158  		peerKPEvent = ev
 159  	case <-ctx.Done():
 160  		return nil, ctx.Err()
 161  	}
 162  
 163  	if peerKPEvent == nil {
 164  		return nil, fmt.Errorf("no key package found for %s", hex.Enc(peerPub))
 165  	}
 166  
 167  	peerKP, err := EventToKeyPackage(peerKPEvent)
 168  	if err != nil {
 169  		return nil, fmt.Errorf("parse peer key package: %w", err)
 170  	}
 171  
 172  	gs, welcome, _, err := CreateDMGroup(c.kpp, peerKP, c.sign.Pub(), peerPub)
 173  	if err != nil {
 174  		return nil, fmt.Errorf("create DM group: %w", err)
 175  	}
 176  
 177  	// Send the welcome as a gift-wrapped event
 178  	wrapEv, err := WelcomeToGiftWrap(welcome, peerPub, c.sign)
 179  	if err != nil {
 180  		return nil, fmt.Errorf("gift wrap welcome: %w", err)
 181  	}
 182  	if err := c.relay.Publish(ctx, wrapEv); err != nil {
 183  		return nil, fmt.Errorf("publish welcome: %w", err)
 184  	}
 185  
 186  	// Store the group
 187  	c.mu.Lock()
 188  	c.groups[string(gs.GroupID)] = gs
 189  	c.mu.Unlock()
 190  
 191  	c.persistGroup(gs)
 192  
 193  	return gs, nil
 194  }
 195  
 196  // HandleEvent processes an incoming event. Call this from the subscription loop.
 197  func (c *Client) HandleEvent(ctx context.Context, ev *event.E) error {
 198  	switch ev.Kind {
 199  	case KindGiftWrap:
 200  		return c.handleWelcome(ctx, ev)
 201  	case KindGroupMessage:
 202  		return c.handleGroupMessage(ctx, ev)
 203  	default:
 204  		return nil
 205  	}
 206  }
 207  
 208  func (c *Client) handleWelcome(ctx context.Context, ev *event.E) error {
 209  	welcome, err := UnwrapWelcome(ev, c.sign)
 210  	if err != nil {
 211  		return fmt.Errorf("unwrap welcome: %w", err)
 212  	}
 213  
 214  	senderPub := ev.Pubkey
 215  	gs, err := JoinDMGroup(welcome, c.kpp, senderPub)
 216  	if err != nil {
 217  		return fmt.Errorf("join DM group: %w", err)
 218  	}
 219  
 220  	// Derive the group ID
 221  	gs.GroupID = DMGroupID(c.sign.Pub(), senderPub)
 222  
 223  	c.mu.Lock()
 224  	c.groups[string(gs.GroupID)] = gs
 225  	c.mu.Unlock()
 226  
 227  	c.persistGroup(gs)
 228  
 229  	log.I.F("joined DM group with %s", hex.Enc(senderPub))
 230  	return nil
 231  }
 232  
 233  func (c *Client) handleGroupMessage(ctx context.Context, ev *event.E) error {
 234  	groupID, ciphertext, err := EventToMessage(ev)
 235  	if err != nil {
 236  		return err
 237  	}
 238  
 239  	c.mu.RLock()
 240  	gs, ok := c.groups[string(groupID)]
 241  	c.mu.RUnlock()
 242  
 243  	if !ok || gs.group == nil {
 244  		return fmt.Errorf("unknown group %x", groupID)
 245  	}
 246  
 247  	plaintext, err := gs.Decrypt(ciphertext)
 248  	if err != nil {
 249  		return fmt.Errorf("decrypt: %w", err)
 250  	}
 251  
 252  	if c.onDM != nil {
 253  		c.onDM(ev.Pubkey, plaintext)
 254  	}
 255  
 256  	return nil
 257  }
 258  
 259  func (c *Client) persistGroup(gs *GroupState) {
 260  	data, err := marshalGroupState(gs)
 261  	if err != nil {
 262  		log.W.F("failed to marshal group state: %v", err)
 263  		return
 264  	}
 265  	if err := c.store.SaveGroup(gs.GroupID, data); err != nil {
 266  		log.W.F("failed to persist group state: %v", err)
 267  	}
 268  }
 269  
 270  // SubscriptionFilter returns the filter for subscribing to events relevant
 271  // to this client (key packages, welcomes, and group messages addressed to us).
 272  func (c *Client) SubscriptionFilter() *filter.S {
 273  	f := filter.New()
 274  	f.Kinds = kind.NewS(
 275  		kind.New(KindGiftWrap),
 276  		kind.New(KindGroupMessage),
 277  	)
 278  	f.Tags = tag.NewS(
 279  		tag.NewFromAny("#p", hex.Enc(c.sign.Pub())),
 280  	)
 281  	return filter.NewS(f)
 282  }
 283