avx_test.go raw

   1  package avx
   2  
   3  import (
   4  	"bytes"
   5  	"crypto/rand"
   6  	"encoding/hex"
   7  	"testing"
   8  )
   9  
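// Unit tests and benchmarks for the secp256k1 arithmetic in this package.
// Known-answer vectors (see TestKnownScalarMult) are secp256k1 base-point
// multiples k*G of the kind published for Bitcoin.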
  11  
  12  func TestUint128Add(t *testing.T) {
  13  	tests := []struct {
  14  		a, b   Uint128
  15  		expect Uint128
  16  		carry  uint64
  17  	}{
  18  		{Uint128{0, 0}, Uint128{0, 0}, Uint128{0, 0}, 0},
  19  		{Uint128{1, 0}, Uint128{1, 0}, Uint128{2, 0}, 0},
  20  		{Uint128{^uint64(0), 0}, Uint128{1, 0}, Uint128{0, 1}, 0},
  21  		{Uint128{^uint64(0), ^uint64(0)}, Uint128{1, 0}, Uint128{0, 0}, 1},
  22  	}
  23  
  24  	for i, tt := range tests {
  25  		result, carry := tt.a.Add(tt.b)
  26  		if result != tt.expect || carry != tt.carry {
  27  			t.Errorf("test %d: got (%v, %d), want (%v, %d)", i, result, carry, tt.expect, tt.carry)
  28  		}
  29  	}
  30  }
  31  
  32  func TestUint128Mul(t *testing.T) {
  33  	// Test: 2^64 * 2^64 = 2^128
  34  	a := Uint128{0, 1} // 2^64
  35  	b := Uint128{0, 1} // 2^64
  36  	result := a.Mul(b)
  37  	// Expected: 2^128 = [0, 0, 1, 0]
  38  	expected := [4]uint64{0, 0, 1, 0}
  39  	if result != expected {
  40  		t.Errorf("2^64 * 2^64: got %v, want %v", result, expected)
  41  	}
  42  
  43  	// Test: (2^64 - 1) * (2^64 - 1)
  44  	a = Uint128{^uint64(0), 0}
  45  	b = Uint128{^uint64(0), 0}
  46  	result = a.Mul(b)
  47  	// (2^64 - 1)^2 = 2^128 - 2^65 + 1
  48  	// = [1, 0xFFFFFFFFFFFFFFFE, 0, 0]
  49  	expected = [4]uint64{1, 0xFFFFFFFFFFFFFFFE, 0, 0}
  50  	if result != expected {
  51  		t.Errorf("(2^64-1)^2: got %v, want %v", result, expected)
  52  	}
  53  }
  54  
  55  func TestScalarSetBytes(t *testing.T) {
  56  	// Test with a known scalar
  57  	bytes32 := make([]byte, 32)
  58  	bytes32[31] = 1 // scalar = 1
  59  
  60  	var s Scalar
  61  	s.SetBytes(bytes32)
  62  
  63  	if !s.IsOne() {
  64  		t.Errorf("expected scalar to be 1, got %+v", s)
  65  	}
  66  
  67  	// Test zero
  68  	bytes32 = make([]byte, 32)
  69  	s.SetBytes(bytes32)
  70  	if !s.IsZero() {
  71  		t.Errorf("expected scalar to be 0, got %+v", s)
  72  	}
  73  }
  74  
  75  func TestScalarAddSub(t *testing.T) {
  76  	var a, b, sum, diff, recovered Scalar
  77  
  78  	// a = 1, b = 2
  79  	a = ScalarOne
  80  	b.D[0].Lo = 2
  81  
  82  	sum.Add(&a, &b)
  83  	if sum.D[0].Lo != 3 {
  84  		t.Errorf("1 + 2: expected 3, got %d", sum.D[0].Lo)
  85  	}
  86  
  87  	diff.Sub(&sum, &b)
  88  	if !diff.Equal(&a) {
  89  		t.Errorf("(1+2) - 2: expected 1, got %+v", diff)
  90  	}
  91  
  92  	// Test with overflow
  93  	a = ScalarN
  94  	a.D[0].Lo-- // n - 1
  95  	b = ScalarOne
  96  
  97  	sum.Add(&a, &b)
  98  	// n - 1 + 1 = n ≡ 0 (mod n)
  99  	if !sum.IsZero() {
 100  		t.Errorf("(n-1) + 1 should be 0 mod n, got %+v", sum)
 101  	}
 102  
 103  	// Test subtraction with borrow
 104  	a = ScalarZero
 105  	b = ScalarOne
 106  	diff.Sub(&a, &b)
 107  	// 0 - 1 = -1 ≡ n - 1 (mod n)
 108  	recovered.Add(&diff, &b)
 109  	if !recovered.IsZero() {
 110  		t.Errorf("(0-1) + 1 should be 0, got %+v", recovered)
 111  	}
 112  }
 113  
 114  func TestScalarMul(t *testing.T) {
 115  	var a, b, product Scalar
 116  
 117  	// 2 * 3 = 6
 118  	a.D[0].Lo = 2
 119  	b.D[0].Lo = 3
 120  	product.Mul(&a, &b)
 121  	if product.D[0].Lo != 6 || product.D[0].Hi != 0 || !product.D[1].IsZero() {
 122  		t.Errorf("2 * 3: expected 6, got %+v", product)
 123  	}
 124  
 125  	// Test with larger values
 126  	a.D[0].Lo = 0xFFFFFFFFFFFFFFFF
 127  	a.D[0].Hi = 0
 128  	b.D[0].Lo = 2
 129  	product.Mul(&a, &b)
 130  	// (2^64 - 1) * 2 = 2^65 - 2
 131  	if product.D[0].Lo != 0xFFFFFFFFFFFFFFFE || product.D[0].Hi != 1 {
 132  		t.Errorf("(2^64-1) * 2: got %+v", product)
 133  	}
 134  }
 135  
 136  func TestScalarNegate(t *testing.T) {
 137  	var a, neg, sum Scalar
 138  
 139  	a.D[0].Lo = 12345
 140  	neg.Negate(&a)
 141  	sum.Add(&a, &neg)
 142  
 143  	if !sum.IsZero() {
 144  		t.Errorf("a + (-a) should be 0, got %+v", sum)
 145  	}
 146  }
 147  
 148  func TestFieldSetBytes(t *testing.T) {
 149  	bytes32 := make([]byte, 32)
 150  	bytes32[31] = 1
 151  
 152  	var f FieldElement
 153  	f.SetBytes(bytes32)
 154  
 155  	if !f.IsOne() {
 156  		t.Errorf("expected field element to be 1, got %+v", f)
 157  	}
 158  }
 159  
 160  func TestFieldAddSub(t *testing.T) {
 161  	var a, b, sum, diff FieldElement
 162  
 163  	a.N[0].Lo = 100
 164  	b.N[0].Lo = 200
 165  
 166  	sum.Add(&a, &b)
 167  	if sum.N[0].Lo != 300 {
 168  		t.Errorf("100 + 200: expected 300, got %d", sum.N[0].Lo)
 169  	}
 170  
 171  	diff.Sub(&sum, &b)
 172  	if !diff.Equal(&a) {
 173  		t.Errorf("(100+200) - 200: expected 100, got %+v", diff)
 174  	}
 175  }
 176  
 177  func TestFieldMul(t *testing.T) {
 178  	var a, b, product FieldElement
 179  
 180  	a.N[0].Lo = 7
 181  	b.N[0].Lo = 8
 182  	product.Mul(&a, &b)
 183  	if product.N[0].Lo != 56 {
 184  		t.Errorf("7 * 8: expected 56, got %d", product.N[0].Lo)
 185  	}
 186  }
 187  
 188  func TestFieldInverse(t *testing.T) {
 189  	var a, inv, product FieldElement
 190  
 191  	a.N[0].Lo = 7
 192  	inv.Inverse(&a)
 193  	product.Mul(&a, &inv)
 194  
 195  	if !product.IsOne() {
 196  		t.Errorf("7 * 7^(-1) should be 1, got %+v", product)
 197  	}
 198  }
 199  
 200  func TestFieldSqrt(t *testing.T) {
 201  	// Test sqrt(4) = 2
 202  	var four, root, check FieldElement
 203  	four.N[0].Lo = 4
 204  
 205  	if !root.Sqrt(&four) {
 206  		t.Fatal("sqrt(4) should exist")
 207  	}
 208  
 209  	check.Sqr(&root)
 210  	if !check.Equal(&four) {
 211  		t.Errorf("sqrt(4)^2 should be 4, got %+v", check)
 212  	}
 213  }
 214  
 215  func TestGeneratorOnCurve(t *testing.T) {
 216  	if !Generator.IsOnCurve() {
 217  		t.Error("generator point should be on the curve")
 218  	}
 219  }
 220  
 221  func TestPointDouble(t *testing.T) {
 222  	var g, doubled JacobianPoint
 223  	var affineResult AffinePoint
 224  
 225  	g.FromAffine(&Generator)
 226  	doubled.Double(&g)
 227  	doubled.ToAffine(&affineResult)
 228  
 229  	if affineResult.Infinity {
 230  		t.Error("2G should not be infinity")
 231  	}
 232  
 233  	if !affineResult.IsOnCurve() {
 234  		t.Error("2G should be on the curve")
 235  	}
 236  }
 237  
 238  func TestPointAdd(t *testing.T) {
 239  	var g, twoG, threeG JacobianPoint
 240  	var affineResult AffinePoint
 241  
 242  	g.FromAffine(&Generator)
 243  	twoG.Double(&g)
 244  	threeG.Add(&twoG, &g)
 245  	threeG.ToAffine(&affineResult)
 246  
 247  	if !affineResult.IsOnCurve() {
 248  		t.Error("3G should be on the curve")
 249  	}
 250  
 251  	// Also test via scalar multiplication
 252  	var three Scalar
 253  	three.D[0].Lo = 3
 254  
 255  	var expected JacobianPoint
 256  	expected.ScalarMult(&g, &three)
 257  	var expectedAffine AffinePoint
 258  	expected.ToAffine(&expectedAffine)
 259  
 260  	if !affineResult.Equal(&expectedAffine) {
 261  		t.Error("G + 2G should equal 3G")
 262  	}
 263  }
 264  
 265  func TestPointAddInfinity(t *testing.T) {
 266  	var g, inf, result JacobianPoint
 267  	var affineResult AffinePoint
 268  
 269  	g.FromAffine(&Generator)
 270  	inf.SetInfinity()
 271  
 272  	result.Add(&g, &inf)
 273  	result.ToAffine(&affineResult)
 274  
 275  	if !affineResult.Equal(&Generator) {
 276  		t.Error("G + O should equal G")
 277  	}
 278  
 279  	result.Add(&inf, &g)
 280  	result.ToAffine(&affineResult)
 281  
 282  	if !affineResult.Equal(&Generator) {
 283  		t.Error("O + G should equal G")
 284  	}
 285  }
 286  
 287  func TestScalarBaseMult(t *testing.T) {
 288  	// Test 1*G = G
 289  	result := BasePointMult(&ScalarOne)
 290  	if !result.Equal(&Generator) {
 291  		t.Error("1*G should equal G")
 292  	}
 293  
 294  	// Test 2*G
 295  	var two Scalar
 296  	two.D[0].Lo = 2
 297  	result = BasePointMult(&two)
 298  
 299  	var g, twoG JacobianPoint
 300  	var expected AffinePoint
 301  	g.FromAffine(&Generator)
 302  	twoG.Double(&g)
 303  	twoG.ToAffine(&expected)
 304  
 305  	if !result.Equal(&expected) {
 306  		t.Error("2*G via scalar mult should equal 2*G via doubling")
 307  	}
 308  }
 309  
 310  func TestKnownScalarMult(t *testing.T) {
 311  	// Test vector: private key and public key from Bitcoin
 312  	// This is a well-known test vector
 313  	privKeyHex := "0000000000000000000000000000000000000000000000000000000000000001"
 314  	expectedXHex := "79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798"
 315  	expectedYHex := "483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8"
 316  
 317  	privKeyBytes, _ := hex.DecodeString(privKeyHex)
 318  	var k Scalar
 319  	k.SetBytes(privKeyBytes)
 320  
 321  	result := BasePointMult(&k)
 322  
 323  	xBytes := result.X.Bytes()
 324  	yBytes := result.Y.Bytes()
 325  
 326  	expectedX, _ := hex.DecodeString(expectedXHex)
 327  	expectedY, _ := hex.DecodeString(expectedYHex)
 328  
 329  	if !bytes.Equal(xBytes[:], expectedX) {
 330  		t.Errorf("X coordinate mismatch:\ngot:  %x\nwant: %x", xBytes, expectedX)
 331  	}
 332  	if !bytes.Equal(yBytes[:], expectedY) {
 333  		t.Errorf("Y coordinate mismatch:\ngot:  %x\nwant: %x", yBytes, expectedY)
 334  	}
 335  }
 336  
 337  // Benchmark tests
 338  
 339  func BenchmarkUint128Mul(b *testing.B) {
 340  	a := Uint128{0x123456789ABCDEF0, 0xFEDCBA9876543210}
 341  	c := Uint128{0xABCDEF0123456789, 0x9876543210FEDCBA}
 342  
 343  	b.ResetTimer()
 344  	for i := 0; i < b.N; i++ {
 345  		_ = a.Mul(c)
 346  	}
 347  }
 348  
 349  func BenchmarkScalarAdd(b *testing.B) {
 350  	var a, c, r Scalar
 351  	aBytes := make([]byte, 32)
 352  	cBytes := make([]byte, 32)
 353  	rand.Read(aBytes)
 354  	rand.Read(cBytes)
 355  	a.SetBytes(aBytes)
 356  	c.SetBytes(cBytes)
 357  
 358  	b.ResetTimer()
 359  	for i := 0; i < b.N; i++ {
 360  		r.Add(&a, &c)
 361  	}
 362  }
 363  
 364  func BenchmarkScalarMul(b *testing.B) {
 365  	var a, c, r Scalar
 366  	aBytes := make([]byte, 32)
 367  	cBytes := make([]byte, 32)
 368  	rand.Read(aBytes)
 369  	rand.Read(cBytes)
 370  	a.SetBytes(aBytes)
 371  	c.SetBytes(cBytes)
 372  
 373  	b.ResetTimer()
 374  	for i := 0; i < b.N; i++ {
 375  		r.Mul(&a, &c)
 376  	}
 377  }
 378  
 379  func BenchmarkFieldAdd(b *testing.B) {
 380  	var a, c, r FieldElement
 381  	aBytes := make([]byte, 32)
 382  	cBytes := make([]byte, 32)
 383  	rand.Read(aBytes)
 384  	rand.Read(cBytes)
 385  	a.SetBytes(aBytes)
 386  	c.SetBytes(cBytes)
 387  
 388  	b.ResetTimer()
 389  	for i := 0; i < b.N; i++ {
 390  		r.Add(&a, &c)
 391  	}
 392  }
 393  
 394  func BenchmarkFieldMul(b *testing.B) {
 395  	var a, c, r FieldElement
 396  	aBytes := make([]byte, 32)
 397  	cBytes := make([]byte, 32)
 398  	rand.Read(aBytes)
 399  	rand.Read(cBytes)
 400  	a.SetBytes(aBytes)
 401  	c.SetBytes(cBytes)
 402  
 403  	b.ResetTimer()
 404  	for i := 0; i < b.N; i++ {
 405  		r.Mul(&a, &c)
 406  	}
 407  }
 408  
 409  func BenchmarkFieldInverse(b *testing.B) {
 410  	var a, r FieldElement
 411  	aBytes := make([]byte, 32)
 412  	rand.Read(aBytes)
 413  	a.SetBytes(aBytes)
 414  
 415  	b.ResetTimer()
 416  	for i := 0; i < b.N; i++ {
 417  		r.Inverse(&a)
 418  	}
 419  }
 420  
 421  func BenchmarkPointDouble(b *testing.B) {
 422  	var g, r JacobianPoint
 423  	g.FromAffine(&Generator)
 424  
 425  	b.ResetTimer()
 426  	for i := 0; i < b.N; i++ {
 427  		r.Double(&g)
 428  	}
 429  }
 430  
 431  func BenchmarkPointAdd(b *testing.B) {
 432  	var g, twoG, r JacobianPoint
 433  	g.FromAffine(&Generator)
 434  	twoG.Double(&g)
 435  
 436  	b.ResetTimer()
 437  	for i := 0; i < b.N; i++ {
 438  		r.Add(&g, &twoG)
 439  	}
 440  }
 441  
 442  func BenchmarkScalarBaseMult(b *testing.B) {
 443  	var k Scalar
 444  	kBytes := make([]byte, 32)
 445  	rand.Read(kBytes)
 446  	k.SetBytes(kBytes)
 447  
 448  	b.ResetTimer()
 449  	for i := 0; i < b.N; i++ {
 450  		_ = BasePointMult(&k)
 451  	}
 452  }
 453