// scalar.go — secp256k1 scalar arithmetic (4x64-bit limb representation)
1 //go:build !js && !wasm && !tinygo && !wasm32
2
3 package p256k1
4
import (
	"crypto/subtle"
	"encoding/binary"
	"math/bits"
	"unsafe"
)
10
// Scalar represents a scalar value modulo the secp256k1 group order n.
// The value is stored in four 64-bit limbs, least-significant first:
// d[0] holds bits 0-63, d[3] holds bits 192-255.
type Scalar struct {
	d [4]uint64
}
16
// Scalar constants ported from the C implementation (libsecp256k1,
// scalar_4x64). All limb sets are little-endian: suffix 0 is the
// least-significant limb.
const (
	// Limbs of the secp256k1 group order n.
	scalarN0 = 0xBFD25E8CD0364141
	scalarN1 = 0xBAAEDCE6AF48A03B
	scalarN2 = 0xFFFFFFFFFFFFFFFE
	scalarN3 = 0xFFFFFFFFFFFFFFFF

	// Limbs of 2^256 - n (the order's complement, used during reduction).
	// The top limb would be zero, so it has no named constant.
	scalarNC0 = 0x402DA1732FC9BEBF // ~scalarN0 + 1
	scalarNC1 = 0x4551231950B75FC4 // ~scalarN1
	scalarNC2 = 0x0000000000000001 // ~scalarN2 = 1

	// Limbs of half the group order, (n-1)/2; used by isHigh.
	scalarNH0 = 0xDFE92F46681B20A0
	scalarNH1 = 0x5D576E7357A4501D
	scalarNH2 = 0xFFFFFFFFFFFFFFFF
	scalarNH3 = 0x7FFFFFFFFFFFFFFF
)
36
// Scalar element constants. The GLV constants below are precomputed
// values used to split a scalar for the secp256k1 endomorphism; they
// are ported verbatim from libsecp256k1's scalar_impl.h.
var (
	// ScalarZero represents the scalar 0.
	ScalarZero = Scalar{d: [4]uint64{0, 0, 0, 0}}

	// ScalarOne represents the scalar 1.
	ScalarOne = Scalar{d: [4]uint64{1, 0, 0, 0}}

	// scalarLambda is the GLV endomorphism constant λ (cube root of unity mod n):
	// λ^3 ≡ 1 (mod n), and λ^2 + λ + 1 ≡ 0 (mod n).
	// Value: 0x5363AD4CC05C30E0A5261C028812645A122E22EA20816678DF02967C1B23BD72
	// From libsecp256k1 scalar_impl.h.
	scalarLambda = Scalar{
		d: [4]uint64{
			0xDF02967C1B23BD72, // limb 0 (least significant)
			0x122E22EA20816678, // limb 1
			0xA5261C028812645A, // limb 2
			0x5363AD4CC05C30E0, // limb 3 (most significant)
		},
	}

	// GLV scalar splitting constants from libsecp256k1 scalar_impl.h.
	// These are used by scalarSplitLambda to decompose a scalar k into
	// k1, k2 such that k1 + k2*λ ≡ k (mod n).

	// scalarMinusB1 = -b1 where b1 is from the GLV lattice basis.
	// Value: 0x00000000000000000000000000000000E4437ED6010E88286F547FA90ABFE4C3
	scalarMinusB1 = Scalar{
		d: [4]uint64{
			0x6F547FA90ABFE4C3, // limb 0
			0xE4437ED6010E8828, // limb 1
			0x0000000000000000, // limb 2
			0x0000000000000000, // limb 3
		},
	}

	// scalarMinusB2 = -b2 where b2 is from the GLV lattice basis.
	// Value: 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFE8A280AC50774346DD765CDA83DB1562C
	scalarMinusB2 = Scalar{
		d: [4]uint64{
			0xD765CDA83DB1562C, // limb 0
			0x8A280AC50774346D, // limb 1
			0xFFFFFFFFFFFFFFFE, // limb 2
			0xFFFFFFFFFFFFFFFF, // limb 3
		},
	}

	// scalarG1 is a precomputed rounding constant: g1 = round(2^384 * b2 / n).
	// Value: 0x3086D221A7D46BCDE86C90E49284EB153DAA8A1471E8CA7FE893209A45DBB031
	scalarG1 = Scalar{
		d: [4]uint64{
			0xE893209A45DBB031, // limb 0
			0x3DAA8A1471E8CA7F, // limb 1
			0xE86C90E49284EB15, // limb 2
			0x3086D221A7D46BCD, // limb 3
		},
	}

	// scalarG2 is a precomputed rounding constant: g2 = round(2^384 * (-b1) / n).
	// Value: 0xE4437ED6010E88286F547FA90ABFE4C4221208AC9DF506C61571B4AE8AC47F71
	scalarG2 = Scalar{
		d: [4]uint64{
			0x1571B4AE8AC47F71, // limb 0
			0x221208AC9DF506C6, // limb 1
			0x6F547FA90ABFE4C4, // limb 2
			0xE4437ED6010E8828, // limb 3
		},
	}
)
106
107 // setInt sets a scalar to a small integer value
108 func (r *Scalar) setInt(v uint) {
109 r.d[0] = uint64(v)
110 r.d[1] = 0
111 r.d[2] = 0
112 r.d[3] = 0
113 }
114
115 // setB32 sets a scalar from a 32-byte big-endian array
116 func (r *Scalar) setB32(b []byte) bool {
117 if len(b) != 32 {
118 panic("scalar byte array must be 32 bytes")
119 }
120
121 // Convert from big-endian bytes to uint64 limbs
122 r.d[0] = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
123 uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
124 r.d[1] = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
125 uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
126 r.d[2] = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
127 uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
128 r.d[3] = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
129 uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
130
131 // Check if the scalar overflows the group order
132 overflow := r.checkOverflow()
133 if overflow {
134 r.reduce(1)
135 }
136
137 return overflow
138 }
139
140 // setB32Seckey sets a scalar from a 32-byte secret key, returns true if valid
141 func (r *Scalar) setB32Seckey(b []byte) bool {
142 overflow := r.setB32(b)
143 return !r.isZero() && !overflow
144 }
145
146 // getB32 converts a scalar to a 32-byte big-endian array
147 func (r *Scalar) getB32(b []byte) {
148 if len(b) != 32 {
149 panic("scalar byte array must be 32 bytes")
150 }
151
152 // Convert from uint64 limbs to big-endian bytes
153 b[31] = byte(r.d[0])
154 b[30] = byte(r.d[0] >> 8)
155 b[29] = byte(r.d[0] >> 16)
156 b[28] = byte(r.d[0] >> 24)
157 b[27] = byte(r.d[0] >> 32)
158 b[26] = byte(r.d[0] >> 40)
159 b[25] = byte(r.d[0] >> 48)
160 b[24] = byte(r.d[0] >> 56)
161
162 b[23] = byte(r.d[1])
163 b[22] = byte(r.d[1] >> 8)
164 b[21] = byte(r.d[1] >> 16)
165 b[20] = byte(r.d[1] >> 24)
166 b[19] = byte(r.d[1] >> 32)
167 b[18] = byte(r.d[1] >> 40)
168 b[17] = byte(r.d[1] >> 48)
169 b[16] = byte(r.d[1] >> 56)
170
171 b[15] = byte(r.d[2])
172 b[14] = byte(r.d[2] >> 8)
173 b[13] = byte(r.d[2] >> 16)
174 b[12] = byte(r.d[2] >> 24)
175 b[11] = byte(r.d[2] >> 32)
176 b[10] = byte(r.d[2] >> 40)
177 b[9] = byte(r.d[2] >> 48)
178 b[8] = byte(r.d[2] >> 56)
179
180 b[7] = byte(r.d[3])
181 b[6] = byte(r.d[3] >> 8)
182 b[5] = byte(r.d[3] >> 16)
183 b[4] = byte(r.d[3] >> 24)
184 b[3] = byte(r.d[3] >> 32)
185 b[2] = byte(r.d[3] >> 40)
186 b[1] = byte(r.d[3] >> 48)
187 b[0] = byte(r.d[3] >> 56)
188 }
189
// checkOverflow reports whether the scalar, read as a 256-bit integer,
// is greater than or equal to the group order n.
//
// The yes/no flag ladder mirrors libsecp256k1's check_overflow: a
// comparison at a lower limb only counts while no higher limb has
// already decided the answer (i.e. while both flags are still 0).
// NOTE(review): the Go compiler may still emit branches for these
// comparisons, so strict constant-time execution is not guaranteed.
func (r *Scalar) checkOverflow() bool {
	yes := 0
	no := 0

	// Check each limb from most significant to least significant.
	if r.d[3] < scalarN3 {
		no = 1
	}
	if r.d[3] > scalarN3 {
		yes = 1
	}

	// (yes ^ 1) / (no ^ 1) gate the update so an already-decided
	// comparison at a higher limb cannot be overturned.
	if r.d[2] < scalarN2 {
		no |= (yes ^ 1)
	}
	if r.d[2] > scalarN2 {
		yes |= (no ^ 1)
	}

	if r.d[1] < scalarN1 {
		no |= (yes ^ 1)
	}
	if r.d[1] > scalarN1 {
		yes |= (no ^ 1)
	}

	// At the lowest limb, equality also means overflow (value == n).
	if r.d[0] >= scalarN0 {
		yes |= (no ^ 1)
	}

	return yes != 0
}
223
// reduce subtracts overflow*n from the scalar. It does so by adding
// overflow * (2^256 - n) limb-wise; the wrap-around modulo 2^256 at
// the top limb completes the subtraction. overflow must be 0 or 1,
// and should be 1 exactly when the current value is >= n.
func (r *Scalar) reduce(overflow int) {
	if overflow < 0 || overflow > 1 {
		panic("overflow must be 0 or 1")
	}

	// Use a 128-bit accumulator so each limb addition keeps its carry.
	var t uint128

	// d[0] += overflow * scalarNC0
	t = uint128FromU64(r.d[0])
	t = t.addU64(uint64(overflow) * scalarNC0)
	r.d[0] = t.lo()
	t = t.rshift(64)

	// d[1] += overflow * scalarNC1 + carry
	t = t.addU64(r.d[1])
	t = t.addU64(uint64(overflow) * scalarNC1)
	r.d[1] = t.lo()
	t = t.rshift(64)

	// d[2] += overflow * scalarNC2 + carry (scalarNC2 == 1)
	t = t.addU64(r.d[2])
	t = t.addU64(uint64(overflow) * scalarNC2)
	r.d[2] = t.lo()
	t = t.rshift(64)

	// d[3] += carry; NC has no limb 3 (it would be zero). Any carry out
	// of this limb is the intentional mod-2^256 wrap.
	t = t.addU64(r.d[3])
	r.d[3] = t.lo()
}
255
// add adds two scalars: r = (a + b) mod n. The return value is true
// when the raw sum was >= n (or exceeded 256 bits) and was reduced.
//
// NOTE(review): the AVX2 fast path always returns false, so the
// overflow indication is lost on CPUs with AVX2. Confirm that no
// caller relies on the return value, or plumb the flag through
// scalarAddAVX2.
func (r *Scalar) add(a, b *Scalar) bool {
	// Use AVX2 if available (AMD64 only)
	if HasAVX2() {
		scalarAddAVX2(r, a, b)
		return false // AVX2 version handles reduction internally
	}
	return r.addPureGo(a, b)
}
265
266 // addPureGo is the pure Go implementation of scalar addition
267 func (r *Scalar) addPureGo(a, b *Scalar) bool {
268 var carry uint64
269
270 r.d[0], carry = bits.Add64(a.d[0], b.d[0], 0)
271 r.d[1], carry = bits.Add64(a.d[1], b.d[1], carry)
272 r.d[2], carry = bits.Add64(a.d[2], b.d[2], carry)
273 r.d[3], carry = bits.Add64(a.d[3], b.d[3], carry)
274
275 overflow := carry != 0 || r.checkOverflow()
276 if overflow {
277 r.reduce(1)
278 }
279
280 return overflow
281 }
282
283 // sub subtracts two scalars: r = a - b
284 func (r *Scalar) sub(a, b *Scalar) {
285 // Use AVX2 if available (AMD64 only)
286 if HasAVX2() {
287 scalarSubAVX2(r, a, b)
288 return
289 }
290 r.subPureGo(a, b)
291 }
292
293 // subPureGo is the pure Go implementation of scalar subtraction
294 func (r *Scalar) subPureGo(a, b *Scalar) {
295 // Compute a - b = a + (-b)
296 var negB Scalar
297 negB.negate(b)
298 *r = *a
299 r.addPureGo(r, &negB)
300 }
301
302 // mul multiplies two scalars: r = a * b
303 func (r *Scalar) mul(a, b *Scalar) {
304 // Use AVX2 if available (AMD64 only)
305 if HasAVX2() {
306 scalarMulAVX2(r, a, b)
307 return
308 }
309 r.mulPureGo(a, b)
310 }
311
312 // mulPureGo is the pure Go implementation of scalar multiplication
313 func (r *Scalar) mulPureGo(a, b *Scalar) {
314 // Compute full 512-bit product using all 16 cross products
315 var l [8]uint64
316 r.mul512(l[:], a, b)
317 r.reduce512(l[:])
318 }
319
// mul512 computes the full 512-bit product a * b into l8 as eight
// little-endian 64-bit limbs, using schoolbook multiplication.
// Port of the product phase of secp256k1_scalar_mul_512.
func (r *Scalar) mul512(l8 []uint64, a, b *Scalar) {
	// 160-bit accumulator: c0 holds the low 64 bits, c1 the next 64,
	// and c2 the top 32 bits of overflow.
	var c0, c1 uint64
	var c2 uint32

	// muladd: accumulator += ai*bi (full 128-bit product), with any
	// carry out of c1 collected in c2.
	muladd := func(ai, bi uint64) {
		hi, lo := bits.Mul64(ai, bi)
		var carry uint64
		c0, carry = bits.Add64(c0, lo, 0)
		c1, carry = bits.Add64(c1, hi, carry)
		c2 += uint32(carry)
	}

	// muladdFast: like muladd, used where the caller knows no carry
	// into c2 can occur.
	muladdFast := func(ai, bi uint64) {
		hi, lo := bits.Mul64(ai, bi)
		var carry uint64
		c0, carry = bits.Add64(c0, lo, 0)
		c1 += hi + carry
	}

	// extract: pop the low 64 bits of the accumulator and shift the
	// rest down by 64 bits.
	extract := func() uint64 {
		result := c0
		c0 = c1
		c1 = uint64(c2)
		c2 = 0
		return result
	}

	// extractFast: like extract, used where c2 is known to be zero.
	extractFast := func() uint64 {
		result := c0
		c0 = c1
		c1 = 0
		return result
	}

	// Each group below accumulates the partial products a.d[i]*b.d[j]
	// with i+j == k, then emits limb l8[k] (mirrors the C ordering).
	muladdFast(a.d[0], b.d[0])
	l8[0] = extractFast()

	muladd(a.d[0], b.d[1])
	muladd(a.d[1], b.d[0])
	l8[1] = extract()

	muladd(a.d[0], b.d[2])
	muladd(a.d[1], b.d[1])
	muladd(a.d[2], b.d[0])
	l8[2] = extract()

	muladd(a.d[0], b.d[3])
	muladd(a.d[1], b.d[2])
	muladd(a.d[2], b.d[1])
	muladd(a.d[3], b.d[0])
	l8[3] = extract()

	muladd(a.d[1], b.d[3])
	muladd(a.d[2], b.d[2])
	muladd(a.d[3], b.d[1])
	l8[4] = extract()

	muladd(a.d[2], b.d[3])
	muladd(a.d[3], b.d[2])
	l8[5] = extract()

	muladdFast(a.d[3], b.d[3])
	l8[6] = extractFast()
	l8[7] = c0
}
389
// reduce512 reduces the 512-bit value in l (eight little-endian limbs,
// as produced by mul512) to a fully reduced scalar in r, modulo the
// group order n. Port of secp256k1_scalar_reduce_512.
//
// Strategy: fold the high limbs back into the low ones repeatedly
// using 2^256 ≡ N_C (mod n), where N_C = 2^256 - n is ~129 bits wide:
// 512 -> 385 bits, then 385 -> 258 bits, then 258 -> 256 bits, then at
// most one final subtraction of n.
func (r *Scalar) reduce512(l []uint64) {
	// 160-bit accumulator: c0 low 64 bits, c1 next 64, c2 top 32.
	var c0, c1 uint64
	var c2 uint32

	// The upper 256 bits of the product.
	n0, n1, n2, n3 := l[4], l[5], l[6], l[7]

	// muladd: accumulator += ai*bi, carry collected in c2.
	muladd := func(ai, bi uint64) {
		hi, lo := bits.Mul64(ai, bi)
		var carry uint64
		c0, carry = bits.Add64(c0, lo, 0)
		c1, carry = bits.Add64(c1, hi, carry)
		c2 += uint32(carry)
	}

	// muladdFast: like muladd where no carry into c2 can occur.
	muladdFast := func(ai, bi uint64) {
		hi, lo := bits.Mul64(ai, bi)
		var carry uint64
		c0, carry = bits.Add64(c0, lo, 0)
		c1 += hi + carry
	}

	// sumadd: accumulator += a, carry collected in c2.
	sumadd := func(a uint64) {
		var carry uint64
		c0, carry = bits.Add64(c0, a, 0)
		c1, carry = bits.Add64(c1, 0, carry)
		c2 += uint32(carry)
	}

	// sumaddFast: like sumadd where no carry into c2 can occur.
	sumaddFast := func(a uint64) {
		var carry uint64
		c0, carry = bits.Add64(c0, a, 0)
		c1 += carry
	}

	// extract: pop the low 64 bits and shift the accumulator down.
	extract := func() uint64 {
		result := c0
		c0 = c1
		c1 = uint64(c2)
		c2 = 0
		return result
	}

	// extractFast: like extract where c2 is known to be zero.
	extractFast := func() uint64 {
		result := c0
		c0 = c1
		c1 = 0
		return result
	}

	// Reduce 512 bits into 385 bits:
	// m[0..6] = l[0..3] + n[0..3] * N_C.
	c0 = l[0]
	c1 = 0
	c2 = 0
	muladdFast(n0, scalarNC0)
	m0 := extractFast()

	sumaddFast(l[1])
	muladd(n1, scalarNC0)
	muladd(n0, scalarNC1)
	m1 := extract()

	sumadd(l[2])
	muladd(n2, scalarNC0)
	muladd(n1, scalarNC1)
	sumadd(n0) // n0 * scalarNC2, and scalarNC2 == 1
	m2 := extract()

	sumadd(l[3])
	muladd(n3, scalarNC0)
	muladd(n2, scalarNC1)
	sumadd(n1)
	m3 := extract()

	muladd(n3, scalarNC1)
	sumadd(n2)
	m4 := extract()

	sumaddFast(n3)
	m5 := extractFast()
	m6 := uint32(c0) // remaining top bits fit in 32 bits

	// Reduce 385 bits into 258 bits:
	// p[0..4] = m[0..3] + m[4..6] * N_C.
	c0 = m0
	c1 = 0
	c2 = 0
	muladdFast(m4, scalarNC0)
	p0 := extractFast()

	sumaddFast(m1)
	muladd(m5, scalarNC0)
	muladd(m4, scalarNC1)
	p1 := extract()

	sumadd(m2)
	muladd(uint64(m6), scalarNC0)
	muladd(m5, scalarNC1)
	sumadd(m4)
	p2 := extract()

	sumaddFast(m3)
	muladdFast(uint64(m6), scalarNC1)
	sumaddFast(m5)
	p3 := extractFast()
	p4 := uint32(c0 + uint64(m6))

	// Reduce 258 bits into 256 bits:
	// r[0..3] = p[0..3] + p4 * N_C, using a 128-bit accumulator.
	var t uint128

	t = uint128FromU64(p0)
	t = t.addMul(scalarNC0, uint64(p4))
	r.d[0] = t.lo()
	t = t.rshift(64)

	t = t.addU64(p1)
	t = t.addMul(scalarNC1, uint64(p4))
	r.d[1] = t.lo()
	t = t.rshift(64)

	t = t.addU64(p2)
	t = t.addU64(uint64(p4)) // p4 * scalarNC2, and scalarNC2 == 1
	r.d[2] = t.lo()
	t = t.rshift(64)

	t = t.addU64(p3)
	r.d[3] = t.lo()
	c := t.hi()

	// Final reduction: c + overflow is at most 1 here (the C original
	// VERIFY-checks this invariant), so reduce's 0/1 precondition holds.
	r.reduce(int(c) + boolToInt(r.checkOverflow()))
}
527
528 // negate negates a scalar: r = -a
529 func (r *Scalar) negate(a *Scalar) {
530 // r = n - a where n is the group order
531 var borrow uint64
532
533 r.d[0], borrow = bits.Sub64(scalarN0, a.d[0], 0)
534 r.d[1], borrow = bits.Sub64(scalarN1, a.d[1], borrow)
535 r.d[2], borrow = bits.Sub64(scalarN2, a.d[2], borrow)
536 r.d[3], _ = bits.Sub64(scalarN3, a.d[3], borrow)
537 }
538
539 // inverse computes the modular inverse of a scalar
540 func (r *Scalar) inverse(a *Scalar) {
541 // Use Fermat's little theorem: a^(-1) = a^(n-2) mod n
542 // where n is the group order (which is prime)
543
544 // Use binary exponentiation with n-2
545 var exp Scalar
546 var borrow uint64
547 exp.d[0], borrow = bits.Sub64(scalarN0, 2, 0)
548 exp.d[1], borrow = bits.Sub64(scalarN1, 0, borrow)
549 exp.d[2], borrow = bits.Sub64(scalarN2, 0, borrow)
550 exp.d[3], _ = bits.Sub64(scalarN3, 0, borrow)
551
552 r.exp(a, &exp)
553 }
554
555 // exp computes r = a^b mod n using binary exponentiation
556 func (r *Scalar) exp(a, b *Scalar) {
557 *r = ScalarOne
558 base := *a
559
560 for i := 0; i < 4; i++ {
561 limb := b.d[i]
562 for j := 0; j < 64; j++ {
563 if limb&1 != 0 {
564 r.mul(r, &base)
565 }
566 base.mul(&base, &base)
567 limb >>= 1
568 }
569 }
570 }
571
572 // half computes r = a/2 mod n
573 func (r *Scalar) half(a *Scalar) {
574 *r = *a
575
576 if r.d[0]&1 == 0 {
577 // Even case: simple right shift
578 r.d[0] = (r.d[0] >> 1) | ((r.d[1] & 1) << 63)
579 r.d[1] = (r.d[1] >> 1) | ((r.d[2] & 1) << 63)
580 r.d[2] = (r.d[2] >> 1) | ((r.d[3] & 1) << 63)
581 r.d[3] = r.d[3] >> 1
582 } else {
583 // Odd case: add n then divide by 2
584 var carry uint64
585 r.d[0], carry = bits.Add64(r.d[0], scalarN0, 0)
586 r.d[1], carry = bits.Add64(r.d[1], scalarN1, carry)
587 r.d[2], carry = bits.Add64(r.d[2], scalarN2, carry)
588 r.d[3], _ = bits.Add64(r.d[3], scalarN3, carry)
589
590 // Now divide by 2
591 r.d[0] = (r.d[0] >> 1) | ((r.d[1] & 1) << 63)
592 r.d[1] = (r.d[1] >> 1) | ((r.d[2] & 1) << 63)
593 r.d[2] = (r.d[2] >> 1) | ((r.d[3] & 1) << 63)
594 r.d[3] = r.d[3] >> 1
595 }
596 }
597
598 // isZero returns true if the scalar is zero
599 func (r *Scalar) isZero() bool {
600 return (r.d[0] | r.d[1] | r.d[2] | r.d[3]) == 0
601 }
602
603 // isOne returns true if the scalar is one
604 func (r *Scalar) isOne() bool {
605 return r.d[0] == 1 && r.d[1] == 0 && r.d[2] == 0 && r.d[3] == 0
606 }
607
608 // isEven returns true if the scalar is even
609 func (r *Scalar) isEven() bool {
610 return r.d[0]&1 == 0
611 }
612
// isHigh reports whether the scalar is greater than n/2 (i.e. greater
// than the NH constants, which hold (n-1)/2).
//
// Uses the same gated yes/no comparison ladder as checkOverflow: a
// lower limb's comparison only counts while no higher limb has already
// decided. NOTE(review): the comparisons may compile to branches, so
// strict constant-time behavior is not guaranteed.
func (r *Scalar) isHigh() bool {
	var yes, no int

	// Most-significant limb first.
	if r.d[3] < scalarNH3 {
		no = 1
	}
	if r.d[3] > scalarNH3 {
		yes = 1
	}

	// Lower limbs are gated on the higher limbs being equal so far.
	if r.d[2] < scalarNH2 {
		no |= (yes ^ 1)
	}
	if r.d[2] > scalarNH2 {
		yes |= (no ^ 1)
	}

	if r.d[1] < scalarNH1 {
		no |= (yes ^ 1)
	}
	if r.d[1] > scalarNH1 {
		yes |= (no ^ 1)
	}

	// Strictly greater at the lowest limb (equality means exactly n/2,
	// which is not "high").
	if r.d[0] > scalarNH0 {
		yes |= (no ^ 1)
	}

	return yes != 0
}
644
645 // condNegate conditionally negates the scalar if flag is true
646 func (r *Scalar) condNegate(flag int) {
647 if flag != 0 {
648 var neg Scalar
649 neg.negate(r)
650 *r = neg
651 }
652 }
653
654 // equal returns true if two scalars are equal
655 func (r *Scalar) equal(a *Scalar) bool {
656 return subtle.ConstantTimeCompare(
657 (*[32]byte)(unsafe.Pointer(&r.d[0]))[:32],
658 (*[32]byte)(unsafe.Pointer(&a.d[0]))[:32],
659 ) == 1
660 }
661
662 // getBits extracts count bits starting at offset
663 func (r *Scalar) getBits(offset, count uint) uint32 {
664 if count == 0 || count > 32 {
665 panic("count must be 1-32")
666 }
667 if offset+count > 256 {
668 panic("offset + count must be <= 256")
669 }
670
671 limbIdx := offset / 64
672 bitIdx := offset % 64
673
674 if bitIdx+count <= 64 {
675 // Bits are within a single limb
676 return uint32((r.d[limbIdx] >> bitIdx) & ((1 << count) - 1))
677 } else {
678 // Bits span two limbs
679 lowBits := 64 - bitIdx
680 highBits := count - lowBits
681 low := uint32((r.d[limbIdx] >> bitIdx) & ((1 << lowBits) - 1))
682 high := uint32(r.d[limbIdx+1] & ((1 << highBits) - 1))
683 return low | (high << lowBits)
684 }
685 }
686
687 // cmov conditionally moves a scalar. If flag is true, r = a; otherwise r is unchanged.
688 func (r *Scalar) cmov(a *Scalar, flag int) {
689 mask := uint64(-(int64(flag) & 1))
690 r.d[0] ^= mask & (r.d[0] ^ a.d[0])
691 r.d[1] ^= mask & (r.d[1] ^ a.d[1])
692 r.d[2] ^= mask & (r.d[2] ^ a.d[2])
693 r.d[3] ^= mask & (r.d[3] ^ a.d[3])
694 }
695
// clear zeroes the scalar's limbs so key material does not linger in
// memory. memclear is a project-local helper (defined elsewhere);
// presumably it resists being optimized away — confirm its contract.
func (r *Scalar) clear() {
	memclear(unsafe.Pointer(&r.d[0]), unsafe.Sizeof(r.d))
}
700
701 // Helper functions for 128-bit arithmetic (using uint128 from field_mul.go)
702
703 func uint128FromU64(x uint64) uint128 {
704 return uint128{low: x, high: 0}
705 }
706
707 func (x uint128) addU64(y uint64) uint128 {
708 low, carry := bits.Add64(x.low, y, 0)
709 high := x.high + carry
710 return uint128{low: low, high: high}
711 }
712
713 func (x uint128) addMul(a, b uint64) uint128 {
714 hi, lo := bits.Mul64(a, b)
715 low, carry := bits.Add64(x.low, lo, 0)
716 high, _ := bits.Add64(x.high, hi, carry)
717 return uint128{low: low, high: high}
718 }
719
720 // Direct function versions to reduce method call overhead
721 // These are equivalent to the method versions but avoid interface dispatch
722
723 // scalarAdd adds two scalars: r = a + b, returns overflow
724 func scalarAdd(r, a, b *Scalar) bool {
725 var carry uint64
726
727 r.d[0], carry = bits.Add64(a.d[0], b.d[0], 0)
728 r.d[1], carry = bits.Add64(a.d[1], b.d[1], carry)
729 r.d[2], carry = bits.Add64(a.d[2], b.d[2], carry)
730 r.d[3], carry = bits.Add64(a.d[3], b.d[3], carry)
731
732 overflow := carry != 0 || scalarCheckOverflow(r)
733 if overflow {
734 scalarReduce(r, 1)
735 }
736
737 return overflow
738 }
739
// scalarMul multiplies two scalars: r = (a * b) mod n.
// It delegates to the method implementation, which performs the full
// 512-bit schoolbook product followed by the modular reduction.
func scalarMul(r, a, b *Scalar) {
	r.mulPureGo(a, b)
}
745
746 // scalarGetB32 serializes a scalar to 32 bytes in big-endian format
747 func scalarGetB32(bin []byte, a *Scalar) {
748 if len(bin) != 32 {
749 panic("scalar byte array must be 32 bytes")
750 }
751
752 // Convert to big-endian bytes
753 for i := 0; i < 4; i++ {
754 bin[31-8*i] = byte(a.d[i])
755 bin[30-8*i] = byte(a.d[i] >> 8)
756 bin[29-8*i] = byte(a.d[i] >> 16)
757 bin[28-8*i] = byte(a.d[i] >> 24)
758 bin[27-8*i] = byte(a.d[i] >> 32)
759 bin[26-8*i] = byte(a.d[i] >> 40)
760 bin[25-8*i] = byte(a.d[i] >> 48)
761 bin[24-8*i] = byte(a.d[i] >> 56)
762 }
763 }
764
765 // scalarIsZero returns true if the scalar is zero
766 func scalarIsZero(a *Scalar) bool {
767 return a.d[0] == 0 && a.d[1] == 0 && a.d[2] == 0 && a.d[3] == 0
768 }
769
// scalarCheckOverflow reports whether r, read as a 256-bit integer, is
// greater than or equal to the group order n.
//
// NOTE(review): unlike the method Scalar.checkOverflow, this helper
// short-circuits and is therefore variable-time; use it only on
// non-secret values.
func scalarCheckOverflow(r *Scalar) bool {
	return (r.d[3] > scalarN3) ||
		(r.d[3] == scalarN3 && r.d[2] > scalarN2) ||
		(r.d[3] == scalarN3 && r.d[2] == scalarN2 && r.d[1] > scalarN1) ||
		(r.d[3] == scalarN3 && r.d[2] == scalarN2 && r.d[1] == scalarN1 && r.d[0] >= scalarN0)
}
777
778 // scalarReduce reduces the scalar modulo the group order
779 func scalarReduce(r *Scalar, overflow int) {
780 var t Scalar
781 var c uint64
782
783 // Compute r + overflow * N_C
784 t.d[0], c = bits.Add64(r.d[0], uint64(overflow)*scalarNC0, 0)
785 t.d[1], c = bits.Add64(r.d[1], uint64(overflow)*scalarNC1, c)
786 t.d[2], c = bits.Add64(r.d[2], uint64(overflow)*scalarNC2, c)
787 t.d[3], c = bits.Add64(r.d[3], 0, c)
788
789 // Mask to keep only the low 256 bits
790 r.d[0] = t.d[0] & 0xFFFFFFFFFFFFFFFF
791 r.d[1] = t.d[1] & 0xFFFFFFFFFFFFFFFF
792 r.d[2] = t.d[2] & 0xFFFFFFFFFFFFFFFF
793 r.d[3] = t.d[3] & 0xFFFFFFFFFFFFFFFF
794
795 // Ensure result is in range [0, N)
796 if scalarCheckOverflow(r) {
797 scalarReduce(r, 1)
798 }
799 }
800
// wNAF converts a scalar to Windowed Non-Adjacent Form representation.
// wNAF represents the scalar using digits in the range [-(2^(w-1)-1), 2^(w-1)-1]
// with the property that non-zero digits are separated by at least w-1 zeros.
//
// Returns the number of digits in the wNAF representation (at most 257 for
// 256-bit scalars) and fills the wnaf array with the digits (wnaf[i] is
// the digit at bit position i).
func (s *Scalar) wNAF(wnaf *[257]int8, w uint) int {
	if w < 2 || w > 8 {
		panic("w must be between 2 and 8")
	}

	// Work on a copy so the receiver is left untouched.
	var k Scalar
	k = *s

	// Note: We do NOT negate the scalar here. The caller is responsible for
	// ensuring the scalar is in the appropriate form. The ecmultEndoSplit
	// function already handles sign normalization.

	numBits := 0
	var carry uint32

	// Clear any digits left over from a previous use of the buffer.
	*wnaf = [257]int8{}

	bit := 0
	for bit < 256 {
		// Positions whose bit equals the pending carry contribute zero
		// digits; skip over them.
		if k.getBits(uint(bit), 1) == carry {
			bit++
			continue
		}

		// Take up to w bits, clipped at bit 256.
		window := w
		if bit+int(window) > 256 {
			window = uint(256 - bit)
		}

		word := uint32(k.getBits(uint(bit), window)) + carry

		// If the window's top bit is set, emit a negative digit and
		// carry 1 into the next window (the uint32 subtraction wraps;
		// the int8(int32(...)) conversion recovers the signed digit).
		carry = (word >> (window - 1)) & 1
		word -= carry << window

		wnaf[bit] = int8(int32(word))
		numBits = bit + int(window) - 1

		bit += int(window)
	}

	// A carry surviving past bit 255 becomes digit 256.
	if carry != 0 {
		wnaf[256] = int8(carry)
		numBits = 256
	}

	return numBits + 1
}
854
855 // wNAFSigned converts a scalar to Windowed Non-Adjacent Form representation,
856 // handling sign normalization. If the scalar has its high bit set (is "negative"
857 // in the modular sense), it will be negated and the negated flag will be true.
858 //
859 // Returns the number of digits and whether the scalar was negated.
860 // The caller must negate the result point if negated is true.
861 func (s *Scalar) wNAFSigned(wnaf *[257]int8, w uint) (int, bool) {
862 if w < 2 || w > 8 {
863 panic("w must be between 2 and 8")
864 }
865
866 var k Scalar
867 k = *s
868
869 negated := false
870 if k.getBits(255, 1) == 1 {
871 k.negate(&k)
872 negated = true
873 }
874
875 bits := k.wNAF(wnaf, w)
876 return bits, negated
877 }
878
879 // =============================================================================
880 // GLV Endomorphism Support Functions
881 // =============================================================================
882
883 // caddBit conditionally adds a power of 2 to the scalar
884 // If flag is non-zero, adds 2^bit to r
885 func (r *Scalar) caddBit(bit uint, flag int) {
886 if flag == 0 {
887 return
888 }
889
890 limbIdx := bit >> 6 // bit / 64
891 bitIdx := bit & 0x3F // bit % 64
892 addVal := uint64(1) << bitIdx
893
894 var carry uint64
895 if limbIdx == 0 {
896 r.d[0], carry = bits.Add64(r.d[0], addVal, 0)
897 r.d[1], carry = bits.Add64(r.d[1], 0, carry)
898 r.d[2], carry = bits.Add64(r.d[2], 0, carry)
899 r.d[3], _ = bits.Add64(r.d[3], 0, carry)
900 } else if limbIdx == 1 {
901 r.d[1], carry = bits.Add64(r.d[1], addVal, 0)
902 r.d[2], carry = bits.Add64(r.d[2], 0, carry)
903 r.d[3], _ = bits.Add64(r.d[3], 0, carry)
904 } else if limbIdx == 2 {
905 r.d[2], carry = bits.Add64(r.d[2], addVal, 0)
906 r.d[3], _ = bits.Add64(r.d[3], 0, carry)
907 } else if limbIdx == 3 {
908 r.d[3], _ = bits.Add64(r.d[3], addVal, 0)
909 }
910 }
911
912 // mulShiftVar computes r = round((a * b) >> shift) for shift >= 256
913 // This is used in GLV scalar splitting to compute c1 = round(k * g1 / 2^384)
914 // The rounding is achieved by adding the bit just below the shift position
915 func (r *Scalar) mulShiftVar(a, b *Scalar, shift uint) {
916 if shift < 256 {
917 panic("mulShiftVar requires shift >= 256")
918 }
919
920 // Compute full 512-bit product
921 var l [8]uint64
922 r.mul512(l[:], a, b)
923
924 // Extract bits [shift, shift+256) from the 512-bit product
925 shiftLimbs := shift >> 6 // Number of full 64-bit limbs to skip
926 shiftLow := shift & 0x3F // Bit offset within the limb
927 shiftHigh := 64 - shiftLow // Complementary shift for combining limbs
928
929 // Extract each limb of the result
930 // For shift=384, shiftLimbs=6, shiftLow=0
931 // r.d[0] = l[6], r.d[1] = l[7], r.d[2] = 0, r.d[3] = 0
932
933 if shift < 512 {
934 if shiftLow != 0 {
935 r.d[0] = (l[shiftLimbs] >> shiftLow) | (l[shiftLimbs+1] << shiftHigh)
936 } else {
937 r.d[0] = l[shiftLimbs]
938 }
939 } else {
940 r.d[0] = 0
941 }
942
943 if shift < 448 {
944 if shiftLow != 0 && shift < 384 {
945 r.d[1] = (l[shiftLimbs+1] >> shiftLow) | (l[shiftLimbs+2] << shiftHigh)
946 } else if shiftLow != 0 {
947 r.d[1] = l[shiftLimbs+1] >> shiftLow
948 } else {
949 r.d[1] = l[shiftLimbs+1]
950 }
951 } else {
952 r.d[1] = 0
953 }
954
955 if shift < 384 {
956 if shiftLow != 0 && shift < 320 {
957 r.d[2] = (l[shiftLimbs+2] >> shiftLow) | (l[shiftLimbs+3] << shiftHigh)
958 } else if shiftLow != 0 {
959 r.d[2] = l[shiftLimbs+2] >> shiftLow
960 } else {
961 r.d[2] = l[shiftLimbs+2]
962 }
963 } else {
964 r.d[2] = 0
965 }
966
967 if shift < 320 {
968 r.d[3] = l[shiftLimbs+3] >> shiftLow
969 } else {
970 r.d[3] = 0
971 }
972
973 // Round by adding the bit just below the shift position
974 // This implements round() instead of floor()
975 roundBit := int((l[(shift-1)>>6] >> ((shift - 1) & 0x3F)) & 1)
976 r.caddBit(0, roundBit)
977 }
978
979 // splitLambda decomposes scalar k into k1, k2 such that k1 + k2*λ ≡ k (mod n)
980 // where k1 and k2 are approximately 128 bits each.
981 // This is the core of the GLV endomorphism optimization.
982 //
983 // The algorithm uses precomputed constants g1, g2 to compute:
984 // c1 = round(k * g1 / 2^384)
985 // c2 = round(k * g2 / 2^384)
986 // k2 = c1*(-b1) + c2*(-b2)
987 // k1 = k - k2*λ
988 //
989 // Reference: libsecp256k1 scalar_impl.h:secp256k1_scalar_split_lambda
990 func scalarSplitLambda(r1, r2, k *Scalar) {
991 var c1, c2 Scalar
992
993 // c1 = round(k * g1 / 2^384)
994 c1.mulShiftVar(k, &scalarG1, 384)
995
996 // c2 = round(k * g2 / 2^384)
997 c2.mulShiftVar(k, &scalarG2, 384)
998
999 // c1 = c1 * (-b1)
1000 c1.mul(&c1, &scalarMinusB1)
1001
1002 // c2 = c2 * (-b2)
1003 c2.mul(&c2, &scalarMinusB2)
1004
1005 // r2 = c1 + c2
1006 r2.add(&c1, &c2)
1007
1008 // r1 = r2 * λ
1009 r1.mul(r2, &scalarLambda)
1010
1011 // r1 = -r1
1012 r1.negate(r1)
1013
1014 // r1 = k + (-r2*λ) = k - r2*λ
1015 r1.add(r1, k)
1016 }
1017
1018 // scalarSplit128 splits a scalar into two 128-bit halves
1019 // r1 = k & ((1 << 128) - 1) (low 128 bits)
1020 // r2 = k >> 128 (high 128 bits)
1021 // This is used for generator multiplication optimization
1022 func scalarSplit128(r1, r2, k *Scalar) {
1023 r1.d[0] = k.d[0]
1024 r1.d[1] = k.d[1]
1025 r1.d[2] = 0
1026 r1.d[3] = 0
1027
1028 r2.d[0] = k.d[2]
1029 r2.d[1] = k.d[3]
1030 r2.d[2] = 0
1031 r2.d[3] = 0
1032 }
1033
1034