// scalar_amd64.s — secp256k1 256-bit scalar arithmetic, amd64.

//go:build amd64

#include "textflag.h"

// n: the secp256k1 group (scalar) order, stored as four little-endian
// 64-bit limbs — offset +0x00 is the LEAST significant limb.
// n = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
DATA scalarN<>+0x00(SB)/8, $0xBFD25E8CD0364141
DATA scalarN<>+0x08(SB)/8, $0xBAAEDCE6AF48A03B
DATA scalarN<>+0x10(SB)/8, $0xFFFFFFFFFFFFFFFE
DATA scalarN<>+0x18(SB)/8, $0xFFFFFFFFFFFFFFFF
GLOBL scalarN<>(SB), RODATA|NOPTR, $32

// NC = 2^256 - n, the additive inverse of n mod 2^256. The routines below
// subtract n by ADDing NC and letting the top carry wrap, which avoids a
// borrow chain. Only the low three limbs are significant:
// NC = [NC0, NC1, 1, 0], so NC < 2^129.
DATA scalarNC<>+0x00(SB)/8, $0x402DA1732FC9BEBF
DATA scalarNC<>+0x08(SB)/8, $0x4551231950B75FC4
DATA scalarNC<>+0x10(SB)/8, $0x0000000000000001
DATA scalarNC<>+0x18(SB)/8, $0x0000000000000000
GLOBL scalarNC<>(SB), RODATA|NOPTR, $32
  19  
  20  // func hasAVX2() bool
  21  TEXT ·hasAVX2(SB), NOSPLIT, $0-1
  22  	MOVL $7, AX
  23  	MOVL $0, CX
  24  	CPUID
  25  	ANDL $0x20, BX  // Check bit 5 of EBX for AVX2
  26  	SETNE AL
  27  	MOVB AL, ret+0(FP)
  28  	RET
  29  
// func ScalarAddAVX2(r, a, b *Scalar)
// r = (a + b) mod n, where a Scalar is four little-endian 64-bit limbs.
//
// Despite the name, the arithmetic is a plain scalar ADD/ADC carry chain;
// no AVX2 register work is done here.
// NOTE(review): not constant time — it branches on data-dependent carry and
// compare flags, so it leaks timing about the operands.
// NOTE(review): the single conditional subtract assumes a and b are already
// reduced (< n), so a+b < 2n — confirm callers guarantee this.
TEXT ·ScalarAddAVX2(SB), NOSPLIT, $0-24
	MOVQ r+0(FP), DI
	MOVQ a+8(FP), SI
	MOVQ b+16(FP), DX

	// Load a's limbs, least significant first (scalar loads for carry chain)
	MOVQ 0(SI), AX      // a.D[0].Lo
	MOVQ 8(SI), BX      // a.D[0].Hi
	MOVQ 16(SI), CX     // a.D[1].Lo
	MOVQ 24(SI), R8     // a.D[1].Hi

	// 256-bit add with rippled carry
	ADDQ 0(DX), AX      // a.D[0].Lo + b.D[0].Lo
	ADCQ 8(DX), BX      // a.D[0].Hi + b.D[0].Hi + carry
	ADCQ 16(DX), CX     // a.D[1].Lo + b.D[1].Lo + carry
	ADCQ 24(DX), R8     // a.D[1].Hi + b.D[1].Hi + carry

	// R9B = 1 iff the sum overflowed 2^256
	SETCS R9B

	// Store the (possibly still unreduced) sum mod 2^256
	MOVQ AX, 0(DI)
	MOVQ BX, 8(DI)
	MOVQ CX, 16(DI)
	MOVQ R8, 24(DI)

	// Reduce when the sum overflowed 2^256 ...
	TESTB R9B, R9B
	JNZ reduce

	// ... or when sum >= n. Compare limbs most-significant first; fall
	// through on equality of all limbs (sum == n also reduces, to 0).
	MOVQ $0xFFFFFFFFFFFFFFFF, R10   // n's top limb (all ones)
	CMPQ R8, R10
	JB done
	JA reduce
	MOVQ scalarN<>+0x10(SB), R10
	CMPQ CX, R10
	JB done
	JA reduce
	MOVQ scalarN<>+0x08(SB), R10
	CMPQ BX, R10
	JB done
	JA reduce
	MOVQ scalarN<>+0x00(SB), R10
	CMPQ AX, R10
	JB done

reduce:
	// Subtract n by adding NC = 2^256 - n; the carry out of the top limb
	// is the intentional wrap mod 2^256.
	MOVQ 0(DI), AX
	MOVQ 8(DI), BX
	MOVQ 16(DI), CX
	MOVQ 24(DI), R8

	MOVQ scalarNC<>+0x00(SB), R10
	ADDQ R10, AX
	MOVQ scalarNC<>+0x08(SB), R10
	ADCQ R10, BX
	MOVQ scalarNC<>+0x10(SB), R10
	ADCQ R10, CX
	MOVQ scalarNC<>+0x18(SB), R10
	ADCQ R10, R8

	MOVQ AX, 0(DI)
	MOVQ BX, 8(DI)
	MOVQ CX, 16(DI)
	MOVQ R8, 24(DI)

done:
	VZEROUPPER          // harmless: no YMM state was touched above
	RET
 105  
// func ScalarSubAVX2(r, a, b *Scalar)
// r = (a - b) mod n for 256-bit scalars (4 little-endian 64-bit limbs).
//
// Plain SUB/SBB borrow chain despite the AVX2 name; no AVX2 registers used.
// NOTE(review): not constant time — branches on the data-dependent borrow.
// NOTE(review): assumes a, b < n so that adding n back at most once yields
// a reduced result — confirm callers guarantee reduced inputs.
TEXT ·ScalarSubAVX2(SB), NOSPLIT, $0-24
	MOVQ r+0(FP), DI
	MOVQ a+8(FP), SI
	MOVQ b+16(FP), DX

	// Load a's limbs, least significant first
	MOVQ 0(SI), AX
	MOVQ 8(SI), BX
	MOVQ 16(SI), CX
	MOVQ 24(SI), R8

	// 256-bit subtract with rippled borrow
	SUBQ 0(DX), AX
	SBBQ 8(DX), BX
	SBBQ 16(DX), CX
	SBBQ 24(DX), R8

	// R9B = 1 iff a < b (the subtraction wrapped below zero)
	SETCS R9B

	// Store the (possibly wrapped) difference mod 2^256
	MOVQ AX, 0(DI)
	MOVQ BX, 8(DI)
	MOVQ CX, 16(DI)
	MOVQ R8, 24(DI)

	// No borrow: difference is exact; with reduced inputs it is < n.
	TESTB R9B, R9B
	JZ done_sub

	// Wrapped: stored value is a-b+2^256 (mod 2^256); adding n maps it to
	// a-b+n, the correct residue. Top carry out wraps intentionally.
	MOVQ scalarN<>+0x00(SB), R10
	ADDQ R10, AX
	MOVQ scalarN<>+0x08(SB), R10
	ADCQ R10, BX
	MOVQ scalarN<>+0x10(SB), R10
	ADCQ R10, CX
	MOVQ scalarN<>+0x18(SB), R10
	ADCQ R10, R8

	MOVQ AX, 0(DI)
	MOVQ BX, 8(DI)
	MOVQ CX, 16(DI)
	MOVQ R8, 24(DI)

done_sub:
	VZEROUPPER          // harmless: no YMM state was touched above
	RET
 156  
 157  // func ScalarMulAVX2(r, a, b *Scalar)
 158  // Multiplies two 256-bit scalars and reduces mod n.
 159  // This is a complex operation requiring 512-bit intermediate.
 160  TEXT ·ScalarMulAVX2(SB), NOSPLIT, $64-24
 161  	MOVQ r+0(FP), DI
 162  	MOVQ a+8(FP), SI
 163  	MOVQ b+16(FP), DX
 164  
 165  	// We need to compute a 512-bit product and reduce mod n.
 166  	// For now, use scalar multiplication with MULX (if BMI2 available) or MUL.
 167  
 168  	// Load a limbs
 169  	MOVQ 0(SI), R8      // a0
 170  	MOVQ 8(SI), R9      // a1
 171  	MOVQ 16(SI), R10    // a2
 172  	MOVQ 24(SI), R11    // a3
 173  
 174  	// Store b pointer for later use
 175  	MOVQ DX, R12
 176  
 177  	// Compute 512-bit product using schoolbook multiplication
 178  	// Product stored on stack at SP+0 to SP+56 (8 limbs)
 179  
 180  	// Initialize product to zero
 181  	XORQ AX, AX
 182  	MOVQ AX, 0(SP)
 183  	MOVQ AX, 8(SP)
 184  	MOVQ AX, 16(SP)
 185  	MOVQ AX, 24(SP)
 186  	MOVQ AX, 32(SP)
 187  	MOVQ AX, 40(SP)
 188  	MOVQ AX, 48(SP)
 189  	MOVQ AX, 56(SP)
 190  
 191  	// Multiply a0 * b[0..3]
 192  	MOVQ R8, AX
 193  	MULQ 0(R12)         // a0 * b0
 194  	MOVQ AX, 0(SP)
 195  	MOVQ DX, R13        // carry
 196  
 197  	MOVQ R8, AX
 198  	MULQ 8(R12)         // a0 * b1
 199  	ADDQ R13, AX
 200  	ADCQ $0, DX
 201  	MOVQ AX, 8(SP)
 202  	MOVQ DX, R13
 203  
 204  	MOVQ R8, AX
 205  	MULQ 16(R12)        // a0 * b2
 206  	ADDQ R13, AX
 207  	ADCQ $0, DX
 208  	MOVQ AX, 16(SP)
 209  	MOVQ DX, R13
 210  
 211  	MOVQ R8, AX
 212  	MULQ 24(R12)        // a0 * b3
 213  	ADDQ R13, AX
 214  	ADCQ $0, DX
 215  	MOVQ AX, 24(SP)
 216  	MOVQ DX, 32(SP)
 217  
 218  	// Multiply a1 * b[0..3] and add
 219  	MOVQ R9, AX
 220  	MULQ 0(R12)         // a1 * b0
 221  	ADDQ AX, 8(SP)
 222  	ADCQ DX, 16(SP)
 223  	ADCQ $0, 24(SP)
 224  	ADCQ $0, 32(SP)
 225  
 226  	MOVQ R9, AX
 227  	MULQ 8(R12)         // a1 * b1
 228  	ADDQ AX, 16(SP)
 229  	ADCQ DX, 24(SP)
 230  	ADCQ $0, 32(SP)
 231  
 232  	MOVQ R9, AX
 233  	MULQ 16(R12)        // a1 * b2
 234  	ADDQ AX, 24(SP)
 235  	ADCQ DX, 32(SP)
 236  	ADCQ $0, 40(SP)
 237  
 238  	MOVQ R9, AX
 239  	MULQ 24(R12)        // a1 * b3
 240  	ADDQ AX, 32(SP)
 241  	ADCQ DX, 40(SP)
 242  
 243  	// Multiply a2 * b[0..3] and add
 244  	MOVQ R10, AX
 245  	MULQ 0(R12)         // a2 * b0
 246  	ADDQ AX, 16(SP)
 247  	ADCQ DX, 24(SP)
 248  	ADCQ $0, 32(SP)
 249  	ADCQ $0, 40(SP)
 250  
 251  	MOVQ R10, AX
 252  	MULQ 8(R12)         // a2 * b1
 253  	ADDQ AX, 24(SP)
 254  	ADCQ DX, 32(SP)
 255  	ADCQ $0, 40(SP)
 256  
 257  	MOVQ R10, AX
 258  	MULQ 16(R12)        // a2 * b2
 259  	ADDQ AX, 32(SP)
 260  	ADCQ DX, 40(SP)
 261  	ADCQ $0, 48(SP)
 262  
 263  	MOVQ R10, AX
 264  	MULQ 24(R12)        // a2 * b3
 265  	ADDQ AX, 40(SP)
 266  	ADCQ DX, 48(SP)
 267  
 268  	// Multiply a3 * b[0..3] and add
 269  	MOVQ R11, AX
 270  	MULQ 0(R12)         // a3 * b0
 271  	ADDQ AX, 24(SP)
 272  	ADCQ DX, 32(SP)
 273  	ADCQ $0, 40(SP)
 274  	ADCQ $0, 48(SP)
 275  
 276  	MOVQ R11, AX
 277  	MULQ 8(R12)         // a3 * b1
 278  	ADDQ AX, 32(SP)
 279  	ADCQ DX, 40(SP)
 280  	ADCQ $0, 48(SP)
 281  
 282  	MOVQ R11, AX
 283  	MULQ 16(R12)        // a3 * b2
 284  	ADDQ AX, 40(SP)
 285  	ADCQ DX, 48(SP)
 286  	ADCQ $0, 56(SP)
 287  
 288  	MOVQ R11, AX
 289  	MULQ 24(R12)        // a3 * b3
 290  	ADDQ AX, 48(SP)
 291  	ADCQ DX, 56(SP)
 292  
 293  	// Now we have the 512-bit product in SP+0 to SP+56 (l[0..7])
 294  	// Need to reduce mod n using the bitcoin-core algorithm:
 295  	//
 296  	// Phase 1: 512->385 bits
 297  	//   c0..c4 = l[0..3] + l[4..7] * NC   (where NC = 2^256 - n)
 298  	// Phase 2: 385->258 bits
 299  	//   d0..d4 = c[0..3] + c[4] * NC
 300  	// Phase 3: 258->256 bits
 301  	//   r[0..3] = d[0..3] + d[4] * NC, then final reduce if >= n
 302  	//
 303  	// NC = [0x402DA1732FC9BEBF, 0x4551231950B75FC4, 1, 0]
 304  
 305  	// ========== Phase 1: 512->385 bits ==========
 306  	// Compute c[0..4] = l[0..3] + l[4..7] * NC
 307  	// NC has only 3 significant limbs: NC[0], NC[1], NC[2]=1
 308  
 309  	// Start with c = l[0..3], then add contributions from l[4..7] * NC
 310  	MOVQ 0(SP), R8      // c0 = l0
 311  	MOVQ 8(SP), R9      // c1 = l1
 312  	MOVQ 16(SP), R10    // c2 = l2
 313  	MOVQ 24(SP), R11    // c3 = l3
 314  	XORQ R14, R14       // c4 = 0
 315  	XORQ R15, R15       // c5 for overflow
 316  
 317  	// l4 * NC[0]
 318  	MOVQ 32(SP), AX
 319  	MOVQ scalarNC<>+0x00(SB), R12
 320  	MULQ R12            // DX:AX = l4 * NC[0]
 321  	ADDQ AX, R8
 322  	ADCQ DX, R9
 323  	ADCQ $0, R10
 324  	ADCQ $0, R11
 325  	ADCQ $0, R14
 326  
 327  	// l4 * NC[1]
 328  	MOVQ 32(SP), AX
 329  	MOVQ scalarNC<>+0x08(SB), R12
 330  	MULQ R12            // DX:AX = l4 * NC[1]
 331  	ADDQ AX, R9
 332  	ADCQ DX, R10
 333  	ADCQ $0, R11
 334  	ADCQ $0, R14
 335  
 336  	// l4 * NC[2] (NC[2] = 1)
 337  	MOVQ 32(SP), AX
 338  	ADDQ AX, R10
 339  	ADCQ $0, R11
 340  	ADCQ $0, R14
 341  
 342  	// l5 * NC[0]
 343  	MOVQ 40(SP), AX
 344  	MOVQ scalarNC<>+0x00(SB), R12
 345  	MULQ R12
 346  	ADDQ AX, R9
 347  	ADCQ DX, R10
 348  	ADCQ $0, R11
 349  	ADCQ $0, R14
 350  
 351  	// l5 * NC[1]
 352  	MOVQ 40(SP), AX
 353  	MOVQ scalarNC<>+0x08(SB), R12
 354  	MULQ R12
 355  	ADDQ AX, R10
 356  	ADCQ DX, R11
 357  	ADCQ $0, R14
 358  
 359  	// l5 * NC[2] (NC[2] = 1)
 360  	MOVQ 40(SP), AX
 361  	ADDQ AX, R11
 362  	ADCQ $0, R14
 363  
 364  	// l6 * NC[0]
 365  	MOVQ 48(SP), AX
 366  	MOVQ scalarNC<>+0x00(SB), R12
 367  	MULQ R12
 368  	ADDQ AX, R10
 369  	ADCQ DX, R11
 370  	ADCQ $0, R14
 371  
 372  	// l6 * NC[1]
 373  	MOVQ 48(SP), AX
 374  	MOVQ scalarNC<>+0x08(SB), R12
 375  	MULQ R12
 376  	ADDQ AX, R11
 377  	ADCQ DX, R14
 378  
 379  	// l6 * NC[2] (NC[2] = 1)
 380  	MOVQ 48(SP), AX
 381  	ADDQ AX, R14
 382  	ADCQ $0, R15
 383  
 384  	// l7 * NC[0]
 385  	MOVQ 56(SP), AX
 386  	MOVQ scalarNC<>+0x00(SB), R12
 387  	MULQ R12
 388  	ADDQ AX, R11
 389  	ADCQ DX, R14
 390  	ADCQ $0, R15
 391  
 392  	// l7 * NC[1]
 393  	MOVQ 56(SP), AX
 394  	MOVQ scalarNC<>+0x08(SB), R12
 395  	MULQ R12
 396  	ADDQ AX, R14
 397  	ADCQ DX, R15
 398  
 399  	// l7 * NC[2] (NC[2] = 1)
 400  	MOVQ 56(SP), AX
 401  	ADDQ AX, R15
 402  
 403  	// Now c[0..5] = R8, R9, R10, R11, R14, R15 (~385 bits max)
 404  
 405  	// ========== Phase 2: 385->258 bits ==========
 406  	// Reduce c[4..5] by multiplying by NC and adding to c[0..3]
 407  
 408  	// c4 * NC[0]
 409  	MOVQ R14, AX
 410  	MOVQ scalarNC<>+0x00(SB), R12
 411  	MULQ R12
 412  	ADDQ AX, R8
 413  	ADCQ DX, R9
 414  	ADCQ $0, R10
 415  	ADCQ $0, R11
 416  
 417  	// c4 * NC[1]
 418  	MOVQ R14, AX
 419  	MOVQ scalarNC<>+0x08(SB), R12
 420  	MULQ R12
 421  	ADDQ AX, R9
 422  	ADCQ DX, R10
 423  	ADCQ $0, R11
 424  
 425  	// c4 * NC[2] (NC[2] = 1)
 426  	ADDQ R14, R10
 427  	ADCQ $0, R11
 428  
 429  	// c5 * NC[0]
 430  	MOVQ R15, AX
 431  	MOVQ scalarNC<>+0x00(SB), R12
 432  	MULQ R12
 433  	ADDQ AX, R9
 434  	ADCQ DX, R10
 435  	ADCQ $0, R11
 436  
 437  	// c5 * NC[1]
 438  	MOVQ R15, AX
 439  	MOVQ scalarNC<>+0x08(SB), R12
 440  	MULQ R12
 441  	ADDQ AX, R10
 442  	ADCQ DX, R11
 443  
 444  	// c5 * NC[2] (NC[2] = 1)
 445  	ADDQ R15, R11
 446  	// Capture any final carry into R14
 447  	MOVQ $0, R14
 448  	ADCQ $0, R14
 449  
 450  	// Now we have ~258 bits in R8, R9, R10, R11, R14
 451  
 452  	// ========== Phase 3: 258->256 bits ==========
 453  	// If R14 (the overflow) is non-zero, reduce again
 454  	TESTQ R14, R14
 455  	JZ check_overflow
 456  
 457  	// R14 * NC
 458  	MOVQ R14, AX
 459  	MOVQ scalarNC<>+0x00(SB), R12
 460  	MULQ R12
 461  	ADDQ AX, R8
 462  	ADCQ DX, R9
 463  	ADCQ $0, R10
 464  	ADCQ $0, R11
 465  
 466  	MOVQ R14, AX
 467  	MOVQ scalarNC<>+0x08(SB), R12
 468  	MULQ R12
 469  	ADDQ AX, R9
 470  	ADCQ DX, R10
 471  	ADCQ $0, R11
 472  
 473  	// R14 * NC[2] (NC[2] = 1)
 474  	ADDQ R14, R10
 475  	ADCQ $0, R11
 476  
 477  check_overflow:
 478  	// Check if result >= n and reduce if needed
 479  	MOVQ $0xFFFFFFFFFFFFFFFF, R13
 480  	CMPQ R11, R13
 481  	JB store_result
 482  	JA do_reduce
 483  	MOVQ scalarN<>+0x10(SB), R13
 484  	CMPQ R10, R13
 485  	JB store_result
 486  	JA do_reduce
 487  	MOVQ scalarN<>+0x08(SB), R13
 488  	CMPQ R9, R13
 489  	JB store_result
 490  	JA do_reduce
 491  	MOVQ scalarN<>+0x00(SB), R13
 492  	CMPQ R8, R13
 493  	JB store_result
 494  
 495  do_reduce:
 496  	// Subtract n (add 2^256 - n)
 497  	MOVQ scalarNC<>+0x00(SB), R13
 498  	ADDQ R13, R8
 499  	MOVQ scalarNC<>+0x08(SB), R13
 500  	ADCQ R13, R9
 501  	MOVQ scalarNC<>+0x10(SB), R13
 502  	ADCQ R13, R10
 503  	MOVQ scalarNC<>+0x18(SB), R13
 504  	ADCQ R13, R11
 505  
 506  store_result:
 507  	// Store result
 508  	MOVQ r+0(FP), DI
 509  	MOVQ R8, 0(DI)
 510  	MOVQ R9, 8(DI)
 511  	MOVQ R10, 16(DI)
 512  	MOVQ R11, 24(DI)
 513  
 514  	VZEROUPPER
 515  	RET
 516