keys_test.go raw

   1  // Copyright 2013-2022 The btcsuite developers
   2  
   3  package musig2
   4  
   5  import (
   6  	"encoding/json"
   7  	"fmt"
   8  	"os"
   9  	"path"
  10  	"strings"
  11  	"testing"
  12  
  13  	"next.orly.dev/pkg/nostr/crypto/ec"
  14  	"next.orly.dev/pkg/nostr/crypto/ec/schnorr"
  15  	"next.orly.dev/pkg/nostr/crypto/ec/secp256k1"
  16  	"next.orly.dev/pkg/nostr/encoders/hex"
  17  
  18  	"github.com/stretchr/testify/require"
  19  )
  20  
  21  const (
  22  	keySortTestVectorFileName  = "key_sort_vectors.json"
  23  	keyAggTestVectorFileName   = "key_agg_vectors.json"
  24  	keyTweakTestVectorFileName = "tweak_vectors.json"
  25  )
  26  
// keySortTestVector is the JSON layout of the key-sort vector file. PubKeys is
// the input list of hex-encoded public keys and SortedKeys is the same set in
// the order sortKeys is expected to return them.
type keySortTestVector struct {
	PubKeys    []string `json:"pubkeys"`
	SortedKeys []string `json:"sorted_pubkeys"`
}
  31  
  32  // TestMusig2KeySort tests that keys are properly sorted according to the
  33  // musig2 test vectors.
  34  func TestMusig2KeySort(t *testing.T) {
  35  	t.Parallel()
  36  	testVectorPath := path.Join(
  37  		testVectorBaseDir, keySortTestVectorFileName,
  38  	)
  39  	testVectorBytes, err := os.ReadFile(testVectorPath)
  40  	require.NoError(t, err)
  41  	var testCase keySortTestVector
  42  	require.NoError(t, json.Unmarshal(testVectorBytes, &testCase))
  43  	keys := make([]*btcec.PublicKey, len(testCase.PubKeys))
  44  	for i, keyStr := range testCase.PubKeys {
  45  		pubKey, err := btcec.ParsePubKey(mustParseHex(keyStr))
  46  		require.NoError(t, err)
  47  		keys[i] = pubKey
  48  	}
  49  	sortedKeys := sortKeys(keys)
  50  	expectedKeys := make([]*btcec.PublicKey, len(testCase.PubKeys))
  51  	for i, keyStr := range testCase.SortedKeys {
  52  		pubKey, err := btcec.ParsePubKey(mustParseHex(keyStr))
  53  		require.NoError(t, err)
  54  		expectedKeys[i] = pubKey
  55  	}
  56  	require.Equal(t, sortedKeys, expectedKeys)
  57  }
  58  
// keyAggValidTest is one valid key-aggregation case. Indices select public
// keys from the vector file's shared pubkeys list, in the exact order they are
// handed to AggregateKeys. Expected is the hex encoding the test compares
// against schnorr.SerializePubKey of the resulting final key.
type keyAggValidTest struct {
	Indices  []int  `json:"key_indices"`
	Expected string `json:"expected"`
}
  63  
// keyAggError describes an expected-error record with the JSON keys "type",
// "signer" and "contrib".
//
// NOTE(review): this type is not referenced by the tests visible in this file,
// and the field name Contring looks like a typo for Contrib. It is left as-is
// in case other files in the package refer to it; confirm before renaming.
type keyAggError struct {
	Type     string `json:"type"`
	Signer   int    `json:"signer"`
	Contring string `json:"contrib"`
}
  69  
// keyAggInvalidTest is one key-aggregation error case.
//
// Fields:
//   - Indices select public keys from the vector file's shared pool.
//   - TweakIndices select tweaks from the shared pool.
//   - IsXOnly gives, position by position, whether each selected tweak is
//     x-only.
//   - Comment is the human-readable description that TestMuSig2KeyAggTestVectors
//     switches on to decide which error to expect.
type keyAggInvalidTest struct {
	Indices      []int  `json:"key_indices"`
	TweakIndices []int  `json:"tweak_indices"`
	IsXOnly      []bool `json:"is_xonly"`
	Comment      string `json:"comment"`
}
  76  
// keyAggTestVectors is the top-level layout of the key-aggregation vector
// file. PubKeys and Tweaks are shared pools of hex strings. ValidCases and
// InvalidCases (the file's "error_test_cases") refer to those pools by index.
type keyAggTestVectors struct {
	PubKeys      []string            `json:"pubkeys"`
	Tweaks       []string            `json:"tweaks"`
	ValidCases   []keyAggValidTest   `json:"valid_test_cases"`
	InvalidCases []keyAggInvalidTest `json:"error_test_cases"`
}
  83  
  84  func keysFromIndices(
  85  	t *testing.T, indices []int,
  86  	pubKeys []string,
  87  ) ([]*btcec.PublicKey, error) {
  88  	t.Helper()
  89  	inputKeys := make([]*btcec.PublicKey, len(indices))
  90  	for i, keyIdx := range indices {
  91  		var err error
  92  		inputKeys[i], err = btcec.ParsePubKey(
  93  			mustParseHex(pubKeys[keyIdx]),
  94  		)
  95  		if err != nil {
  96  			return nil, err
  97  		}
  98  	}
  99  	return inputKeys, nil
 100  }
 101  
 102  func tweaksFromIndices(
 103  	t *testing.T, indices []int,
 104  	tweaks []string, isXonly []bool,
 105  ) []KeyTweakDesc {
 106  
 107  	t.Helper()
 108  	testTweaks := make([]KeyTweakDesc, len(indices))
 109  	for i, idx := range indices {
 110  		var rawTweak [32]byte
 111  		copy(rawTweak[:], mustParseHex(tweaks[idx]))
 112  		testTweaks[i] = KeyTweakDesc{
 113  			Tweak:   rawTweak,
 114  			IsXOnly: isXonly[i],
 115  		}
 116  	}
 117  	return testTweaks
 118  }
 119  
 120  // TestMuSig2KeyAggTestVectors tests that this implementation of musig2 key
 121  // aggregation lines up with the secp256k1-zkp test vectors.
 122  func TestMuSig2KeyAggTestVectors(t *testing.T) {
 123  	t.Parallel()
 124  	testVectorPath := path.Join(
 125  		testVectorBaseDir, keyAggTestVectorFileName,
 126  	)
 127  	testVectorBytes, err := os.ReadFile(testVectorPath)
 128  	require.NoError(t, err)
 129  	var testCases keyAggTestVectors
 130  	require.NoError(t, json.Unmarshal(testVectorBytes, &testCases))
 131  	tweaks := make([][]byte, len(testCases.Tweaks))
 132  	for i := range testCases.Tweaks {
 133  		tweaks[i] = mustParseHex(testCases.Tweaks[i])
 134  	}
 135  	for i, testCase := range testCases.ValidCases {
 136  		testCase := testCase
 137  		// Assemble the set of keys we'll pass in based on their key
 138  		// index. We don't use sorting to ensure we send the keys in
 139  		// the exact same order as the test vectors do.
 140  		inputKeys, err := keysFromIndices(
 141  			t, testCase.Indices, testCases.PubKeys,
 142  		)
 143  		require.NoError(t, err)
 144  		t.Run(
 145  			fmt.Sprintf("test_case=%v", i), func(t *testing.T) {
 146  				uniqueKeyIndex := secondUniqueKeyIndex(inputKeys, false)
 147  				opts := []KeyAggOption{WithUniqueKeyIndex(uniqueKeyIndex)}
 148  
 149  				combinedKey, _, _, err := AggregateKeys(
 150  					inputKeys, false, opts...,
 151  				)
 152  				require.NoError(t, err)
 153  
 154  				require.Equal(
 155  					t, schnorr.SerializePubKey(combinedKey.FinalKey),
 156  					mustParseHex(testCase.Expected),
 157  				)
 158  			},
 159  		)
 160  	}
 161  	for _, testCase := range testCases.InvalidCases {
 162  		testCase := testCase
 163  		testName := fmt.Sprintf(
 164  			"invalid_%v",
 165  			strings.ToLower(testCase.Comment),
 166  		)
 167  		t.Run(
 168  			testName, func(t *testing.T) {
 169  				// For each test, we'll extract the set of input keys
 170  				// as well as the tweaks since this set of cases also
 171  				// exercises error cases related to the set of tweaks.
 172  				inputKeys, err := keysFromIndices(
 173  					t, testCase.Indices, testCases.PubKeys,
 174  				)
 175  				// In this set of test cases, we should only get this
 176  				// for the very first vector.
 177  				if err != nil {
 178  					switch testCase.Comment {
 179  					case "Invalid public key":
 180  						require.ErrorIs(
 181  							t, err,
 182  							secp256k1.ErrPubKeyNotOnCurve,
 183  						)
 184  					case "Public key exceeds field size":
 185  						require.ErrorIs(
 186  							t, err, secp256k1.ErrPubKeyXTooBig,
 187  						)
 188  					case "First byte of public key is not 2 or 3":
 189  						require.ErrorIs(
 190  							t, err,
 191  							secp256k1.ErrPubKeyInvalidFormat,
 192  						)
 193  					default:
 194  						t.Fatalf("uncaught err: %v", err)
 195  					}
 196  					return
 197  				}
 198  				var tweaks []KeyTweakDesc
 199  				if len(testCase.TweakIndices) != 0 {
 200  					tweaks = tweaksFromIndices(
 201  						t, testCase.TweakIndices, testCases.Tweaks,
 202  						testCase.IsXOnly,
 203  					)
 204  				}
 205  				uniqueKeyIndex := secondUniqueKeyIndex(inputKeys, false)
 206  				opts := []KeyAggOption{
 207  					WithUniqueKeyIndex(uniqueKeyIndex),
 208  				}
 209  				if len(tweaks) != 0 {
 210  					opts = append(opts, WithKeyTweaks(tweaks...))
 211  				}
 212  				_, _, _, err = AggregateKeys(
 213  					inputKeys, false, opts...,
 214  				)
 215  				require.Error(t, err)
 216  				switch testCase.Comment {
 217  				case "Tweak is out of range":
 218  					require.ErrorIs(t, err, ErrTweakedKeyOverflows)
 219  				case "Intermediate tweaking result is point at infinity":
 220  					require.ErrorIs(t, err, ErrTweakedKeyIsInfinity)
 221  				default:
 222  					t.Fatalf("uncaught err: %v", err)
 223  				}
 224  			},
 225  		)
 226  	}
 227  }
 228  
 229  type keyTweakInvalidTest struct {
 230  	Indices      []int  `json:"key_indices"`
 231  	NonceIndices []int  `json:"nonce_indices"`
 232  	TweakIndices []int  `json:"tweak_indices"`
 233  	IsXOnly      []bool `json:"is_only"`
 234  	SignerIndex  int    `json:"signer_index"`
 235  	Comment      string `json:"comment"`
 236  }
 237  
// keyTweakValidTest is one valid case from the tweak vector file.
//
// Fields:
//   - Indices, NonceIndices and TweakIndices select entries from the file's
//     shared pools of public keys, public nonces and tweaks.
//   - IsXOnly gives, position by position, whether each selected tweak is
//     x-only.
//   - Expected is the hex encoding the test compares against the 32-byte
//     encoding of the partial signature's S value.
//   - Comment names the subtest.
//
// NOTE(review): SignerIndex is decoded but not read by the test in this file.
type keyTweakValidTest struct {
	Indices      []int  `json:"key_indices"`
	NonceIndices []int  `json:"nonce_indices"`
	TweakIndices []int  `json:"tweak_indices"`
	IsXOnly      []bool `json:"is_xonly"`
	SignerIndex  int    `json:"signer_index"`
	Expected     string `json:"expected"`
	Comment      string `json:"comment"`
}
 247  
// keyTweakVector is the top-level layout of the tweak vector file.
//
// Fields:
//   - SecKey, PrivNonce and Msg are single hex values shared by every case.
//   - PubKeys, PubNonces and Tweaks are pools of hex strings that individual
//     cases reference by index.
//
// NOTE(review): AggNnoce looks like a typo for AggNonce, and the field is not
// read here; TestMuSig2TweakTestVectors aggregates the nonces itself via
// AggregateNonces. It is left unrenamed in case other files use it.
type keyTweakVector struct {
	SecKey       string                `json:"sk"`
	PubKeys      []string              `json:"pubkeys"`
	PrivNonce    string                `json:"secnonce"`
	PubNonces    []string              `json:"pnonces"`
	AggNnoce     string                `json:"aggnonce"`
	Tweaks       []string              `json:"tweaks"`
	Msg          string                `json:"msg"`
	ValidCases   []keyTweakValidTest   `json:"valid_test_cases"`
	InvalidCases []keyTweakInvalidTest `json:"error_test_cases"`
}
 259  
 260  func pubNoncesFromIndices(
 261  	t *testing.T, nonceIndices []int,
 262  	pubNonces []string,
 263  ) [][PubNonceSize]byte {
 264  
 265  	nonces := make([][PubNonceSize]byte, len(nonceIndices))
 266  	for i, idx := range nonceIndices {
 267  		var pubNonce [PubNonceSize]byte
 268  		copy(pubNonce[:], mustParseHex(pubNonces[idx]))
 269  		nonces[i] = pubNonce
 270  	}
 271  	return nonces
 272  }
 273  
 274  // TestMuSig2TweakTestVectors tests that we properly handle the various edge
 275  // cases related to tweaking public keys.
 276  func TestMuSig2TweakTestVectors(t *testing.T) {
 277  	t.Parallel()
 278  	testVectorPath := path.Join(
 279  		testVectorBaseDir, keyTweakTestVectorFileName,
 280  	)
 281  	testVectorBytes, err := os.ReadFile(testVectorPath)
 282  	require.NoError(t, err)
 283  	var testCases keyTweakVector
 284  	require.NoError(t, json.Unmarshal(testVectorBytes, &testCases))
 285  	privKey, _ := btcec.SecKeyFromBytes(mustParseHex(testCases.SecKey))
 286  	var msg [32]byte
 287  	copy(msg[:], mustParseHex(testCases.Msg))
 288  	var secNonce [SecNonceSize]byte
 289  	copy(secNonce[:], mustParseHex(testCases.PrivNonce))
 290  	for _, testCase := range testCases.ValidCases {
 291  		testCase := testCase
 292  		testName := fmt.Sprintf(
 293  			"valid_%v",
 294  			strings.ToLower(testCase.Comment),
 295  		)
 296  		t.Run(
 297  			testName, func(t *testing.T) {
 298  				pubKeys, err := keysFromIndices(
 299  					t, testCase.Indices, testCases.PubKeys,
 300  				)
 301  				require.NoError(t, err)
 302  				var tweaks []KeyTweakDesc
 303  				if len(testCase.TweakIndices) != 0 {
 304  					tweaks = tweaksFromIndices(
 305  						t, testCase.TweakIndices,
 306  						testCases.Tweaks, testCase.IsXOnly,
 307  					)
 308  				}
 309  				pubNonces := pubNoncesFromIndices(
 310  					t, testCase.NonceIndices, testCases.PubNonces,
 311  				)
 312  				combinedNonce, err := AggregateNonces(pubNonces)
 313  				require.NoError(t, err)
 314  				var opts []SignOption
 315  				if len(tweaks) != 0 {
 316  					opts = append(opts, WithTweaks(tweaks...))
 317  				}
 318  				partialSig, err := Sign(
 319  					secNonce, privKey, combinedNonce, pubKeys,
 320  					msg, opts...,
 321  				)
 322  				var partialSigBytes [32]byte
 323  				partialSig.S.PutBytesUnchecked(partialSigBytes[:])
 324  				require.Equal(
 325  					t, hex.Enc(partialSigBytes[:]),
 326  					hex.Enc(mustParseHex(testCase.Expected)),
 327  				)
 328  
 329  			},
 330  		)
 331  	}
 332  }
 333