// scalar.go — scalar arithmetic modulo the secp256k1 group order.

   1  //go:build !js && !wasm && !tinygo && !wasm32
   2  
   3  package p256k1
   4  
import (
	"crypto/subtle"
	"encoding/binary"
	"math/bits"
	"unsafe"
)
  10  
// Scalar represents a scalar value modulo the secp256k1 group order.
// Uses 4 uint64 limbs to represent a 256-bit scalar.
// Limb order is little-endian: d[0] is the least-significant 64 bits,
// d[3] the most significant (as established by setB32/getB32 below).
// A canonical Scalar holds a value in [0, n); helpers such as setB32 and
// add reduce out-of-range values via reduce().
type Scalar struct {
	d [4]uint64
}
  16  
// Scalar constants from the C implementation (libsecp256k1 scalar_4x64_impl.h).
const (
	// Limbs of the secp256k1 order
	// n = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141,
	// little-endian limb order (N0 least significant).
	scalarN0 = 0xBFD25E8CD0364141
	scalarN1 = 0xBAAEDCE6AF48A03B
	scalarN2 = 0xFFFFFFFFFFFFFFFE
	scalarN3 = 0xFFFFFFFFFFFFFFFF

	// Limbs of 2^256 minus the secp256k1 order (complement constants).
	// Adding these is equivalent to subtracting n modulo 2^256; the
	// implicit NC3 limb is 0.
	scalarNC0 = 0x402DA1732FC9BEBF // ~scalarN0 + 1
	scalarNC1 = 0x4551231950B75FC4 // ~scalarN1
	scalarNC2 = 0x0000000000000001 // 1

	// Limbs of half the secp256k1 order, i.e. (n-1)/2; used by isHigh.
	scalarNH0 = 0xDFE92F46681B20A0
	scalarNH1 = 0x5D576E7357A4501D
	scalarNH2 = 0xFFFFFFFFFFFFFFFF
	scalarNH3 = 0x7FFFFFFFFFFFFFFF
)
  36  
// Scalar element constants
var (
	// ScalarZero represents the scalar 0
	ScalarZero = Scalar{d: [4]uint64{0, 0, 0, 0}}

	// ScalarOne represents the scalar 1
	ScalarOne = Scalar{d: [4]uint64{1, 0, 0, 0}}

	// scalarLambda is the GLV endomorphism constant λ (cube root of unity mod n)
	// λ^3 ≡ 1 (mod n), and λ^2 + λ + 1 ≡ 0 (mod n)
	// Value: 0x5363AD4CC05C30E0A5261C028812645A122E22EA20816678DF02967C1B23BD72
	// From libsecp256k1 scalar_impl.h line 81-84
	scalarLambda = Scalar{
		d: [4]uint64{
			0xDF02967C1B23BD72, // limb 0 (least significant)
			0x122E22EA20816678, // limb 1
			0xA5261C028812645A, // limb 2
			0x5363AD4CC05C30E0, // limb 3 (most significant)
		},
	}

	// GLV scalar splitting constants from libsecp256k1 scalar_impl.h lines 142-157
	// These are used in the splitLambda function to decompose a scalar k
	// into k1 and k2 such that k1 + k2*λ ≡ k (mod n)

	// scalarMinusB1 = -b1 where b1 is from the GLV basis
	// Value: 0x00000000000000000000000000000000E4437ED6010E88286F547FA90ABFE4C3
	scalarMinusB1 = Scalar{
		d: [4]uint64{
			0x6F547FA90ABFE4C3, // limb 0
			0xE4437ED6010E8828, // limb 1
			0x0000000000000000, // limb 2
			0x0000000000000000, // limb 3
		},
	}

	// scalarMinusB2 = -b2 where b2 is from the GLV basis
	// Value: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE8A280AC50774346DD765CDA83DB1562C
	scalarMinusB2 = Scalar{
		d: [4]uint64{
			0xD765CDA83DB1562C, // limb 0
			0x8A280AC50774346D, // limb 1
			0xFFFFFFFFFFFFFFFE, // limb 2
			0xFFFFFFFFFFFFFFFF, // limb 3
		},
	}

	// scalarG1 is a precomputed constant for scalar splitting: g1 = round(2^384 * b2 / n)
	// Value: 0x3086D221A7D46BCDE86C90E49284EB153DAA8A1471E8CA7FE893209A45DBB031
	// Used by scalarSplitLambda with mulShiftVar(…, 384).
	scalarG1 = Scalar{
		d: [4]uint64{
			0xE893209A45DBB031, // limb 0
			0x3DAA8A1471E8CA7F, // limb 1
			0xE86C90E49284EB15, // limb 2
			0x3086D221A7D46BCD, // limb 3
		},
	}

	// scalarG2 is a precomputed constant for scalar splitting: g2 = round(2^384 * (-b1) / n)
	// Value: 0xE4437ED6010E88286F547FA90ABFE4C4221208AC9DF506C61571B4AE8AC47F71
	// Used by scalarSplitLambda with mulShiftVar(…, 384).
	scalarG2 = Scalar{
		d: [4]uint64{
			0x1571B4AE8AC47F71, // limb 0
			0x221208AC9DF506C6, // limb 1
			0x6F547FA90ABFE4C4, // limb 2
			0xE4437ED6010E8828, // limb 3
		},
	}
)
 106  
 107  // setInt sets a scalar to a small integer value
 108  func (r *Scalar) setInt(v uint) {
 109  	r.d[0] = uint64(v)
 110  	r.d[1] = 0
 111  	r.d[2] = 0
 112  	r.d[3] = 0
 113  }
 114  
 115  // setB32 sets a scalar from a 32-byte big-endian array
 116  func (r *Scalar) setB32(b []byte) bool {
 117  	if len(b) != 32 {
 118  		panic("scalar byte array must be 32 bytes")
 119  	}
 120  
 121  	// Convert from big-endian bytes to uint64 limbs
 122  	r.d[0] = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
 123  		uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
 124  	r.d[1] = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
 125  		uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
 126  	r.d[2] = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
 127  		uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
 128  	r.d[3] = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
 129  		uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
 130  
 131  	// Check if the scalar overflows the group order
 132  	overflow := r.checkOverflow()
 133  	if overflow {
 134  		r.reduce(1)
 135  	}
 136  
 137  	return overflow
 138  }
 139  
 140  // setB32Seckey sets a scalar from a 32-byte secret key, returns true if valid
 141  func (r *Scalar) setB32Seckey(b []byte) bool {
 142  	overflow := r.setB32(b)
 143  	return !r.isZero() && !overflow
 144  }
 145  
 146  // getB32 converts a scalar to a 32-byte big-endian array
 147  func (r *Scalar) getB32(b []byte) {
 148  	if len(b) != 32 {
 149  		panic("scalar byte array must be 32 bytes")
 150  	}
 151  
 152  	// Convert from uint64 limbs to big-endian bytes
 153  	b[31] = byte(r.d[0])
 154  	b[30] = byte(r.d[0] >> 8)
 155  	b[29] = byte(r.d[0] >> 16)
 156  	b[28] = byte(r.d[0] >> 24)
 157  	b[27] = byte(r.d[0] >> 32)
 158  	b[26] = byte(r.d[0] >> 40)
 159  	b[25] = byte(r.d[0] >> 48)
 160  	b[24] = byte(r.d[0] >> 56)
 161  
 162  	b[23] = byte(r.d[1])
 163  	b[22] = byte(r.d[1] >> 8)
 164  	b[21] = byte(r.d[1] >> 16)
 165  	b[20] = byte(r.d[1] >> 24)
 166  	b[19] = byte(r.d[1] >> 32)
 167  	b[18] = byte(r.d[1] >> 40)
 168  	b[17] = byte(r.d[1] >> 48)
 169  	b[16] = byte(r.d[1] >> 56)
 170  
 171  	b[15] = byte(r.d[2])
 172  	b[14] = byte(r.d[2] >> 8)
 173  	b[13] = byte(r.d[2] >> 16)
 174  	b[12] = byte(r.d[2] >> 24)
 175  	b[11] = byte(r.d[2] >> 32)
 176  	b[10] = byte(r.d[2] >> 40)
 177  	b[9] = byte(r.d[2] >> 48)
 178  	b[8] = byte(r.d[2] >> 56)
 179  
 180  	b[7] = byte(r.d[3])
 181  	b[6] = byte(r.d[3] >> 8)
 182  	b[5] = byte(r.d[3] >> 16)
 183  	b[4] = byte(r.d[3] >> 24)
 184  	b[3] = byte(r.d[3] >> 32)
 185  	b[2] = byte(r.d[3] >> 40)
 186  	b[1] = byte(r.d[3] >> 48)
 187  	b[0] = byte(r.d[3] >> 56)
 188  }
 189  
// checkOverflow checks if the scalar is >= the group order.
//
// It performs a lexicographic comparison from the most-significant limb
// down, mirroring libsecp256k1's yes/no flag scheme: once `no` is set
// (some higher limb was strictly below n's limb) a lower limb can no
// longer set `yes`, and vice versa. NOTE(review): the Go `if` statements
// branch on the data, so unlike the C original this is presumably not
// constant-time — confirm whether that matters for callers.
func (r *Scalar) checkOverflow() bool {
	yes := 0
	no := 0

	// Check each limb from most significant to least significant
	if r.d[3] < scalarN3 {
		no = 1
	}
	if r.d[3] > scalarN3 {
		// Unreachable: scalarN3 is the maximum uint64 value.
		yes = 1
	}

	if r.d[2] < scalarN2 {
		no |= (yes ^ 1)
	}
	if r.d[2] > scalarN2 {
		yes |= (no ^ 1)
	}

	if r.d[1] < scalarN1 {
		no |= (yes ^ 1)
	}
	if r.d[1] > scalarN1 {
		yes |= (no ^ 1)
	}

	// >= here (not >): a value equal to n also overflows.
	if r.d[0] >= scalarN0 {
		yes |= (no ^ 1)
	}

	return yes != 0
}
 223  
 224  // reduce reduces the scalar modulo the group order
 225  func (r *Scalar) reduce(overflow int) {
 226  	if overflow < 0 || overflow > 1 {
 227  		panic("overflow must be 0 or 1")
 228  	}
 229  
 230  	// Use 128-bit arithmetic for the reduction
 231  	var t uint128
 232  
 233  	// d[0] += overflow * scalarNC0
 234  	t = uint128FromU64(r.d[0])
 235  	t = t.addU64(uint64(overflow) * scalarNC0)
 236  	r.d[0] = t.lo()
 237  	t = t.rshift(64)
 238  
 239  	// d[1] += overflow * scalarNC1 + carry
 240  	t = t.addU64(r.d[1])
 241  	t = t.addU64(uint64(overflow) * scalarNC1)
 242  	r.d[1] = t.lo()
 243  	t = t.rshift(64)
 244  
 245  	// d[2] += overflow * scalarNC2 + carry
 246  	t = t.addU64(r.d[2])
 247  	t = t.addU64(uint64(overflow) * scalarNC2)
 248  	r.d[2] = t.lo()
 249  	t = t.rshift(64)
 250  
 251  	// d[3] += carry (scalarNC3 = 0)
 252  	t = t.addU64(r.d[3])
 253  	r.d[3] = t.lo()
 254  }
 255  
 256  // add adds two scalars: r = a + b, returns overflow
 257  func (r *Scalar) add(a, b *Scalar) bool {
 258  	// Use AVX2 if available (AMD64 only)
 259  	if HasAVX2() {
 260  		scalarAddAVX2(r, a, b)
 261  		return false // AVX2 version handles reduction internally
 262  	}
 263  	return r.addPureGo(a, b)
 264  }
 265  
 266  // addPureGo is the pure Go implementation of scalar addition
 267  func (r *Scalar) addPureGo(a, b *Scalar) bool {
 268  	var carry uint64
 269  
 270  	r.d[0], carry = bits.Add64(a.d[0], b.d[0], 0)
 271  	r.d[1], carry = bits.Add64(a.d[1], b.d[1], carry)
 272  	r.d[2], carry = bits.Add64(a.d[2], b.d[2], carry)
 273  	r.d[3], carry = bits.Add64(a.d[3], b.d[3], carry)
 274  
 275  	overflow := carry != 0 || r.checkOverflow()
 276  	if overflow {
 277  		r.reduce(1)
 278  	}
 279  
 280  	return overflow
 281  }
 282  
 283  // sub subtracts two scalars: r = a - b
 284  func (r *Scalar) sub(a, b *Scalar) {
 285  	// Use AVX2 if available (AMD64 only)
 286  	if HasAVX2() {
 287  		scalarSubAVX2(r, a, b)
 288  		return
 289  	}
 290  	r.subPureGo(a, b)
 291  }
 292  
 293  // subPureGo is the pure Go implementation of scalar subtraction
 294  func (r *Scalar) subPureGo(a, b *Scalar) {
 295  	// Compute a - b = a + (-b)
 296  	var negB Scalar
 297  	negB.negate(b)
 298  	*r = *a
 299  	r.addPureGo(r, &negB)
 300  }
 301  
 302  // mul multiplies two scalars: r = a * b
 303  func (r *Scalar) mul(a, b *Scalar) {
 304  	// Use AVX2 if available (AMD64 only)
 305  	if HasAVX2() {
 306  		scalarMulAVX2(r, a, b)
 307  		return
 308  	}
 309  	r.mulPureGo(a, b)
 310  }
 311  
 312  // mulPureGo is the pure Go implementation of scalar multiplication
 313  func (r *Scalar) mulPureGo(a, b *Scalar) {
 314  	// Compute full 512-bit product using all 16 cross products
 315  	var l [8]uint64
 316  	r.mul512(l[:], a, b)
 317  	r.reduce512(l[:])
 318  }
 319  
// mul512 computes the 512-bit product of two scalars (from C implementation)
// into l8[0..7], least-significant limb first. r itself is not modified.
//
// A 160-bit accumulator — c0 (low 64), c1 (middle 64), c2 (top 32) —
// sums the partial products column by column, exactly mirroring
// libsecp256k1's muladd/extract macro scheme. Column k collects every
// a.d[i]*b.d[j] with i+j == k.
func (r *Scalar) mul512(l8 []uint64, a, b *Scalar) {
	// 160-bit accumulator (c0, c1, c2)
	var c0, c1 uint64
	var c2 uint32

	// Helper macros translated from C
	// muladd: accumulator += ai*bi, carrying into all three words.
	muladd := func(ai, bi uint64) {
		hi, lo := bits.Mul64(ai, bi)
		var carry uint64
		c0, carry = bits.Add64(c0, lo, 0)
		c1, carry = bits.Add64(c1, hi, carry)
		c2 += uint32(carry)
	}

	// muladdFast: as muladd, for columns where c2 provably stays zero.
	muladdFast := func(ai, bi uint64) {
		hi, lo := bits.Mul64(ai, bi)
		var carry uint64
		c0, carry = bits.Add64(c0, lo, 0)
		c1 += hi + carry
	}

	// extract: pop the low 64 bits and shift the accumulator down a word.
	extract := func() uint64 {
		result := c0
		c0 = c1
		c1 = uint64(c2)
		c2 = 0
		return result
	}

	// extractFast: as extract, when c2 is known to be zero.
	extractFast := func() uint64 {
		result := c0
		c0 = c1
		c1 = 0
		return result
	}

	// l8[0..7] = a[0..3] * b[0..3] (following C implementation exactly)
	muladdFast(a.d[0], b.d[0])
	l8[0] = extractFast()

	muladd(a.d[0], b.d[1])
	muladd(a.d[1], b.d[0])
	l8[1] = extract()

	muladd(a.d[0], b.d[2])
	muladd(a.d[1], b.d[1])
	muladd(a.d[2], b.d[0])
	l8[2] = extract()

	muladd(a.d[0], b.d[3])
	muladd(a.d[1], b.d[2])
	muladd(a.d[2], b.d[1])
	muladd(a.d[3], b.d[0])
	l8[3] = extract()

	muladd(a.d[1], b.d[3])
	muladd(a.d[2], b.d[2])
	muladd(a.d[3], b.d[1])
	l8[4] = extract()

	muladd(a.d[2], b.d[3])
	muladd(a.d[3], b.d[2])
	l8[5] = extract()

	// Final column cannot carry past c1, so the fast variants suffice.
	muladdFast(a.d[3], b.d[3])
	l8[6] = extractFast()
	l8[7] = c0
}
 389  
// reduce512 reduces a 512-bit value to 256-bit (from C implementation).
//
// Writing the input as l = low + 2^256*high, it uses
// 2^256 ≡ NC (mod n) (where NC = 2^256 - n fits in ~129 bits) to fold
// the high limbs down in three passes: 512→385 bits, 385→258 bits,
// 258→256 bits, then a final conditional subtraction of n. Mirrors
// libsecp256k1's secp256k1_scalar_reduce_512.
func (r *Scalar) reduce512(l []uint64) {
	// 160-bit accumulator (c0 low, c1 middle, c2 top 32 bits)
	var c0, c1 uint64
	var c2 uint32

	// Extract upper 256 bits
	n0, n1, n2, n3 := l[4], l[5], l[6], l[7]

	// Helper macros
	// muladd: accumulator += ai*bi, carrying into all three words.
	muladd := func(ai, bi uint64) {
		hi, lo := bits.Mul64(ai, bi)
		var carry uint64
		c0, carry = bits.Add64(c0, lo, 0)
		c1, carry = bits.Add64(c1, hi, carry)
		c2 += uint32(carry)
	}

	// muladdFast: as muladd, for columns where c2 provably stays zero.
	muladdFast := func(ai, bi uint64) {
		hi, lo := bits.Mul64(ai, bi)
		var carry uint64
		c0, carry = bits.Add64(c0, lo, 0)
		c1 += hi + carry
	}

	// sumadd: accumulator += a (64-bit addend), with full carry chain.
	sumadd := func(a uint64) {
		var carry uint64
		c0, carry = bits.Add64(c0, a, 0)
		c1, carry = bits.Add64(c1, 0, carry)
		c2 += uint32(carry)
	}

	// sumaddFast: as sumadd, when the carry cannot reach c2.
	sumaddFast := func(a uint64) {
		var carry uint64
		c0, carry = bits.Add64(c0, a, 0)
		c1 += carry
	}

	// extract: pop the low 64 bits and shift the accumulator down a word.
	extract := func() uint64 {
		result := c0
		c0 = c1
		c1 = uint64(c2)
		c2 = 0
		return result
	}

	// extractFast: as extract, when c2 is known to be zero.
	extractFast := func() uint64 {
		result := c0
		c0 = c1
		c1 = 0
		return result
	}

	// Reduce 512 bits into 385 bits
	// m[0..6] = l[0..3] + n[0..3] * SECP256K1_N_C
	c0 = l[0]
	c1 = 0
	c2 = 0
	muladdFast(n0, scalarNC0)
	m0 := extractFast()

	sumaddFast(l[1])
	muladd(n1, scalarNC0)
	muladd(n0, scalarNC1)
	m1 := extract()

	sumadd(l[2])
	muladd(n2, scalarNC0)
	muladd(n1, scalarNC1)
	sumadd(n0) // n0 * NC2, and NC2 == 1
	m2 := extract()

	sumadd(l[3])
	muladd(n3, scalarNC0)
	muladd(n2, scalarNC1)
	sumadd(n1)
	m3 := extract()

	muladd(n3, scalarNC1)
	sumadd(n2)
	m4 := extract()

	sumaddFast(n3)
	m5 := extractFast()
	m6 := uint32(c0) // top word of the 385-bit intermediate fits in 32 bits

	// Reduce 385 bits into 258 bits
	// p[0..4] = m[0..3] + m[4..6] * SECP256K1_N_C
	c0 = m0
	c1 = 0
	c2 = 0
	muladdFast(m4, scalarNC0)
	p0 := extractFast()

	sumaddFast(m1)
	muladd(m5, scalarNC0)
	muladd(m4, scalarNC1)
	p1 := extract()

	sumadd(m2)
	muladd(uint64(m6), scalarNC0)
	muladd(m5, scalarNC1)
	sumadd(m4)
	p2 := extract()

	sumaddFast(m3)
	muladdFast(uint64(m6), scalarNC1)
	sumaddFast(m5)
	p3 := extractFast()
	p4 := uint32(c0 + uint64(m6)) // top word of the 258-bit intermediate

	// Reduce 258 bits into 256 bits
	// r[0..3] = p[0..3] + p[4] * SECP256K1_N_C
	var t uint128

	t = uint128FromU64(p0)
	t = t.addMul(scalarNC0, uint64(p4))
	r.d[0] = t.lo()
	t = t.rshift(64)

	t = t.addU64(p1)
	t = t.addMul(scalarNC1, uint64(p4))
	r.d[1] = t.lo()
	t = t.rshift(64)

	t = t.addU64(p2)
	t = t.addU64(uint64(p4)) // p4 * NC2, and NC2 == 1
	r.d[2] = t.lo()
	t = t.rshift(64)

	t = t.addU64(p3)
	r.d[3] = t.lo()
	c := t.hi() // final carry out of bit 256 (0 or 1)

	// Final reduction: at most one extra subtraction of n is needed.
	r.reduce(int(c) + boolToInt(r.checkOverflow()))
}
 527  
 528  // negate negates a scalar: r = -a
 529  func (r *Scalar) negate(a *Scalar) {
 530  	// r = n - a where n is the group order
 531  	var borrow uint64
 532  
 533  	r.d[0], borrow = bits.Sub64(scalarN0, a.d[0], 0)
 534  	r.d[1], borrow = bits.Sub64(scalarN1, a.d[1], borrow)
 535  	r.d[2], borrow = bits.Sub64(scalarN2, a.d[2], borrow)
 536  	r.d[3], _ = bits.Sub64(scalarN3, a.d[3], borrow)
 537  }
 538  
 539  // inverse computes the modular inverse of a scalar
 540  func (r *Scalar) inverse(a *Scalar) {
 541  	// Use Fermat's little theorem: a^(-1) = a^(n-2) mod n
 542  	// where n is the group order (which is prime)
 543  
 544  	// Use binary exponentiation with n-2
 545  	var exp Scalar
 546  	var borrow uint64
 547  	exp.d[0], borrow = bits.Sub64(scalarN0, 2, 0)
 548  	exp.d[1], borrow = bits.Sub64(scalarN1, 0, borrow)
 549  	exp.d[2], borrow = bits.Sub64(scalarN2, 0, borrow)
 550  	exp.d[3], _ = bits.Sub64(scalarN3, 0, borrow)
 551  
 552  	r.exp(a, &exp)
 553  }
 554  
 555  // exp computes r = a^b mod n using binary exponentiation
 556  func (r *Scalar) exp(a, b *Scalar) {
 557  	*r = ScalarOne
 558  	base := *a
 559  
 560  	for i := 0; i < 4; i++ {
 561  		limb := b.d[i]
 562  		for j := 0; j < 64; j++ {
 563  			if limb&1 != 0 {
 564  				r.mul(r, &base)
 565  			}
 566  			base.mul(&base, &base)
 567  			limb >>= 1
 568  		}
 569  	}
 570  }
 571  
 572  // half computes r = a/2 mod n
 573  func (r *Scalar) half(a *Scalar) {
 574  	*r = *a
 575  
 576  	if r.d[0]&1 == 0 {
 577  		// Even case: simple right shift
 578  		r.d[0] = (r.d[0] >> 1) | ((r.d[1] & 1) << 63)
 579  		r.d[1] = (r.d[1] >> 1) | ((r.d[2] & 1) << 63)
 580  		r.d[2] = (r.d[2] >> 1) | ((r.d[3] & 1) << 63)
 581  		r.d[3] = r.d[3] >> 1
 582  	} else {
 583  		// Odd case: add n then divide by 2
 584  		var carry uint64
 585  		r.d[0], carry = bits.Add64(r.d[0], scalarN0, 0)
 586  		r.d[1], carry = bits.Add64(r.d[1], scalarN1, carry)
 587  		r.d[2], carry = bits.Add64(r.d[2], scalarN2, carry)
 588  		r.d[3], _ = bits.Add64(r.d[3], scalarN3, carry)
 589  
 590  		// Now divide by 2
 591  		r.d[0] = (r.d[0] >> 1) | ((r.d[1] & 1) << 63)
 592  		r.d[1] = (r.d[1] >> 1) | ((r.d[2] & 1) << 63)
 593  		r.d[2] = (r.d[2] >> 1) | ((r.d[3] & 1) << 63)
 594  		r.d[3] = r.d[3] >> 1
 595  	}
 596  }
 597  
 598  // isZero returns true if the scalar is zero
 599  func (r *Scalar) isZero() bool {
 600  	return (r.d[0] | r.d[1] | r.d[2] | r.d[3]) == 0
 601  }
 602  
 603  // isOne returns true if the scalar is one
 604  func (r *Scalar) isOne() bool {
 605  	return r.d[0] == 1 && r.d[1] == 0 && r.d[2] == 0 && r.d[3] == 0
 606  }
 607  
 608  // isEven returns true if the scalar is even
 609  func (r *Scalar) isEven() bool {
 610  	return r.d[0]&1 == 0
 611  }
 612  
// isHigh returns true if the scalar is > n/2, i.e. strictly greater than
// (n-1)/2 (the scalarNH* limbs).
//
// Same yes/no flag scheme as checkOverflow: a lexicographic compare from
// the most-significant limb down, where lower limbs only decide the
// result when all higher limbs were equal. A value exactly equal to
// (n-1)/2 returns false (the final test is > rather than >=).
func (r *Scalar) isHigh() bool {
	var yes, no int

	if r.d[3] < scalarNH3 {
		no = 1
	}
	if r.d[3] > scalarNH3 {
		yes = 1
	}

	if r.d[2] < scalarNH2 {
		no |= (yes ^ 1)
	}
	if r.d[2] > scalarNH2 {
		yes |= (no ^ 1)
	}

	if r.d[1] < scalarNH1 {
		no |= (yes ^ 1)
	}
	if r.d[1] > scalarNH1 {
		yes |= (no ^ 1)
	}

	// Strict > here: equality with n/2 is not "high".
	if r.d[0] > scalarNH0 {
		yes |= (no ^ 1)
	}

	return yes != 0
}
 644  
 645  // condNegate conditionally negates the scalar if flag is true
 646  func (r *Scalar) condNegate(flag int) {
 647  	if flag != 0 {
 648  		var neg Scalar
 649  		neg.negate(r)
 650  		*r = neg
 651  	}
 652  }
 653  
 654  // equal returns true if two scalars are equal
 655  func (r *Scalar) equal(a *Scalar) bool {
 656  	return subtle.ConstantTimeCompare(
 657  		(*[32]byte)(unsafe.Pointer(&r.d[0]))[:32],
 658  		(*[32]byte)(unsafe.Pointer(&a.d[0]))[:32],
 659  	) == 1
 660  }
 661  
 662  // getBits extracts count bits starting at offset
 663  func (r *Scalar) getBits(offset, count uint) uint32 {
 664  	if count == 0 || count > 32 {
 665  		panic("count must be 1-32")
 666  	}
 667  	if offset+count > 256 {
 668  		panic("offset + count must be <= 256")
 669  	}
 670  
 671  	limbIdx := offset / 64
 672  	bitIdx := offset % 64
 673  
 674  	if bitIdx+count <= 64 {
 675  		// Bits are within a single limb
 676  		return uint32((r.d[limbIdx] >> bitIdx) & ((1 << count) - 1))
 677  	} else {
 678  		// Bits span two limbs
 679  		lowBits := 64 - bitIdx
 680  		highBits := count - lowBits
 681  		low := uint32((r.d[limbIdx] >> bitIdx) & ((1 << lowBits) - 1))
 682  		high := uint32(r.d[limbIdx+1] & ((1 << highBits) - 1))
 683  		return low | (high << lowBits)
 684  	}
 685  }
 686  
 687  // cmov conditionally moves a scalar. If flag is true, r = a; otherwise r is unchanged.
 688  func (r *Scalar) cmov(a *Scalar, flag int) {
 689  	mask := uint64(-(int64(flag) & 1))
 690  	r.d[0] ^= mask & (r.d[0] ^ a.d[0])
 691  	r.d[1] ^= mask & (r.d[1] ^ a.d[1])
 692  	r.d[2] ^= mask & (r.d[2] ^ a.d[2])
 693  	r.d[3] ^= mask & (r.d[3] ^ a.d[3])
 694  }
 695  
// clear clears a scalar to prevent leaking sensitive information.
// memclear is used instead of a plain zero assignment — presumably so
// the store cannot be elided by the compiler as dead; confirm against
// memclear's definition elsewhere in the package.
func (r *Scalar) clear() {
	memclear(unsafe.Pointer(&r.d[0]), unsafe.Sizeof(r.d))
}
 700  
 701  // Helper functions for 128-bit arithmetic (using uint128 from field_mul.go)
 702  
 703  func uint128FromU64(x uint64) uint128 {
 704  	return uint128{low: x, high: 0}
 705  }
 706  
 707  func (x uint128) addU64(y uint64) uint128 {
 708  	low, carry := bits.Add64(x.low, y, 0)
 709  	high := x.high + carry
 710  	return uint128{low: low, high: high}
 711  }
 712  
 713  func (x uint128) addMul(a, b uint64) uint128 {
 714  	hi, lo := bits.Mul64(a, b)
 715  	low, carry := bits.Add64(x.low, lo, 0)
 716  	high, _ := bits.Add64(x.high, hi, carry)
 717  	return uint128{low: low, high: high}
 718  }
 719  
 720  // Direct function versions to reduce method call overhead
 721  // These are equivalent to the method versions but avoid interface dispatch
 722  
 723  // scalarAdd adds two scalars: r = a + b, returns overflow
 724  func scalarAdd(r, a, b *Scalar) bool {
 725  	var carry uint64
 726  
 727  	r.d[0], carry = bits.Add64(a.d[0], b.d[0], 0)
 728  	r.d[1], carry = bits.Add64(a.d[1], b.d[1], carry)
 729  	r.d[2], carry = bits.Add64(a.d[2], b.d[2], carry)
 730  	r.d[3], carry = bits.Add64(a.d[3], b.d[3], carry)
 731  
 732  	overflow := carry != 0 || scalarCheckOverflow(r)
 733  	if overflow {
 734  		scalarReduce(r, 1)
 735  	}
 736  
 737  	return overflow
 738  }
 739  
 740  // scalarMul multiplies two scalars: r = a * b
 741  func scalarMul(r, a, b *Scalar) {
 742  	// Use the method version which has the correct 512-bit reduction
 743  	r.mulPureGo(a, b)
 744  }
 745  
 746  // scalarGetB32 serializes a scalar to 32 bytes in big-endian format
 747  func scalarGetB32(bin []byte, a *Scalar) {
 748  	if len(bin) != 32 {
 749  		panic("scalar byte array must be 32 bytes")
 750  	}
 751  
 752  	// Convert to big-endian bytes
 753  	for i := 0; i < 4; i++ {
 754  		bin[31-8*i] = byte(a.d[i])
 755  		bin[30-8*i] = byte(a.d[i] >> 8)
 756  		bin[29-8*i] = byte(a.d[i] >> 16)
 757  		bin[28-8*i] = byte(a.d[i] >> 24)
 758  		bin[27-8*i] = byte(a.d[i] >> 32)
 759  		bin[26-8*i] = byte(a.d[i] >> 40)
 760  		bin[25-8*i] = byte(a.d[i] >> 48)
 761  		bin[24-8*i] = byte(a.d[i] >> 56)
 762  	}
 763  }
 764  
 765  // scalarIsZero returns true if the scalar is zero
 766  func scalarIsZero(a *Scalar) bool {
 767  	return a.d[0] == 0 && a.d[1] == 0 && a.d[2] == 0 && a.d[3] == 0
 768  }
 769  
 770  // scalarCheckOverflow checks if the scalar is >= the group order
 771  func scalarCheckOverflow(r *Scalar) bool {
 772  	return (r.d[3] > scalarN3) ||
 773  		(r.d[3] == scalarN3 && r.d[2] > scalarN2) ||
 774  		(r.d[3] == scalarN3 && r.d[2] == scalarN2 && r.d[1] > scalarN1) ||
 775  		(r.d[3] == scalarN3 && r.d[2] == scalarN2 && r.d[1] == scalarN1 && r.d[0] >= scalarN0)
 776  }
 777  
 778  // scalarReduce reduces the scalar modulo the group order
 779  func scalarReduce(r *Scalar, overflow int) {
 780  	var t Scalar
 781  	var c uint64
 782  
 783  	// Compute r + overflow * N_C
 784  	t.d[0], c = bits.Add64(r.d[0], uint64(overflow)*scalarNC0, 0)
 785  	t.d[1], c = bits.Add64(r.d[1], uint64(overflow)*scalarNC1, c)
 786  	t.d[2], c = bits.Add64(r.d[2], uint64(overflow)*scalarNC2, c)
 787  	t.d[3], c = bits.Add64(r.d[3], 0, c)
 788  
 789  	// Mask to keep only the low 256 bits
 790  	r.d[0] = t.d[0] & 0xFFFFFFFFFFFFFFFF
 791  	r.d[1] = t.d[1] & 0xFFFFFFFFFFFFFFFF
 792  	r.d[2] = t.d[2] & 0xFFFFFFFFFFFFFFFF
 793  	r.d[3] = t.d[3] & 0xFFFFFFFFFFFFFFFF
 794  
 795  	// Ensure result is in range [0, N)
 796  	if scalarCheckOverflow(r) {
 797  		scalarReduce(r, 1)
 798  	}
 799  }
 800  
// wNAF converts a scalar to Windowed Non-Adjacent Form representation
// wNAF represents the scalar using digits in the range [-(2^(w-1)-1), 2^(w-1)-1]
// with the property that non-zero digits are separated by at least w-1 zeros.
//
// Returns the number of digits in the wNAF representation (at most 257 for 256-bit scalars)
// and fills the wnaf array with the digits.
func (s *Scalar) wNAF(wnaf *[257]int8, w uint) int {
	if w < 2 || w > 8 {
		panic("w must be between 2 and 8")
	}

	var k Scalar
	k = *s

	// Note: We do NOT negate the scalar here. The caller is responsible for
	// ensuring the scalar is in the appropriate form. The ecmultEndoSplit
	// function already handles sign normalization.

	numBits := 0
	var carry uint32

	*wnaf = [257]int8{}

	bit := 0
	for bit < 256 {
		// Skip positions whose bit equals the pending carry: they yield
		// zero digits. This guard also guarantees that bit 0 of the next
		// window differs from carry, which keeps word < 2^window below.
		if k.getBits(uint(bit), 1) == carry {
			bit++
			continue
		}

		// Take up to w bits (fewer near the top of the scalar).
		window := w
		if bit+int(window) > 256 {
			window = uint(256 - bit)
		}

		word := uint32(k.getBits(uint(bit), window)) + carry

		// If the window's top bit is set, make the digit negative and
		// carry 1 into the next window. Thanks to the guard above,
		// word < 2^window here, so the signed digit lies strictly within
		// (-2^(w-1), 2^(w-1)) and fits in int8 even for w == 8.
		carry = (word >> (window - 1)) & 1
		word -= carry << window

		wnaf[bit] = int8(int32(word))
		// NOTE(review): only wnaf[bit] was written, yet numBits records
		// bit+window-1, so the returned count may cover trailing zero
		// digits (libsecp256k1 returns last_set_bit+1 instead). Harmless
		// if callers treat zero digits as no-ops — confirm at call sites.
		numBits = bit + int(window) - 1

		bit += int(window)
	}

	// A leftover carry becomes a final +1 digit at position 256.
	if carry != 0 {
		wnaf[256] = int8(carry)
		numBits = 256
	}

	return numBits + 1
}
 854  
 855  // wNAFSigned converts a scalar to Windowed Non-Adjacent Form representation,
 856  // handling sign normalization. If the scalar has its high bit set (is "negative"
 857  // in the modular sense), it will be negated and the negated flag will be true.
 858  //
 859  // Returns the number of digits and whether the scalar was negated.
 860  // The caller must negate the result point if negated is true.
 861  func (s *Scalar) wNAFSigned(wnaf *[257]int8, w uint) (int, bool) {
 862  	if w < 2 || w > 8 {
 863  		panic("w must be between 2 and 8")
 864  	}
 865  
 866  	var k Scalar
 867  	k = *s
 868  
 869  	negated := false
 870  	if k.getBits(255, 1) == 1 {
 871  		k.negate(&k)
 872  		negated = true
 873  	}
 874  
 875  	bits := k.wNAF(wnaf, w)
 876  	return bits, negated
 877  }
 878  
 879  // =============================================================================
 880  // GLV Endomorphism Support Functions
 881  // =============================================================================
 882  
 883  // caddBit conditionally adds a power of 2 to the scalar
 884  // If flag is non-zero, adds 2^bit to r
 885  func (r *Scalar) caddBit(bit uint, flag int) {
 886  	if flag == 0 {
 887  		return
 888  	}
 889  
 890  	limbIdx := bit >> 6        // bit / 64
 891  	bitIdx := bit & 0x3F       // bit % 64
 892  	addVal := uint64(1) << bitIdx
 893  
 894  	var carry uint64
 895  	if limbIdx == 0 {
 896  		r.d[0], carry = bits.Add64(r.d[0], addVal, 0)
 897  		r.d[1], carry = bits.Add64(r.d[1], 0, carry)
 898  		r.d[2], carry = bits.Add64(r.d[2], 0, carry)
 899  		r.d[3], _ = bits.Add64(r.d[3], 0, carry)
 900  	} else if limbIdx == 1 {
 901  		r.d[1], carry = bits.Add64(r.d[1], addVal, 0)
 902  		r.d[2], carry = bits.Add64(r.d[2], 0, carry)
 903  		r.d[3], _ = bits.Add64(r.d[3], 0, carry)
 904  	} else if limbIdx == 2 {
 905  		r.d[2], carry = bits.Add64(r.d[2], addVal, 0)
 906  		r.d[3], _ = bits.Add64(r.d[3], 0, carry)
 907  	} else if limbIdx == 3 {
 908  		r.d[3], _ = bits.Add64(r.d[3], addVal, 0)
 909  	}
 910  }
 911  
 912  // mulShiftVar computes r = round((a * b) >> shift) for shift >= 256
 913  // This is used in GLV scalar splitting to compute c1 = round(k * g1 / 2^384)
 914  // The rounding is achieved by adding the bit just below the shift position
 915  func (r *Scalar) mulShiftVar(a, b *Scalar, shift uint) {
 916  	if shift < 256 {
 917  		panic("mulShiftVar requires shift >= 256")
 918  	}
 919  
 920  	// Compute full 512-bit product
 921  	var l [8]uint64
 922  	r.mul512(l[:], a, b)
 923  
 924  	// Extract bits [shift, shift+256) from the 512-bit product
 925  	shiftLimbs := shift >> 6      // Number of full 64-bit limbs to skip
 926  	shiftLow := shift & 0x3F      // Bit offset within the limb
 927  	shiftHigh := 64 - shiftLow    // Complementary shift for combining limbs
 928  
 929  	// Extract each limb of the result
 930  	// For shift=384, shiftLimbs=6, shiftLow=0
 931  	// r.d[0] = l[6], r.d[1] = l[7], r.d[2] = 0, r.d[3] = 0
 932  
 933  	if shift < 512 {
 934  		if shiftLow != 0 {
 935  			r.d[0] = (l[shiftLimbs] >> shiftLow) | (l[shiftLimbs+1] << shiftHigh)
 936  		} else {
 937  			r.d[0] = l[shiftLimbs]
 938  		}
 939  	} else {
 940  		r.d[0] = 0
 941  	}
 942  
 943  	if shift < 448 {
 944  		if shiftLow != 0 && shift < 384 {
 945  			r.d[1] = (l[shiftLimbs+1] >> shiftLow) | (l[shiftLimbs+2] << shiftHigh)
 946  		} else if shiftLow != 0 {
 947  			r.d[1] = l[shiftLimbs+1] >> shiftLow
 948  		} else {
 949  			r.d[1] = l[shiftLimbs+1]
 950  		}
 951  	} else {
 952  		r.d[1] = 0
 953  	}
 954  
 955  	if shift < 384 {
 956  		if shiftLow != 0 && shift < 320 {
 957  			r.d[2] = (l[shiftLimbs+2] >> shiftLow) | (l[shiftLimbs+3] << shiftHigh)
 958  		} else if shiftLow != 0 {
 959  			r.d[2] = l[shiftLimbs+2] >> shiftLow
 960  		} else {
 961  			r.d[2] = l[shiftLimbs+2]
 962  		}
 963  	} else {
 964  		r.d[2] = 0
 965  	}
 966  
 967  	if shift < 320 {
 968  		r.d[3] = l[shiftLimbs+3] >> shiftLow
 969  	} else {
 970  		r.d[3] = 0
 971  	}
 972  
 973  	// Round by adding the bit just below the shift position
 974  	// This implements round() instead of floor()
 975  	roundBit := int((l[(shift-1)>>6] >> ((shift - 1) & 0x3F)) & 1)
 976  	r.caddBit(0, roundBit)
 977  }
 978  
 979  // splitLambda decomposes scalar k into k1, k2 such that k1 + k2*λ ≡ k (mod n)
 980  // where k1 and k2 are approximately 128 bits each.
 981  // This is the core of the GLV endomorphism optimization.
 982  //
 983  // The algorithm uses precomputed constants g1, g2 to compute:
 984  //   c1 = round(k * g1 / 2^384)
 985  //   c2 = round(k * g2 / 2^384)
 986  //   k2 = c1*(-b1) + c2*(-b2)
 987  //   k1 = k - k2*λ
 988  //
 989  // Reference: libsecp256k1 scalar_impl.h:secp256k1_scalar_split_lambda
 990  func scalarSplitLambda(r1, r2, k *Scalar) {
 991  	var c1, c2 Scalar
 992  
 993  	// c1 = round(k * g1 / 2^384)
 994  	c1.mulShiftVar(k, &scalarG1, 384)
 995  
 996  	// c2 = round(k * g2 / 2^384)
 997  	c2.mulShiftVar(k, &scalarG2, 384)
 998  
 999  	// c1 = c1 * (-b1)
1000  	c1.mul(&c1, &scalarMinusB1)
1001  
1002  	// c2 = c2 * (-b2)
1003  	c2.mul(&c2, &scalarMinusB2)
1004  
1005  	// r2 = c1 + c2
1006  	r2.add(&c1, &c2)
1007  
1008  	// r1 = r2 * λ
1009  	r1.mul(r2, &scalarLambda)
1010  
1011  	// r1 = -r1
1012  	r1.negate(r1)
1013  
1014  	// r1 = k + (-r2*λ) = k - r2*λ
1015  	r1.add(r1, k)
1016  }
1017  
1018  // scalarSplit128 splits a scalar into two 128-bit halves
1019  // r1 = k & ((1 << 128) - 1)  (low 128 bits)
1020  // r2 = k >> 128               (high 128 bits)
1021  // This is used for generator multiplication optimization
1022  func scalarSplit128(r1, r2, k *Scalar) {
1023  	r1.d[0] = k.d[0]
1024  	r1.d[1] = k.d[1]
1025  	r1.d[2] = 0
1026  	r1.d[3] = 0
1027  
1028  	r2.d[0] = k.d[2]
1029  	r2.d[1] = k.d[3]
1030  	r2.d[2] = 0
1031  	r2.d[3] = 0
1032  }
1033  
1034