crypto_test.go raw

   1  package marmot
   2  
   3  import (
   4  	"bytes"
   5  	"context"
   6  	"sync"
   7  	"testing"
   8  	"time"
   9  
  10  	"github.com/stretchr/testify/assert"
  11  	"github.com/stretchr/testify/require"
  12  	"git.smesh.lol/orly/pkg/nostr/encoders/event"
  13  	"git.smesh.lol/orly/pkg/nostr/encoders/hex"
  14  	"git.smesh.lol/orly/pkg/nostr/interfaces/signer/p8k"
  15  )
  16  
  17  // simulatedExtension creates a ProxyCrypto backed by real keys, with a
  18  // goroutine that intercepts sendFn calls and resolves them using LocalCrypto.
  19  // This simulates the browser extension path without needing a browser.
  20  func simulatedExtension(t *testing.T, sign *p8k.Signer) (*ProxyCrypto, func()) {
  21  	t.Helper()
  22  	local := &LocalCrypto{Sign: sign}
  23  
  24  	type cryptoReq struct {
  25  		op, peerHex, data string
  26  		id                int
  27  	}
  28  	ch := make(chan cryptoReq, 16)
  29  
  30  	proxy := NewProxyCrypto(sign.Pub(), func(op, peerHex, data string, id int) {
  31  		ch <- cryptoReq{op, peerHex, data, id}
  32  	})
  33  
  34  	ctx, cancel := context.WithCancel(context.Background())
  35  	go func() {
  36  		for {
  37  			select {
  38  			case req := <-ch:
  39  				var result string
  40  				var errMsg string
  41  				switch req.op {
  42  				case "signEvent":
  43  					var ev event.E
  44  					if err := ev.UnmarshalJSON([]byte(req.data)); err != nil {
  45  						errMsg = err.Error()
  46  					} else if err := local.SignEvent(&ev); err != nil {
  47  						errMsg = err.Error()
  48  					} else {
  49  						b, _ := ev.MarshalJSON()
  50  						result = string(b)
  51  					}
  52  				case "nip44Encrypt":
  53  					peer, _ := hex.Dec(req.peerHex)
  54  					ct, err := local.Nip44Encrypt(peer, []byte(req.data))
  55  					if err != nil {
  56  						errMsg = err.Error()
  57  					} else {
  58  						result = ct
  59  					}
  60  				case "nip44Decrypt":
  61  					peer, _ := hex.Dec(req.peerHex)
  62  					pt, err := local.Nip44Decrypt(peer, req.data)
  63  					if err != nil {
  64  						errMsg = err.Error()
  65  					} else {
  66  						result = pt
  67  					}
  68  				default:
  69  					errMsg = "unknown op: " + req.op
  70  				}
  71  				proxy.Resolve(req.id, result, errMsg)
  72  			case <-ctx.Done():
  73  				return
  74  			}
  75  		}
  76  	}()
  77  
  78  	return proxy, cancel
  79  }
  80  
  81  func TestProxyCrypto_BasicRequestResolve(t *testing.T) {
  82  	var proxy *ProxyCrypto
  83  	proxy = NewProxyCrypto([]byte("testpub"), func(op, peerHex, data string, id int) {
  84  		go func() { proxy.Resolve(id, "result-"+data, "") }()
  85  	})
  86  	// Use the low-level request directly.
  87  	result, err := proxy.request("test", "", "hello")
  88  	require.NoError(t, err)
  89  	assert.Equal(t, "result-hello", result)
  90  }
  91  
  92  func TestProxyCrypto_Close(t *testing.T) {
  93  	proxy := NewProxyCrypto([]byte("testpub"), func(op, peerHex, data string, id int) {
  94  		// Don't resolve — let Close() unblock it.
  95  	})
  96  	done := make(chan error, 1)
  97  	go func() {
  98  		_, err := proxy.request("test", "", "blocked")
  99  		done <- err
 100  	}()
 101  	time.Sleep(10 * time.Millisecond) // let goroutine start
 102  	proxy.Close()
 103  	err := <-done
 104  	require.Error(t, err)
 105  	assert.Contains(t, err.Error(), "connection closed")
 106  }
 107  
 108  func TestProxyCrypto_ConcurrentRequests(t *testing.T) {
 109  	var mu sync.Mutex
 110  	pending := make(map[int]string)
 111  
 112  	proxy := NewProxyCrypto([]byte("testpub"), func(op, peerHex, data string, id int) {
 113  		mu.Lock()
 114  		pending[id] = data
 115  		mu.Unlock()
 116  	})
 117  
 118  	const n = 5
 119  	results := make([]string, n)
 120  	errs := make([]error, n)
 121  	var wg sync.WaitGroup
 122  	wg.Add(n)
 123  
 124  	for i := 0; i < n; i++ {
 125  		go func(idx int) {
 126  			defer wg.Done()
 127  			results[idx], errs[idx] = proxy.request("test", "", "req-"+string(rune('A'+idx)))
 128  		}(i)
 129  	}
 130  
 131  	// Wait for all requests to register.
 132  	time.Sleep(20 * time.Millisecond)
 133  
 134  	// Resolve all pending requests.
 135  	mu.Lock()
 136  	for id, data := range pending {
 137  		proxy.Resolve(id, "ok-"+data, "")
 138  	}
 139  	mu.Unlock()
 140  
 141  	wg.Wait()
 142  	for i := 0; i < n; i++ {
 143  		require.NoError(t, errs[i], "request %d", i)
 144  		assert.Contains(t, results[i], "ok-req-", "request %d", i)
 145  	}
 146  }
 147  
 148  func TestProxyCrypto_ResolveUnknownID(t *testing.T) {
 149  	proxy := NewProxyCrypto([]byte("testpub"), func(op, peerHex, data string, id int) {})
 150  	// Should not panic.
 151  	proxy.Resolve(999, "whatever", "")
 152  	proxy.Resolve(-1, "", "error")
 153  }
 154  
 155  func TestProxyCrypto_CloseIdempotent(t *testing.T) {
 156  	proxy := NewProxyCrypto([]byte("testpub"), func(op, peerHex, data string, id int) {})
 157  	proxy.Close()
 158  	proxy.Close() // should not panic
 159  }
 160  
 161  func TestProxyCrypto_ResolveAfterClose(t *testing.T) {
 162  	proxy := NewProxyCrypto([]byte("testpub"), func(op, peerHex, data string, id int) {})
 163  
 164  	done := make(chan error, 1)
 165  	go func() {
 166  		_, err := proxy.request("test", "", "data")
 167  		done <- err
 168  	}()
 169  	time.Sleep(10 * time.Millisecond)
 170  	proxy.Close()
 171  	err := <-done
 172  	require.Error(t, err)
 173  
 174  	// Late resolve after close — should not panic.
 175  	proxy.Resolve(0, "late", "")
 176  }
 177  
 178  func TestProxyCrypto_SignEventRoundTrip(t *testing.T) {
 179  	sign := generateSigner(t)
 180  	proxy, cancel := simulatedExtension(t, sign)
 181  	defer cancel()
 182  
 183  	ev := event.New()
 184  	ev.CreatedAt = time.Now().Unix()
 185  	ev.Kind = 1
 186  	ev.Content = []byte("hello proxy")
 187  
 188  	err := proxy.SignEvent(ev)
 189  	require.NoError(t, err)
 190  
 191  	// Event should have valid ID and sig from the simulated extension.
 192  	assert.True(t, len(ev.ID) == 32, "ID should be 32 bytes")
 193  	assert.True(t, len(ev.Sig) == 64, "Sig should be 64 bytes")
 194  	assert.True(t, bytes.Equal(ev.Pubkey, sign.Pub()), "pubkey should match signer")
 195  
 196  	// Verify the signature independently.
 197  	ok, err := sign.Verify(ev.ID, ev.Sig)
 198  	require.NoError(t, err)
 199  	assert.True(t, ok, "signature should verify")
 200  }
 201  
 202  func TestProxyCrypto_Nip44EncryptDecryptRoundTrip(t *testing.T) {
 203  	alice := generateSigner(t)
 204  	bob := generateSigner(t)
 205  
 206  	aliceProxy, cancel := simulatedExtension(t, alice)
 207  	defer cancel()
 208  	bobLocal := &LocalCrypto{Sign: bob}
 209  
 210  	// Alice encrypts for Bob via proxy.
 211  	ct, err := aliceProxy.Nip44Encrypt(bob.Pub(), []byte("secret message"))
 212  	require.NoError(t, err)
 213  	assert.NotEmpty(t, ct)
 214  
 215  	// Bob decrypts with local crypto.
 216  	pt, err := bobLocal.Nip44Decrypt(alice.Pub(), ct)
 217  	require.NoError(t, err)
 218  	assert.Equal(t, "secret message", pt)
 219  }
 220  
 221  func TestProxyCrypto_WelcomeGiftWrapViaProxy(t *testing.T) {
 222  	alice := generateSigner(t)
 223  	bob := generateSigner(t)
 224  
 225  	aliceProxy, cancel := simulatedExtension(t, alice)
 226  	defer cancel()
 227  	bobLocal := &LocalCrypto{Sign: bob}
 228  
 229  	// Generate key packages.
 230  	aliceKPP, err := GenerateKeyPackage(aliceProxy)
 231  	require.NoError(t, err)
 232  	bobKPP, err := GenerateKeyPackage(bobLocal)
 233  	require.NoError(t, err)
 234  
 235  	// Alice creates group with Bob's KP.
 236  	_, welcome, _, err := CreateDMGroup(aliceKPP, &bobKPP.Public, alice.Pub(), bob.Pub(), nil)
 237  	require.NoError(t, err)
 238  
 239  	// Alice gift-wraps the welcome via ProxyCrypto.
 240  	wrapEv, err := WelcomeToGiftWrap(welcome, bob.Pub(), aliceProxy, nil, nil)
 241  	require.NoError(t, err)
 242  	assert.Equal(t, KindGiftWrap, wrapEv.Kind)
 243  
 244  	// Bob unwraps with LocalCrypto.
 245  	unwrapped, err := UnwrapWelcome(wrapEv, bobLocal)
 246  	require.NoError(t, err)
 247  
 248  	// Sender should be Alice's real pubkey from the seal layer.
 249  	assert.True(t, bytes.Equal(unwrapped.SenderPub, alice.Pub()),
 250  		"sender should be Alice, got %s", hex.Enc(unwrapped.SenderPub))
 251  	assert.NotNil(t, unwrapped.Welcome)
 252  }
 253  
 254  func TestProxyCrypto_ClientE2E(t *testing.T) {
 255  	relay := newMockRelay()
 256  	alice := generateSigner(t)
 257  	bob := generateSigner(t)
 258  
 259  	// Alice uses ProxyCrypto (simulated extension).
 260  	aliceProxy, cancelProxy := simulatedExtension(t, alice)
 261  	defer cancelProxy()
 262  
 263  	aliceClient, err := NewClient(aliceProxy, NewMemoryGroupStore(), relay)
 264  	require.NoError(t, err)
 265  
 266  	// Bob uses LocalCrypto.
 267  	bobClient, err := NewClient(&LocalCrypto{Sign: bob}, NewMemoryGroupStore(), relay)
 268  	require.NoError(t, err)
 269  
 270  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
 271  	defer cancel()
 272  
 273  	// Bob publishes key package.
 274  	require.NoError(t, bobClient.PublishKeyPackage(ctx))
 275  
 276  	// Bob listens for events.
 277  	received := make(chan struct {
 278  		sender    []byte
 279  		plaintext []byte
 280  	}, 1)
 281  	bobClient.OnDM(func(senderPub []byte, plaintext []byte) {
 282  		received <- struct {
 283  			sender    []byte
 284  			plaintext []byte
 285  		}{senderPub, plaintext}
 286  	})
 287  
 288  	go func() {
 289  		stream, err := relay.Subscribe(ctx, bobClient.SubscriptionFilters())
 290  		if err != nil {
 291  			return
 292  		}
 293  		defer stream.Close()
 294  		for {
 295  			select {
 296  			case ev := <-stream.Events():
 297  				_ = bobClient.HandleEvent(ctx, ev)
 298  			case <-ctx.Done():
 299  				return
 300  			}
 301  		}
 302  	}()
 303  
 304  	time.Sleep(50 * time.Millisecond)
 305  
 306  	// Alice sends DM to Bob via ProxyCrypto.
 307  	err = aliceClient.SendDM(ctx, bob.Pub(), []byte("hello from proxy"))
 308  	require.NoError(t, err)
 309  
 310  	select {
 311  	case msg := <-received:
 312  		assert.Equal(t, []byte("hello from proxy"), msg.plaintext)
 313  		assert.True(t, bytes.Equal(msg.sender, alice.Pub()),
 314  			"sender should be Alice's real pubkey, got %s", hex.Enc(msg.sender))
 315  	case <-time.After(5 * time.Second):
 316  		t.Fatal("timeout waiting for DM")
 317  	}
 318  }
 319