scalar.go raw
1 package avx
2
3 import "math/bits"
4
5 // Scalar operations modulo the secp256k1 group order n.
6 // n = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
7
8 // SetBytes sets a scalar from a 32-byte big-endian slice.
9 // Returns true if the value was >= n and was reduced.
10 func (s *Scalar) SetBytes(b []byte) bool {
11 if len(b) != 32 {
12 panic("scalar must be 32 bytes")
13 }
14
15 // Convert big-endian bytes to little-endian limbs
16 s.D[0].Lo = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
17 uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
18 s.D[0].Hi = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
19 uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
20 s.D[1].Lo = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
21 uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
22 s.D[1].Hi = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
23 uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
24
25 // Check overflow and reduce if necessary
26 overflow := s.checkOverflow()
27 if overflow {
28 s.reduce()
29 }
30 return overflow
31 }
32
33 // Bytes returns the scalar as a 32-byte big-endian slice.
34 func (s *Scalar) Bytes() [32]byte {
35 var b [32]byte
36 b[31] = byte(s.D[0].Lo)
37 b[30] = byte(s.D[0].Lo >> 8)
38 b[29] = byte(s.D[0].Lo >> 16)
39 b[28] = byte(s.D[0].Lo >> 24)
40 b[27] = byte(s.D[0].Lo >> 32)
41 b[26] = byte(s.D[0].Lo >> 40)
42 b[25] = byte(s.D[0].Lo >> 48)
43 b[24] = byte(s.D[0].Lo >> 56)
44
45 b[23] = byte(s.D[0].Hi)
46 b[22] = byte(s.D[0].Hi >> 8)
47 b[21] = byte(s.D[0].Hi >> 16)
48 b[20] = byte(s.D[0].Hi >> 24)
49 b[19] = byte(s.D[0].Hi >> 32)
50 b[18] = byte(s.D[0].Hi >> 40)
51 b[17] = byte(s.D[0].Hi >> 48)
52 b[16] = byte(s.D[0].Hi >> 56)
53
54 b[15] = byte(s.D[1].Lo)
55 b[14] = byte(s.D[1].Lo >> 8)
56 b[13] = byte(s.D[1].Lo >> 16)
57 b[12] = byte(s.D[1].Lo >> 24)
58 b[11] = byte(s.D[1].Lo >> 32)
59 b[10] = byte(s.D[1].Lo >> 40)
60 b[9] = byte(s.D[1].Lo >> 48)
61 b[8] = byte(s.D[1].Lo >> 56)
62
63 b[7] = byte(s.D[1].Hi)
64 b[6] = byte(s.D[1].Hi >> 8)
65 b[5] = byte(s.D[1].Hi >> 16)
66 b[4] = byte(s.D[1].Hi >> 24)
67 b[3] = byte(s.D[1].Hi >> 32)
68 b[2] = byte(s.D[1].Hi >> 40)
69 b[1] = byte(s.D[1].Hi >> 48)
70 b[0] = byte(s.D[1].Hi >> 56)
71
72 return b
73 }
74
75 // IsZero returns true if the scalar is zero.
76 func (s *Scalar) IsZero() bool {
77 return s.D[0].IsZero() && s.D[1].IsZero()
78 }
79
80 // IsOne returns true if the scalar is one.
81 func (s *Scalar) IsOne() bool {
82 return s.D[0].Lo == 1 && s.D[0].Hi == 0 && s.D[1].IsZero()
83 }
84
85 // Equal returns true if two scalars are equal.
86 func (s *Scalar) Equal(other *Scalar) bool {
87 return s.D[0].Lo == other.D[0].Lo && s.D[0].Hi == other.D[0].Hi &&
88 s.D[1].Lo == other.D[1].Lo && s.D[1].Hi == other.D[1].Hi
89 }
90
91 // checkOverflow returns true if s >= n.
92 func (s *Scalar) checkOverflow() bool {
93 // Compare high to low
94 if s.D[1].Hi > ScalarN.D[1].Hi {
95 return true
96 }
97 if s.D[1].Hi < ScalarN.D[1].Hi {
98 return false
99 }
100 if s.D[1].Lo > ScalarN.D[1].Lo {
101 return true
102 }
103 if s.D[1].Lo < ScalarN.D[1].Lo {
104 return false
105 }
106 if s.D[0].Hi > ScalarN.D[0].Hi {
107 return true
108 }
109 if s.D[0].Hi < ScalarN.D[0].Hi {
110 return false
111 }
112 return s.D[0].Lo >= ScalarN.D[0].Lo
113 }
114
115 // reduce reduces s modulo n by adding the complement (2^256 - n).
116 func (s *Scalar) reduce() {
117 // s = s - n = s + (2^256 - n) mod 2^256
118 var carry uint64
119 s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
120 s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, carry)
121 s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, carry)
122 s.D[1].Hi, _ = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, carry)
123 }
124
125 // Add sets s = a + b mod n.
126 func (s *Scalar) Add(a, b *Scalar) *Scalar {
127 var carry uint64
128 s.D[0].Lo, carry = bits.Add64(a.D[0].Lo, b.D[0].Lo, 0)
129 s.D[0].Hi, carry = bits.Add64(a.D[0].Hi, b.D[0].Hi, carry)
130 s.D[1].Lo, carry = bits.Add64(a.D[1].Lo, b.D[1].Lo, carry)
131 s.D[1].Hi, carry = bits.Add64(a.D[1].Hi, b.D[1].Hi, carry)
132
133 // If there was a carry or if result >= n, reduce
134 if carry != 0 || s.checkOverflow() {
135 s.reduce()
136 }
137 return s
138 }
139
140 // Sub sets s = a - b mod n.
141 func (s *Scalar) Sub(a, b *Scalar) *Scalar {
142 var borrow uint64
143 s.D[0].Lo, borrow = bits.Sub64(a.D[0].Lo, b.D[0].Lo, 0)
144 s.D[0].Hi, borrow = bits.Sub64(a.D[0].Hi, b.D[0].Hi, borrow)
145 s.D[1].Lo, borrow = bits.Sub64(a.D[1].Lo, b.D[1].Lo, borrow)
146 s.D[1].Hi, borrow = bits.Sub64(a.D[1].Hi, b.D[1].Hi, borrow)
147
148 // If there was a borrow, add n back
149 if borrow != 0 {
150 var carry uint64
151 s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, ScalarN.D[0].Lo, 0)
152 s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, ScalarN.D[0].Hi, carry)
153 s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, ScalarN.D[1].Lo, carry)
154 s.D[1].Hi, _ = bits.Add64(s.D[1].Hi, ScalarN.D[1].Hi, carry)
155 }
156 return s
157 }
158
159 // Negate sets s = -a mod n.
160 func (s *Scalar) Negate(a *Scalar) *Scalar {
161 if a.IsZero() {
162 *s = ScalarZero
163 return s
164 }
165 // s = n - a
166 var borrow uint64
167 s.D[0].Lo, borrow = bits.Sub64(ScalarN.D[0].Lo, a.D[0].Lo, 0)
168 s.D[0].Hi, borrow = bits.Sub64(ScalarN.D[0].Hi, a.D[0].Hi, borrow)
169 s.D[1].Lo, borrow = bits.Sub64(ScalarN.D[1].Lo, a.D[1].Lo, borrow)
170 s.D[1].Hi, _ = bits.Sub64(ScalarN.D[1].Hi, a.D[1].Hi, borrow)
171 return s
172 }
173
174 // Mul sets s = a * b mod n.
175 func (s *Scalar) Mul(a, b *Scalar) *Scalar {
176 // Compute 512-bit product
177 var prod [8]uint64
178 scalarMul512(&prod, a, b)
179
180 // Reduce mod n
181 scalarReduce512(s, &prod)
182 return s
183 }
184
185 // scalarMul512 computes the 512-bit product of two 256-bit scalars.
186 // Result is stored in prod[0..7] where prod[0] is the least significant.
187 func scalarMul512(prod *[8]uint64, a, b *Scalar) {
188 // Using schoolbook multiplication with 64-bit limbs
189 // a = a[0] + a[1]*2^64 + a[2]*2^128 + a[3]*2^192
190 // b = b[0] + b[1]*2^64 + b[2]*2^128 + b[3]*2^192
191
192 aLimbs := [4]uint64{a.D[0].Lo, a.D[0].Hi, a.D[1].Lo, a.D[1].Hi}
193 bLimbs := [4]uint64{b.D[0].Lo, b.D[0].Hi, b.D[1].Lo, b.D[1].Hi}
194
195 // Clear product
196 for i := range prod {
197 prod[i] = 0
198 }
199
200 // Schoolbook multiplication
201 for i := 0; i < 4; i++ {
202 var carry uint64
203 for j := 0; j < 4; j++ {
204 hi, lo := bits.Mul64(aLimbs[i], bLimbs[j])
205 lo, c := bits.Add64(lo, prod[i+j], 0)
206 hi, _ = bits.Add64(hi, 0, c)
207 lo, c = bits.Add64(lo, carry, 0)
208 hi, _ = bits.Add64(hi, 0, c)
209 prod[i+j] = lo
210 carry = hi
211 }
212 prod[i+4] = carry
213 }
214 }
215
216 // scalarReduce512 reduces a 512-bit value mod n.
217 func scalarReduce512(s *Scalar, prod *[8]uint64) {
218 // Barrett reduction or simple repeated subtraction
219 // For now, use a simpler approach: extract high 256 bits, multiply by (2^256 mod n), add to low
220
221 // 2^256 mod n = 2^256 - n = ScalarNC (approximately 0x14551231950B75FC4...etc)
222 // This is a simplified reduction - a full implementation would use Barrett reduction
223
224 // Copy low 256 bits to result
225 s.D[0].Lo = prod[0]
226 s.D[0].Hi = prod[1]
227 s.D[1].Lo = prod[2]
228 s.D[1].Hi = prod[3]
229
230 // If high 256 bits are non-zero, we need to reduce
231 if prod[4] != 0 || prod[5] != 0 || prod[6] != 0 || prod[7] != 0 {
232 // high * (2^256 mod n) + low
233 // This is a simplified version - multiply high by NC and add
234 highScalar := Scalar{
235 D: [2]Uint128{
236 {Lo: prod[4], Hi: prod[5]},
237 {Lo: prod[6], Hi: prod[7]},
238 },
239 }
240
241 // Multiply high by NC (which is small: ~2^129)
242 // For correctness, we'd need full multiplication, but NC is small enough
243 // that we can use a simplified approach
244
245 // NC = 0x14551231950B75FC4402DA1732FC9BEBF
246 // NC.D[0] = {Lo: 0x402DA1732FC9BEBF, Hi: 0x4551231950B75FC4}
247 // NC.D[1] = {Lo: 0x1, Hi: 0}
248
249 // Approximate: high * NC ≈ high * 2^129 (since NC ≈ 2^129)
250 // This means we shift high left by 129 bits and add
251
252 // For a correct implementation, compute high * NC properly:
253 var reduction [8]uint64
254 ncLimbs := [4]uint64{ScalarNC.D[0].Lo, ScalarNC.D[0].Hi, ScalarNC.D[1].Lo, ScalarNC.D[1].Hi}
255 highLimbs := [4]uint64{highScalar.D[0].Lo, highScalar.D[0].Hi, highScalar.D[1].Lo, highScalar.D[1].Hi}
256
257 for i := 0; i < 4; i++ {
258 var carry uint64
259 for j := 0; j < 4; j++ {
260 hi, lo := bits.Mul64(highLimbs[i], ncLimbs[j])
261 lo, c := bits.Add64(lo, reduction[i+j], 0)
262 hi, _ = bits.Add64(hi, 0, c)
263 lo, c = bits.Add64(lo, carry, 0)
264 hi, _ = bits.Add64(hi, 0, c)
265 reduction[i+j] = lo
266 carry = hi
267 }
268 if i+4 < 8 {
269 reduction[i+4], _ = bits.Add64(reduction[i+4], carry, 0)
270 }
271 }
272
273 // Add reduction to s
274 var carry uint64
275 s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, reduction[0], 0)
276 s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, reduction[1], carry)
277 s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, reduction[2], carry)
278 s.D[1].Hi, carry = bits.Add64(s.D[1].Hi, reduction[3], carry)
279
280 // Handle any remaining high bits by repeated reduction
281 // If there's a carry, it represents 2^256 which equals NC mod n
282 // If reduction[4..7] are non-zero, we need to reduce those too
283 if carry != 0 || reduction[4] != 0 || reduction[5] != 0 || reduction[6] != 0 || reduction[7] != 0 {
284 // The carry and reduction[4..7] together represent additional multiples of 2^256
285 // Each 2^256 ≡ NC (mod n), so we add (carry + reduction[4..7]) * NC
286
287 // First, handle the carry
288 if carry != 0 {
289 // carry * NC
290 var c uint64
291 s.D[0].Lo, c = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
292 s.D[0].Hi, c = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, c)
293 s.D[1].Lo, c = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, c)
294 s.D[1].Hi, c = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, c)
295
296 // If there's still a carry, add NC again
297 for c != 0 {
298 s.D[0].Lo, c = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
299 s.D[0].Hi, c = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, c)
300 s.D[1].Lo, c = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, c)
301 s.D[1].Hi, c = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, c)
302 }
303 }
304
305 // Handle reduction[4..7] if non-zero
306 if reduction[4] != 0 || reduction[5] != 0 || reduction[6] != 0 || reduction[7] != 0 {
307 // Compute reduction[4..7] * NC and add
308 highScalar2 := Scalar{
309 D: [2]Uint128{
310 {Lo: reduction[4], Hi: reduction[5]},
311 {Lo: reduction[6], Hi: reduction[7]},
312 },
313 }
314
315 var reduction2 [8]uint64
316 high2Limbs := [4]uint64{highScalar2.D[0].Lo, highScalar2.D[0].Hi, highScalar2.D[1].Lo, highScalar2.D[1].Hi}
317
318 for i := 0; i < 4; i++ {
319 var c uint64
320 for j := 0; j < 4; j++ {
321 hi, lo := bits.Mul64(high2Limbs[i], ncLimbs[j])
322 lo, cc := bits.Add64(lo, reduction2[i+j], 0)
323 hi, _ = bits.Add64(hi, 0, cc)
324 lo, cc = bits.Add64(lo, c, 0)
325 hi, _ = bits.Add64(hi, 0, cc)
326 reduction2[i+j] = lo
327 c = hi
328 }
329 if i+4 < 8 {
330 reduction2[i+4], _ = bits.Add64(reduction2[i+4], c, 0)
331 }
332 }
333
334 var c uint64
335 s.D[0].Lo, c = bits.Add64(s.D[0].Lo, reduction2[0], 0)
336 s.D[0].Hi, c = bits.Add64(s.D[0].Hi, reduction2[1], c)
337 s.D[1].Lo, c = bits.Add64(s.D[1].Lo, reduction2[2], c)
338 s.D[1].Hi, c = bits.Add64(s.D[1].Hi, reduction2[3], c)
339
340 // Handle cascading carries
341 for c != 0 || reduction2[4] != 0 || reduction2[5] != 0 || reduction2[6] != 0 || reduction2[7] != 0 {
342 // This case is extremely rare but handle it
343 for s.checkOverflow() {
344 s.reduce()
345 }
346 break
347 }
348 }
349 }
350 }
351
352 // Final reduction if needed
353 if s.checkOverflow() {
354 s.reduce()
355 }
356 }
357
358 // Sqr sets s = a^2 mod n.
359 func (s *Scalar) Sqr(a *Scalar) *Scalar {
360 return s.Mul(a, a)
361 }
362
363 // Inverse sets s = a^(-1) mod n using Fermat's little theorem.
364 // a^(-1) = a^(n-2) mod n
365 func (s *Scalar) Inverse(a *Scalar) *Scalar {
366 // n-2 in binary is used for square-and-multiply
367 // This is a simplified implementation using binary exponentiation
368
369 var result, base Scalar
370 result = ScalarOne
371 base = *a
372
373 // n-2 bytes (big-endian)
374 nMinus2 := [32]byte{
375 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
376 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
377 0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
378 0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x3F,
379 }
380
381 for i := 0; i < 32; i++ {
382 b := nMinus2[31-i]
383 for j := 0; j < 8; j++ {
384 if (b>>j)&1 == 1 {
385 result.Mul(&result, &base)
386 }
387 base.Sqr(&base)
388 }
389 }
390
391 *s = result
392 return s
393 }
394
395 // IsHigh returns true if s > n/2.
396 func (s *Scalar) IsHigh() bool {
397 // Compare with n/2
398 if s.D[1].Hi > ScalarNHalf.D[1].Hi {
399 return true
400 }
401 if s.D[1].Hi < ScalarNHalf.D[1].Hi {
402 return false
403 }
404 if s.D[1].Lo > ScalarNHalf.D[1].Lo {
405 return true
406 }
407 if s.D[1].Lo < ScalarNHalf.D[1].Lo {
408 return false
409 }
410 if s.D[0].Hi > ScalarNHalf.D[0].Hi {
411 return true
412 }
413 if s.D[0].Hi < ScalarNHalf.D[0].Hi {
414 return false
415 }
416 return s.D[0].Lo > ScalarNHalf.D[0].Lo
417 }
418
419 // CondNegate negates s if cond is true.
420 func (s *Scalar) CondNegate(cond bool) *Scalar {
421 if cond {
422 s.Negate(s)
423 }
424 return s
425 }
426