field.go raw
1 package avx
2
3 import "math/bits"
4
5 // Field operations modulo the secp256k1 field prime p.
6 // p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
7 // = 2^256 - 2^32 - 977
8
9 // SetBytes sets a field element from a 32-byte big-endian slice.
10 // Returns true if the value was >= p and was reduced.
11 func (f *FieldElement) SetBytes(b []byte) bool {
12 if len(b) != 32 {
13 panic("field element must be 32 bytes")
14 }
15
16 // Convert big-endian bytes to little-endian limbs
17 f.N[0].Lo = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
18 uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
19 f.N[0].Hi = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
20 uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
21 f.N[1].Lo = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
22 uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
23 f.N[1].Hi = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
24 uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
25
26 // Check overflow and reduce if necessary
27 overflow := f.checkOverflow()
28 if overflow {
29 f.reduce()
30 }
31 return overflow
32 }
33
34 // Bytes returns the field element as a 32-byte big-endian slice.
35 func (f *FieldElement) Bytes() [32]byte {
36 var b [32]byte
37 b[31] = byte(f.N[0].Lo)
38 b[30] = byte(f.N[0].Lo >> 8)
39 b[29] = byte(f.N[0].Lo >> 16)
40 b[28] = byte(f.N[0].Lo >> 24)
41 b[27] = byte(f.N[0].Lo >> 32)
42 b[26] = byte(f.N[0].Lo >> 40)
43 b[25] = byte(f.N[0].Lo >> 48)
44 b[24] = byte(f.N[0].Lo >> 56)
45
46 b[23] = byte(f.N[0].Hi)
47 b[22] = byte(f.N[0].Hi >> 8)
48 b[21] = byte(f.N[0].Hi >> 16)
49 b[20] = byte(f.N[0].Hi >> 24)
50 b[19] = byte(f.N[0].Hi >> 32)
51 b[18] = byte(f.N[0].Hi >> 40)
52 b[17] = byte(f.N[0].Hi >> 48)
53 b[16] = byte(f.N[0].Hi >> 56)
54
55 b[15] = byte(f.N[1].Lo)
56 b[14] = byte(f.N[1].Lo >> 8)
57 b[13] = byte(f.N[1].Lo >> 16)
58 b[12] = byte(f.N[1].Lo >> 24)
59 b[11] = byte(f.N[1].Lo >> 32)
60 b[10] = byte(f.N[1].Lo >> 40)
61 b[9] = byte(f.N[1].Lo >> 48)
62 b[8] = byte(f.N[1].Lo >> 56)
63
64 b[7] = byte(f.N[1].Hi)
65 b[6] = byte(f.N[1].Hi >> 8)
66 b[5] = byte(f.N[1].Hi >> 16)
67 b[4] = byte(f.N[1].Hi >> 24)
68 b[3] = byte(f.N[1].Hi >> 32)
69 b[2] = byte(f.N[1].Hi >> 40)
70 b[1] = byte(f.N[1].Hi >> 48)
71 b[0] = byte(f.N[1].Hi >> 56)
72
73 return b
74 }
75
76 // IsZero returns true if the field element is zero.
77 func (f *FieldElement) IsZero() bool {
78 return f.N[0].IsZero() && f.N[1].IsZero()
79 }
80
81 // IsOne returns true if the field element is one.
82 func (f *FieldElement) IsOne() bool {
83 return f.N[0].Lo == 1 && f.N[0].Hi == 0 && f.N[1].IsZero()
84 }
85
86 // Equal returns true if two field elements are equal.
87 func (f *FieldElement) Equal(other *FieldElement) bool {
88 return f.N[0].Lo == other.N[0].Lo && f.N[0].Hi == other.N[0].Hi &&
89 f.N[1].Lo == other.N[1].Lo && f.N[1].Hi == other.N[1].Hi
90 }
91
92 // IsOdd returns true if the field element is odd.
93 func (f *FieldElement) IsOdd() bool {
94 return f.N[0].Lo&1 == 1
95 }
96
97 // checkOverflow returns true if f >= p.
98 func (f *FieldElement) checkOverflow() bool {
99 // p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
100 // Compare high to low
101 if f.N[1].Hi > FieldP.N[1].Hi {
102 return true
103 }
104 if f.N[1].Hi < FieldP.N[1].Hi {
105 return false
106 }
107 if f.N[1].Lo > FieldP.N[1].Lo {
108 return true
109 }
110 if f.N[1].Lo < FieldP.N[1].Lo {
111 return false
112 }
113 if f.N[0].Hi > FieldP.N[0].Hi {
114 return true
115 }
116 if f.N[0].Hi < FieldP.N[0].Hi {
117 return false
118 }
119 return f.N[0].Lo >= FieldP.N[0].Lo
120 }
121
122 // reduce reduces f modulo p by adding the complement (2^256 - p = 2^32 + 977).
123 func (f *FieldElement) reduce() {
124 // f = f - p = f + (2^256 - p) mod 2^256
125 // 2^256 - p = 0x1000003D1
126 var carry uint64
127 f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, 0x1000003D1, 0)
128 f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, 0, carry)
129 f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
130 f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, 0, carry)
131 }
132
133 // Add sets f = a + b mod p.
134 func (f *FieldElement) Add(a, b *FieldElement) *FieldElement {
135 var carry uint64
136 f.N[0].Lo, carry = bits.Add64(a.N[0].Lo, b.N[0].Lo, 0)
137 f.N[0].Hi, carry = bits.Add64(a.N[0].Hi, b.N[0].Hi, carry)
138 f.N[1].Lo, carry = bits.Add64(a.N[1].Lo, b.N[1].Lo, carry)
139 f.N[1].Hi, carry = bits.Add64(a.N[1].Hi, b.N[1].Hi, carry)
140
141 // If there was a carry or if result >= p, reduce
142 if carry != 0 || f.checkOverflow() {
143 f.reduce()
144 }
145 return f
146 }
147
148 // Sub sets f = a - b mod p.
149 func (f *FieldElement) Sub(a, b *FieldElement) *FieldElement {
150 var borrow uint64
151 f.N[0].Lo, borrow = bits.Sub64(a.N[0].Lo, b.N[0].Lo, 0)
152 f.N[0].Hi, borrow = bits.Sub64(a.N[0].Hi, b.N[0].Hi, borrow)
153 f.N[1].Lo, borrow = bits.Sub64(a.N[1].Lo, b.N[1].Lo, borrow)
154 f.N[1].Hi, borrow = bits.Sub64(a.N[1].Hi, b.N[1].Hi, borrow)
155
156 // If there was a borrow, add p back
157 if borrow != 0 {
158 var carry uint64
159 f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, FieldP.N[0].Lo, 0)
160 f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, FieldP.N[0].Hi, carry)
161 f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, FieldP.N[1].Lo, carry)
162 f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, FieldP.N[1].Hi, carry)
163 }
164 return f
165 }
166
167 // Negate sets f = -a mod p.
168 func (f *FieldElement) Negate(a *FieldElement) *FieldElement {
169 if a.IsZero() {
170 *f = FieldZero
171 return f
172 }
173 // f = p - a
174 var borrow uint64
175 f.N[0].Lo, borrow = bits.Sub64(FieldP.N[0].Lo, a.N[0].Lo, 0)
176 f.N[0].Hi, borrow = bits.Sub64(FieldP.N[0].Hi, a.N[0].Hi, borrow)
177 f.N[1].Lo, borrow = bits.Sub64(FieldP.N[1].Lo, a.N[1].Lo, borrow)
178 f.N[1].Hi, _ = bits.Sub64(FieldP.N[1].Hi, a.N[1].Hi, borrow)
179 return f
180 }
181
182 // Mul sets f = a * b mod p.
183 func (f *FieldElement) Mul(a, b *FieldElement) *FieldElement {
184 // Compute 512-bit product
185 var prod [8]uint64
186 fieldMul512(&prod, a, b)
187
188 // Reduce mod p using secp256k1's special structure
189 fieldReduce512(f, &prod)
190 return f
191 }
192
193 // fieldMul512 computes the 512-bit product of two 256-bit field elements.
194 func fieldMul512(prod *[8]uint64, a, b *FieldElement) {
195 aLimbs := [4]uint64{a.N[0].Lo, a.N[0].Hi, a.N[1].Lo, a.N[1].Hi}
196 bLimbs := [4]uint64{b.N[0].Lo, b.N[0].Hi, b.N[1].Lo, b.N[1].Hi}
197
198 // Clear product
199 for i := range prod {
200 prod[i] = 0
201 }
202
203 // Schoolbook multiplication
204 for i := 0; i < 4; i++ {
205 var carry uint64
206 for j := 0; j < 4; j++ {
207 hi, lo := bits.Mul64(aLimbs[i], bLimbs[j])
208 lo, c := bits.Add64(lo, prod[i+j], 0)
209 hi, _ = bits.Add64(hi, 0, c)
210 lo, c = bits.Add64(lo, carry, 0)
211 hi, _ = bits.Add64(hi, 0, c)
212 prod[i+j] = lo
213 carry = hi
214 }
215 prod[i+4] = carry
216 }
217 }
218
219 // fieldReduce512 reduces a 512-bit value mod p using secp256k1's special structure.
220 // p = 2^256 - 2^32 - 977, so 2^256 ≡ 2^32 + 977 (mod p)
221 func fieldReduce512(f *FieldElement, prod *[8]uint64) {
222 // The key insight: if we have a 512-bit number split as H*2^256 + L
223 // then H*2^256 + L ≡ H*(2^32 + 977) + L (mod p)
224
225 // Extract low and high 256-bit parts
226 low := [4]uint64{prod[0], prod[1], prod[2], prod[3]}
227 high := [4]uint64{prod[4], prod[5], prod[6], prod[7]}
228
229 // Compute high * (2^32 + 977) = high * 0x1000003D1
230 // This gives us at most a 289-bit result (256 + 33 bits)
231 const c = uint64(0x1000003D1)
232
233 var reduction [5]uint64
234 var carry uint64
235
236 for i := 0; i < 4; i++ {
237 hi, lo := bits.Mul64(high[i], c)
238 lo, cc := bits.Add64(lo, carry, 0)
239 hi, _ = bits.Add64(hi, 0, cc)
240 reduction[i] = lo
241 carry = hi
242 }
243 reduction[4] = carry
244
245 // Add low + reduction
246 var result [5]uint64
247 carry = 0
248 for i := 0; i < 4; i++ {
249 result[i], carry = bits.Add64(low[i], reduction[i], carry)
250 }
251 result[4] = carry + reduction[4]
252
253 // If result[4] is non-zero, we need to reduce again
254 // result[4] * 2^256 ≡ result[4] * (2^32 + 977) (mod p)
255 if result[4] != 0 {
256 hi, lo := bits.Mul64(result[4], c)
257 result[0], carry = bits.Add64(result[0], lo, 0)
258 result[1], carry = bits.Add64(result[1], hi, carry)
259 result[2], carry = bits.Add64(result[2], 0, carry)
260 result[3], _ = bits.Add64(result[3], 0, carry)
261 result[4] = 0
262 }
263
264 // Store result
265 f.N[0].Lo = result[0]
266 f.N[0].Hi = result[1]
267 f.N[1].Lo = result[2]
268 f.N[1].Hi = result[3]
269
270 // Final reduction if >= p
271 if f.checkOverflow() {
272 f.reduce()
273 }
274 }
275
276 // Sqr sets f = a^2 mod p.
277 func (f *FieldElement) Sqr(a *FieldElement) *FieldElement {
278 // Optimized squaring could save some multiplications, but for now use Mul
279 return f.Mul(a, a)
280 }
281
282 // Inverse sets f = a^(-1) mod p using Fermat's little theorem.
283 // a^(-1) = a^(p-2) mod p
284 func (f *FieldElement) Inverse(a *FieldElement) *FieldElement {
285 // p-2 in bytes (big-endian)
286 // p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
287 // p-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2D
288 pMinus2 := [32]byte{
289 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
290 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
291 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
292 0xFF, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFC, 0x2D,
293 }
294
295 var result, base FieldElement
296 result = FieldOne
297 base = *a
298
299 for i := 0; i < 32; i++ {
300 b := pMinus2[31-i]
301 for j := 0; j < 8; j++ {
302 if (b>>j)&1 == 1 {
303 result.Mul(&result, &base)
304 }
305 base.Sqr(&base)
306 }
307 }
308
309 *f = result
310 return f
311 }
312
313 // Sqrt sets f = sqrt(a) mod p if it exists, returns true if successful.
314 // For secp256k1, p ≡ 3 (mod 4), so sqrt(a) = a^((p+1)/4) mod p
315 func (f *FieldElement) Sqrt(a *FieldElement) bool {
316 // (p+1)/4 in bytes
317 // p+1 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC30
318 // (p+1)/4 = 3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFBFFFFF0C
319 pPlus1Div4 := [32]byte{
320 0x3F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
321 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
322 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
323 0xFF, 0xFF, 0xFF, 0xFF, 0xBF, 0xFF, 0xFF, 0x0C,
324 }
325
326 var result, base FieldElement
327 result = FieldOne
328 base = *a
329
330 for i := 0; i < 32; i++ {
331 b := pPlus1Div4[31-i]
332 for j := 0; j < 8; j++ {
333 if (b>>j)&1 == 1 {
334 result.Mul(&result, &base)
335 }
336 base.Sqr(&base)
337 }
338 }
339
340 // Verify: result^2 should equal a
341 var check FieldElement
342 check.Sqr(&result)
343
344 if check.Equal(a) {
345 *f = result
346 return true
347 }
348 return false
349 }
350
351 // MulInt sets f = a * n mod p where n is a small integer.
352 func (f *FieldElement) MulInt(a *FieldElement, n uint64) *FieldElement {
353 if n == 0 {
354 *f = FieldZero
355 return f
356 }
357 if n == 1 {
358 *f = *a
359 return f
360 }
361
362 // Multiply by small integer using proper carry chain
363 // We need to compute a 320-bit result (256 + 64 bits max)
364 var result [5]uint64
365 var carry uint64
366
367 // Multiply each 64-bit limb by n
368 var hi uint64
369 hi, result[0] = bits.Mul64(a.N[0].Lo, n)
370 carry = hi
371
372 hi, result[1] = bits.Mul64(a.N[0].Hi, n)
373 result[1], carry = bits.Add64(result[1], carry, 0)
374 carry = hi + carry // carry can be at most 1 here, so no overflow
375
376 hi, result[2] = bits.Mul64(a.N[1].Lo, n)
377 result[2], carry = bits.Add64(result[2], carry, 0)
378 carry = hi + carry
379
380 hi, result[3] = bits.Mul64(a.N[1].Hi, n)
381 result[3], carry = bits.Add64(result[3], carry, 0)
382 result[4] = hi + carry
383
384 // Store preliminary result
385 f.N[0].Lo = result[0]
386 f.N[0].Hi = result[1]
387 f.N[1].Lo = result[2]
388 f.N[1].Hi = result[3]
389
390 // Reduce overflow
391 if result[4] != 0 {
392 // overflow * 2^256 ≡ overflow * (2^32 + 977) (mod p)
393 hi, lo := bits.Mul64(result[4], 0x1000003D1)
394 f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, lo, 0)
395 f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, hi, carry)
396 f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
397 f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, 0, carry)
398 }
399
400 if f.checkOverflow() {
401 f.reduce()
402 }
403 return f
404 }
405
406 // Double sets f = 2*a mod p (optimized addition).
407 func (f *FieldElement) Double(a *FieldElement) *FieldElement {
408 return f.Add(a, a)
409 }
410
411 // Half sets f = a/2 mod p.
412 func (f *FieldElement) Half(a *FieldElement) *FieldElement {
413 // If a is even, just shift right
414 // If a is odd, add p first (which makes it even), then shift right
415 var result FieldElement = *a
416
417 if result.N[0].Lo&1 == 1 {
418 // Add p
419 var carry uint64
420 result.N[0].Lo, carry = bits.Add64(result.N[0].Lo, FieldP.N[0].Lo, 0)
421 result.N[0].Hi, carry = bits.Add64(result.N[0].Hi, FieldP.N[0].Hi, carry)
422 result.N[1].Lo, carry = bits.Add64(result.N[1].Lo, FieldP.N[1].Lo, carry)
423 result.N[1].Hi, _ = bits.Add64(result.N[1].Hi, FieldP.N[1].Hi, carry)
424 }
425
426 // Shift right by 1
427 f.N[0].Lo = (result.N[0].Lo >> 1) | (result.N[0].Hi << 63)
428 f.N[0].Hi = (result.N[0].Hi >> 1) | (result.N[1].Lo << 63)
429 f.N[1].Lo = (result.N[1].Lo >> 1) | (result.N[1].Hi << 63)
430 f.N[1].Hi = result.N[1].Hi >> 1
431
432 return f
433 }
434
435 // CMov conditionally moves b into f if cond is true (constant-time).
436 func (f *FieldElement) CMov(b *FieldElement, cond bool) *FieldElement {
437 mask := uint64(0)
438 if cond {
439 mask = ^uint64(0)
440 }
441 f.N[0].Lo = (f.N[0].Lo &^ mask) | (b.N[0].Lo & mask)
442 f.N[0].Hi = (f.N[0].Hi &^ mask) | (b.N[0].Hi & mask)
443 f.N[1].Lo = (f.N[1].Lo &^ mask) | (b.N[1].Lo & mask)
444 f.N[1].Hi = (f.N[1].Hi &^ mask) | (b.N[1].Hi & mask)
445 return f
446 }
447