gnarl_test.go raw

   1  package gnarl
   2  
   3  import (
   4  	"math/big"
   5  	"testing"
   6  )
   7  
   8  func TestPrimeProperties(t *testing.T) {
   9  	p, q, n := Params()
  10  
  11  	// P should be 216 bits.
  12  	if p.BitLen() != 216 {
  13  		t.Fatalf("P bits = %d, want 216", p.BitLen())
  14  	}
  15  
  16  	// Q should be ~213 bits.
  17  	if q.BitLen() < 212 || q.BitLen() > 214 {
  18  		t.Fatalf("Q bits = %d, want ~213", q.BitLen())
  19  	}
  20  
  21  	// P should be prime.
  22  	if !p.ProbablyPrime(20) {
  23  		t.Fatal("P is not prime")
  24  	}
  25  
  26  	// Q should be prime.
  27  	if !q.ProbablyPrime(20) {
  28  		t.Fatal("Q is not prime")
  29  	}
  30  
  31  	// P ≡ 2 mod 3.
  32  	if new(big.Int).Mod(p, big.NewInt(3)).Int64() != 2 {
  33  		t.Fatalf("P mod 3 = %d, want 2", new(big.Int).Mod(p, big.NewInt(3)).Int64())
  34  	}
  35  
  36  	// P mod 5 ∈ {2, 3}.
  37  	pm5 := new(big.Int).Mod(p, big.NewInt(5)).Int64()
  38  	if pm5 != 2 && pm5 != 3 {
  39  		t.Fatalf("P mod 5 = %d, want 2 or 3", pm5)
  40  	}
  41  
  42  	// N = P + 1.
  43  	if new(big.Int).Add(p, big.NewInt(1)).Cmp(n) != 0 {
  44  		t.Fatal("N != P + 1")
  45  	}
  46  
  47  	// N = 6Q.
  48  	sixQ := new(big.Int).Mul(big.NewInt(6), q)
  49  	if n.Cmp(sixQ) != 0 {
  50  		t.Fatal("N != 6Q")
  51  	}
  52  
  53  	// 5 is QNR mod P.
  54  	pm1 := new(big.Int).Sub(p, big.NewInt(1))
  55  	exp := new(big.Int).Rsh(pm1, 1)
  56  	euler := new(big.Int).Exp(big.NewInt(5), exp, p)
  57  	if euler.Cmp(pm1) != 0 {
  58  		t.Fatal("5 is not QNR mod P")
  59  	}
  60  }
  61  
  62  func TestDeterministicDerivation(t *testing.T) {
  63  	// Running Params twice should give the same values (it's deterministic from seed).
  64  	p1, q1, _ := Params()
  65  	p2, q2, _ := Params()
  66  	if p1.Cmp(p2) != 0 {
  67  		t.Fatal("P is not deterministic")
  68  	}
  69  	if q1.Cmp(q2) != 0 {
  70  		t.Fatal("Q is not deterministic")
  71  	}
  72  }
  73  
  74  func TestFieldArithmeticIdentity(t *testing.T) {
  75  	// 1 * 1 = 1 in Montgomery form.
  76  	var r fe
  77  	montMul(&r, &feOne, &feOne)
  78  	if feEqual(&r, &feOne) != 1 {
  79  		t.Fatal("1 * 1 != 1")
  80  	}
  81  
  82  	// 0 + 1 = 1.
  83  	feAdd(&r, &feZero, &feOne)
  84  	if feEqual(&r, &feOne) != 1 {
  85  		t.Fatal("0 + 1 != 1")
  86  	}
  87  
  88  	// 1 - 1 = 0.
  89  	feSub(&r, &feOne, &feOne)
  90  	if feIsZero(&r) != 1 {
  91  		t.Fatal("1 - 1 != 0")
  92  	}
  93  
  94  	// -0 = 0.
  95  	feNeg(&r, &feZero)
  96  	if feIsZero(&r) != 1 {
  97  		t.Fatal("-0 != 0")
  98  	}
  99  }
 100  
 101  func TestMontgomeryRoundTrip(t *testing.T) {
 102  	// Convert 42 to Montgomery form and back.
 103  	var a, b fe
 104  	a = fe{42, 0, 0, 0}
 105  	feToMont(&b, &a)
 106  	feFromMont(&a, &b)
 107  	if a[0] != 42 || a[1] != 0 || a[2] != 0 || a[3] != 0 {
 108  		t.Fatalf("round-trip failed: got %v", a)
 109  	}
 110  }
 111  
 112  func TestFieldMulCommutativity(t *testing.T) {
 113  	var a, b, ab, ba fe
 114  	feFromSmall(&a, 7)
 115  	feFromSmall(&b, 13)
 116  	montMul(&ab, &a, &b)
 117  	montMul(&ba, &b, &a)
 118  	if feEqual(&ab, &ba) != 1 {
 119  		t.Fatal("a*b != b*a")
 120  	}
 121  
 122  	// Check that 7 * 13 = 91.
 123  	var expected fe
 124  	feFromSmall(&expected, 91)
 125  	if feEqual(&ab, &expected) != 1 {
 126  		t.Fatal("7 * 13 != 91")
 127  	}
 128  }
 129  
 130  func TestFieldInverse(t *testing.T) {
 131  	var a, ainv, r fe
 132  	feFromSmall(&a, 42)
 133  	feInv(&ainv, &a)
 134  	montMul(&r, &a, &ainv)
 135  	if feEqual(&r, &feOne) != 1 {
 136  		t.Fatal("42 * 42^{-1} != 1")
 137  	}
 138  }
 139  
 140  func TestFieldSqrt(t *testing.T) {
 141  	// Compute 7^2 = 49, then sqrt(49) should be ±7.
 142  	var seven, square, root, neg7 fe
 143  	feFromSmall(&seven, 7)
 144  	montSquare(&square, &seven)
 145  
 146  	if !feSqrt(&root, &square) {
 147  		t.Fatal("sqrt(49) failed")
 148  	}
 149  
 150  	// root should be 7 or P-7.
 151  	feNeg(&neg7, &seven)
 152  	if feEqual(&root, &seven) != 1 && feEqual(&root, &neg7) != 1 {
 153  		t.Fatal("sqrt(49) != ±7")
 154  	}
 155  }
 156  
 157  func TestFieldBytes27RoundTrip(t *testing.T) {
 158  	var a, b fe
 159  	feFromSmall(&a, 12345)
 160  	var buf [27]byte
 161  	feToBytes27(buf[:], &a)
 162  
 163  	if !feFromBytes27(&b, buf[:]) {
 164  		t.Fatal("feFromBytes27 rejected valid input")
 165  	}
 166  	if feEqual(&a, &b) != 1 {
 167  		t.Fatal("round-trip failed")
 168  	}
 169  }
 170  
 171  func TestScalarArithmetic(t *testing.T) {
 172  	a := scalar{7, 0, 0, 0}
 173  	b := scalar{13, 0, 0, 0}
 174  
 175  	// 7 + 13 = 20.
 176  	var sum scalar
 177  	scAdd(&sum, &a, &b)
 178  	if sum[0] != 20 || sum[1] != 0 || sum[2] != 0 || sum[3] != 0 {
 179  		t.Fatalf("7 + 13 = %v, want 20", sum)
 180  	}
 181  
 182  	// 13 - 7 = 6.
 183  	var diff scalar
 184  	scSub(&diff, &b, &a)
 185  	if diff[0] != 6 || diff[1] != 0 || diff[2] != 0 || diff[3] != 0 {
 186  		t.Fatalf("13 - 7 = %v, want 6", diff)
 187  	}
 188  
 189  	// 7 * 13 = 91.
 190  	var prod scalar
 191  	scMul(&prod, &a, &b)
 192  	if prod[0] != 91 || prod[1] != 0 || prod[2] != 0 || prod[3] != 0 {
 193  		t.Fatalf("7 * 13 = %v, want 91", prod)
 194  	}
 195  }
 196  
 197  func TestScalarBytes27RoundTrip(t *testing.T) {
 198  	a := scalar{0xdeadbeef, 0xcafebabe, 0x12345678, 0}
 199  	var buf [27]byte
 200  	scToBytes27(buf[:], &a)
 201  
 202  	var b scalar
 203  	scFromBytes27(&b, buf[:])
 204  	if a != b {
 205  		t.Fatalf("round-trip: got %v, want %v", b, a)
 206  	}
 207  }
 208  
 209  func TestGeneratorOrder(t *testing.T) {
 210  	// SchnorrGen^Q should be the identity.
 211  	_, q, _ := Params()
 212  
 213  	var m4gen, result mat4
 214  	tmToMat4(&m4gen, &schnorrGenTM)
 215  	m4PowBig(&result, &m4gen, q)
 216  
 217  	if !m4IsIdentity(&result) {
 218  		t.Fatal("SchnorrGen^Q != I")
 219  	}
 220  }
 221  
 222  func TestGeneratorNotIdentity(t *testing.T) {
 223  	if tmIsIdentity(&schnorrGenTM) {
 224  		t.Fatal("SchnorrGen is identity")
 225  	}
 226  }
 227  
 228  func TestGeneratorDetOne(t *testing.T) {
 229  	// det(SchnorrGen) should be 1.
 230  	var m4gen mat4
 231  	tmToMat4(&m4gen, &schnorrGenTM)
 232  	var det fe
 233  	m4Det(&det, &m4gen)
 234  	if feEqual(&det, &feOne) != 1 {
 235  		t.Fatal("det(SchnorrGen) != 1")
 236  	}
 237  }
 238  
 239  func TestKeyGeneration(t *testing.T) {
 240  	sk, pk, err := GenerateKey()
 241  	if err != nil {
 242  		t.Fatal(err)
 243  	}
 244  
 245  	// Private key should be non-zero.
 246  	if scIsZero(&sk.s) {
 247  		t.Fatal("zero private key")
 248  	}
 249  
 250  	// Public key should not be identity.
 251  	if tmIsIdentity(&pk.tm) {
 252  		t.Fatal("identity public key")
 253  	}
 254  
 255  	// Public key should have det = 1.
 256  	var m4pk mat4
 257  	tmToMat4(&m4pk, &pk.tm)
 258  	var det fe
 259  	m4Det(&det, &m4pk)
 260  	if feEqual(&det, &feOne) != 1 {
 261  		t.Fatal("det(PK) != 1")
 262  	}
 263  }
 264  
 265  func TestKeySerializationRoundTrip(t *testing.T) {
 266  	sk, _, err := GenerateKey()
 267  	if err != nil {
 268  		t.Fatal(err)
 269  	}
 270  
 271  	// Private key round-trip.
 272  	skBytes := sk.Bytes()
 273  	if len(skBytes) != 27 {
 274  		t.Fatalf("sk bytes = %d, want 27", len(skBytes))
 275  	}
 276  	sk2, err := PrivateKeyFromBytes(skBytes)
 277  	if err != nil {
 278  		t.Fatal(err)
 279  	}
 280  	if sk.s != sk2.s {
 281  		t.Logf("sk.s  = %016x %016x %016x %016x", sk.s[3], sk.s[2], sk.s[1], sk.s[0])
 282  		t.Logf("sk2.s = %016x %016x %016x %016x", sk2.s[3], sk2.s[2], sk2.s[1], sk2.s[0])
 283  		t.Logf("bytes = %x", skBytes)
 284  		t.Fatal("private key round-trip failed")
 285  	}
 286  }
 287  
 288  // mockGMid is a stand-in challenge hash for testing without importing pkg/crypto.
 289  func mockGMid(data []byte) [27]byte {
 290  	var result [27]byte
 291  	// FNV-1a-like hash spread across 27 bytes.
 292  	var state uint64 = 14695981039346656037
 293  	for _, b := range data {
 294  		state ^= uint64(b)
 295  		state *= 1099511628211
 296  	}
 297  	for i := range result {
 298  		state ^= uint64(i)
 299  		state *= 1099511628211
 300  		result[i] = byte(state >> 32)
 301  	}
 302  	return result
 303  }
 304  
 305  func TestSignVerifyRoundTrip(t *testing.T) {
 306  	sk, pk, err := GenerateKey()
 307  	if err != nil {
 308  		t.Fatal(err)
 309  	}
 310  
 311  	msg := []byte("test message for gnarl signature")
 312  
 313  	sig, err := Sign(sk, msg, mockGMid)
 314  	if err != nil {
 315  		t.Fatal(err)
 316  	}
 317  
 318  	if !Verify(pk, msg, sig, mockGMid) {
 319  		t.Fatal("valid signature rejected")
 320  	}
 321  }
 322  
 323  func TestSignVerifyTamperedMessage(t *testing.T) {
 324  	sk, pk, err := GenerateKey()
 325  	if err != nil {
 326  		t.Fatal(err)
 327  	}
 328  
 329  	msg := []byte("original message")
 330  	sig, err := Sign(sk, msg, mockGMid)
 331  	if err != nil {
 332  		t.Fatal(err)
 333  	}
 334  
 335  	tampered := []byte("tampered message")
 336  	if Verify(pk, tampered, sig, mockGMid) {
 337  		t.Fatal("tampered message accepted")
 338  	}
 339  }
 340  
 341  func TestSignVerifyWrongKey(t *testing.T) {
 342  	sk, _, err := GenerateKey()
 343  	if err != nil {
 344  		t.Fatal(err)
 345  	}
 346  	_, pk2, err := GenerateKey()
 347  	if err != nil {
 348  		t.Fatal(err)
 349  	}
 350  
 351  	msg := []byte("test message")
 352  	sig, err := Sign(sk, msg, mockGMid)
 353  	if err != nil {
 354  		t.Fatal(err)
 355  	}
 356  
 357  	if Verify(pk2, msg, sig, mockGMid) {
 358  		t.Fatal("wrong key accepted")
 359  	}
 360  }
 361  
 362  func TestSignatureSerialization(t *testing.T) {
 363  	sk, _, err := GenerateKey()
 364  	if err != nil {
 365  		t.Fatal(err)
 366  	}
 367  
 368  	msg := []byte("serialize me")
 369  	sig, err := Sign(sk, msg, mockGMid)
 370  	if err != nil {
 371  		t.Fatal(err)
 372  	}
 373  
 374  	sigBytes := sig.Bytes()
 375  	if len(sigBytes) != 54 {
 376  		t.Fatalf("signature bytes = %d, want 54", len(sigBytes))
 377  	}
 378  
 379  	sig2, err := SignatureFromBytes(sigBytes)
 380  	if err != nil {
 381  		t.Fatal(err)
 382  	}
 383  
 384  	if sig.e != sig2.e {
 385  		t.Fatal("challenge mismatch after round-trip")
 386  	}
 387  	if sig.z != sig2.z {
 388  		t.Fatal("response mismatch after round-trip")
 389  	}
 390  }
 391  
 392  func TestPublicKeyYBytesLength(t *testing.T) {
 393  	_, pk, err := GenerateKey()
 394  	if err != nil {
 395  		t.Fatal(err)
 396  	}
 397  
 398  	yBytes := pk.YBytes()
 399  	if len(yBytes) != 27 {
 400  		t.Fatalf("y-only pubkey = %d bytes, want 27", len(yBytes))
 401  	}
 402  }
 403  
 404  func TestTorusMatrixDetInvariant(t *testing.T) {
 405  	// Generate a few keys and check det = 1 for all intermediate operations.
 406  	for i := 0; i < 5; i++ {
 407  		sk, _, err := GenerateKey()
 408  		if err != nil {
 409  			t.Fatal(err)
 410  		}
 411  
 412  		var pkTM tmat
 413  		tmFixedExp(&pkTM, &sk.s)
 414  
 415  		var m4pk mat4
 416  		tmToMat4(&m4pk, &pkTM)
 417  		var det fe
 418  		m4Det(&det, &m4pk)
 419  		if feEqual(&det, &feOne) != 1 {
 420  			t.Fatalf("iteration %d: det != 1", i)
 421  		}
 422  	}
 423  }
 424  
 425  // --- Benchmarks ---
 426  
 427  func BenchmarkMontMul(b *testing.B) {
 428  	var a, c fe
 429  	feFromSmall(&a, 42)
 430  	feFromSmall(&c, 7)
 431  	b.ResetTimer()
 432  	for i := 0; i < b.N; i++ {
 433  		montMul(&a, &a, &c)
 434  	}
 435  }
 436  
 437  func BenchmarkMontSquare(b *testing.B) {
 438  	var a fe
 439  	feFromSmall(&a, 42)
 440  	b.ResetTimer()
 441  	for i := 0; i < b.N; i++ {
 442  		montSquare(&a, &a)
 443  	}
 444  }
 445  
 446  func BenchmarkTmMul(b *testing.B) {
 447  	sk, _, err := GenerateKey()
 448  	if err != nil {
 449  		b.Fatal(err)
 450  	}
 451  	var pk tmat
 452  	tmFixedExp(&pk, &sk.s)
 453  	b.ResetTimer()
 454  	for i := 0; i < b.N; i++ {
 455  		tmMul(&pk, &pk, &schnorrGenTM)
 456  	}
 457  }
 458  
 459  func BenchmarkTmSquare(b *testing.B) {
 460  	sk, _, err := GenerateKey()
 461  	if err != nil {
 462  		b.Fatal(err)
 463  	}
 464  	var pk tmat
 465  	tmFixedExp(&pk, &sk.s)
 466  	b.ResetTimer()
 467  	for i := 0; i < b.N; i++ {
 468  		tmSquare(&pk, &pk)
 469  	}
 470  }
 471  
 472  func BenchmarkFixedExp(b *testing.B) {
 473  	var s scalar
 474  	scRandom(&s)
 475  	var r tmat
 476  	b.ResetTimer()
 477  	for i := 0; i < b.N; i++ {
 478  		tmFixedExp(&r, &s)
 479  	}
 480  }
 481  
 482  func BenchmarkShamirExp(b *testing.B) {
 483  	sk, pk, err := GenerateKey()
 484  	if err != nil {
 485  		b.Fatal(err)
 486  	}
 487  	var s2 scalar
 488  	scRandom(&s2)
 489  	var r tmat
 490  	b.ResetTimer()
 491  	for i := 0; i < b.N; i++ {
 492  		tmShamirExp(&r, &sk.s, &pk.tm, &s2)
 493  	}
 494  }
 495  
 496  func BenchmarkSign(b *testing.B) {
 497  	sk, _, err := GenerateKey()
 498  	if err != nil {
 499  		b.Fatal(err)
 500  	}
 501  	msg := []byte("benchmark message for signing")
 502  	b.ResetTimer()
 503  	for i := 0; i < b.N; i++ {
 504  		_, err := Sign(sk, msg, mockGMid)
 505  		if err != nil {
 506  			b.Fatal(err)
 507  		}
 508  	}
 509  }
 510  
 511  func BenchmarkVerify(b *testing.B) {
 512  	sk, pk, err := GenerateKey()
 513  	if err != nil {
 514  		b.Fatal(err)
 515  	}
 516  	msg := []byte("benchmark message for verification")
 517  	sig, err := Sign(sk, msg, mockGMid)
 518  	if err != nil {
 519  		b.Fatal(err)
 520  	}
 521  	b.ResetTimer()
 522  	for i := 0; i < b.N; i++ {
 523  		if !Verify(pk, msg, sig, mockGMid) {
 524  			b.Fatal("verification failed")
 525  		}
 526  	}
 527  }
 528