field.go raw

   1  package avx
   2  
   3  import "math/bits"
   4  
   5  // Field operations modulo the secp256k1 field prime p.
   6  // p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
   7  //   = 2^256 - 2^32 - 977
   8  
   9  // SetBytes sets a field element from a 32-byte big-endian slice.
  10  // Returns true if the value was >= p and was reduced.
  11  func (f *FieldElement) SetBytes(b []byte) bool {
  12  	if len(b) != 32 {
  13  		panic("field element must be 32 bytes")
  14  	}
  15  
  16  	// Convert big-endian bytes to little-endian limbs
  17  	f.N[0].Lo = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
  18  		uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
  19  	f.N[0].Hi = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
  20  		uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
  21  	f.N[1].Lo = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
  22  		uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
  23  	f.N[1].Hi = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
  24  		uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
  25  
  26  	// Check overflow and reduce if necessary
  27  	overflow := f.checkOverflow()
  28  	if overflow {
  29  		f.reduce()
  30  	}
  31  	return overflow
  32  }
  33  
  34  // Bytes returns the field element as a 32-byte big-endian slice.
  35  func (f *FieldElement) Bytes() [32]byte {
  36  	var b [32]byte
  37  	b[31] = byte(f.N[0].Lo)
  38  	b[30] = byte(f.N[0].Lo >> 8)
  39  	b[29] = byte(f.N[0].Lo >> 16)
  40  	b[28] = byte(f.N[0].Lo >> 24)
  41  	b[27] = byte(f.N[0].Lo >> 32)
  42  	b[26] = byte(f.N[0].Lo >> 40)
  43  	b[25] = byte(f.N[0].Lo >> 48)
  44  	b[24] = byte(f.N[0].Lo >> 56)
  45  
  46  	b[23] = byte(f.N[0].Hi)
  47  	b[22] = byte(f.N[0].Hi >> 8)
  48  	b[21] = byte(f.N[0].Hi >> 16)
  49  	b[20] = byte(f.N[0].Hi >> 24)
  50  	b[19] = byte(f.N[0].Hi >> 32)
  51  	b[18] = byte(f.N[0].Hi >> 40)
  52  	b[17] = byte(f.N[0].Hi >> 48)
  53  	b[16] = byte(f.N[0].Hi >> 56)
  54  
  55  	b[15] = byte(f.N[1].Lo)
  56  	b[14] = byte(f.N[1].Lo >> 8)
  57  	b[13] = byte(f.N[1].Lo >> 16)
  58  	b[12] = byte(f.N[1].Lo >> 24)
  59  	b[11] = byte(f.N[1].Lo >> 32)
  60  	b[10] = byte(f.N[1].Lo >> 40)
  61  	b[9] = byte(f.N[1].Lo >> 48)
  62  	b[8] = byte(f.N[1].Lo >> 56)
  63  
  64  	b[7] = byte(f.N[1].Hi)
  65  	b[6] = byte(f.N[1].Hi >> 8)
  66  	b[5] = byte(f.N[1].Hi >> 16)
  67  	b[4] = byte(f.N[1].Hi >> 24)
  68  	b[3] = byte(f.N[1].Hi >> 32)
  69  	b[2] = byte(f.N[1].Hi >> 40)
  70  	b[1] = byte(f.N[1].Hi >> 48)
  71  	b[0] = byte(f.N[1].Hi >> 56)
  72  
  73  	return b
  74  }
  75  
  76  // IsZero returns true if the field element is zero.
  77  func (f *FieldElement) IsZero() bool {
  78  	return f.N[0].IsZero() && f.N[1].IsZero()
  79  }
  80  
  81  // IsOne returns true if the field element is one.
  82  func (f *FieldElement) IsOne() bool {
  83  	return f.N[0].Lo == 1 && f.N[0].Hi == 0 && f.N[1].IsZero()
  84  }
  85  
  86  // Equal returns true if two field elements are equal.
  87  func (f *FieldElement) Equal(other *FieldElement) bool {
  88  	return f.N[0].Lo == other.N[0].Lo && f.N[0].Hi == other.N[0].Hi &&
  89  		f.N[1].Lo == other.N[1].Lo && f.N[1].Hi == other.N[1].Hi
  90  }
  91  
  92  // IsOdd returns true if the field element is odd.
  93  func (f *FieldElement) IsOdd() bool {
  94  	return f.N[0].Lo&1 == 1
  95  }
  96  
  97  // checkOverflow returns true if f >= p.
  98  func (f *FieldElement) checkOverflow() bool {
  99  	// p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
 100  	// Compare high to low
 101  	if f.N[1].Hi > FieldP.N[1].Hi {
 102  		return true
 103  	}
 104  	if f.N[1].Hi < FieldP.N[1].Hi {
 105  		return false
 106  	}
 107  	if f.N[1].Lo > FieldP.N[1].Lo {
 108  		return true
 109  	}
 110  	if f.N[1].Lo < FieldP.N[1].Lo {
 111  		return false
 112  	}
 113  	if f.N[0].Hi > FieldP.N[0].Hi {
 114  		return true
 115  	}
 116  	if f.N[0].Hi < FieldP.N[0].Hi {
 117  		return false
 118  	}
 119  	return f.N[0].Lo >= FieldP.N[0].Lo
 120  }
 121  
 122  // reduce reduces f modulo p by adding the complement (2^256 - p = 2^32 + 977).
 123  func (f *FieldElement) reduce() {
 124  	// f = f - p = f + (2^256 - p) mod 2^256
 125  	// 2^256 - p = 0x1000003D1
 126  	var carry uint64
 127  	f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, 0x1000003D1, 0)
 128  	f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, 0, carry)
 129  	f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
 130  	f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, 0, carry)
 131  }
 132  
 133  // Add sets f = a + b mod p.
 134  func (f *FieldElement) Add(a, b *FieldElement) *FieldElement {
 135  	var carry uint64
 136  	f.N[0].Lo, carry = bits.Add64(a.N[0].Lo, b.N[0].Lo, 0)
 137  	f.N[0].Hi, carry = bits.Add64(a.N[0].Hi, b.N[0].Hi, carry)
 138  	f.N[1].Lo, carry = bits.Add64(a.N[1].Lo, b.N[1].Lo, carry)
 139  	f.N[1].Hi, carry = bits.Add64(a.N[1].Hi, b.N[1].Hi, carry)
 140  
 141  	// If there was a carry or if result >= p, reduce
 142  	if carry != 0 || f.checkOverflow() {
 143  		f.reduce()
 144  	}
 145  	return f
 146  }
 147  
 148  // Sub sets f = a - b mod p.
 149  func (f *FieldElement) Sub(a, b *FieldElement) *FieldElement {
 150  	var borrow uint64
 151  	f.N[0].Lo, borrow = bits.Sub64(a.N[0].Lo, b.N[0].Lo, 0)
 152  	f.N[0].Hi, borrow = bits.Sub64(a.N[0].Hi, b.N[0].Hi, borrow)
 153  	f.N[1].Lo, borrow = bits.Sub64(a.N[1].Lo, b.N[1].Lo, borrow)
 154  	f.N[1].Hi, borrow = bits.Sub64(a.N[1].Hi, b.N[1].Hi, borrow)
 155  
 156  	// If there was a borrow, add p back
 157  	if borrow != 0 {
 158  		var carry uint64
 159  		f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, FieldP.N[0].Lo, 0)
 160  		f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, FieldP.N[0].Hi, carry)
 161  		f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, FieldP.N[1].Lo, carry)
 162  		f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, FieldP.N[1].Hi, carry)
 163  	}
 164  	return f
 165  }
 166  
 167  // Negate sets f = -a mod p.
 168  func (f *FieldElement) Negate(a *FieldElement) *FieldElement {
 169  	if a.IsZero() {
 170  		*f = FieldZero
 171  		return f
 172  	}
 173  	// f = p - a
 174  	var borrow uint64
 175  	f.N[0].Lo, borrow = bits.Sub64(FieldP.N[0].Lo, a.N[0].Lo, 0)
 176  	f.N[0].Hi, borrow = bits.Sub64(FieldP.N[0].Hi, a.N[0].Hi, borrow)
 177  	f.N[1].Lo, borrow = bits.Sub64(FieldP.N[1].Lo, a.N[1].Lo, borrow)
 178  	f.N[1].Hi, _ = bits.Sub64(FieldP.N[1].Hi, a.N[1].Hi, borrow)
 179  	return f
 180  }
 181  
 182  // Mul sets f = a * b mod p.
 183  func (f *FieldElement) Mul(a, b *FieldElement) *FieldElement {
 184  	// Compute 512-bit product
 185  	var prod [8]uint64
 186  	fieldMul512(&prod, a, b)
 187  
 188  	// Reduce mod p using secp256k1's special structure
 189  	fieldReduce512(f, &prod)
 190  	return f
 191  }
 192  
 193  // fieldMul512 computes the 512-bit product of two 256-bit field elements.
 194  func fieldMul512(prod *[8]uint64, a, b *FieldElement) {
 195  	aLimbs := [4]uint64{a.N[0].Lo, a.N[0].Hi, a.N[1].Lo, a.N[1].Hi}
 196  	bLimbs := [4]uint64{b.N[0].Lo, b.N[0].Hi, b.N[1].Lo, b.N[1].Hi}
 197  
 198  	// Clear product
 199  	for i := range prod {
 200  		prod[i] = 0
 201  	}
 202  
 203  	// Schoolbook multiplication
 204  	for i := 0; i < 4; i++ {
 205  		var carry uint64
 206  		for j := 0; j < 4; j++ {
 207  			hi, lo := bits.Mul64(aLimbs[i], bLimbs[j])
 208  			lo, c := bits.Add64(lo, prod[i+j], 0)
 209  			hi, _ = bits.Add64(hi, 0, c)
 210  			lo, c = bits.Add64(lo, carry, 0)
 211  			hi, _ = bits.Add64(hi, 0, c)
 212  			prod[i+j] = lo
 213  			carry = hi
 214  		}
 215  		prod[i+4] = carry
 216  	}
 217  }
 218  
 219  // fieldReduce512 reduces a 512-bit value mod p using secp256k1's special structure.
 220  // p = 2^256 - 2^32 - 977, so 2^256 ≡ 2^32 + 977 (mod p)
 221  func fieldReduce512(f *FieldElement, prod *[8]uint64) {
 222  	// The key insight: if we have a 512-bit number split as H*2^256 + L
 223  	// then H*2^256 + L ≡ H*(2^32 + 977) + L (mod p)
 224  
 225  	// Extract low and high 256-bit parts
 226  	low := [4]uint64{prod[0], prod[1], prod[2], prod[3]}
 227  	high := [4]uint64{prod[4], prod[5], prod[6], prod[7]}
 228  
 229  	// Compute high * (2^32 + 977) = high * 0x1000003D1
 230  	// This gives us at most a 289-bit result (256 + 33 bits)
 231  	const c = uint64(0x1000003D1)
 232  
 233  	var reduction [5]uint64
 234  	var carry uint64
 235  
 236  	for i := 0; i < 4; i++ {
 237  		hi, lo := bits.Mul64(high[i], c)
 238  		lo, cc := bits.Add64(lo, carry, 0)
 239  		hi, _ = bits.Add64(hi, 0, cc)
 240  		reduction[i] = lo
 241  		carry = hi
 242  	}
 243  	reduction[4] = carry
 244  
 245  	// Add low + reduction
 246  	var result [5]uint64
 247  	carry = 0
 248  	for i := 0; i < 4; i++ {
 249  		result[i], carry = bits.Add64(low[i], reduction[i], carry)
 250  	}
 251  	result[4] = carry + reduction[4]
 252  
 253  	// If result[4] is non-zero, we need to reduce again
 254  	// result[4] * 2^256 ≡ result[4] * (2^32 + 977) (mod p)
 255  	if result[4] != 0 {
 256  		hi, lo := bits.Mul64(result[4], c)
 257  		result[0], carry = bits.Add64(result[0], lo, 0)
 258  		result[1], carry = bits.Add64(result[1], hi, carry)
 259  		result[2], carry = bits.Add64(result[2], 0, carry)
 260  		result[3], _ = bits.Add64(result[3], 0, carry)
 261  		result[4] = 0
 262  	}
 263  
 264  	// Store result
 265  	f.N[0].Lo = result[0]
 266  	f.N[0].Hi = result[1]
 267  	f.N[1].Lo = result[2]
 268  	f.N[1].Hi = result[3]
 269  
 270  	// Final reduction if >= p
 271  	if f.checkOverflow() {
 272  		f.reduce()
 273  	}
 274  }
 275  
 276  // Sqr sets f = a^2 mod p.
 277  func (f *FieldElement) Sqr(a *FieldElement) *FieldElement {
 278  	// Optimized squaring could save some multiplications, but for now use Mul
 279  	return f.Mul(a, a)
 280  }
 281  
 282  // Inverse sets f = a^(-1) mod p using Fermat's little theorem.
 283  // a^(-1) = a^(p-2) mod p
 284  func (f *FieldElement) Inverse(a *FieldElement) *FieldElement {
 285  	// p-2 in bytes (big-endian)
 286  	// p = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
 287  	// p-2 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2D
 288  	pMinus2 := [32]byte{
 289  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
 290  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
 291  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
 292  		0xFF, 0xFF, 0xFF, 0xFE, 0xFF, 0xFF, 0xFC, 0x2D,
 293  	}
 294  
 295  	var result, base FieldElement
 296  	result = FieldOne
 297  	base = *a
 298  
 299  	for i := 0; i < 32; i++ {
 300  		b := pMinus2[31-i]
 301  		for j := 0; j < 8; j++ {
 302  			if (b>>j)&1 == 1 {
 303  				result.Mul(&result, &base)
 304  			}
 305  			base.Sqr(&base)
 306  		}
 307  	}
 308  
 309  	*f = result
 310  	return f
 311  }
 312  
 313  // Sqrt sets f = sqrt(a) mod p if it exists, returns true if successful.
 314  // For secp256k1, p ≡ 3 (mod 4), so sqrt(a) = a^((p+1)/4) mod p
 315  func (f *FieldElement) Sqrt(a *FieldElement) bool {
 316  	// (p+1)/4 in bytes
 317  	// p+1 = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC30
 318  	// (p+1)/4 = 3FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFBFFFFF0C
 319  	pPlus1Div4 := [32]byte{
 320  		0x3F, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
 321  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
 322  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
 323  		0xFF, 0xFF, 0xFF, 0xFF, 0xBF, 0xFF, 0xFF, 0x0C,
 324  	}
 325  
 326  	var result, base FieldElement
 327  	result = FieldOne
 328  	base = *a
 329  
 330  	for i := 0; i < 32; i++ {
 331  		b := pPlus1Div4[31-i]
 332  		for j := 0; j < 8; j++ {
 333  			if (b>>j)&1 == 1 {
 334  				result.Mul(&result, &base)
 335  			}
 336  			base.Sqr(&base)
 337  		}
 338  	}
 339  
 340  	// Verify: result^2 should equal a
 341  	var check FieldElement
 342  	check.Sqr(&result)
 343  
 344  	if check.Equal(a) {
 345  		*f = result
 346  		return true
 347  	}
 348  	return false
 349  }
 350  
 351  // MulInt sets f = a * n mod p where n is a small integer.
 352  func (f *FieldElement) MulInt(a *FieldElement, n uint64) *FieldElement {
 353  	if n == 0 {
 354  		*f = FieldZero
 355  		return f
 356  	}
 357  	if n == 1 {
 358  		*f = *a
 359  		return f
 360  	}
 361  
 362  	// Multiply by small integer using proper carry chain
 363  	// We need to compute a 320-bit result (256 + 64 bits max)
 364  	var result [5]uint64
 365  	var carry uint64
 366  
 367  	// Multiply each 64-bit limb by n
 368  	var hi uint64
 369  	hi, result[0] = bits.Mul64(a.N[0].Lo, n)
 370  	carry = hi
 371  
 372  	hi, result[1] = bits.Mul64(a.N[0].Hi, n)
 373  	result[1], carry = bits.Add64(result[1], carry, 0)
 374  	carry = hi + carry // carry can be at most 1 here, so no overflow
 375  
 376  	hi, result[2] = bits.Mul64(a.N[1].Lo, n)
 377  	result[2], carry = bits.Add64(result[2], carry, 0)
 378  	carry = hi + carry
 379  
 380  	hi, result[3] = bits.Mul64(a.N[1].Hi, n)
 381  	result[3], carry = bits.Add64(result[3], carry, 0)
 382  	result[4] = hi + carry
 383  
 384  	// Store preliminary result
 385  	f.N[0].Lo = result[0]
 386  	f.N[0].Hi = result[1]
 387  	f.N[1].Lo = result[2]
 388  	f.N[1].Hi = result[3]
 389  
 390  	// Reduce overflow
 391  	if result[4] != 0 {
 392  		// overflow * 2^256 ≡ overflow * (2^32 + 977) (mod p)
 393  		hi, lo := bits.Mul64(result[4], 0x1000003D1)
 394  		f.N[0].Lo, carry = bits.Add64(f.N[0].Lo, lo, 0)
 395  		f.N[0].Hi, carry = bits.Add64(f.N[0].Hi, hi, carry)
 396  		f.N[1].Lo, carry = bits.Add64(f.N[1].Lo, 0, carry)
 397  		f.N[1].Hi, _ = bits.Add64(f.N[1].Hi, 0, carry)
 398  	}
 399  
 400  	if f.checkOverflow() {
 401  		f.reduce()
 402  	}
 403  	return f
 404  }
 405  
 406  // Double sets f = 2*a mod p (optimized addition).
 407  func (f *FieldElement) Double(a *FieldElement) *FieldElement {
 408  	return f.Add(a, a)
 409  }
 410  
 411  // Half sets f = a/2 mod p.
 412  func (f *FieldElement) Half(a *FieldElement) *FieldElement {
 413  	// If a is even, just shift right
 414  	// If a is odd, add p first (which makes it even), then shift right
 415  	var result FieldElement = *a
 416  
 417  	if result.N[0].Lo&1 == 1 {
 418  		// Add p
 419  		var carry uint64
 420  		result.N[0].Lo, carry = bits.Add64(result.N[0].Lo, FieldP.N[0].Lo, 0)
 421  		result.N[0].Hi, carry = bits.Add64(result.N[0].Hi, FieldP.N[0].Hi, carry)
 422  		result.N[1].Lo, carry = bits.Add64(result.N[1].Lo, FieldP.N[1].Lo, carry)
 423  		result.N[1].Hi, _ = bits.Add64(result.N[1].Hi, FieldP.N[1].Hi, carry)
 424  	}
 425  
 426  	// Shift right by 1
 427  	f.N[0].Lo = (result.N[0].Lo >> 1) | (result.N[0].Hi << 63)
 428  	f.N[0].Hi = (result.N[0].Hi >> 1) | (result.N[1].Lo << 63)
 429  	f.N[1].Lo = (result.N[1].Lo >> 1) | (result.N[1].Hi << 63)
 430  	f.N[1].Hi = result.N[1].Hi >> 1
 431  
 432  	return f
 433  }
 434  
 435  // CMov conditionally moves b into f if cond is true (constant-time).
 436  func (f *FieldElement) CMov(b *FieldElement, cond bool) *FieldElement {
 437  	mask := uint64(0)
 438  	if cond {
 439  		mask = ^uint64(0)
 440  	}
 441  	f.N[0].Lo = (f.N[0].Lo &^ mask) | (b.N[0].Lo & mask)
 442  	f.N[0].Hi = (f.N[0].Hi &^ mask) | (b.N[0].Hi & mask)
 443  	f.N[1].Lo = (f.N[1].Lo &^ mask) | (b.N[1].Lo & mask)
 444  	f.N[1].Hi = (f.N[1].Hi &^ mask) | (b.N[1].Hi & mask)
 445  	return f
 446  }
 447