payment_test.go raw

   1  package bridge
   2  
import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"testing"
	"time"
)
  10  
  11  // mockNWC implements NWCRequester for testing.
  12  type mockNWC struct {
  13  	responses map[string]any    // method -> response to marshal into result
  14  	errors    map[string]error  // method -> error to return
  15  	callCount map[string]int    // method -> number of calls
  16  }
  17  
  18  func newMockNWC() *mockNWC {
  19  	return &mockNWC{
  20  		responses: make(map[string]any),
  21  		errors:    make(map[string]error),
  22  		callCount: make(map[string]int),
  23  	}
  24  }
  25  
  26  func (m *mockNWC) Request(ctx context.Context, method string, params, result any) error {
  27  	m.callCount[method]++
  28  	if err, ok := m.errors[method]; ok {
  29  		return err
  30  	}
  31  	if resp, ok := m.responses[method]; ok {
  32  		data, _ := json.Marshal(resp)
  33  		return json.Unmarshal(data, result)
  34  	}
  35  	return nil
  36  }
  37  
  38  func TestNewPaymentProcessorWithClient(t *testing.T) {
  39  	mock := newMockNWC()
  40  	pp := NewPaymentProcessorWithClient(mock, 2100)
  41  	if pp == nil {
  42  		t.Fatal("expected non-nil PaymentProcessor")
  43  	}
  44  	if pp.monthlyPriceMSats != 2100000 {
  45  		t.Errorf("monthlyPriceMSats = %d, want 2100000", pp.monthlyPriceMSats)
  46  	}
  47  }
  48  
  49  func TestCreateSubscriptionInvoice(t *testing.T) {
  50  	mock := newMockNWC()
  51  	mock.responses["make_invoice"] = map[string]any{
  52  		"invoice":      "lnbc21000n1...",
  53  		"payment_hash": "abc123",
  54  		"amount":       2100000,
  55  		"description":  "test",
  56  	}
  57  
  58  	pp := NewPaymentProcessorWithClient(mock, 2100)
  59  	inv, err := pp.CreateSubscriptionInvoice(context.Background())
  60  	if err != nil {
  61  		t.Fatalf("unexpected error: %v", err)
  62  	}
  63  	if inv.Bolt11 != "lnbc21000n1..." {
  64  		t.Errorf("Bolt11 = %q, want lnbc21000n1...", inv.Bolt11)
  65  	}
  66  	if inv.PaymentHash != "abc123" {
  67  		t.Errorf("PaymentHash = %q, want abc123", inv.PaymentHash)
  68  	}
  69  	if mock.callCount["make_invoice"] != 1 {
  70  		t.Errorf("callCount = %d, want 1", mock.callCount["make_invoice"])
  71  	}
  72  }
  73  
  74  func TestCreateSubscriptionInvoice_Error(t *testing.T) {
  75  	mock := newMockNWC()
  76  	mock.errors["make_invoice"] = fmt.Errorf("wallet offline")
  77  
  78  	pp := NewPaymentProcessorWithClient(mock, 2100)
  79  	_, err := pp.CreateSubscriptionInvoice(context.Background())
  80  	if err == nil {
  81  		t.Fatal("expected error")
  82  	}
  83  	if !contains(err.Error(), "make_invoice") {
  84  		t.Errorf("error = %q, want to contain 'make_invoice'", err)
  85  	}
  86  }
  87  
  88  func TestLookupInvoice_Paid(t *testing.T) {
  89  	mock := newMockNWC()
  90  	mock.responses["lookup_invoice"] = map[string]any{
  91  		"invoice":      "lnbc...",
  92  		"payment_hash": "abc123",
  93  		"amount":       2100000,
  94  		"settled_at":   1700000000,
  95  		"preimage":     "deadbeef",
  96  	}
  97  
  98  	pp := NewPaymentProcessorWithClient(mock, 2100)
  99  	status, err := pp.LookupInvoice(context.Background(), "abc123")
 100  	if err != nil {
 101  		t.Fatalf("unexpected error: %v", err)
 102  	}
 103  	if !status.IsPaid {
 104  		t.Error("expected IsPaid=true")
 105  	}
 106  	if status.Preimage != "deadbeef" {
 107  		t.Errorf("Preimage = %q, want deadbeef", status.Preimage)
 108  	}
 109  }
 110  
 111  func TestLookupInvoice_Unpaid(t *testing.T) {
 112  	mock := newMockNWC()
 113  	mock.responses["lookup_invoice"] = map[string]any{
 114  		"invoice":      "lnbc...",
 115  		"payment_hash": "abc123",
 116  		"amount":       2100000,
 117  	}
 118  
 119  	pp := NewPaymentProcessorWithClient(mock, 2100)
 120  	status, err := pp.LookupInvoice(context.Background(), "abc123")
 121  	if err != nil {
 122  		t.Fatalf("unexpected error: %v", err)
 123  	}
 124  	if status.IsPaid {
 125  		t.Error("expected IsPaid=false")
 126  	}
 127  }
 128  
 129  func TestLookupInvoice_Error(t *testing.T) {
 130  	mock := newMockNWC()
 131  	mock.errors["lookup_invoice"] = fmt.Errorf("not found")
 132  
 133  	pp := NewPaymentProcessorWithClient(mock, 2100)
 134  	_, err := pp.LookupInvoice(context.Background(), "abc123")
 135  	if err == nil {
 136  		t.Fatal("expected error")
 137  	}
 138  }
 139  
 140  func TestLookupInvoice_PaidByPreimageOnly(t *testing.T) {
 141  	mock := newMockNWC()
 142  	mock.responses["lookup_invoice"] = map[string]any{
 143  		"invoice":      "lnbc...",
 144  		"payment_hash": "abc123",
 145  		"preimage":     "some_preimage",
 146  	}
 147  
 148  	pp := NewPaymentProcessorWithClient(mock, 2100)
 149  	status, err := pp.LookupInvoice(context.Background(), "abc123")
 150  	if err != nil {
 151  		t.Fatalf("unexpected error: %v", err)
 152  	}
 153  	if !status.IsPaid {
 154  		t.Error("expected IsPaid=true when preimage present")
 155  	}
 156  }
 157  
 158  func TestWaitForPayment_ImmediatelyPaid(t *testing.T) {
 159  	mock := newMockNWC()
 160  	mock.responses["lookup_invoice"] = map[string]any{
 161  		"payment_hash": "abc123",
 162  		"settled_at":   1700000000,
 163  	}
 164  
 165  	pp := NewPaymentProcessorWithClient(mock, 2100)
 166  	status, err := pp.WaitForPayment(context.Background(), "abc123", 10*time.Millisecond)
 167  	if err != nil {
 168  		t.Fatalf("unexpected error: %v", err)
 169  	}
 170  	if !status.IsPaid {
 171  		t.Error("expected IsPaid=true")
 172  	}
 173  }
 174  
 175  func TestWaitForPayment_Cancelled(t *testing.T) {
 176  	mock := newMockNWC()
 177  	mock.responses["lookup_invoice"] = map[string]any{
 178  		"payment_hash": "abc123",
 179  	}
 180  
 181  	pp := NewPaymentProcessorWithClient(mock, 2100)
 182  	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Millisecond)
 183  	defer cancel()
 184  
 185  	_, err := pp.WaitForPayment(ctx, "abc123", 10*time.Millisecond)
 186  	if err == nil {
 187  		t.Fatal("expected error from cancelled context")
 188  	}
 189  }
 190  
 191  func TestWaitForPayment_DefaultInterval(t *testing.T) {
 192  	mock := newMockNWC()
 193  	mock.responses["lookup_invoice"] = map[string]any{
 194  		"payment_hash": "abc123",
 195  		"settled_at":   1700000000,
 196  	}
 197  
 198  	pp := NewPaymentProcessorWithClient(mock, 2100)
 199  	// Pass 0 interval to trigger default
 200  	ctx, cancel := context.WithTimeout(context.Background(), 6*time.Second)
 201  	defer cancel()
 202  
 203  	status, err := pp.WaitForPayment(ctx, "abc123", 0)
 204  	if err != nil {
 205  		t.Fatalf("unexpected error: %v", err)
 206  	}
 207  	if !status.IsPaid {
 208  		t.Error("expected IsPaid=true")
 209  	}
 210  }
 211  
 212  func TestWaitForPayment_TransientErrors(t *testing.T) {
 213  	callCount := 0
 214  	mock := &mockNWC{
 215  		responses: make(map[string]any),
 216  		errors:    make(map[string]error),
 217  		callCount: make(map[string]int),
 218  	}
 219  	// Override Request to return errors first, then success
 220  	originalRequest := mock.Request
 221  	_ = originalRequest
 222  	transientMock := &transientNWC{
 223  		failCount: 2,
 224  		successResp: map[string]any{
 225  			"payment_hash": "abc123",
 226  			"settled_at":   1700000000,
 227  		},
 228  		calls: &callCount,
 229  	}
 230  
 231  	pp := NewPaymentProcessorWithClient(transientMock, 2100)
 232  	status, err := pp.WaitForPayment(context.Background(), "abc123", 10*time.Millisecond)
 233  	if err != nil {
 234  		t.Fatalf("unexpected error: %v", err)
 235  	}
 236  	if !status.IsPaid {
 237  		t.Error("expected IsPaid=true after transient errors")
 238  	}
 239  	if callCount < 3 {
 240  		t.Errorf("expected at least 3 calls, got %d", callCount)
 241  	}
 242  }
 243  
 244  type transientNWC struct {
 245  	failCount   int
 246  	successResp map[string]any
 247  	calls       *int
 248  }
 249  
 250  func (m *transientNWC) Request(ctx context.Context, method string, params, result any) error {
 251  	*m.calls++
 252  	if *m.calls <= m.failCount {
 253  		return fmt.Errorf("transient error %d", *m.calls)
 254  	}
 255  	data, _ := json.Marshal(m.successResp)
 256  	return json.Unmarshal(data, result)
 257  }
 258  
 259  func contains(s, substr string) bool {
 260  	return len(s) >= len(substr) && (s == substr || len(s) > 0 && containsHelper(s, substr))
 261  }
 262  
 263  func containsHelper(s, substr string) bool {
 264  	for i := 0; i <= len(s)-len(substr); i++ {
 265  		if s[i:i+len(substr)] == substr {
 266  			return true
 267  		}
 268  	}
 269  	return false
 270  }
 271