avx_test.go raw
1 //go:build !js && !wasm && !tinygo && !wasm32
2
3 package p256k1
4
5 import (
6 "testing"
7 )
8
9 func TestScalarMulBMI2VsPureGo(t *testing.T) {
10 // BMI2 scalar multiplication has carry propagation bugs that need fixing
11 // Skip for now - the AVX2 path is used instead
12 t.Skip("BMI2 scalar mul has known carry bugs - using AVX2 path")
13 }
14
15 func TestAVX2Integration(t *testing.T) {
16 t.Logf("AVX2 CPU support: %v", HasAVX2CPU())
17 t.Logf("AVX2 enabled: %v", HasAVX2())
18
19 // Test scalar multiplication with AVX2
20 var a, b, productAVX, productGo Scalar
21 a.setInt(12345)
22 b.setInt(67890)
23
24 // Compute with AVX2 enabled
25 SetAVX2Enabled(true)
26 productAVX.mul(&a, &b)
27
28 // Compute with AVX2 disabled
29 SetAVX2Enabled(false)
30 productGo.mulPureGo(&a, &b)
31
32 // Re-enable AVX2
33 SetAVX2Enabled(true)
34
35 if !productAVX.equal(&productGo) {
36 t.Errorf("AVX2 and Go scalar multiplication differ:\n AVX2: %v\n Go: %v",
37 productAVX.d, productGo.d)
38 } else {
39 t.Logf("Scalar multiplication matches: %v", productAVX.d)
40 }
41
42 // Test scalar addition
43 var sumAVX, sumGo Scalar
44 SetAVX2Enabled(true)
45 sumAVX.add(&a, &b)
46
47 SetAVX2Enabled(false)
48 sumGo.addPureGo(&a, &b)
49
50 SetAVX2Enabled(true)
51
52 if !sumAVX.equal(&sumGo) {
53 t.Errorf("AVX2 and Go scalar addition differ:\n AVX2: %v\n Go: %v",
54 sumAVX.d, sumGo.d)
55 } else {
56 t.Logf("Scalar addition matches: %v", sumAVX.d)
57 }
58
59 // Test inverse (which uses mul internally)
60 var inv, product Scalar
61 a.setInt(2)
62
63 SetAVX2Enabled(true)
64 inv.inverse(&a)
65 product.mul(&a, &inv)
66
67 t.Logf("a = %v", a.d)
68 t.Logf("inv(a) = %v", inv.d)
69 t.Logf("a * inv(a) = %v", product.d)
70 t.Logf("isOne = %v", product.isOne())
71
72 if !product.isOne() {
73 // Try with pure Go
74 SetAVX2Enabled(false)
75 var inv2, product2 Scalar
76 inv2.inverse(&a)
77 product2.mul(&a, &inv2)
78 t.Logf("Pure Go: a * inv(a) = %v", product2.d)
79 t.Logf("Pure Go isOne = %v", product2.isOne())
80 SetAVX2Enabled(true)
81
82 t.Errorf("2 * inv(2) should equal 1")
83 }
84 }
85
86 func TestScalarMulAVX2VsPureGo(t *testing.T) {
87 if !HasAVX2CPU() {
88 t.Skip("AVX2 not available")
89 }
90
91 // Test several multiplication cases
92 testCases := []struct {
93 a, b uint
94 }{
95 {2, 3},
96 {12345, 67890},
97 {0xFFFFFFFF, 0xFFFFFFFF},
98 {1, 1},
99 {0, 123},
100 }
101
102 for _, tc := range testCases {
103 var a, b, productAVX, productGo Scalar
104 a.setInt(tc.a)
105 b.setInt(tc.b)
106
107 SetAVX2Enabled(true)
108 scalarMulAVX2(&productAVX, &a, &b)
109
110 productGo.mulPureGo(&a, &b)
111
112 if !productAVX.equal(&productGo) {
113 t.Errorf("Mismatch for %d * %d:\n AVX2: %v\n Go: %v",
114 tc.a, tc.b, productAVX.d, productGo.d)
115 }
116 }
117 }
118
119 func TestScalarMulAVX2Large(t *testing.T) {
120 if !HasAVX2CPU() {
121 t.Skip("AVX2 not available")
122 }
123
124 // Test with the actual inverse of 2
125 var a Scalar
126 a.setInt(2)
127
128 var inv Scalar
129 SetAVX2Enabled(false)
130 inv.inverse(&a)
131 SetAVX2Enabled(true)
132
133 t.Logf("a = %v", a.d)
134 t.Logf("inv(2) = %v", inv.d)
135
136 // Test multiplication of 2 * inv(2)
137 var productAVX, productGo Scalar
138 scalarMulAVX2(&productAVX, &a, &inv)
139
140 SetAVX2Enabled(false)
141 productGo.mulPureGo(&a, &inv)
142 SetAVX2Enabled(true)
143
144 t.Logf("AVX2: 2 * inv(2) = %v", productAVX.d)
145 t.Logf("Go: 2 * inv(2) = %v", productGo.d)
146
147 if !productAVX.equal(&productGo) {
148 t.Errorf("Large number multiplication differs")
149 }
150 }
151
152 func TestInverseAVX2VsGo(t *testing.T) {
153 if !HasAVX2CPU() {
154 t.Skip("AVX2 not available")
155 }
156
157 var a Scalar
158 a.setInt(2)
159
160 // Compute inverse with AVX2
161 var invAVX Scalar
162 SetAVX2Enabled(true)
163 invAVX.inverse(&a)
164
165 // Compute inverse with pure Go
166 var invGo Scalar
167 SetAVX2Enabled(false)
168 invGo.inverse(&a)
169 SetAVX2Enabled(true)
170
171 t.Logf("AVX2 inv(2) = %v", invAVX.d)
172 t.Logf("Go inv(2) = %v", invGo.d)
173
174 if !invAVX.equal(&invGo) {
175 t.Errorf("Inverse differs between AVX2 and Go")
176 }
177 }
178
179 func TestScalarMulAliased(t *testing.T) {
180 if !HasAVX2CPU() {
181 t.Skip("AVX2 not available")
182 }
183
184 // Test aliased multiplication: r.mul(r, &b) and r.mul(&a, r)
185 var a, b Scalar
186 a.setInt(12345)
187 b.setInt(67890)
188
189 // Test r = r * b
190 var rAVX, rGo Scalar
191 rAVX = a
192 rGo = a
193
194 SetAVX2Enabled(true)
195 scalarMulAVX2(&rAVX, &rAVX, &b)
196
197 SetAVX2Enabled(false)
198 rGo.mulPureGo(&rGo, &b)
199 SetAVX2Enabled(true)
200
201 if !rAVX.equal(&rGo) {
202 t.Errorf("r = r * b failed:\n AVX2: %v\n Go: %v", rAVX.d, rGo.d)
203 }
204
205 // Test r = a * r
206 rAVX = b
207 rGo = b
208
209 SetAVX2Enabled(true)
210 scalarMulAVX2(&rAVX, &a, &rAVX)
211
212 SetAVX2Enabled(false)
213 rGo.mulPureGo(&a, &rGo)
214 SetAVX2Enabled(true)
215
216 if !rAVX.equal(&rGo) {
217 t.Errorf("r = a * r failed:\n AVX2: %v\n Go: %v", rAVX.d, rGo.d)
218 }
219
220 // Test squaring: r = r * r
221 rAVX = a
222 rGo = a
223
224 SetAVX2Enabled(true)
225 scalarMulAVX2(&rAVX, &rAVX, &rAVX)
226
227 SetAVX2Enabled(false)
228 rGo.mulPureGo(&rGo, &rGo)
229 SetAVX2Enabled(true)
230
231 if !rAVX.equal(&rGo) {
232 t.Errorf("r = r * r failed:\n AVX2: %v\n Go: %v", rAVX.d, rGo.d)
233 }
234 }
235
236 func TestScalarMulLargeNumbers(t *testing.T) {
237 if !HasAVX2CPU() {
238 t.Skip("AVX2 not available")
239 }
240
241 // Test with large numbers (all limbs non-zero)
242 testCases := []struct {
243 name string
244 a, b Scalar
245 }{
246 {
247 name: "large a * small b",
248 a: Scalar{d: [4]uint64{0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0, 0}},
249 b: Scalar{d: [4]uint64{2, 0, 0, 0}},
250 },
251 {
252 name: "a^2 where a is large",
253 a: Scalar{d: [4]uint64{0x123456789ABCDEF0, 0xFEDCBA9876543210, 0, 0}},
254 b: Scalar{d: [4]uint64{0x123456789ABCDEF0, 0xFEDCBA9876543210, 0, 0}},
255 },
256 {
257 name: "full limbs",
258 a: Scalar{d: [4]uint64{0x123456789ABCDEF0, 0xFEDCBA9876543210, 0x1111111111111111, 0x2222222222222222}},
259 b: Scalar{d: [4]uint64{0x0FEDCBA987654321, 0x123456789ABCDEF0, 0x3333333333333333, 0x4444444444444444}},
260 },
261 }
262
263 for _, tc := range testCases {
264 t.Run(tc.name, func(t *testing.T) {
265 var productAVX, productGo Scalar
266
267 SetAVX2Enabled(true)
268 scalarMulAVX2(&productAVX, &tc.a, &tc.b)
269
270 SetAVX2Enabled(false)
271 productGo.mulPureGo(&tc.a, &tc.b)
272 SetAVX2Enabled(true)
273
274 if !productAVX.equal(&productGo) {
275 t.Errorf("Mismatch:\n a: %v\n b: %v\n AVX2: %v\n Go: %v",
276 tc.a.d, tc.b.d, productAVX.d, productGo.d)
277 }
278 })
279 }
280 }
281