// scalar_amd64.s — hand-written AMD64 assembly for secp256k1 scalar arithmetic.

//go:build amd64

#include "textflag.h"
5 // Constants for scalar reduction
6 // n = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
7 DATA scalarN<>+0x00(SB)/8, $0xBFD25E8CD0364141
8 DATA scalarN<>+0x08(SB)/8, $0xBAAEDCE6AF48A03B
9 DATA scalarN<>+0x10(SB)/8, $0xFFFFFFFFFFFFFFFE
10 DATA scalarN<>+0x18(SB)/8, $0xFFFFFFFFFFFFFFFF
11 GLOBL scalarN<>(SB), RODATA|NOPTR, $32
12
13 // 2^256 - n (for reduction)
14 DATA scalarNC<>+0x00(SB)/8, $0x402DA1732FC9BEBF
15 DATA scalarNC<>+0x08(SB)/8, $0x4551231950B75FC4
16 DATA scalarNC<>+0x10(SB)/8, $0x0000000000000001
17 DATA scalarNC<>+0x18(SB)/8, $0x0000000000000000
18 GLOBL scalarNC<>(SB), RODATA|NOPTR, $32
19
20 // func hasAVX2() bool
21 TEXT ·hasAVX2(SB), NOSPLIT, $0-1
22 MOVL $7, AX
23 MOVL $0, CX
24 CPUID
25 ANDL $0x20, BX // Check bit 5 of EBX for AVX2
26 SETNE AL
27 MOVB AL, ret+0(FP)
28 RET
29
30 // func ScalarAddAVX2(r, a, b *Scalar)
31 // Adds two 256-bit scalars using AVX2 for loading/storing and scalar ADD with carry.
32 //
33 // YMM layout: [D[0].Lo, D[0].Hi, D[1].Lo, D[1].Hi] = 4 x 64-bit
34 TEXT ·ScalarAddAVX2(SB), NOSPLIT, $0-24
35 MOVQ r+0(FP), DI
36 MOVQ a+8(FP), SI
37 MOVQ b+16(FP), DX
38
39 // Load a and b into registers (scalar loads for carry chain)
40 MOVQ 0(SI), AX // a.D[0].Lo
41 MOVQ 8(SI), BX // a.D[0].Hi
42 MOVQ 16(SI), CX // a.D[1].Lo
43 MOVQ 24(SI), R8 // a.D[1].Hi
44
45 // Add b with carry chain
46 ADDQ 0(DX), AX // a.D[0].Lo + b.D[0].Lo
47 ADCQ 8(DX), BX // a.D[0].Hi + b.D[0].Hi + carry
48 ADCQ 16(DX), CX // a.D[1].Lo + b.D[1].Lo + carry
49 ADCQ 24(DX), R8 // a.D[1].Hi + b.D[1].Hi + carry
50
51 // Save carry flag
52 SETCS R9B
53
54 // Store preliminary result
55 MOVQ AX, 0(DI)
56 MOVQ BX, 8(DI)
57 MOVQ CX, 16(DI)
58 MOVQ R8, 24(DI)
59
60 // Check if we need to reduce (carry set or result >= n)
61 TESTB R9B, R9B
62 JNZ reduce
63
64 // Compare with n (from high to low)
65 MOVQ $0xFFFFFFFFFFFFFFFF, R10
66 CMPQ R8, R10
67 JB done
68 JA reduce
69 MOVQ scalarN<>+0x10(SB), R10
70 CMPQ CX, R10
71 JB done
72 JA reduce
73 MOVQ scalarN<>+0x08(SB), R10
74 CMPQ BX, R10
75 JB done
76 JA reduce
77 MOVQ scalarN<>+0x00(SB), R10
78 CMPQ AX, R10
79 JB done
80
81 reduce:
82 // Add 2^256 - n (which is equivalent to subtracting n)
83 MOVQ 0(DI), AX
84 MOVQ 8(DI), BX
85 MOVQ 16(DI), CX
86 MOVQ 24(DI), R8
87
88 MOVQ scalarNC<>+0x00(SB), R10
89 ADDQ R10, AX
90 MOVQ scalarNC<>+0x08(SB), R10
91 ADCQ R10, BX
92 MOVQ scalarNC<>+0x10(SB), R10
93 ADCQ R10, CX
94 MOVQ scalarNC<>+0x18(SB), R10
95 ADCQ R10, R8
96
97 MOVQ AX, 0(DI)
98 MOVQ BX, 8(DI)
99 MOVQ CX, 16(DI)
100 MOVQ R8, 24(DI)
101
102 done:
103 VZEROUPPER
104 RET
105
106 // func ScalarSubAVX2(r, a, b *Scalar)
107 // Subtracts two 256-bit scalars.
108 TEXT ·ScalarSubAVX2(SB), NOSPLIT, $0-24
109 MOVQ r+0(FP), DI
110 MOVQ a+8(FP), SI
111 MOVQ b+16(FP), DX
112
113 // Load a
114 MOVQ 0(SI), AX
115 MOVQ 8(SI), BX
116 MOVQ 16(SI), CX
117 MOVQ 24(SI), R8
118
119 // Subtract b with borrow chain
120 SUBQ 0(DX), AX
121 SBBQ 8(DX), BX
122 SBBQ 16(DX), CX
123 SBBQ 24(DX), R8
124
125 // Save borrow flag
126 SETCS R9B
127
128 // Store preliminary result
129 MOVQ AX, 0(DI)
130 MOVQ BX, 8(DI)
131 MOVQ CX, 16(DI)
132 MOVQ R8, 24(DI)
133
134 // If borrow, add n back
135 TESTB R9B, R9B
136 JZ done_sub
137
138 // Add n
139 MOVQ scalarN<>+0x00(SB), R10
140 ADDQ R10, AX
141 MOVQ scalarN<>+0x08(SB), R10
142 ADCQ R10, BX
143 MOVQ scalarN<>+0x10(SB), R10
144 ADCQ R10, CX
145 MOVQ scalarN<>+0x18(SB), R10
146 ADCQ R10, R8
147
148 MOVQ AX, 0(DI)
149 MOVQ BX, 8(DI)
150 MOVQ CX, 16(DI)
151 MOVQ R8, 24(DI)
152
153 done_sub:
154 VZEROUPPER
155 RET
156
157 // func ScalarMulAVX2(r, a, b *Scalar)
158 // Multiplies two 256-bit scalars and reduces mod n.
159 // This is a complex operation requiring 512-bit intermediate.
160 TEXT ·ScalarMulAVX2(SB), NOSPLIT, $64-24
161 MOVQ r+0(FP), DI
162 MOVQ a+8(FP), SI
163 MOVQ b+16(FP), DX
164
165 // We need to compute a 512-bit product and reduce mod n.
166 // For now, use scalar multiplication with MULX (if BMI2 available) or MUL.
167
168 // Load a limbs
169 MOVQ 0(SI), R8 // a0
170 MOVQ 8(SI), R9 // a1
171 MOVQ 16(SI), R10 // a2
172 MOVQ 24(SI), R11 // a3
173
174 // Store b pointer for later use
175 MOVQ DX, R12
176
177 // Compute 512-bit product using schoolbook multiplication
178 // Product stored on stack at SP+0 to SP+56 (8 limbs)
179
180 // Initialize product to zero
181 XORQ AX, AX
182 MOVQ AX, 0(SP)
183 MOVQ AX, 8(SP)
184 MOVQ AX, 16(SP)
185 MOVQ AX, 24(SP)
186 MOVQ AX, 32(SP)
187 MOVQ AX, 40(SP)
188 MOVQ AX, 48(SP)
189 MOVQ AX, 56(SP)
190
191 // Multiply a0 * b[0..3]
192 MOVQ R8, AX
193 MULQ 0(R12) // a0 * b0
194 MOVQ AX, 0(SP)
195 MOVQ DX, R13 // carry
196
197 MOVQ R8, AX
198 MULQ 8(R12) // a0 * b1
199 ADDQ R13, AX
200 ADCQ $0, DX
201 MOVQ AX, 8(SP)
202 MOVQ DX, R13
203
204 MOVQ R8, AX
205 MULQ 16(R12) // a0 * b2
206 ADDQ R13, AX
207 ADCQ $0, DX
208 MOVQ AX, 16(SP)
209 MOVQ DX, R13
210
211 MOVQ R8, AX
212 MULQ 24(R12) // a0 * b3
213 ADDQ R13, AX
214 ADCQ $0, DX
215 MOVQ AX, 24(SP)
216 MOVQ DX, 32(SP)
217
218 // Multiply a1 * b[0..3] and add
219 MOVQ R9, AX
220 MULQ 0(R12) // a1 * b0
221 ADDQ AX, 8(SP)
222 ADCQ DX, 16(SP)
223 ADCQ $0, 24(SP)
224 ADCQ $0, 32(SP)
225
226 MOVQ R9, AX
227 MULQ 8(R12) // a1 * b1
228 ADDQ AX, 16(SP)
229 ADCQ DX, 24(SP)
230 ADCQ $0, 32(SP)
231
232 MOVQ R9, AX
233 MULQ 16(R12) // a1 * b2
234 ADDQ AX, 24(SP)
235 ADCQ DX, 32(SP)
236 ADCQ $0, 40(SP)
237
238 MOVQ R9, AX
239 MULQ 24(R12) // a1 * b3
240 ADDQ AX, 32(SP)
241 ADCQ DX, 40(SP)
242
243 // Multiply a2 * b[0..3] and add
244 MOVQ R10, AX
245 MULQ 0(R12) // a2 * b0
246 ADDQ AX, 16(SP)
247 ADCQ DX, 24(SP)
248 ADCQ $0, 32(SP)
249 ADCQ $0, 40(SP)
250
251 MOVQ R10, AX
252 MULQ 8(R12) // a2 * b1
253 ADDQ AX, 24(SP)
254 ADCQ DX, 32(SP)
255 ADCQ $0, 40(SP)
256
257 MOVQ R10, AX
258 MULQ 16(R12) // a2 * b2
259 ADDQ AX, 32(SP)
260 ADCQ DX, 40(SP)
261 ADCQ $0, 48(SP)
262
263 MOVQ R10, AX
264 MULQ 24(R12) // a2 * b3
265 ADDQ AX, 40(SP)
266 ADCQ DX, 48(SP)
267
268 // Multiply a3 * b[0..3] and add
269 MOVQ R11, AX
270 MULQ 0(R12) // a3 * b0
271 ADDQ AX, 24(SP)
272 ADCQ DX, 32(SP)
273 ADCQ $0, 40(SP)
274 ADCQ $0, 48(SP)
275
276 MOVQ R11, AX
277 MULQ 8(R12) // a3 * b1
278 ADDQ AX, 32(SP)
279 ADCQ DX, 40(SP)
280 ADCQ $0, 48(SP)
281
282 MOVQ R11, AX
283 MULQ 16(R12) // a3 * b2
284 ADDQ AX, 40(SP)
285 ADCQ DX, 48(SP)
286 ADCQ $0, 56(SP)
287
288 MOVQ R11, AX
289 MULQ 24(R12) // a3 * b3
290 ADDQ AX, 48(SP)
291 ADCQ DX, 56(SP)
292
293 // Now we have the 512-bit product in SP+0 to SP+56 (l[0..7])
294 // Need to reduce mod n using the bitcoin-core algorithm:
295 //
296 // Phase 1: 512->385 bits
297 // c0..c4 = l[0..3] + l[4..7] * NC (where NC = 2^256 - n)
298 // Phase 2: 385->258 bits
299 // d0..d4 = c[0..3] + c[4] * NC
300 // Phase 3: 258->256 bits
301 // r[0..3] = d[0..3] + d[4] * NC, then final reduce if >= n
302 //
303 // NC = [0x402DA1732FC9BEBF, 0x4551231950B75FC4, 1, 0]
304
305 // ========== Phase 1: 512->385 bits ==========
306 // Compute c[0..4] = l[0..3] + l[4..7] * NC
307 // NC has only 3 significant limbs: NC[0], NC[1], NC[2]=1
308
309 // Start with c = l[0..3], then add contributions from l[4..7] * NC
310 MOVQ 0(SP), R8 // c0 = l0
311 MOVQ 8(SP), R9 // c1 = l1
312 MOVQ 16(SP), R10 // c2 = l2
313 MOVQ 24(SP), R11 // c3 = l3
314 XORQ R14, R14 // c4 = 0
315 XORQ R15, R15 // c5 for overflow
316
317 // l4 * NC[0]
318 MOVQ 32(SP), AX
319 MOVQ scalarNC<>+0x00(SB), R12
320 MULQ R12 // DX:AX = l4 * NC[0]
321 ADDQ AX, R8
322 ADCQ DX, R9
323 ADCQ $0, R10
324 ADCQ $0, R11
325 ADCQ $0, R14
326
327 // l4 * NC[1]
328 MOVQ 32(SP), AX
329 MOVQ scalarNC<>+0x08(SB), R12
330 MULQ R12 // DX:AX = l4 * NC[1]
331 ADDQ AX, R9
332 ADCQ DX, R10
333 ADCQ $0, R11
334 ADCQ $0, R14
335
336 // l4 * NC[2] (NC[2] = 1)
337 MOVQ 32(SP), AX
338 ADDQ AX, R10
339 ADCQ $0, R11
340 ADCQ $0, R14
341
342 // l5 * NC[0]
343 MOVQ 40(SP), AX
344 MOVQ scalarNC<>+0x00(SB), R12
345 MULQ R12
346 ADDQ AX, R9
347 ADCQ DX, R10
348 ADCQ $0, R11
349 ADCQ $0, R14
350
351 // l5 * NC[1]
352 MOVQ 40(SP), AX
353 MOVQ scalarNC<>+0x08(SB), R12
354 MULQ R12
355 ADDQ AX, R10
356 ADCQ DX, R11
357 ADCQ $0, R14
358
359 // l5 * NC[2] (NC[2] = 1)
360 MOVQ 40(SP), AX
361 ADDQ AX, R11
362 ADCQ $0, R14
363
364 // l6 * NC[0]
365 MOVQ 48(SP), AX
366 MOVQ scalarNC<>+0x00(SB), R12
367 MULQ R12
368 ADDQ AX, R10
369 ADCQ DX, R11
370 ADCQ $0, R14
371
372 // l6 * NC[1]
373 MOVQ 48(SP), AX
374 MOVQ scalarNC<>+0x08(SB), R12
375 MULQ R12
376 ADDQ AX, R11
377 ADCQ DX, R14
378
379 // l6 * NC[2] (NC[2] = 1)
380 MOVQ 48(SP), AX
381 ADDQ AX, R14
382 ADCQ $0, R15
383
384 // l7 * NC[0]
385 MOVQ 56(SP), AX
386 MOVQ scalarNC<>+0x00(SB), R12
387 MULQ R12
388 ADDQ AX, R11
389 ADCQ DX, R14
390 ADCQ $0, R15
391
392 // l7 * NC[1]
393 MOVQ 56(SP), AX
394 MOVQ scalarNC<>+0x08(SB), R12
395 MULQ R12
396 ADDQ AX, R14
397 ADCQ DX, R15
398
399 // l7 * NC[2] (NC[2] = 1)
400 MOVQ 56(SP), AX
401 ADDQ AX, R15
402
403 // Now c[0..5] = R8, R9, R10, R11, R14, R15 (~385 bits max)
404
405 // ========== Phase 2: 385->258 bits ==========
406 // Reduce c[4..5] by multiplying by NC and adding to c[0..3]
407
408 // c4 * NC[0]
409 MOVQ R14, AX
410 MOVQ scalarNC<>+0x00(SB), R12
411 MULQ R12
412 ADDQ AX, R8
413 ADCQ DX, R9
414 ADCQ $0, R10
415 ADCQ $0, R11
416
417 // c4 * NC[1]
418 MOVQ R14, AX
419 MOVQ scalarNC<>+0x08(SB), R12
420 MULQ R12
421 ADDQ AX, R9
422 ADCQ DX, R10
423 ADCQ $0, R11
424
425 // c4 * NC[2] (NC[2] = 1)
426 ADDQ R14, R10
427 ADCQ $0, R11
428
429 // c5 * NC[0]
430 MOVQ R15, AX
431 MOVQ scalarNC<>+0x00(SB), R12
432 MULQ R12
433 ADDQ AX, R9
434 ADCQ DX, R10
435 ADCQ $0, R11
436
437 // c5 * NC[1]
438 MOVQ R15, AX
439 MOVQ scalarNC<>+0x08(SB), R12
440 MULQ R12
441 ADDQ AX, R10
442 ADCQ DX, R11
443
444 // c5 * NC[2] (NC[2] = 1)
445 ADDQ R15, R11
446 // Capture any final carry into R14
447 MOVQ $0, R14
448 ADCQ $0, R14
449
450 // Now we have ~258 bits in R8, R9, R10, R11, R14
451
452 // ========== Phase 3: 258->256 bits ==========
453 // If R14 (the overflow) is non-zero, reduce again
454 TESTQ R14, R14
455 JZ check_overflow
456
457 // R14 * NC
458 MOVQ R14, AX
459 MOVQ scalarNC<>+0x00(SB), R12
460 MULQ R12
461 ADDQ AX, R8
462 ADCQ DX, R9
463 ADCQ $0, R10
464 ADCQ $0, R11
465
466 MOVQ R14, AX
467 MOVQ scalarNC<>+0x08(SB), R12
468 MULQ R12
469 ADDQ AX, R9
470 ADCQ DX, R10
471 ADCQ $0, R11
472
473 // R14 * NC[2] (NC[2] = 1)
474 ADDQ R14, R10
475 ADCQ $0, R11
476
477 check_overflow:
478 // Check if result >= n and reduce if needed
479 MOVQ $0xFFFFFFFFFFFFFFFF, R13
480 CMPQ R11, R13
481 JB store_result
482 JA do_reduce
483 MOVQ scalarN<>+0x10(SB), R13
484 CMPQ R10, R13
485 JB store_result
486 JA do_reduce
487 MOVQ scalarN<>+0x08(SB), R13
488 CMPQ R9, R13
489 JB store_result
490 JA do_reduce
491 MOVQ scalarN<>+0x00(SB), R13
492 CMPQ R8, R13
493 JB store_result
494
495 do_reduce:
496 // Subtract n (add 2^256 - n)
497 MOVQ scalarNC<>+0x00(SB), R13
498 ADDQ R13, R8
499 MOVQ scalarNC<>+0x08(SB), R13
500 ADCQ R13, R9
501 MOVQ scalarNC<>+0x10(SB), R13
502 ADCQ R13, R10
503 MOVQ scalarNC<>+0x18(SB), R13
504 ADCQ R13, R11
505
506 store_result:
507 // Store result
508 MOVQ r+0(FP), DI
509 MOVQ R8, 0(DI)
510 MOVQ R9, 8(DI)
511 MOVQ R10, 16(DI)
512 MOVQ R11, 24(DI)
513
514 VZEROUPPER
515 RET
516