sign_test.go raw

   1  // Copyright 2013-2022 The btcsuite developers
   2  
   3  package musig2
   4  
   5  import (
   6  	"bytes"
   7  	"encoding/json"
   8  	"fmt"
   9  	"os"
  10  	"path"
  11  	"strings"
  12  	"testing"
  13  
  14  	"next.orly.dev/pkg/nostr/crypto/ec"
  15  	"next.orly.dev/pkg/nostr/crypto/ec/secp256k1"
  16  	"next.orly.dev/pkg/nostr/encoders/hex"
  17  	"github.com/stretchr/testify/require"
  18  )
  19  
  20  const (
  21  	signVerifyTestVectorFileName = "sign_verify_vectors.json"
  22  	sigCombineTestVectorFileName = "sig_agg_vectors.json"
  23  )
  24  
  25  type signVerifyValidCase struct {
  26  	Indices       []int  `json:"key_indices"`
  27  	NonceIndices  []int  `json:"nonce_indices"`
  28  	AggNonceIndex int    `json:"aggnonce_index"`
  29  	MsgIndex      int    `json:"msg_index"`
  30  	SignerIndex   int    `json:"signer_index"`
  31  	Expected      string `json:"expected"`
  32  }
  33  
  34  type signErrorCase struct {
  35  	Indices       []int  `json:"key_indices"`
  36  	AggNonceIndex int    `json:"aggnonce_index"`
  37  	MsgIndex      int    `json:"msg_index"`
  38  	SecNonceIndex int    `json:"secnonce_index"`
  39  	Comment       string `json:"comment"`
  40  }
  41  
  42  type verifyFailCase struct {
  43  	Sig          string `json:"sig"`
  44  	Indices      []int  `json:"key_indices"`
  45  	NonceIndices []int  `json:"nonce_indices"`
  46  	MsgIndex     int    `json:"msg_index"`
  47  	SignerIndex  int    `json:"signer_index"`
  48  	Comment      string `json:"comment"`
  49  }
  50  
  51  type verifyErrorCase struct {
  52  	Sig          string `json:"sig"`
  53  	Indices      []int  `json:"key_indices"`
  54  	NonceIndices []int  `json:"nonce_indices"`
  55  	MsgIndex     int    `json:"msg_index"`
  56  	SignerIndex  int    `json:"signer_index"`
  57  	Comment      string `json:"comment"`
  58  }
  59  
  60  type signVerifyTestVectors struct {
  61  	SecKey           string                `json:"sk"`
  62  	PubKeys          []string              `json:"pubkeys"`
  63  	PrivNonces       []string              `json:"secnonces"`
  64  	PubNonces        []string              `json:"pnonces"`
  65  	AggNonces        []string              `json:"aggnonces"`
  66  	Msgs             []string              `json:"msgs"`
  67  	ValidCases       []signVerifyValidCase `json:"valid_test_cases"`
  68  	SignErrorCases   []signErrorCase       `json:"sign_error_test_cases"`
  69  	VerifyFailCases  []verifyFailCase      `json:"verify_fail_test_cases"`
  70  	VerifyErrorCases []verifyErrorCase     `json:"verify_error_test_cases"`
  71  }
  72  
  73  // TestMusig2SignVerify tests that we pass the musig2 verification tests.
  74  func TestMusig2SignVerify(t *testing.T) {
  75  	t.Parallel()
  76  	testVectorPath := path.Join(
  77  		testVectorBaseDir, signVerifyTestVectorFileName,
  78  	)
  79  	testVectorBytes, err := os.ReadFile(testVectorPath)
  80  	require.NoError(t, err)
  81  	var testCases signVerifyTestVectors
  82  	require.NoError(t, json.Unmarshal(testVectorBytes, &testCases))
  83  	privKey, _ := btcec.SecKeyFromBytes(mustParseHex(testCases.SecKey))
  84  	for i, testCase := range testCases.ValidCases {
  85  		testCase := testCase
  86  		testName := fmt.Sprintf("valid_case_%v", i)
  87  		t.Run(
  88  			testName, func(t *testing.T) {
  89  				pubKeys, err := keysFromIndices(
  90  					t, testCase.Indices, testCases.PubKeys,
  91  				)
  92  				require.NoError(t, err)
  93  				pubNonces := pubNoncesFromIndices(
  94  					t, testCase.NonceIndices, testCases.PubNonces,
  95  				)
  96  				combinedNonce, err := AggregateNonces(pubNonces)
  97  				require.NoError(t, err)
  98  				var msg [32]byte
  99  				copy(msg[:], mustParseHex(testCases.Msgs[testCase.MsgIndex]))
 100  				var secNonce [SecNonceSize]byte
 101  				copy(secNonce[:], mustParseHex(testCases.PrivNonces[0]))
 102  				partialSig, err := Sign(
 103  					secNonce, privKey, combinedNonce, pubKeys,
 104  					msg,
 105  				)
 106  				var partialSigBytes [32]byte
 107  				partialSig.S.PutBytesUnchecked(partialSigBytes[:])
 108  				require.Equal(
 109  					t, hex.Enc(partialSigBytes[:]),
 110  					hex.Enc(mustParseHex(testCase.Expected)),
 111  				)
 112  			},
 113  		)
 114  	}
 115  	for _, testCase := range testCases.SignErrorCases {
 116  		testCase := testCase
 117  		testName := fmt.Sprintf(
 118  			"invalid_case_%v",
 119  			strings.ToLower(testCase.Comment),
 120  		)
 121  		t.Run(
 122  			testName, func(t *testing.T) {
 123  				pubKeys, err := keysFromIndices(
 124  					t, testCase.Indices, testCases.PubKeys,
 125  				)
 126  				if err != nil {
 127  					require.ErrorIs(t, err, secp256k1.ErrPubKeyNotOnCurve)
 128  					return
 129  				}
 130  				var aggNonce [PubNonceSize]byte
 131  				copy(
 132  					aggNonce[:],
 133  					mustParseHex(
 134  						testCases.AggNonces[testCase.AggNonceIndex],
 135  					),
 136  				)
 137  				var msg [32]byte
 138  				copy(msg[:], mustParseHex(testCases.Msgs[testCase.MsgIndex]))
 139  				var secNonce [SecNonceSize]byte
 140  				copy(
 141  					secNonce[:],
 142  					mustParseHex(
 143  						testCases.PrivNonces[testCase.SecNonceIndex],
 144  					),
 145  				)
 146  				_, err = Sign(
 147  					secNonce, privKey, aggNonce, pubKeys,
 148  					msg,
 149  				)
 150  				require.Error(t, err)
 151  			},
 152  		)
 153  	}
 154  	for _, testCase := range testCases.VerifyFailCases {
 155  		testCase := testCase
 156  		testName := fmt.Sprintf(
 157  			"verify_fail_%v",
 158  			strings.ToLower(testCase.Comment),
 159  		)
 160  		t.Run(
 161  			testName, func(t *testing.T) {
 162  				pubKeys, err := keysFromIndices(
 163  					t, testCase.Indices, testCases.PubKeys,
 164  				)
 165  				require.NoError(t, err)
 166  				pubNonces := pubNoncesFromIndices(
 167  					t, testCase.NonceIndices, testCases.PubNonces,
 168  				)
 169  				combinedNonce, err := AggregateNonces(pubNonces)
 170  				require.NoError(t, err)
 171  				var msg [32]byte
 172  				copy(
 173  					msg[:],
 174  					mustParseHex(testCases.Msgs[testCase.MsgIndex]),
 175  				)
 176  				var secNonce [SecNonceSize]byte
 177  				copy(secNonce[:], mustParseHex(testCases.PrivNonces[0]))
 178  				signerNonce := secNonceToPubNonce(secNonce)
 179  				var partialSig PartialSignature
 180  				err = partialSig.Decode(
 181  					bytes.NewReader(mustParseHex(testCase.Sig)),
 182  				)
 183  				if err != nil && strings.Contains(
 184  					testCase.Comment, "group size",
 185  				) {
 186  					require.ErrorIs(t, err, ErrPartialSigInvalid)
 187  				}
 188  				err = verifyPartialSig(
 189  					&partialSig, signerNonce, combinedNonce,
 190  					pubKeys, privKey.PubKey().SerializeCompressed(),
 191  					msg,
 192  				)
 193  				require.Error(t, err)
 194  			},
 195  		)
 196  	}
 197  
 198  	for _, testCase := range testCases.VerifyErrorCases {
 199  		testCase := testCase
 200  		testName := fmt.Sprintf(
 201  			"verify_error_%v",
 202  			strings.ToLower(testCase.Comment),
 203  		)
 204  		t.Run(
 205  			testName, func(t *testing.T) {
 206  				switch testCase.Comment {
 207  				case "Invalid pubnonce":
 208  					pubNonces := pubNoncesFromIndices(
 209  						t, testCase.NonceIndices, testCases.PubNonces,
 210  					)
 211  					_, err := AggregateNonces(pubNonces)
 212  					require.ErrorIs(t, err, secp256k1.ErrPubKeyNotOnCurve)
 213  
 214  				case "Invalid pubkey":
 215  					_, err := keysFromIndices(
 216  						t, testCase.Indices, testCases.PubKeys,
 217  					)
 218  					require.ErrorIs(t, err, secp256k1.ErrPubKeyNotOnCurve)
 219  
 220  				default:
 221  					t.Fatalf("unhandled case: %v", testCase.Comment)
 222  				}
 223  			},
 224  		)
 225  	}
 226  
 227  }
 228  
 229  type sigCombineValidCase struct {
 230  	AggNonce     string `json:"aggnonce"`
 231  	NonceIndices []int  `json:"nonce_indices"`
 232  	Indices      []int  `json:"key_indices"`
 233  	TweakIndices []int  `json:"tweak_indices"`
 234  	IsXOnly      []bool `json:"is_xonly"`
 235  	PSigIndices  []int  `json:"psig_indices"`
 236  	Expected     string `json:"expected"`
 237  }
 238  
 239  type sigCombineTestVectors struct {
 240  	PubKeys    []string              `json:"pubkeys"`
 241  	PubNonces  []string              `json:"pnonces"`
 242  	Tweaks     []string              `json:"tweaks"`
 243  	Psigs      []string              `json:"psigs"`
 244  	Msg        string                `json:"msg"`
 245  	ValidCases []sigCombineValidCase `json:"valid_test_cases"`
 246  }
 247  
 248  func pSigsFromIndicies(
 249  	t *testing.T, sigs []string,
 250  	indices []int,
 251  ) []*PartialSignature {
 252  	pSigs := make([]*PartialSignature, len(indices))
 253  	for i, idx := range indices {
 254  		var pSig PartialSignature
 255  		err := pSig.Decode(bytes.NewReader(mustParseHex(sigs[idx])))
 256  		require.NoError(t, err)
 257  		pSigs[i] = &pSig
 258  	}
 259  	return pSigs
 260  }
 261  
 262  // TestMusig2SignCombine tests that we pass the musig2 sig combination tests.
 263  func TestMusig2SignCombine(t *testing.T) {
 264  	t.Parallel()
 265  	testVectorPath := path.Join(
 266  		testVectorBaseDir, sigCombineTestVectorFileName,
 267  	)
 268  	testVectorBytes, err := os.ReadFile(testVectorPath)
 269  	require.NoError(t, err)
 270  	var testCases sigCombineTestVectors
 271  	require.NoError(t, json.Unmarshal(testVectorBytes, &testCases))
 272  	var msg [32]byte
 273  	copy(msg[:], mustParseHex(testCases.Msg))
 274  	for i, testCase := range testCases.ValidCases {
 275  		testCase := testCase
 276  		testName := fmt.Sprintf("valid_case_%v", i)
 277  		t.Run(
 278  			testName, func(t *testing.T) {
 279  				pubKeys, err := keysFromIndices(
 280  					t, testCase.Indices, testCases.PubKeys,
 281  				)
 282  				require.NoError(t, err)
 283  				pubNonces := pubNoncesFromIndices(
 284  					t, testCase.NonceIndices, testCases.PubNonces,
 285  				)
 286  				partialSigs := pSigsFromIndicies(
 287  					t, testCases.Psigs, testCase.PSigIndices,
 288  				)
 289  				var (
 290  					combineOpts []CombineOption
 291  					keyOpts     []KeyAggOption
 292  				)
 293  				if len(testCase.TweakIndices) > 0 {
 294  					tweaks := tweaksFromIndices(
 295  						t, testCase.TweakIndices,
 296  						testCases.Tweaks, testCase.IsXOnly,
 297  					)
 298  					combineOpts = append(
 299  						combineOpts, WithTweakedCombine(
 300  							msg, pubKeys, tweaks, false,
 301  						),
 302  					)
 303  					keyOpts = append(keyOpts, WithKeyTweaks(tweaks...))
 304  				}
 305  				combinedKey, _, _, err := AggregateKeys(
 306  					pubKeys, false, keyOpts...,
 307  				)
 308  				require.NoError(t, err)
 309  				combinedNonce, err := AggregateNonces(pubNonces)
 310  				require.NoError(t, err)
 311  				finalNonceJ, _, err := computeSigningNonce(
 312  					combinedNonce, combinedKey.FinalKey, msg,
 313  				)
 314  				finalNonceJ.ToAffine()
 315  				finalNonce := btcec.NewPublicKey(
 316  					&finalNonceJ.X, &finalNonceJ.Y,
 317  				)
 318  				combinedSig := CombineSigs(
 319  					finalNonce, partialSigs, combineOpts...,
 320  				)
 321  				require.Equal(
 322  					t,
 323  					strings.ToLower(testCase.Expected),
 324  					hex.Enc(combinedSig.Serialize()),
 325  				)
 326  			},
 327  		)
 328  	}
 329  }
 330