nonces_test.go raw

   1  // Copyright 2013-2022 The btcsuite developers
   2  
   3  package musig2
   4  
   5  import (
   6  	"encoding/json"
   7  	"fmt"
   8  	"os"
   9  	"path"
  10  	"testing"
  11  
  12  	"next.orly.dev/pkg/nostr/encoders/hex"
  13  	"next.orly.dev/pkg/nostr/utils"
  14  	"github.com/stretchr/testify/require"
  15  )
  16  
  17  type nonceGenTestCase struct {
  18  	Rand     string  `json:"rand_"`
  19  	Sk       string  `json:"sk"`
  20  	AggPk    string  `json:"aggpk"`
  21  	Msg      *string `json:"msg"`
  22  	ExtraIn  string  `json:"extra_in"`
  23  	Pk       string  `json:"pk"`
  24  	Expected string  `json:"expected"`
  25  }
  26  
  27  type nonceGenTestCases struct {
  28  	TestCases []nonceGenTestCase `json:"test_cases"`
  29  }
  30  
  31  const (
  32  	nonceGenTestVectorsFileName = "nonce_gen_vectors.json"
  33  	nonceAggTestVectorsFileName = "nonce_agg_vectors.json"
  34  )
  35  
  36  // TestMusig2NonceGenTestVectors tests the nonce generation function with the
  37  // testvectors defined in the Musig2 BIP.
  38  func TestMusig2NonceGenTestVectors(t *testing.T) {
  39  	t.Parallel()
  40  	testVectorPath := path.Join(
  41  		testVectorBaseDir, nonceGenTestVectorsFileName,
  42  	)
  43  	testVectorBytes, err := os.ReadFile(testVectorPath)
  44  	require.NoError(t, err)
  45  	var testCases nonceGenTestCases
  46  	require.NoError(t, json.Unmarshal(testVectorBytes, &testCases))
  47  	for i, testCase := range testCases.TestCases {
  48  		testCase := testCase
  49  		customOpts := nonceGenOpts{
  50  			randReader:  &memsetRandReader{i: 0},
  51  			secretKey:   mustParseHex(testCase.Sk),
  52  			combinedKey: mustParseHex(testCase.AggPk),
  53  			auxInput:    mustParseHex(testCase.ExtraIn),
  54  			publicKey:   mustParseHex(testCase.Pk),
  55  		}
  56  		if testCase.Msg != nil {
  57  			customOpts.msg = mustParseHex(*testCase.Msg)
  58  		}
  59  		t.Run(
  60  			fmt.Sprintf("test_case=%v", i), func(t *testing.T) {
  61  				nonce, err := GenNonces(withCustomOptions(customOpts))
  62  				if err != nil {
  63  					t.Fatalf("err gen nonce aux bytes %v", err)
  64  				}
  65  				expectedBytes, _ := hex.Dec(testCase.Expected)
  66  				if !utils.FastEqual(nonce.SecNonce[:], expectedBytes) {
  67  					t.Fatalf(
  68  						"nonces don't match: expected %x, got %x",
  69  						expectedBytes, nonce.SecNonce[:],
  70  					)
  71  				}
  72  			},
  73  		)
  74  	}
  75  }
  76  
  77  type nonceAggError struct {
  78  	Type    string `json:"type"`
  79  	Signer  int    `json:"signer"`
  80  	Contrib string `json:"contrib"`
  81  }
  82  
  83  type nonceAggValidCase struct {
  84  	Indices  []int  `json:"pnonce_indices"`
  85  	Expected string `json:"expected"`
  86  	Comment  string `json:"comment"`
  87  }
  88  
  89  type nonceAggInvalidCase struct {
  90  	Indices     []int         `json:"pnonce_indices"`
  91  	Error       nonceAggError `json:"error"`
  92  	Comment     string        `json:"comment"`
  93  	ExpectedErr string        `json:"btcec_err"`
  94  }
  95  
  96  type nonceAggTestCases struct {
  97  	Nonces       []string              `json:"pnonces"`
  98  	ValidCases   []nonceAggValidCase   `json:"valid_test_cases"`
  99  	InvalidCases []nonceAggInvalidCase `json:"error_test_cases"`
 100  }
 101  
 102  // TestMusig2AggregateNoncesTestVectors tests that the musig2 implementation
 103  // passes the nonce aggregration test vectors for musig2 1.0.
 104  func TestMusig2AggregateNoncesTestVectors(t *testing.T) {
 105  	t.Parallel()
 106  	testVectorPath := path.Join(
 107  		testVectorBaseDir, nonceAggTestVectorsFileName,
 108  	)
 109  	testVectorBytes, err := os.ReadFile(testVectorPath)
 110  	require.NoError(t, err)
 111  	var testCases nonceAggTestCases
 112  	require.NoError(t, json.Unmarshal(testVectorBytes, &testCases))
 113  	nonces := make([][PubNonceSize]byte, len(testCases.Nonces))
 114  	for i := range testCases.Nonces {
 115  		var nonce [PubNonceSize]byte
 116  		copy(nonce[:], mustParseHex(testCases.Nonces[i]))
 117  		nonces[i] = nonce
 118  	}
 119  	for i, testCase := range testCases.ValidCases {
 120  		testCase := testCase
 121  		var testNonces [][PubNonceSize]byte
 122  		for _, idx := range testCase.Indices {
 123  			testNonces = append(testNonces, nonces[idx])
 124  		}
 125  		t.Run(
 126  			fmt.Sprintf("valid_case=%v", i), func(t *testing.T) {
 127  				aggregatedNonce, err := AggregateNonces(testNonces)
 128  				require.NoError(t, err)
 129  				var expectedNonce [PubNonceSize]byte
 130  				copy(expectedNonce[:], mustParseHex(testCase.Expected))
 131  				require.Equal(t, aggregatedNonce[:], expectedNonce[:])
 132  			},
 133  		)
 134  	}
 135  
 136  	for i, testCase := range testCases.InvalidCases {
 137  		var testNonces [][PubNonceSize]byte
 138  		for _, idx := range testCase.Indices {
 139  			testNonces = append(testNonces, nonces[idx])
 140  		}
 141  		t.Run(
 142  			fmt.Sprintf("invalid_case=%v", i), func(t *testing.T) {
 143  				_, err := AggregateNonces(testNonces)
 144  				require.True(t, err != nil)
 145  				require.Equal(t, testCase.ExpectedErr, err.Error())
 146  			},
 147  		)
 148  	}
 149  }
 150