avx_test.go raw
1 package avx
2
3 import (
4 "bytes"
5 "crypto/rand"
6 "encoding/hex"
7 "testing"
8 )
9
// Known-answer test vectors below are taken from the secp256k1 curve as used by Bitcoin.
11
12 func TestUint128Add(t *testing.T) {
13 tests := []struct {
14 a, b Uint128
15 expect Uint128
16 carry uint64
17 }{
18 {Uint128{0, 0}, Uint128{0, 0}, Uint128{0, 0}, 0},
19 {Uint128{1, 0}, Uint128{1, 0}, Uint128{2, 0}, 0},
20 {Uint128{^uint64(0), 0}, Uint128{1, 0}, Uint128{0, 1}, 0},
21 {Uint128{^uint64(0), ^uint64(0)}, Uint128{1, 0}, Uint128{0, 0}, 1},
22 }
23
24 for i, tt := range tests {
25 result, carry := tt.a.Add(tt.b)
26 if result != tt.expect || carry != tt.carry {
27 t.Errorf("test %d: got (%v, %d), want (%v, %d)", i, result, carry, tt.expect, tt.carry)
28 }
29 }
30 }
31
32 func TestUint128Mul(t *testing.T) {
33 // Test: 2^64 * 2^64 = 2^128
34 a := Uint128{0, 1} // 2^64
35 b := Uint128{0, 1} // 2^64
36 result := a.Mul(b)
37 // Expected: 2^128 = [0, 0, 1, 0]
38 expected := [4]uint64{0, 0, 1, 0}
39 if result != expected {
40 t.Errorf("2^64 * 2^64: got %v, want %v", result, expected)
41 }
42
43 // Test: (2^64 - 1) * (2^64 - 1)
44 a = Uint128{^uint64(0), 0}
45 b = Uint128{^uint64(0), 0}
46 result = a.Mul(b)
47 // (2^64 - 1)^2 = 2^128 - 2^65 + 1
48 // = [1, 0xFFFFFFFFFFFFFFFE, 0, 0]
49 expected = [4]uint64{1, 0xFFFFFFFFFFFFFFFE, 0, 0}
50 if result != expected {
51 t.Errorf("(2^64-1)^2: got %v, want %v", result, expected)
52 }
53 }
54
55 func TestScalarSetBytes(t *testing.T) {
56 // Test with a known scalar
57 bytes32 := make([]byte, 32)
58 bytes32[31] = 1 // scalar = 1
59
60 var s Scalar
61 s.SetBytes(bytes32)
62
63 if !s.IsOne() {
64 t.Errorf("expected scalar to be 1, got %+v", s)
65 }
66
67 // Test zero
68 bytes32 = make([]byte, 32)
69 s.SetBytes(bytes32)
70 if !s.IsZero() {
71 t.Errorf("expected scalar to be 0, got %+v", s)
72 }
73 }
74
75 func TestScalarAddSub(t *testing.T) {
76 var a, b, sum, diff, recovered Scalar
77
78 // a = 1, b = 2
79 a = ScalarOne
80 b.D[0].Lo = 2
81
82 sum.Add(&a, &b)
83 if sum.D[0].Lo != 3 {
84 t.Errorf("1 + 2: expected 3, got %d", sum.D[0].Lo)
85 }
86
87 diff.Sub(&sum, &b)
88 if !diff.Equal(&a) {
89 t.Errorf("(1+2) - 2: expected 1, got %+v", diff)
90 }
91
92 // Test with overflow
93 a = ScalarN
94 a.D[0].Lo-- // n - 1
95 b = ScalarOne
96
97 sum.Add(&a, &b)
98 // n - 1 + 1 = n ≡ 0 (mod n)
99 if !sum.IsZero() {
100 t.Errorf("(n-1) + 1 should be 0 mod n, got %+v", sum)
101 }
102
103 // Test subtraction with borrow
104 a = ScalarZero
105 b = ScalarOne
106 diff.Sub(&a, &b)
107 // 0 - 1 = -1 ≡ n - 1 (mod n)
108 recovered.Add(&diff, &b)
109 if !recovered.IsZero() {
110 t.Errorf("(0-1) + 1 should be 0, got %+v", recovered)
111 }
112 }
113
114 func TestScalarMul(t *testing.T) {
115 var a, b, product Scalar
116
117 // 2 * 3 = 6
118 a.D[0].Lo = 2
119 b.D[0].Lo = 3
120 product.Mul(&a, &b)
121 if product.D[0].Lo != 6 || product.D[0].Hi != 0 || !product.D[1].IsZero() {
122 t.Errorf("2 * 3: expected 6, got %+v", product)
123 }
124
125 // Test with larger values
126 a.D[0].Lo = 0xFFFFFFFFFFFFFFFF
127 a.D[0].Hi = 0
128 b.D[0].Lo = 2
129 product.Mul(&a, &b)
130 // (2^64 - 1) * 2 = 2^65 - 2
131 if product.D[0].Lo != 0xFFFFFFFFFFFFFFFE || product.D[0].Hi != 1 {
132 t.Errorf("(2^64-1) * 2: got %+v", product)
133 }
134 }
135
136 func TestScalarNegate(t *testing.T) {
137 var a, neg, sum Scalar
138
139 a.D[0].Lo = 12345
140 neg.Negate(&a)
141 sum.Add(&a, &neg)
142
143 if !sum.IsZero() {
144 t.Errorf("a + (-a) should be 0, got %+v", sum)
145 }
146 }
147
148 func TestFieldSetBytes(t *testing.T) {
149 bytes32 := make([]byte, 32)
150 bytes32[31] = 1
151
152 var f FieldElement
153 f.SetBytes(bytes32)
154
155 if !f.IsOne() {
156 t.Errorf("expected field element to be 1, got %+v", f)
157 }
158 }
159
160 func TestFieldAddSub(t *testing.T) {
161 var a, b, sum, diff FieldElement
162
163 a.N[0].Lo = 100
164 b.N[0].Lo = 200
165
166 sum.Add(&a, &b)
167 if sum.N[0].Lo != 300 {
168 t.Errorf("100 + 200: expected 300, got %d", sum.N[0].Lo)
169 }
170
171 diff.Sub(&sum, &b)
172 if !diff.Equal(&a) {
173 t.Errorf("(100+200) - 200: expected 100, got %+v", diff)
174 }
175 }
176
177 func TestFieldMul(t *testing.T) {
178 var a, b, product FieldElement
179
180 a.N[0].Lo = 7
181 b.N[0].Lo = 8
182 product.Mul(&a, &b)
183 if product.N[0].Lo != 56 {
184 t.Errorf("7 * 8: expected 56, got %d", product.N[0].Lo)
185 }
186 }
187
188 func TestFieldInverse(t *testing.T) {
189 var a, inv, product FieldElement
190
191 a.N[0].Lo = 7
192 inv.Inverse(&a)
193 product.Mul(&a, &inv)
194
195 if !product.IsOne() {
196 t.Errorf("7 * 7^(-1) should be 1, got %+v", product)
197 }
198 }
199
200 func TestFieldSqrt(t *testing.T) {
201 // Test sqrt(4) = 2
202 var four, root, check FieldElement
203 four.N[0].Lo = 4
204
205 if !root.Sqrt(&four) {
206 t.Fatal("sqrt(4) should exist")
207 }
208
209 check.Sqr(&root)
210 if !check.Equal(&four) {
211 t.Errorf("sqrt(4)^2 should be 4, got %+v", check)
212 }
213 }
214
215 func TestGeneratorOnCurve(t *testing.T) {
216 if !Generator.IsOnCurve() {
217 t.Error("generator point should be on the curve")
218 }
219 }
220
221 func TestPointDouble(t *testing.T) {
222 var g, doubled JacobianPoint
223 var affineResult AffinePoint
224
225 g.FromAffine(&Generator)
226 doubled.Double(&g)
227 doubled.ToAffine(&affineResult)
228
229 if affineResult.Infinity {
230 t.Error("2G should not be infinity")
231 }
232
233 if !affineResult.IsOnCurve() {
234 t.Error("2G should be on the curve")
235 }
236 }
237
238 func TestPointAdd(t *testing.T) {
239 var g, twoG, threeG JacobianPoint
240 var affineResult AffinePoint
241
242 g.FromAffine(&Generator)
243 twoG.Double(&g)
244 threeG.Add(&twoG, &g)
245 threeG.ToAffine(&affineResult)
246
247 if !affineResult.IsOnCurve() {
248 t.Error("3G should be on the curve")
249 }
250
251 // Also test via scalar multiplication
252 var three Scalar
253 three.D[0].Lo = 3
254
255 var expected JacobianPoint
256 expected.ScalarMult(&g, &three)
257 var expectedAffine AffinePoint
258 expected.ToAffine(&expectedAffine)
259
260 if !affineResult.Equal(&expectedAffine) {
261 t.Error("G + 2G should equal 3G")
262 }
263 }
264
265 func TestPointAddInfinity(t *testing.T) {
266 var g, inf, result JacobianPoint
267 var affineResult AffinePoint
268
269 g.FromAffine(&Generator)
270 inf.SetInfinity()
271
272 result.Add(&g, &inf)
273 result.ToAffine(&affineResult)
274
275 if !affineResult.Equal(&Generator) {
276 t.Error("G + O should equal G")
277 }
278
279 result.Add(&inf, &g)
280 result.ToAffine(&affineResult)
281
282 if !affineResult.Equal(&Generator) {
283 t.Error("O + G should equal G")
284 }
285 }
286
287 func TestScalarBaseMult(t *testing.T) {
288 // Test 1*G = G
289 result := BasePointMult(&ScalarOne)
290 if !result.Equal(&Generator) {
291 t.Error("1*G should equal G")
292 }
293
294 // Test 2*G
295 var two Scalar
296 two.D[0].Lo = 2
297 result = BasePointMult(&two)
298
299 var g, twoG JacobianPoint
300 var expected AffinePoint
301 g.FromAffine(&Generator)
302 twoG.Double(&g)
303 twoG.ToAffine(&expected)
304
305 if !result.Equal(&expected) {
306 t.Error("2*G via scalar mult should equal 2*G via doubling")
307 }
308 }
309
310 func TestKnownScalarMult(t *testing.T) {
311 // Test vector: private key and public key from Bitcoin
312 // This is a well-known test vector
313 privKeyHex := "0000000000000000000000000000000000000000000000000000000000000001"
314 expectedXHex := "79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798"
315 expectedYHex := "483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8"
316
317 privKeyBytes, _ := hex.DecodeString(privKeyHex)
318 var k Scalar
319 k.SetBytes(privKeyBytes)
320
321 result := BasePointMult(&k)
322
323 xBytes := result.X.Bytes()
324 yBytes := result.Y.Bytes()
325
326 expectedX, _ := hex.DecodeString(expectedXHex)
327 expectedY, _ := hex.DecodeString(expectedYHex)
328
329 if !bytes.Equal(xBytes[:], expectedX) {
330 t.Errorf("X coordinate mismatch:\ngot: %x\nwant: %x", xBytes, expectedX)
331 }
332 if !bytes.Equal(yBytes[:], expectedY) {
333 t.Errorf("Y coordinate mismatch:\ngot: %x\nwant: %x", yBytes, expectedY)
334 }
335 }
336
337 // Benchmark tests
338
339 func BenchmarkUint128Mul(b *testing.B) {
340 a := Uint128{0x123456789ABCDEF0, 0xFEDCBA9876543210}
341 c := Uint128{0xABCDEF0123456789, 0x9876543210FEDCBA}
342
343 b.ResetTimer()
344 for i := 0; i < b.N; i++ {
345 _ = a.Mul(c)
346 }
347 }
348
349 func BenchmarkScalarAdd(b *testing.B) {
350 var a, c, r Scalar
351 aBytes := make([]byte, 32)
352 cBytes := make([]byte, 32)
353 rand.Read(aBytes)
354 rand.Read(cBytes)
355 a.SetBytes(aBytes)
356 c.SetBytes(cBytes)
357
358 b.ResetTimer()
359 for i := 0; i < b.N; i++ {
360 r.Add(&a, &c)
361 }
362 }
363
364 func BenchmarkScalarMul(b *testing.B) {
365 var a, c, r Scalar
366 aBytes := make([]byte, 32)
367 cBytes := make([]byte, 32)
368 rand.Read(aBytes)
369 rand.Read(cBytes)
370 a.SetBytes(aBytes)
371 c.SetBytes(cBytes)
372
373 b.ResetTimer()
374 for i := 0; i < b.N; i++ {
375 r.Mul(&a, &c)
376 }
377 }
378
379 func BenchmarkFieldAdd(b *testing.B) {
380 var a, c, r FieldElement
381 aBytes := make([]byte, 32)
382 cBytes := make([]byte, 32)
383 rand.Read(aBytes)
384 rand.Read(cBytes)
385 a.SetBytes(aBytes)
386 c.SetBytes(cBytes)
387
388 b.ResetTimer()
389 for i := 0; i < b.N; i++ {
390 r.Add(&a, &c)
391 }
392 }
393
394 func BenchmarkFieldMul(b *testing.B) {
395 var a, c, r FieldElement
396 aBytes := make([]byte, 32)
397 cBytes := make([]byte, 32)
398 rand.Read(aBytes)
399 rand.Read(cBytes)
400 a.SetBytes(aBytes)
401 c.SetBytes(cBytes)
402
403 b.ResetTimer()
404 for i := 0; i < b.N; i++ {
405 r.Mul(&a, &c)
406 }
407 }
408
409 func BenchmarkFieldInverse(b *testing.B) {
410 var a, r FieldElement
411 aBytes := make([]byte, 32)
412 rand.Read(aBytes)
413 a.SetBytes(aBytes)
414
415 b.ResetTimer()
416 for i := 0; i < b.N; i++ {
417 r.Inverse(&a)
418 }
419 }
420
421 func BenchmarkPointDouble(b *testing.B) {
422 var g, r JacobianPoint
423 g.FromAffine(&Generator)
424
425 b.ResetTimer()
426 for i := 0; i < b.N; i++ {
427 r.Double(&g)
428 }
429 }
430
431 func BenchmarkPointAdd(b *testing.B) {
432 var g, twoG, r JacobianPoint
433 g.FromAffine(&Generator)
434 twoG.Double(&g)
435
436 b.ResetTimer()
437 for i := 0; i < b.N; i++ {
438 r.Add(&g, &twoG)
439 }
440 }
441
442 func BenchmarkScalarBaseMult(b *testing.B) {
443 var k Scalar
444 kBytes := make([]byte, 32)
445 rand.Read(kBytes)
446 k.SetBytes(kBytes)
447
448 b.ResetTimer()
449 for i := 0; i < b.N; i++ {
450 _ = BasePointMult(&k)
451 }
452 }
453