scalar_amd64_bmi2.s raw

   1  //go:build amd64 && !purego
   2  
   3  #include "textflag.h"
   4  
   5  // Scalar multiplication for secp256k1 using BMI2 MULX instruction.
   6  // This is faster than traditional MUL because MULX doesn't clobber flags,
   7  // allowing better instruction scheduling with carry chains.
   8  //
   9  // Stack layout (64 bytes):
  10  //   SP+0:  l4 (saved)
  11  //   SP+8:  l5 (saved)
  12  //   SP+16: l6 (saved)
  13  //   SP+24: l7 (saved)
  14  //   SP+32: m0
  15  //   SP+40: m1
  16  //   SP+48: m2
  17  //   SP+56: m3
  18  //
  19  // func scalarMulBMI2(r, a, b *Scalar)
  20  TEXT ·scalarMulBMI2(SB), NOSPLIT, $64-24
  21      MOVQ r+0(FP), DI      // result pointer
  22      MOVQ a+8(FP), SI      // a pointer
  23      MOVQ b+16(FP), CX     // b pointer
  24  
  25      // Load a[0..3]
  26      MOVQ 0(SI), R8        // a0
  27      MOVQ 8(SI), R9        // a1
  28      MOVQ 16(SI), R10      // a2
  29      MOVQ 24(SI), R11      // a3
  30  
  31      // We'll compute the 512-bit product column by column using MULX
  32      // MULX puts DX as the implicit multiplier, result goes to specified registers
  33  
  34      // Column 0: a0*b0
  35      MOVQ 0(CX), DX        // b0 into DX for MULX
  36      MULXQ R8, R12, R13    // a0*b0 -> R13:R12 (hi:lo)
  37  
  38      // Column 1: a0*b1 + a1*b0
  39      MOVQ 8(CX), DX        // b1
  40      MULXQ R8, AX, BX      // a0*b1 -> BX:AX
  41      ADDQ AX, R13
  42      ADCQ $0, BX
  43  
  44      MOVQ 0(CX), DX        // b0
  45      MULXQ R9, AX, R14     // a1*b0 -> R14:AX
  46      ADDQ AX, R13
  47      ADCQ BX, R14
  48      MOVQ $0, R15
  49      ADCQ $0, R15
  50  
  51      // Column 2: a0*b2 + a1*b1 + a2*b0
  52      MOVQ 16(CX), DX       // b2
  53      MULXQ R8, AX, BX      // a0*b2 -> BX:AX
  54      ADDQ AX, R14
  55      ADCQ BX, R15
  56  
  57      MOVQ 8(CX), DX        // b1
  58      MULXQ R9, AX, BX      // a1*b1 -> BX:AX
  59      ADDQ AX, R14
  60      ADCQ BX, R15
  61      MOVQ $0, BP
  62      ADCQ $0, BP
  63  
  64      MOVQ 0(CX), DX        // b0
  65      MULXQ R10, AX, BX     // a2*b0 -> BX:AX
  66      ADDQ AX, R14
  67      ADCQ BX, R15
  68      ADCQ $0, BP
  69  
  70      // Column 3: a0*b3 + a1*b2 + a2*b1 + a3*b0
  71      // Save R12-R14 (columns 0-2), use them for column 3+
  72      MOVQ R12, 0(DI)       // Save l0
  73      MOVQ R13, 8(DI)       // Save l1
  74      MOVQ R14, 16(DI)      // Save l2
  75  
  76      // Now R12, R13, R14 are free
  77      MOVQ R15, R12         // l3 accumulator low
  78      MOVQ BP, R13          // l3 accumulator high
  79      XORQ R14, R14         // l4 accumulator
  80  
  81      MOVQ 24(CX), DX       // b3
  82      MULXQ R8, AX, BX      // a0*b3 -> BX:AX
  83      ADDQ AX, R12
  84      ADCQ BX, R13
  85      ADCQ $0, R14
  86  
  87      MOVQ 16(CX), DX       // b2
  88      MULXQ R9, AX, BX      // a1*b2 -> BX:AX
  89      ADDQ AX, R12
  90      ADCQ BX, R13
  91      ADCQ $0, R14
  92  
  93      MOVQ 8(CX), DX        // b1
  94      MULXQ R10, AX, BX     // a2*b1 -> BX:AX
  95      ADDQ AX, R12
  96      ADCQ BX, R13
  97      ADCQ $0, R14
  98  
  99      MOVQ 0(CX), DX        // b0
 100      MULXQ R11, AX, BX     // a3*b0 -> BX:AX
 101      ADDQ AX, R12
 102      ADCQ BX, R13
 103      ADCQ $0, R14
 104  
 105      MOVQ R12, 24(DI)      // Save l3
 106  
 107      // Column 4: a1*b3 + a2*b2 + a3*b1
 108      MOVQ R13, R12         // l4 accumulator low
 109      MOVQ R14, R13         // l4 accumulator high
 110      XORQ R14, R14
 111  
 112      MOVQ 24(CX), DX       // b3
 113      MULXQ R9, AX, BX      // a1*b3 -> BX:AX
 114      ADDQ AX, R12
 115      ADCQ BX, R13
 116      ADCQ $0, R14
 117  
 118      MOVQ 16(CX), DX       // b2
 119      MULXQ R10, AX, BX     // a2*b2 -> BX:AX
 120      ADDQ AX, R12
 121      ADCQ BX, R13
 122      ADCQ $0, R14
 123  
 124      MOVQ 8(CX), DX        // b1
 125      MULXQ R11, AX, BX     // a3*b1 -> BX:AX
 126      ADDQ AX, R12
 127      ADCQ BX, R13
 128      ADCQ $0, R14
 129  
 130      // l4 is in R12, carry in R13:R14
 131  
 132      // Column 5: a2*b3 + a3*b2
 133      MOVQ R13, R15         // l5 accumulator low
 134      MOVQ R14, BP          // l5 accumulator high
 135      XORQ R8, R8           // reuse R8 for l6
 136  
 137      MOVQ 24(CX), DX       // b3
 138      MULXQ R10, AX, BX     // a2*b3 -> BX:AX
 139      ADDQ AX, R15
 140      ADCQ BX, BP
 141      ADCQ $0, R8
 142  
 143      MOVQ 16(CX), DX       // b2
 144      MULXQ R11, AX, BX     // a3*b2 -> BX:AX
 145      ADDQ AX, R15
 146      ADCQ BX, BP
 147      ADCQ $0, R8
 148  
 149      // Column 6: a3*b3
 150      MOVQ BP, R9           // l6 accumulator low
 151      MOVQ R8, R10          // l6 accumulator high (will be l7)
 152  
 153      MOVQ 24(CX), DX       // b3
 154      MULXQ R11, AX, BX     // a3*b3 -> BX:AX
 155      ADDQ AX, R9
 156      ADCQ BX, R10
 157  
 158      // Now we have:
 159      // l[0..3] in memory at DI
 160      // l[4] = R12
 161      // l[5] = R15
 162      // l[6] = R9
 163      // l[7] = R10
 164  
 165      // Save l4-l7 to stack for reduction phase
 166      MOVQ R12, 0(SP)       // l4
 167      MOVQ R15, 8(SP)       // l5
 168      MOVQ R9, 16(SP)       // l6
 169      MOVQ R10, 24(SP)      // l7
 170  
 171      // === Reduction modulo scalar order n ===
 172      // n = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
 173      // NC = 2^256 - n = { 0x402DA1732FC9BEBF, 0x4551231950B75FC4, 1, 0 }
 174      //
 175      // Phase 1: Reduce 512 bits to 385 bits
 176      // m[0..6] = l[0..3] + l[4..7] * NC
 177  
 178      // Load constants
 179      MOVQ $0x402DA1732FC9BEBF, R8  // NC0
 180      MOVQ $0x4551231950B75FC4, R11 // NC1
 181  
 182      // === m0 ===
 183      // c0 = l[0], c1 = 0
 184      // muladd_fast(l4, NC0)
 185      // m0 = extract_fast()
 186      MOVQ 0(DI), R13       // c0 = l0
 187      XORQ R14, R14         // c1 = 0
 188  
 189      MOVQ R12, DX          // l4
 190      MULXQ R8, AX, BX      // l4 * NC0 -> BX:AX
 191      ADDQ AX, R13          // c0 += lo
 192      ADCQ BX, R14          // c1 += hi + carry
 193  
 194      // m0 = c0, shift accum
 195      MOVQ R13, 32(SP)      // m0
 196      MOVQ R14, R13         // c0 = c1
 197      XORQ R14, R14         // c1 = 0
 198      XORQ BP, BP           // c2 = 0
 199  
 200      // === m1 ===
 201      // sumadd_fast(l[1])
 202      // muladd(l5, NC0)
 203      // muladd(l4, NC1)
 204      ADDQ 8(DI), R13       // c0 += l1
 205      ADCQ $0, R14
 206  
 207      MOVQ R15, DX          // l5
 208      MULXQ R8, AX, BX      // l5 * NC0 -> BX:AX
 209      ADDQ AX, R13
 210      ADCQ BX, R14
 211      ADCQ $0, BP
 212  
 213      MOVQ R12, DX          // l4
 214      MULXQ R11, AX, BX     // l4 * NC1 -> BX:AX
 215      ADDQ AX, R13
 216      ADCQ BX, R14
 217      ADCQ $0, BP
 218  
 219      // m1 = c0, shift
 220      MOVQ R13, 40(SP)      // m1
 221      MOVQ R14, R13
 222      MOVQ BP, R14
 223      XORQ BP, BP
 224  
 225      // === m2 ===
 226      // sumadd(l[2])
 227      // muladd(l6, NC0)
 228      // muladd(l5, NC1)
 229      // sumadd(l4)  (NC2 = 1)
 230      ADDQ 16(DI), R13      // c0 += l2
 231      ADCQ $0, R14
 232      ADCQ $0, BP
 233  
 234      MOVQ 16(SP), DX       // l6
 235      MULXQ R8, AX, BX      // l6 * NC0 -> BX:AX
 236      ADDQ AX, R13
 237      ADCQ BX, R14
 238      ADCQ $0, BP
 239  
 240      MOVQ R15, DX          // l5
 241      MULXQ R11, AX, BX     // l5 * NC1 -> BX:AX
 242      ADDQ AX, R13
 243      ADCQ BX, R14
 244      ADCQ $0, BP
 245  
 246      ADDQ R12, R13         // c0 += l4 (l4 * NC2 = l4 * 1)
 247      ADCQ $0, R14
 248      ADCQ $0, BP
 249  
 250      // m2 = c0
 251      MOVQ R13, 48(SP)      // m2
 252      MOVQ R14, R13
 253      MOVQ BP, R14
 254      XORQ BP, BP
 255  
 256      // === m3 ===
 257      // sumadd(l[3])
 258      // muladd(l7, NC0)
 259      // muladd(l6, NC1)
 260      // sumadd(l5)
 261      ADDQ 24(DI), R13      // c0 += l3
 262      ADCQ $0, R14
 263      ADCQ $0, BP
 264  
 265      MOVQ 24(SP), DX       // l7
 266      MULXQ R8, AX, BX      // l7 * NC0 -> BX:AX
 267      ADDQ AX, R13
 268      ADCQ BX, R14
 269      ADCQ $0, BP
 270  
 271      MOVQ 16(SP), DX       // l6
 272      MULXQ R11, AX, BX     // l6 * NC1 -> BX:AX
 273      ADDQ AX, R13
 274      ADCQ BX, R14
 275      ADCQ $0, BP
 276  
 277      ADDQ R15, R13         // c0 += l5
 278      ADCQ $0, R14
 279      ADCQ $0, BP
 280  
 281      // m3 = c0
 282      MOVQ R13, 56(SP)      // m3
 283      MOVQ R14, R13
 284      MOVQ BP, R14
 285  
 286      // === m4 ===
 287      // muladd(l7, NC1)
 288      // sumadd(l6)
 289      MOVQ 24(SP), DX       // l7
 290      MULXQ R11, AX, BX     // l7 * NC1 -> BX:AX
 291      ADDQ AX, R13
 292      ADCQ BX, R14
 293  
 294      ADDQ 16(SP), R13      // c0 += l6
 295      ADCQ $0, R14
 296  
 297      // m4 in R13
 298      MOVQ R13, R12         // m4 = c0
 299      MOVQ R14, R13         // c0 = c1
 300  
 301      // === m5 ===
 302      // sumadd_fast(l7)
 303      ADDQ 24(SP), R13      // c0 += l7
 304      MOVQ $0, R9
 305      ADCQ $0, R9
 306      // m5 in R13
 307      MOVQ R13, R15         // m5
 308  
 309      // === m6 ===
 310      // m6 = carry (should be small, often 0)
 311      // R9 already has the carry
 312  
 313      // Phase 2: Reduce 385 bits to 258 bits
 314      // p[0..4] = m[0..3] + m[4..6] * NC
 315  
 316      // === p0 ===
 317      MOVQ 32(SP), R13      // c0 = m0
 318      XORQ R14, R14         // c1 = 0
 319  
 320      MOVQ R12, DX          // m4
 321      MULXQ R8, AX, BX      // m4 * NC0 -> BX:AX
 322      ADDQ AX, R13
 323      ADCQ BX, R14
 324  
 325      MOVQ R13, 0(DI)       // p0 = c0
 326      MOVQ R14, R13
 327      XORQ R14, R14
 328      XORQ BP, BP
 329  
 330      // === p1 ===
 331      ADDQ 40(SP), R13      // c0 += m1
 332  
 333      MOVQ R15, DX          // m5
 334      MULXQ R8, AX, BX      // m5 * NC0 -> BX:AX
 335      ADDQ AX, R13
 336      ADCQ BX, R14
 337      ADCQ $0, BP
 338  
 339      MOVQ R12, DX          // m4
 340      MULXQ R11, AX, BX     // m4 * NC1 -> BX:AX
 341      ADDQ AX, R13
 342      ADCQ BX, R14
 343      ADCQ $0, BP
 344  
 345      MOVQ R13, 8(DI)       // p1
 346      MOVQ R14, R13
 347      MOVQ BP, R14
 348      XORQ BP, BP
 349  
 350      // === p2 ===
 351      ADDQ 48(SP), R13      // c0 += m2
 352      ADCQ $0, R14
 353      ADCQ $0, BP
 354  
 355      MOVQ R9, DX           // m6
 356      MULXQ R8, AX, BX      // m6 * NC0 -> BX:AX
 357      ADDQ AX, R13
 358      ADCQ BX, R14
 359      ADCQ $0, BP
 360  
 361      MOVQ R15, DX          // m5
 362      MULXQ R11, AX, BX     // m5 * NC1 -> BX:AX
 363      ADDQ AX, R13
 364      ADCQ BX, R14
 365      ADCQ $0, BP
 366  
 367      ADDQ R12, R13         // c0 += m4
 368      ADCQ $0, R14
 369      ADCQ $0, BP
 370  
 371      MOVQ R13, 16(DI)      // p2
 372      MOVQ R14, R13
 373      MOVQ BP, R14
 374  
 375      // === p3 ===
 376      ADDQ 56(SP), R13      // c0 += m3
 377  
 378      MOVQ R9, DX           // m6
 379      MULXQ R11, AX, BX     // m6 * NC1 -> BX:AX
 380      ADDQ AX, R13
 381      ADCQ BX, R14
 382  
 383      ADDQ R15, R13         // c0 += m5
 384      ADCQ $0, R14
 385  
 386      MOVQ R13, 24(DI)      // p3
 387      // p4 = c1 + m6
 388      ADDQ R14, R9          // p4 = R9
 389  
 390      // Phase 3: Reduce 258 bits to 256 bits
 391      // r[0..3] = p[0..3] + p[4] * NC
 392  
 393      // r0 = p0 + p4 * NC0
 394      MOVQ R9, DX           // p4
 395      MULXQ R8, AX, BX      // p4 * NC0 -> BX:AX
 396      ADDQ 0(DI), AX        // AX = p0 + lo
 397      ADCQ $0, BX           // BX = hi + carry
 398      MOVQ AX, R12          // r0
 399      MOVQ BX, R13          // carry
 400  
 401      // r1 = p1 + p4 * NC1 + carry
 402      MOVQ R9, DX           // p4
 403      MULXQ R11, AX, BX     // p4 * NC1 -> BX:AX
 404      ADDQ R13, AX          // AX += carry
 405      ADCQ $0, BX
 406      ADDQ 8(DI), AX        // AX += p1
 407      ADCQ $0, BX
 408      MOVQ AX, R14          // r1
 409      MOVQ BX, R13          // carry
 410  
 411      // r2 = p2 + p4 + carry (NC2 = 1)
 412      MOVQ 16(DI), AX
 413      ADDQ R13, AX          // AX = p2 + carry
 414      MOVQ $0, R13
 415      ADCQ $0, R13
 416      ADDQ R9, AX           // AX += p4
 417      ADCQ $0, R13
 418      MOVQ AX, R15          // r2
 419  
 420      // r3 = p3 + carry
 421      MOVQ 24(DI), AX
 422      ADDQ R13, AX
 423      MOVQ $0, R10
 424      ADCQ $0, R10          // final carry
 425      MOVQ AX, BP           // r3
 426  
 427      // Check if we need to reduce (carry or result >= n)
 428      TESTQ R10, R10
 429      JNZ bmi2_do_final_reduce
 430  
 431      // Compare with n (from high to low)
 432      MOVQ $0xFFFFFFFFFFFFFFFF, R13
 433      CMPQ BP, R13
 434      JB bmi2_store_result
 435      JA bmi2_do_final_reduce
 436      MOVQ $0xFFFFFFFFFFFFFFFE, R13
 437      CMPQ R15, R13
 438      JB bmi2_store_result
 439      JA bmi2_do_final_reduce
 440      MOVQ $0xBAAEDCE6AF48A03B, R13
 441      CMPQ R14, R13
 442      JB bmi2_store_result
 443      JA bmi2_do_final_reduce
 444      MOVQ $0xBFD25E8CD0364141, R13
 445      CMPQ R12, R13
 446      JB bmi2_store_result
 447  
 448  bmi2_do_final_reduce:
 449      // Add 2^256 - n
 450      ADDQ R8, R12          // r0 += NC0
 451      ADCQ R11, R14         // r1 += NC1
 452      ADCQ $1, R15          // r2 += NC2 = 1
 453      ADCQ $0, BP           // r3 += 0
 454  
 455  bmi2_store_result:
 456      // Store result
 457      MOVQ R12, 0(DI)
 458      MOVQ R14, 8(DI)
 459      MOVQ R15, 16(DI)
 460      MOVQ BP, 24(DI)
 461  
 462      RET
 463