marmot_test.go raw

   1  package marmot
   2  
   3  import (
   4  	"context"
   5  	"sync"
   6  	"testing"
   7  	"time"
   8  
   9  	"next.orly.dev/pkg/nostr/encoders/event"
  10  	"next.orly.dev/pkg/nostr/encoders/filter"
  11  	"next.orly.dev/pkg/nostr/interfaces/signer/p8k"
  12  	"github.com/stretchr/testify/assert"
  13  	"github.com/stretchr/testify/require"
  14  )
  15  
  16  // mockRelay is a test relay that routes events between connected clients.
  17  // It buffers all published events and replays matching ones to new subscribers.
  18  type mockRelay struct {
  19  	mu          sync.Mutex
  20  	events      []*event.E
  21  	subscribers []chan *event.E
  22  }
  23  
  24  func newMockRelay() *mockRelay {
  25  	return &mockRelay{}
  26  }
  27  
  28  func (r *mockRelay) Publish(ctx context.Context, ev *event.E) error {
  29  	r.mu.Lock()
  30  	r.events = append(r.events, ev.Clone())
  31  	subs := make([]chan *event.E, len(r.subscribers))
  32  	copy(subs, r.subscribers)
  33  	r.mu.Unlock()
  34  
  35  	for _, ch := range subs {
  36  		select {
  37  		case ch <- ev.Clone():
  38  		default:
  39  		}
  40  	}
  41  	return nil
  42  }
  43  
  44  func (r *mockRelay) Subscribe(ctx context.Context, ff *filter.S) (EventStream, error) {
  45  	ch := make(chan *event.E, 64)
  46  
  47  	r.mu.Lock()
  48  	// Replay existing events that match the filter
  49  	for _, ev := range r.events {
  50  		if ff.Match(ev) {
  51  			select {
  52  			case ch <- ev.Clone():
  53  			default:
  54  			}
  55  		}
  56  	}
  57  	r.subscribers = append(r.subscribers, ch)
  58  	r.mu.Unlock()
  59  
  60  	return &mockStream{ch: ch}, nil
  61  }
  62  
  63  type mockStream struct {
  64  	ch   chan *event.E
  65  	once sync.Once
  66  }
  67  
  68  func (s *mockStream) Events() <-chan *event.E { return s.ch }
  69  func (s *mockStream) Close() {
  70  	s.once.Do(func() {
  71  		// Don't close — just drain. Closing a channel that's being
  72  		// written to by Publish would panic.
  73  	})
  74  }
  75  
  76  func generateSigner(t *testing.T) *p8k.Signer {
  77  	t.Helper()
  78  	s, err := p8k.New()
  79  	require.NoError(t, err)
  80  	require.NoError(t, s.Generate())
  81  	return s
  82  }
  83  
  84  func TestDMGroupID(t *testing.T) {
  85  	pubA := []byte("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa1")
  86  	pubB := []byte("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb1")
  87  
  88  	// Same result regardless of order
  89  	id1 := DMGroupID(pubA, pubB)
  90  	id2 := DMGroupID(pubB, pubA)
  91  	assert.Equal(t, id1, id2, "group ID should be order-independent")
  92  	assert.Len(t, id1, 32, "group ID should be 32 bytes (SHA256)")
  93  }
  94  
  95  func TestKeyPackageRoundTrip(t *testing.T) {
  96  	sign := generateSigner(t)
  97  
  98  	kpp, err := GenerateKeyPackage(sign)
  99  	require.NoError(t, err)
 100  	require.NotNil(t, kpp)
 101  
 102  	ev, err := KeyPackageToEvent(kpp, sign)
 103  	require.NoError(t, err)
 104  	assert.Equal(t, KindKeyPackage, ev.Kind)
 105  
 106  	kp, err := EventToKeyPackage(ev)
 107  	require.NoError(t, err)
 108  	require.NotNil(t, kp)
 109  
 110  	// The key package bytes should round-trip
 111  	assert.Equal(t, kpp.Public.Bytes(), kp.Bytes())
 112  }
 113  
 114  func TestGroupCreateJoinEncryptDecrypt(t *testing.T) {
 115  	alice := generateSigner(t)
 116  	bob := generateSigner(t)
 117  
 118  	aliceKPP, err := GenerateKeyPackage(alice)
 119  	require.NoError(t, err)
 120  	bobKPP, err := GenerateKeyPackage(bob)
 121  	require.NoError(t, err)
 122  
 123  	// Alice creates a group and invites Bob
 124  	gs, welcome, _, err := CreateDMGroup(aliceKPP, &bobKPP.Public, alice.Pub(), bob.Pub())
 125  	require.NoError(t, err)
 126  	require.NotNil(t, gs)
 127  	require.NotNil(t, welcome)
 128  
 129  	// Bob joins the group from the welcome
 130  	bobGS, err := JoinDMGroup(welcome, bobKPP, alice.Pub())
 131  	require.NoError(t, err)
 132  	require.NotNil(t, bobGS)
 133  
 134  	// Alice sends a message
 135  	plaintext := []byte("hello from alice")
 136  	ciphertext, err := gs.Encrypt(plaintext)
 137  	require.NoError(t, err)
 138  	require.NotEmpty(t, ciphertext)
 139  
 140  	// Bob decrypts
 141  	decrypted, err := bobGS.Decrypt(ciphertext)
 142  	require.NoError(t, err)
 143  	assert.Equal(t, plaintext, decrypted)
 144  
 145  	// Bob sends a reply
 146  	reply := []byte("hello from bob")
 147  	replyCT, err := bobGS.Encrypt(reply)
 148  	require.NoError(t, err)
 149  
 150  	// Alice decrypts
 151  	decryptedReply, err := gs.Decrypt(replyCT)
 152  	require.NoError(t, err)
 153  	assert.Equal(t, reply, decryptedReply)
 154  }
 155  
 156  func TestWelcomeGiftWrapRoundTrip(t *testing.T) {
 157  	alice := generateSigner(t)
 158  	bob := generateSigner(t)
 159  
 160  	aliceKPP, err := GenerateKeyPackage(alice)
 161  	require.NoError(t, err)
 162  	bobKPP, err := GenerateKeyPackage(bob)
 163  	require.NoError(t, err)
 164  
 165  	_, welcome, _, err := CreateDMGroup(aliceKPP, &bobKPP.Public, alice.Pub(), bob.Pub())
 166  	require.NoError(t, err)
 167  
 168  	// Alice wraps the welcome for Bob
 169  	wrapEv, err := WelcomeToGiftWrap(welcome, bob.Pub(), alice)
 170  	require.NoError(t, err)
 171  	assert.Equal(t, KindGiftWrap, wrapEv.Kind)
 172  
 173  	// Bob unwraps
 174  	unwrapped, err := UnwrapWelcome(wrapEv, bob)
 175  	require.NoError(t, err)
 176  	require.NotNil(t, unwrapped)
 177  
 178  	// Bob should be able to join the group from the unwrapped welcome
 179  	bobGS, err := JoinDMGroup(unwrapped, bobKPP, alice.Pub())
 180  	require.NoError(t, err)
 181  	require.NotNil(t, bobGS)
 182  }
 183  
 184  func TestMessageEventRoundTrip(t *testing.T) {
 185  	sign := generateSigner(t)
 186  	groupID := []byte("test-group-id-32-bytes-long!!!!!")
 187  	ciphertext := []byte("encrypted-payload-here")
 188  
 189  	ev, err := MessageToEvent(groupID, ciphertext, sign)
 190  	require.NoError(t, err)
 191  	assert.Equal(t, KindGroupMessage, ev.Kind)
 192  
 193  	gid, ct, err := EventToMessage(ev)
 194  	require.NoError(t, err)
 195  	assert.Equal(t, groupID, gid)
 196  	assert.Equal(t, ciphertext, ct)
 197  }
 198  
 199  func TestGroupStore(t *testing.T) {
 200  	store := NewMemoryGroupStore()
 201  
 202  	groupID := []byte("test-group-id")
 203  	state := []byte(`{"test": true}`)
 204  
 205  	// Save
 206  	require.NoError(t, store.SaveGroup(groupID, state))
 207  
 208  	// Load
 209  	loaded, err := store.LoadGroup(groupID)
 210  	require.NoError(t, err)
 211  	assert.Equal(t, state, loaded)
 212  
 213  	// List
 214  	ids, err := store.ListGroups()
 215  	require.NoError(t, err)
 216  	assert.Len(t, ids, 1)
 217  
 218  	// Delete
 219  	require.NoError(t, store.DeleteGroup(groupID))
 220  	_, err = store.LoadGroup(groupID)
 221  	assert.Error(t, err)
 222  }
 223  
 224  func TestFileGroupStore(t *testing.T) {
 225  	dir := t.TempDir()
 226  	store, err := NewFileGroupStore(dir)
 227  	require.NoError(t, err)
 228  
 229  	groupID := []byte{0xde, 0xad, 0xbe, 0xef}
 230  	state := []byte(`{"group": "data"}`)
 231  
 232  	require.NoError(t, store.SaveGroup(groupID, state))
 233  
 234  	loaded, err := store.LoadGroup(groupID)
 235  	require.NoError(t, err)
 236  	assert.Equal(t, state, loaded)
 237  
 238  	ids, err := store.ListGroups()
 239  	require.NoError(t, err)
 240  	assert.Len(t, ids, 1)
 241  	assert.Equal(t, groupID, ids[0])
 242  }
 243  
 244  func TestClientSendReceiveDM(t *testing.T) {
 245  	relay := newMockRelay()
 246  	alice := generateSigner(t)
 247  	bob := generateSigner(t)
 248  
 249  	aliceStore := NewMemoryGroupStore()
 250  	bobStore := NewMemoryGroupStore()
 251  
 252  	aliceClient, err := NewClient(alice, aliceStore, relay)
 253  	require.NoError(t, err)
 254  	bobClient, err := NewClient(bob, bobStore, relay)
 255  	require.NoError(t, err)
 256  
 257  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 258  	defer cancel()
 259  
 260  	// Bob publishes his key package so Alice can find it
 261  	require.NoError(t, bobClient.PublishKeyPackage(ctx))
 262  
 263  	// Set up Bob's DM handler
 264  	received := make(chan []byte, 1)
 265  	bobClient.OnDM(func(senderPub []byte, plaintext []byte) {
 266  		received <- plaintext
 267  	})
 268  
 269  	// Start Bob listening for events
 270  	go func() {
 271  		stream, err := relay.Subscribe(ctx, bobClient.SubscriptionFilter())
 272  		if err != nil {
 273  			return
 274  		}
 275  		defer stream.Close()
 276  		for ev := range stream.Events() {
 277  			_ = bobClient.HandleEvent(ctx, ev)
 278  		}
 279  	}()
 280  
 281  	// Give Bob a moment to subscribe
 282  	time.Sleep(50 * time.Millisecond)
 283  
 284  	// Alice sends a DM to Bob
 285  	err = aliceClient.SendDM(ctx, bob.Pub(), []byte("hello bob"))
 286  	require.NoError(t, err)
 287  
 288  	// Wait for Bob to receive and decrypt
 289  	select {
 290  	case msg := <-received:
 291  		assert.Equal(t, []byte("hello bob"), msg)
 292  	case <-time.After(5 * time.Second):
 293  		t.Fatal("timeout waiting for DM")
 294  	}
 295  }
 296