// scalar_32bit.go — 8x32-limb scalar arithmetic for 32-bit/WASM targets.
1 //go:build js || wasm || tinygo || wasm32
2
3 // Copyright (c) 2024 mleku
4 // Adapted from github.com/decred/dcrd/dcrec/secp256k1/v4
5 // Copyright (c) 2020-2024 The Decred developers
6
7 package p256k1
8
9 import (
10 "crypto/subtle"
11 "unsafe"
12 )
13
// Scalar represents a scalar value modulo the secp256k1 group order.
// This implementation uses 8 uint32 limbs in base 2^32, optimized for 32-bit platforms.
type Scalar struct {
	n [8]uint32 // little-endian limbs: n[0] holds the least significant 32 bits
}
19
// Scalar constants in 8x32 representation. All multi-word values are stored
// least significant word first, matching the limb order of Scalar.n.
const (
	// The secp256k1 group order N, from least to most significant word:
	// N = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141.
	orderWord0 uint32 = 0xd0364141
	orderWord1 uint32 = 0xbfd25e8c
	orderWord2 uint32 = 0xaf48a03b
	orderWord3 uint32 = 0xbaaedce6
	orderWord4 uint32 = 0xfffffffe
	orderWord5 uint32 = 0xffffffff
	orderWord6 uint32 = 0xffffffff
	orderWord7 uint32 = 0xffffffff

	// Low words of 2^256 - N (the order's two's complement), used to fold
	// overflow back into range. The remaining words of the complement are 1
	// at word 4 and 0 at words 5-7; those are handled inline where needed.
	orderCompWord0 uint32 = 0x2fc9bebf // ~orderWord0 + 1
	orderCompWord1 uint32 = 0x402da173 // ~orderWord1
	orderCompWord2 uint32 = 0x50b75fc4 // ~orderWord2
	orderCompWord3 uint32 = 0x45512319 // ~orderWord3

	// (N-1)/2, used by isHigh to test whether a scalar exceeds half the order.
	halfOrderWord0 uint32 = 0x681b20a0
	halfOrderWord1 uint32 = 0xdfe92f46
	halfOrderWord2 uint32 = 0x57a4501d
	halfOrderWord3 uint32 = 0x5d576e73
	halfOrderWord4 uint32 = 0xffffffff
	halfOrderWord5 uint32 = 0xffffffff
	halfOrderWord6 uint32 = 0xffffffff
	halfOrderWord7 uint32 = 0x7fffffff

	// Mask for the low 32 bits of a uint64 accumulator.
	uint32Mask = 0xffffffff
)
50
// Scalar element constants.
var (
	ScalarZero = Scalar{n: [8]uint32{0, 0, 0, 0, 0, 0, 0, 0}}
	ScalarOne  = Scalar{n: [8]uint32{1, 0, 0, 0, 0, 0, 0, 0}}

	// GLV endomorphism constants in 8x32 representation (least significant
	// word first), consumed by scalarSplitLambda. Values adapted from the
	// reference implementation — see the file header attribution.

	// lambda: the scalar multiplier associated with the secp256k1
	// endomorphism used for the GLV decomposition.
	scalarLambda = Scalar{
		n: [8]uint32{
			0x1b23bd72, 0xdf02967c, 0x20816678, 0x122e22ea,
			0x8812645a, 0xa5261c02, 0xc05c30e0, 0x5363ad4c,
		},
	}

	// -b1: negated lattice basis coefficient for the decomposition.
	scalarMinusB1 = Scalar{
		n: [8]uint32{
			0x0abfe4c3, 0x6f547fa9, 0x010e8828, 0xe4437ed6,
			0x00000000, 0x00000000, 0x00000000, 0x00000000,
		},
	}

	// -b2: negated lattice basis coefficient for the decomposition.
	scalarMinusB2 = Scalar{
		n: [8]uint32{
			0x3db1562c, 0xd765cda8, 0x0774346d, 0x8a280ac5,
			0xfffffffe, 0xffffffff, 0xffffffff, 0xffffffff,
		},
	}

	// g1: precomputed rounding constant, used with mulShiftVar(k, g1, 384).
	scalarG1 = Scalar{
		n: [8]uint32{
			0x45dbb031, 0xe893209a, 0x71e8ca7f, 0x3daa8a14,
			0x9284eb15, 0xe86c90e4, 0xa7d46bcd, 0x3086d221,
		},
	}

	// g2: precomputed rounding constant, used with mulShiftVar(k, g2, 384).
	scalarG2 = Scalar{
		n: [8]uint32{
			0x8ac47f71, 0x1571b4ae, 0x9df506c6, 0x221208ac,
			0x0abfe4c4, 0x6f547fa9, 0x010e8828, 0xe4437ed6,
		},
	}
)
92
93 // setInt sets a scalar to a small integer value
94 func (s *Scalar) setInt(v uint) {
95 s.n[0] = uint32(v)
96 for i := 1; i < 8; i++ {
97 s.n[i] = 0
98 }
99 }
100
101 // setB32 sets a scalar from a 32-byte big-endian array
102 func (s *Scalar) setB32(b []byte) bool {
103 if len(b) != 32 {
104 panic("scalar byte array must be 32 bytes")
105 }
106
107 s.n[0] = uint32(b[31]) | uint32(b[30])<<8 | uint32(b[29])<<16 | uint32(b[28])<<24
108 s.n[1] = uint32(b[27]) | uint32(b[26])<<8 | uint32(b[25])<<16 | uint32(b[24])<<24
109 s.n[2] = uint32(b[23]) | uint32(b[22])<<8 | uint32(b[21])<<16 | uint32(b[20])<<24
110 s.n[3] = uint32(b[19]) | uint32(b[18])<<8 | uint32(b[17])<<16 | uint32(b[16])<<24
111 s.n[4] = uint32(b[15]) | uint32(b[14])<<8 | uint32(b[13])<<16 | uint32(b[12])<<24
112 s.n[5] = uint32(b[11]) | uint32(b[10])<<8 | uint32(b[9])<<16 | uint32(b[8])<<24
113 s.n[6] = uint32(b[7]) | uint32(b[6])<<8 | uint32(b[5])<<16 | uint32(b[4])<<24
114 s.n[7] = uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
115
116 overflow := s.overflows()
117 s.reduce256(overflow)
118 return overflow != 0
119 }
120
121 // setB32Seckey sets a scalar from a 32-byte secret key, returns true if valid
122 func (s *Scalar) setB32Seckey(b []byte) bool {
123 overflow := s.setB32(b)
124 return !s.isZero() && !overflow
125 }
126
127 // getB32 converts a scalar to a 32-byte big-endian array
128 func (s *Scalar) getB32(b []byte) {
129 if len(b) != 32 {
130 panic("scalar byte array must be 32 bytes")
131 }
132
133 b[31] = byte(s.n[0])
134 b[30] = byte(s.n[0] >> 8)
135 b[29] = byte(s.n[0] >> 16)
136 b[28] = byte(s.n[0] >> 24)
137 b[27] = byte(s.n[1])
138 b[26] = byte(s.n[1] >> 8)
139 b[25] = byte(s.n[1] >> 16)
140 b[24] = byte(s.n[1] >> 24)
141 b[23] = byte(s.n[2])
142 b[22] = byte(s.n[2] >> 8)
143 b[21] = byte(s.n[2] >> 16)
144 b[20] = byte(s.n[2] >> 24)
145 b[19] = byte(s.n[3])
146 b[18] = byte(s.n[3] >> 8)
147 b[17] = byte(s.n[3] >> 16)
148 b[16] = byte(s.n[3] >> 24)
149 b[15] = byte(s.n[4])
150 b[14] = byte(s.n[4] >> 8)
151 b[13] = byte(s.n[4] >> 16)
152 b[12] = byte(s.n[4] >> 24)
153 b[11] = byte(s.n[5])
154 b[10] = byte(s.n[5] >> 8)
155 b[9] = byte(s.n[5] >> 16)
156 b[8] = byte(s.n[5] >> 24)
157 b[7] = byte(s.n[6])
158 b[6] = byte(s.n[6] >> 8)
159 b[5] = byte(s.n[6] >> 16)
160 b[4] = byte(s.n[6] >> 24)
161 b[3] = byte(s.n[7])
162 b[2] = byte(s.n[7] >> 8)
163 b[1] = byte(s.n[7] >> 16)
164 b[0] = byte(s.n[7] >> 24)
165 }
166
// overflows returns 1 if the scalar is >= the group order N and 0 otherwise,
// in constant time.
//
// It scans limbs from most to least significant. highWordsEqual tracks
// whether every limb inspected so far equals the corresponding limb of N;
// a strictly greater limb under an all-equal prefix proves s > N, and the
// final greater-or-equal test on limb 0 is what catches s == N exactly.
func (s *Scalar) overflows() uint32 {
	highWordsEqual := constantTimeEq32(s.n[7], orderWord7)
	highWordsEqual &= constantTimeEq32(s.n[6], orderWord6)
	highWordsEqual &= constantTimeEq32(s.n[5], orderWord5)
	overflow := highWordsEqual & constantTimeGreater32(s.n[4], orderWord4)
	highWordsEqual &= constantTimeEq32(s.n[4], orderWord4)
	overflow |= highWordsEqual & constantTimeGreater32(s.n[3], orderWord3)
	highWordsEqual &= constantTimeEq32(s.n[3], orderWord3)
	overflow |= highWordsEqual & constantTimeGreater32(s.n[2], orderWord2)
	highWordsEqual &= constantTimeEq32(s.n[2], orderWord2)
	overflow |= highWordsEqual & constantTimeGreater32(s.n[1], orderWord1)
	highWordsEqual &= constantTimeEq32(s.n[1], orderWord1)
	overflow |= highWordsEqual & constantTimeGreaterOrEq32(s.n[0], orderWord0)
	return overflow
}
183
184 // reduce256 reduces the scalar modulo the order
185 func (s *Scalar) reduce256(overflows uint32) {
186 overflows64 := uint64(overflows)
187 c := uint64(s.n[0]) + overflows64*uint64(orderCompWord0)
188 s.n[0] = uint32(c & uint32Mask)
189 c = (c >> 32) + uint64(s.n[1]) + overflows64*uint64(orderCompWord1)
190 s.n[1] = uint32(c & uint32Mask)
191 c = (c >> 32) + uint64(s.n[2]) + overflows64*uint64(orderCompWord2)
192 s.n[2] = uint32(c & uint32Mask)
193 c = (c >> 32) + uint64(s.n[3]) + overflows64*uint64(orderCompWord3)
194 s.n[3] = uint32(c & uint32Mask)
195 c = (c >> 32) + uint64(s.n[4]) + overflows64
196 s.n[4] = uint32(c & uint32Mask)
197 c = (c >> 32) + uint64(s.n[5])
198 s.n[5] = uint32(c & uint32Mask)
199 c = (c >> 32) + uint64(s.n[6])
200 s.n[6] = uint32(c & uint32Mask)
201 c = (c >> 32) + uint64(s.n[7])
202 s.n[7] = uint32(c & uint32Mask)
203 }
204
205 // checkOverflow checks if the scalar overflows
206 func (s *Scalar) checkOverflow() bool {
207 return s.overflows() != 0
208 }
209
210 // reduce reduces the scalar modulo the order
211 func (s *Scalar) reduce(overflow int) {
212 s.reduce256(uint32(overflow))
213 }
214
215 // add adds two scalars: r = a + b
216 func (s *Scalar) add(a, b *Scalar) bool {
217 c := uint64(a.n[0]) + uint64(b.n[0])
218 s.n[0] = uint32(c & uint32Mask)
219 c = (c >> 32) + uint64(a.n[1]) + uint64(b.n[1])
220 s.n[1] = uint32(c & uint32Mask)
221 c = (c >> 32) + uint64(a.n[2]) + uint64(b.n[2])
222 s.n[2] = uint32(c & uint32Mask)
223 c = (c >> 32) + uint64(a.n[3]) + uint64(b.n[3])
224 s.n[3] = uint32(c & uint32Mask)
225 c = (c >> 32) + uint64(a.n[4]) + uint64(b.n[4])
226 s.n[4] = uint32(c & uint32Mask)
227 c = (c >> 32) + uint64(a.n[5]) + uint64(b.n[5])
228 s.n[5] = uint32(c & uint32Mask)
229 c = (c >> 32) + uint64(a.n[6]) + uint64(b.n[6])
230 s.n[6] = uint32(c & uint32Mask)
231 c = (c >> 32) + uint64(a.n[7]) + uint64(b.n[7])
232 s.n[7] = uint32(c & uint32Mask)
233
234 s.reduce256(uint32(c>>32) + s.overflows())
235 return false
236 }
237
238 // addPureGo is an alias for add in 32-bit mode
239 func (s *Scalar) addPureGo(a, b *Scalar) bool {
240 return s.add(a, b)
241 }
242
243 // sub subtracts two scalars: r = a - b
244 func (s *Scalar) sub(a, b *Scalar) {
245 var negB Scalar
246 negB.negate(b)
247 s.add(a, &negB)
248 }
249
250 // subPureGo is an alias for sub in 32-bit mode
251 func (s *Scalar) subPureGo(a, b *Scalar) {
252 s.sub(a, b)
253 }
254
255 // negate negates a scalar
256 func (s *Scalar) negate(a *Scalar) {
257 bits := a.n[0] | a.n[1] | a.n[2] | a.n[3] | a.n[4] | a.n[5] | a.n[6] | a.n[7]
258 mask := uint64(uint32Mask * constantTimeNotEq32(bits, 0))
259 c := uint64(orderWord0) + (uint64(^a.n[0]) + 1)
260 s.n[0] = uint32(c & mask)
261 c = (c >> 32) + uint64(orderWord1) + uint64(^a.n[1])
262 s.n[1] = uint32(c & mask)
263 c = (c >> 32) + uint64(orderWord2) + uint64(^a.n[2])
264 s.n[2] = uint32(c & mask)
265 c = (c >> 32) + uint64(orderWord3) + uint64(^a.n[3])
266 s.n[3] = uint32(c & mask)
267 c = (c >> 32) + uint64(orderWord4) + uint64(^a.n[4])
268 s.n[4] = uint32(c & mask)
269 c = (c >> 32) + uint64(orderWord5) + uint64(^a.n[5])
270 s.n[5] = uint32(c & mask)
271 c = (c >> 32) + uint64(orderWord6) + uint64(^a.n[6])
272 s.n[6] = uint32(c & mask)
273 c = (c >> 32) + uint64(orderWord7) + uint64(^a.n[7])
274 s.n[7] = uint32(c & mask)
275 }
276
277 // mul multiplies two scalars: r = a * b
278 func (s *Scalar) mul(a, b *Scalar) {
279 s.mulPureGo(a, b)
280 }
281
282 // mulPureGo performs multiplication using 32-bit arithmetic
283 func (s *Scalar) mulPureGo(a, b *Scalar) {
284 // Compute 512-bit product then reduce
285 var l [16]uint64
286
287 // Full 512-bit multiplication (using 64-bit intermediates for 32x32->64)
288 for i := 0; i < 8; i++ {
289 var c uint64
290 for j := 0; j < 8; j++ {
291 c += l[i+j] + uint64(a.n[i])*uint64(b.n[j])
292 l[i+j] = c & uint32Mask
293 c >>= 32
294 }
295 l[i+8] = c
296 }
297
298 // Reduce 512 bits to 256 bits modulo order
299 s.reduce512_32(l[:])
300 }
301
// reduce512_32 reduces a 512-bit value l (sixteen 32-bit limbs, least
// significant first, one limb per uint64 slot) modulo the group order N,
// storing the result in s.
//
// Each pass substitutes 2^256 == (2^256 - N) (mod N): the high limbs are
// multiplied by the order's complement words (orderCompWord0..3, plus an
// implicit 1 at limb offset 4 — the "+ l[k]" terms) and folded into the low
// limbs. Three passes shrink 512 -> 385 -> 258 -> 256 bits, then reduce256
// performs the final conditional subtraction of N.
func (s *Scalar) reduce512_32(l []uint64) {
	// First reduction: fold l[8..15] into the low limbs, 512 -> 385 bits.
	var m [13]uint64
	var c uint64

	c = l[0] + l[8]*uint64(orderCompWord0)
	m[0] = c & uint32Mask
	c >>= 32
	c += l[1] + l[8]*uint64(orderCompWord1) + l[9]*uint64(orderCompWord0)
	m[1] = c & uint32Mask
	c >>= 32
	c += l[2] + l[8]*uint64(orderCompWord2) + l[9]*uint64(orderCompWord1) + l[10]*uint64(orderCompWord0)
	m[2] = c & uint32Mask
	c >>= 32
	c += l[3] + l[8]*uint64(orderCompWord3) + l[9]*uint64(orderCompWord2) + l[10]*uint64(orderCompWord1) + l[11]*uint64(orderCompWord0)
	m[3] = c & uint32Mask
	c >>= 32
	// The bare l[8], l[9], ... terms below are the complement's limb-4 value
	// of 1 applied at the appropriate offsets.
	c += l[4] + l[8] + l[9]*uint64(orderCompWord3) + l[10]*uint64(orderCompWord2) + l[11]*uint64(orderCompWord1) + l[12]*uint64(orderCompWord0)
	m[4] = c & uint32Mask
	c >>= 32
	c += l[5] + l[9] + l[10]*uint64(orderCompWord3) + l[11]*uint64(orderCompWord2) + l[12]*uint64(orderCompWord1) + l[13]*uint64(orderCompWord0)
	m[5] = c & uint32Mask
	c >>= 32
	c += l[6] + l[10] + l[11]*uint64(orderCompWord3) + l[12]*uint64(orderCompWord2) + l[13]*uint64(orderCompWord1) + l[14]*uint64(orderCompWord0)
	m[6] = c & uint32Mask
	c >>= 32
	c += l[7] + l[11] + l[12]*uint64(orderCompWord3) + l[13]*uint64(orderCompWord2) + l[14]*uint64(orderCompWord1) + l[15]*uint64(orderCompWord0)
	m[7] = c & uint32Mask
	c >>= 32
	c += l[12] + l[13]*uint64(orderCompWord3) + l[14]*uint64(orderCompWord2) + l[15]*uint64(orderCompWord1)
	m[8] = c & uint32Mask
	c >>= 32
	c += l[13] + l[14]*uint64(orderCompWord3) + l[15]*uint64(orderCompWord2)
	m[9] = c & uint32Mask
	c >>= 32
	c += l[14] + l[15]*uint64(orderCompWord3)
	m[10] = c & uint32Mask
	c >>= 32
	c += l[15]
	m[11] = c & uint32Mask
	c >>= 32
	m[12] = c

	// Second reduction: fold m[8..12] back in, 385 -> 258 bits.
	var p [9]uint64
	c = m[0] + m[8]*uint64(orderCompWord0)
	p[0] = c & uint32Mask
	c >>= 32
	c += m[1] + m[8]*uint64(orderCompWord1) + m[9]*uint64(orderCompWord0)
	p[1] = c & uint32Mask
	c >>= 32
	c += m[2] + m[8]*uint64(orderCompWord2) + m[9]*uint64(orderCompWord1) + m[10]*uint64(orderCompWord0)
	p[2] = c & uint32Mask
	c >>= 32
	c += m[3] + m[8]*uint64(orderCompWord3) + m[9]*uint64(orderCompWord2) + m[10]*uint64(orderCompWord1) + m[11]*uint64(orderCompWord0)
	p[3] = c & uint32Mask
	c >>= 32
	c += m[4] + m[8] + m[9]*uint64(orderCompWord3) + m[10]*uint64(orderCompWord2) + m[11]*uint64(orderCompWord1) + m[12]*uint64(orderCompWord0)
	p[4] = c & uint32Mask
	c >>= 32
	c += m[5] + m[9] + m[10]*uint64(orderCompWord3) + m[11]*uint64(orderCompWord2) + m[12]*uint64(orderCompWord1)
	p[5] = c & uint32Mask
	c >>= 32
	c += m[6] + m[10] + m[11]*uint64(orderCompWord3) + m[12]*uint64(orderCompWord2)
	p[6] = c & uint32Mask
	c >>= 32
	c += m[7] + m[11] + m[12]*uint64(orderCompWord3)
	p[7] = c & uint32Mask
	c >>= 32
	p[8] = c + m[12]

	// Final reduction: fold p[8] in, 258 -> 256 bits.
	c = p[0] + p[8]*uint64(orderCompWord0)
	s.n[0] = uint32(c & uint32Mask)
	c >>= 32
	c += p[1] + p[8]*uint64(orderCompWord1)
	s.n[1] = uint32(c & uint32Mask)
	c >>= 32
	c += p[2] + p[8]*uint64(orderCompWord2)
	s.n[2] = uint32(c & uint32Mask)
	c >>= 32
	c += p[3] + p[8]*uint64(orderCompWord3)
	s.n[3] = uint32(c & uint32Mask)
	c >>= 32
	c += p[4] + p[8]
	s.n[4] = uint32(c & uint32Mask)
	c >>= 32
	c += p[5]
	s.n[5] = uint32(c & uint32Mask)
	c >>= 32
	c += p[6]
	s.n[6] = uint32(c & uint32Mask)
	c >>= 32
	c += p[7]
	s.n[7] = uint32(c & uint32Mask)

	// One last conditional subtraction of N covers the residual carry.
	s.reduce256(uint32(c>>32) + s.overflows())
}
401
402 // inverse computes the modular inverse
403 func (s *Scalar) inverse(a *Scalar) {
404 // Use Fermat's little theorem: a^(-1) = a^(n-2) mod n
405 var exp Scalar
406 exp.n[0] = orderWord0 - 2
407 exp.n[1] = orderWord1
408 exp.n[2] = orderWord2
409 exp.n[3] = orderWord3
410 exp.n[4] = orderWord4
411 exp.n[5] = orderWord5
412 exp.n[6] = orderWord6
413 exp.n[7] = orderWord7
414
415 s.exp(a, &exp)
416 }
417
418 // exp computes s = a^b mod n
419 func (s *Scalar) exp(a, b *Scalar) {
420 *s = ScalarOne
421 base := *a
422
423 for i := 0; i < 8; i++ {
424 limb := b.n[i]
425 for j := 0; j < 32; j++ {
426 if limb&1 != 0 {
427 s.mul(s, &base)
428 }
429 base.mul(&base, &base)
430 limb >>= 1
431 }
432 }
433 }
434
435 // half computes s = a/2 mod n
436 func (s *Scalar) half(a *Scalar) {
437 *s = *a
438 if s.n[0]&1 == 0 {
439 // Even: simple right shift
440 for i := 0; i < 7; i++ {
441 s.n[i] = (s.n[i] >> 1) | ((s.n[i+1] & 1) << 31)
442 }
443 s.n[7] >>= 1
444 } else {
445 // Odd: add n then divide by 2
446 var c uint64
447 c = uint64(s.n[0]) + uint64(orderWord0)
448 s.n[0] = uint32(c)
449 c = (c >> 32) + uint64(s.n[1]) + uint64(orderWord1)
450 s.n[1] = uint32(c)
451 c = (c >> 32) + uint64(s.n[2]) + uint64(orderWord2)
452 s.n[2] = uint32(c)
453 c = (c >> 32) + uint64(s.n[3]) + uint64(orderWord3)
454 s.n[3] = uint32(c)
455 c = (c >> 32) + uint64(s.n[4]) + uint64(orderWord4)
456 s.n[4] = uint32(c)
457 c = (c >> 32) + uint64(s.n[5]) + uint64(orderWord5)
458 s.n[5] = uint32(c)
459 c = (c >> 32) + uint64(s.n[6]) + uint64(orderWord6)
460 s.n[6] = uint32(c)
461 c = (c >> 32) + uint64(s.n[7]) + uint64(orderWord7)
462 s.n[7] = uint32(c)
463
464 // Divide by 2
465 for i := 0; i < 7; i++ {
466 s.n[i] = (s.n[i] >> 1) | ((s.n[i+1] & 1) << 31)
467 }
468 s.n[7] >>= 1
469 }
470 }
471
472 // isZero returns true if the scalar is zero
473 func (s *Scalar) isZero() bool {
474 bits := s.n[0] | s.n[1] | s.n[2] | s.n[3] | s.n[4] | s.n[5] | s.n[6] | s.n[7]
475 return bits == 0
476 }
477
478 // isOne returns true if the scalar is one
479 func (s *Scalar) isOne() bool {
480 return s.n[0] == 1 && s.n[1] == 0 && s.n[2] == 0 && s.n[3] == 0 &&
481 s.n[4] == 0 && s.n[5] == 0 && s.n[6] == 0 && s.n[7] == 0
482 }
483
484 // isEven returns true if the scalar is even
485 func (s *Scalar) isEven() bool {
486 return s.n[0]&1 == 0
487 }
488
// isHigh returns true if the scalar is > N/2 (i.e. strictly greater than
// halfOrder == (N-1)/2), in constant time.
//
// Same cascade structure as overflows: highWordsEqual tracks an all-equal
// prefix from the most significant limb down, and a strictly greater limb
// beneath it proves s > halfOrder. Note the final test on limb 0 is strict
// (greater, not greater-or-equal), so s == halfOrder is not "high".
func (s *Scalar) isHigh() bool {
	result := constantTimeGreater32(s.n[7], halfOrderWord7)
	highWordsEqual := constantTimeEq32(s.n[7], halfOrderWord7)
	highWordsEqual &= constantTimeEq32(s.n[6], halfOrderWord6)
	highWordsEqual &= constantTimeEq32(s.n[5], halfOrderWord5)
	highWordsEqual &= constantTimeEq32(s.n[4], halfOrderWord4)
	result |= highWordsEqual & constantTimeGreater32(s.n[3], halfOrderWord3)
	highWordsEqual &= constantTimeEq32(s.n[3], halfOrderWord3)
	result |= highWordsEqual & constantTimeGreater32(s.n[2], halfOrderWord2)
	highWordsEqual &= constantTimeEq32(s.n[2], halfOrderWord2)
	result |= highWordsEqual & constantTimeGreater32(s.n[1], halfOrderWord1)
	highWordsEqual &= constantTimeEq32(s.n[1], halfOrderWord1)
	result |= highWordsEqual & constantTimeGreater32(s.n[0], halfOrderWord0)
	return result != 0
}
505
506 // condNegate conditionally negates the scalar
507 func (s *Scalar) condNegate(flag int) {
508 if flag != 0 {
509 var neg Scalar
510 neg.negate(s)
511 *s = neg
512 }
513 }
514
// equal reports whether s and a hold the same value, using a constant-time
// byte comparison.
//
// The unsafe casts reinterpret each [8]uint32 limb array as its raw 32-byte
// in-memory representation; equal limb arrays always have equal bytes, so
// this is endianness-independent for equality. Both scalars are assumed to
// be fully reduced so that equal values have identical limbs.
func (s *Scalar) equal(a *Scalar) bool {
	return subtle.ConstantTimeCompare(
		(*[32]byte)(unsafe.Pointer(&s.n[0]))[:32],
		(*[32]byte)(unsafe.Pointer(&a.n[0]))[:32],
	) == 1
}
522
// getBits extracts count bits of the scalar starting at bit offset (counted
// from the least significant end) and returns them in the low bits of the
// result. count must be 1-32 and offset+count must not exceed 256.
func (s *Scalar) getBits(offset, count uint) uint32 {
	if count == 0 || count > 32 {
		panic("count must be 1-32")
	}
	if offset+count > 256 {
		panic("offset + count must be <= 256")
	}

	limbIdx := offset / 32
	bitIdx := offset % 32

	if bitIdx+count <= 32 {
		// Single-limb read. For count == 32 (only possible with bitIdx ==
		// 0), Go defines uint32(1)<<32 as 0, so the mask wraps to
		// 0xffffffff — exactly what a full-limb read needs.
		return (s.n[limbIdx] >> bitIdx) & ((1 << count) - 1)
	}
	// The range straddles a limb boundary: stitch the top of this limb to
	// the bottom of the next.
	lowBits := 32 - bitIdx
	highBits := count - lowBits
	low := (s.n[limbIdx] >> bitIdx) & ((1 << lowBits) - 1)
	high := s.n[limbIdx+1] & ((1 << highBits) - 1)
	return low | (high << lowBits)
}
544
545 // cmov conditionally moves a scalar
546 func (s *Scalar) cmov(a *Scalar, flag int) {
547 mask := uint32(-(int32(flag) & 1))
548 for i := 0; i < 8; i++ {
549 s.n[i] ^= mask & (s.n[i] ^ a.n[i])
550 }
551 }
552
553 // clear clears a scalar
554 func (s *Scalar) clear() {
555 for i := 0; i < 8; i++ {
556 s.n[i] = 0
557 }
558 }
559
// wNAF converts the scalar to window-w non-adjacent form: a digit array
// where sum(wnaf[i] * 2^i) reconstructs the scalar, every nonzero digit is
// odd, and after each nonzero digit the next window-1 positions are zero.
// Returns the number of entries that may be nonzero (highest used index
// plus one). Not constant time. w must be in [2, 8].
func (s *Scalar) wNAF(wnaf *[257]int8, w uint) int {
	if w < 2 || w > 8 {
		panic("w must be between 2 and 8")
	}

	var k Scalar
	k = *s

	numBits := 0
	var carry uint32

	// Reset any digits left over from a previous use of the array.
	*wnaf = [257]int8{}

	bit := 0
	for bit < 256 {
		// Positions whose bit equals the pending carry become zero digits.
		if k.getBits(uint(bit), 1) == carry {
			bit++
			continue
		}

		window := w
		if bit+int(window) > 256 {
			// Clamp the final window at the 256-bit boundary.
			window = uint(256 - bit)
		}

		// Read a window of bits and fold in the carry; the value is odd
		// here (its low bit differs from the old carry). If the window's
		// top bit is set, push a carry into the next window and make this
		// digit negative by subtracting 2^window.
		word := k.getBits(uint(bit), window) + carry
		carry = (word >> (window - 1)) & 1
		word -= carry << window

		// The uint32 wrap from the subtraction reinterprets as the intended
		// negative value via the uint32 -> int32 conversion.
		wnaf[bit] = int8(int32(word))
		numBits = bit + int(window) - 1

		bit += int(window)
	}

	if carry != 0 {
		// A carry past bit 255 becomes a final +1 digit at position 256.
		wnaf[256] = int8(carry)
		numBits = 256
	}

	return numBits + 1
}
603
604 // wNAFSigned converts a scalar to wNAF representation with sign handling
605 func (s *Scalar) wNAFSigned(wnaf *[257]int8, w uint) (int, bool) {
606 if w < 2 || w > 8 {
607 panic("w must be between 2 and 8")
608 }
609
610 var k Scalar
611 k = *s
612
613 negated := false
614 if k.getBits(255, 1) == 1 {
615 k.negate(&k)
616 negated = true
617 }
618
619 bits := k.wNAF(wnaf, w)
620 return bits, negated
621 }
622
623 // caddBit conditionally adds a power of 2
624 func (s *Scalar) caddBit(bit uint, flag int) {
625 if flag == 0 {
626 return
627 }
628
629 limbIdx := bit >> 5
630 bitIdx := bit & 0x1F
631 addVal := uint32(1) << bitIdx
632
633 var c uint64
634 for i := limbIdx; i < 8; i++ {
635 if i == limbIdx {
636 c = uint64(s.n[i]) + uint64(addVal)
637 } else {
638 c = uint64(s.n[i]) + (c >> 32)
639 }
640 s.n[i] = uint32(c)
641 if c>>32 == 0 {
642 break
643 }
644 }
645 }
646
// mulShiftVar computes s = round((a * b) >> shift) for shift >= 256, i.e.
// the 512-bit product shifted right with round-half-up on the bit just
// below the cut. Used by the GLV split with shift == 384.
//
// NOTE(review): the "Var" suffix and the early-exit caddBit make this
// variable time — confirm the timing of callers on secret scalars is
// acceptable.
func (s *Scalar) mulShiftVar(a, b *Scalar, shift uint) {
	if shift < 256 {
		panic("mulShiftVar requires shift >= 256")
	}

	// Compute the full 512-bit product as sixteen 32-bit limbs (one per
	// uint64 slot), schoolbook style.
	var l [16]uint64
	for i := 0; i < 8; i++ {
		var c uint64
		for j := 0; j < 8; j++ {
			c += l[i+j] + uint64(a.n[i])*uint64(b.n[j])
			l[i+j] = c & uint32Mask
			c >>= 32
		}
		l[i+8] = c
	}

	// Extract bits [shift, shift+256) by stitching adjacent product limbs.
	shiftLimbs := shift >> 5
	shiftLow := shift & 0x1F
	shiftHigh := 32 - shiftLow

	for i := 0; i < 8; i++ {
		srcIdx := shiftLimbs + uint(i)
		if srcIdx < 16 {
			if shiftLow != 0 && srcIdx+1 < 16 {
				// Combine the top of limb srcIdx with the bottom of the
				// next limb.
				s.n[i] = uint32((l[srcIdx] >> shiftLow) | (l[srcIdx+1] << shiftHigh))
			} else {
				s.n[i] = uint32(l[srcIdx] >> shiftLow)
			}
		} else {
			// Past the end of the product: zero.
			s.n[i] = 0
		}
	}

	// Round half up: add the product bit just below the shift point.
	roundBit := int((l[(shift-1)>>5] >> ((shift - 1) & 0x1F)) & 1)
	s.caddBit(0, roundBit)
}
687
688 // Constant-time helper functions
689 func constantTimeNotEq32(a, b uint32) uint32 {
690 return ^constantTimeEq32(a, b) & 1
691 }
692
// constantTimeGreaterOrEq32 returns 1 if a >= b and 0 otherwise, in constant
// time.
//
// uint64(a) - uint64(b) sets bit 63 (borrow) exactly when a < b, so
// inverting that bit yields the >= predicate. The previous expression
// subtracted an extra 1, turning the test into strict a > b and returning 0
// for a == b — which broke overflows() for a scalar exactly equal to the
// group order.
func constantTimeGreaterOrEq32(a, b uint32) uint32 {
	return uint32((uint64(a)-uint64(b))>>63) ^ 1
}
696
697 // scalarSplitLambda decomposes k into k1, k2 for GLV
698 func scalarSplitLambda(r1, r2, k *Scalar) {
699 var c1, c2 Scalar
700
701 c1.mulShiftVar(k, &scalarG1, 384)
702 c2.mulShiftVar(k, &scalarG2, 384)
703
704 c1.mul(&c1, &scalarMinusB1)
705 c2.mul(&c2, &scalarMinusB2)
706
707 r2.add(&c1, &c2)
708 r1.mul(r2, &scalarLambda)
709 r1.negate(r1)
710 r1.add(r1, k)
711 }
712
713 // scalarSplit128 splits a scalar into two 128-bit halves
714 func scalarSplit128(r1, r2, k *Scalar) {
715 r1.n[0] = k.n[0]
716 r1.n[1] = k.n[1]
717 r1.n[2] = k.n[2]
718 r1.n[3] = k.n[3]
719 r1.n[4] = 0
720 r1.n[5] = 0
721 r1.n[6] = 0
722 r1.n[7] = 0
723
724 r2.n[0] = k.n[4]
725 r2.n[1] = k.n[5]
726 r2.n[2] = k.n[6]
727 r2.n[3] = k.n[7]
728 r2.n[4] = 0
729 r2.n[5] = 0
730 r2.n[6] = 0
731 r2.n[7] = 0
732 }
733
// Direct function versions for compatibility with free-function call sites;
// each simply forwards to the corresponding Scalar method.
func scalarAdd(r, a, b *Scalar) bool { return r.add(a, b) }
func scalarMul(r, a, b *Scalar) { r.mul(a, b) }
func scalarGetB32(bin []byte, a *Scalar) { a.getB32(bin) }
func scalarIsZero(a *Scalar) bool { return a.isZero() }
func scalarCheckOverflow(r *Scalar) bool { return r.checkOverflow() }
func scalarReduce(r *Scalar, overflow int) { r.reduce(overflow) }
741
// Stubs for AVX2 functions (not available on WASM/32-bit targets); these
// forward to the pure Go implementations so call sites compile unchanged.
// The add result is intentionally discarded to match the void-style
// signatures.
func scalarAddAVX2(r, a, b *Scalar) { r.add(a, b) }
func scalarSubAVX2(r, a, b *Scalar) { r.sub(a, b) }
func scalarMulAVX2(r, a, b *Scalar) { r.mul(a, b) }
746
// Compatibility constants for verify.go: the group order N packed as four
// 64-bit words, least significant first (the 4x64 representation used on
// 64-bit builds).
const (
	scalarN0 = uint64(orderWord0) | uint64(orderWord1)<<32
	scalarN1 = uint64(orderWord2) | uint64(orderWord3)<<32
	scalarN2 = uint64(orderWord4) | uint64(orderWord5)<<32
	scalarN3 = uint64(orderWord6) | uint64(orderWord7)<<32
)
754
755 // d returns the scalar limbs as a 4-element uint64 array for compatibility
756 // This converts from 8x32 to 4x64 representation
757 func (s *Scalar) d() [4]uint64 {
758 return [4]uint64{
759 uint64(s.n[0]) | uint64(s.n[1])<<32,
760 uint64(s.n[2]) | uint64(s.n[3])<<32,
761 uint64(s.n[4]) | uint64(s.n[5])<<32,
762 uint64(s.n[6]) | uint64(s.n[7])<<32,
763 }
764 }
765