scalar.go raw

   1  package avx
   2  
   3  import "math/bits"
   4  
   5  // Scalar operations modulo the secp256k1 group order n.
   6  // n = FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
   7  
   8  // SetBytes sets a scalar from a 32-byte big-endian slice.
   9  // Returns true if the value was >= n and was reduced.
  10  func (s *Scalar) SetBytes(b []byte) bool {
  11  	if len(b) != 32 {
  12  		panic("scalar must be 32 bytes")
  13  	}
  14  
  15  	// Convert big-endian bytes to little-endian limbs
  16  	s.D[0].Lo = uint64(b[31]) | uint64(b[30])<<8 | uint64(b[29])<<16 | uint64(b[28])<<24 |
  17  		uint64(b[27])<<32 | uint64(b[26])<<40 | uint64(b[25])<<48 | uint64(b[24])<<56
  18  	s.D[0].Hi = uint64(b[23]) | uint64(b[22])<<8 | uint64(b[21])<<16 | uint64(b[20])<<24 |
  19  		uint64(b[19])<<32 | uint64(b[18])<<40 | uint64(b[17])<<48 | uint64(b[16])<<56
  20  	s.D[1].Lo = uint64(b[15]) | uint64(b[14])<<8 | uint64(b[13])<<16 | uint64(b[12])<<24 |
  21  		uint64(b[11])<<32 | uint64(b[10])<<40 | uint64(b[9])<<48 | uint64(b[8])<<56
  22  	s.D[1].Hi = uint64(b[7]) | uint64(b[6])<<8 | uint64(b[5])<<16 | uint64(b[4])<<24 |
  23  		uint64(b[3])<<32 | uint64(b[2])<<40 | uint64(b[1])<<48 | uint64(b[0])<<56
  24  
  25  	// Check overflow and reduce if necessary
  26  	overflow := s.checkOverflow()
  27  	if overflow {
  28  		s.reduce()
  29  	}
  30  	return overflow
  31  }
  32  
  33  // Bytes returns the scalar as a 32-byte big-endian slice.
  34  func (s *Scalar) Bytes() [32]byte {
  35  	var b [32]byte
  36  	b[31] = byte(s.D[0].Lo)
  37  	b[30] = byte(s.D[0].Lo >> 8)
  38  	b[29] = byte(s.D[0].Lo >> 16)
  39  	b[28] = byte(s.D[0].Lo >> 24)
  40  	b[27] = byte(s.D[0].Lo >> 32)
  41  	b[26] = byte(s.D[0].Lo >> 40)
  42  	b[25] = byte(s.D[0].Lo >> 48)
  43  	b[24] = byte(s.D[0].Lo >> 56)
  44  
  45  	b[23] = byte(s.D[0].Hi)
  46  	b[22] = byte(s.D[0].Hi >> 8)
  47  	b[21] = byte(s.D[0].Hi >> 16)
  48  	b[20] = byte(s.D[0].Hi >> 24)
  49  	b[19] = byte(s.D[0].Hi >> 32)
  50  	b[18] = byte(s.D[0].Hi >> 40)
  51  	b[17] = byte(s.D[0].Hi >> 48)
  52  	b[16] = byte(s.D[0].Hi >> 56)
  53  
  54  	b[15] = byte(s.D[1].Lo)
  55  	b[14] = byte(s.D[1].Lo >> 8)
  56  	b[13] = byte(s.D[1].Lo >> 16)
  57  	b[12] = byte(s.D[1].Lo >> 24)
  58  	b[11] = byte(s.D[1].Lo >> 32)
  59  	b[10] = byte(s.D[1].Lo >> 40)
  60  	b[9] = byte(s.D[1].Lo >> 48)
  61  	b[8] = byte(s.D[1].Lo >> 56)
  62  
  63  	b[7] = byte(s.D[1].Hi)
  64  	b[6] = byte(s.D[1].Hi >> 8)
  65  	b[5] = byte(s.D[1].Hi >> 16)
  66  	b[4] = byte(s.D[1].Hi >> 24)
  67  	b[3] = byte(s.D[1].Hi >> 32)
  68  	b[2] = byte(s.D[1].Hi >> 40)
  69  	b[1] = byte(s.D[1].Hi >> 48)
  70  	b[0] = byte(s.D[1].Hi >> 56)
  71  
  72  	return b
  73  }
  74  
  75  // IsZero returns true if the scalar is zero.
  76  func (s *Scalar) IsZero() bool {
  77  	return s.D[0].IsZero() && s.D[1].IsZero()
  78  }
  79  
  80  // IsOne returns true if the scalar is one.
  81  func (s *Scalar) IsOne() bool {
  82  	return s.D[0].Lo == 1 && s.D[0].Hi == 0 && s.D[1].IsZero()
  83  }
  84  
  85  // Equal returns true if two scalars are equal.
  86  func (s *Scalar) Equal(other *Scalar) bool {
  87  	return s.D[0].Lo == other.D[0].Lo && s.D[0].Hi == other.D[0].Hi &&
  88  		s.D[1].Lo == other.D[1].Lo && s.D[1].Hi == other.D[1].Hi
  89  }
  90  
  91  // checkOverflow returns true if s >= n.
  92  func (s *Scalar) checkOverflow() bool {
  93  	// Compare high to low
  94  	if s.D[1].Hi > ScalarN.D[1].Hi {
  95  		return true
  96  	}
  97  	if s.D[1].Hi < ScalarN.D[1].Hi {
  98  		return false
  99  	}
 100  	if s.D[1].Lo > ScalarN.D[1].Lo {
 101  		return true
 102  	}
 103  	if s.D[1].Lo < ScalarN.D[1].Lo {
 104  		return false
 105  	}
 106  	if s.D[0].Hi > ScalarN.D[0].Hi {
 107  		return true
 108  	}
 109  	if s.D[0].Hi < ScalarN.D[0].Hi {
 110  		return false
 111  	}
 112  	return s.D[0].Lo >= ScalarN.D[0].Lo
 113  }
 114  
 115  // reduce reduces s modulo n by adding the complement (2^256 - n).
 116  func (s *Scalar) reduce() {
 117  	// s = s - n = s + (2^256 - n) mod 2^256
 118  	var carry uint64
 119  	s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
 120  	s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, carry)
 121  	s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, carry)
 122  	s.D[1].Hi, _ = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, carry)
 123  }
 124  
 125  // Add sets s = a + b mod n.
 126  func (s *Scalar) Add(a, b *Scalar) *Scalar {
 127  	var carry uint64
 128  	s.D[0].Lo, carry = bits.Add64(a.D[0].Lo, b.D[0].Lo, 0)
 129  	s.D[0].Hi, carry = bits.Add64(a.D[0].Hi, b.D[0].Hi, carry)
 130  	s.D[1].Lo, carry = bits.Add64(a.D[1].Lo, b.D[1].Lo, carry)
 131  	s.D[1].Hi, carry = bits.Add64(a.D[1].Hi, b.D[1].Hi, carry)
 132  
 133  	// If there was a carry or if result >= n, reduce
 134  	if carry != 0 || s.checkOverflow() {
 135  		s.reduce()
 136  	}
 137  	return s
 138  }
 139  
 140  // Sub sets s = a - b mod n.
 141  func (s *Scalar) Sub(a, b *Scalar) *Scalar {
 142  	var borrow uint64
 143  	s.D[0].Lo, borrow = bits.Sub64(a.D[0].Lo, b.D[0].Lo, 0)
 144  	s.D[0].Hi, borrow = bits.Sub64(a.D[0].Hi, b.D[0].Hi, borrow)
 145  	s.D[1].Lo, borrow = bits.Sub64(a.D[1].Lo, b.D[1].Lo, borrow)
 146  	s.D[1].Hi, borrow = bits.Sub64(a.D[1].Hi, b.D[1].Hi, borrow)
 147  
 148  	// If there was a borrow, add n back
 149  	if borrow != 0 {
 150  		var carry uint64
 151  		s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, ScalarN.D[0].Lo, 0)
 152  		s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, ScalarN.D[0].Hi, carry)
 153  		s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, ScalarN.D[1].Lo, carry)
 154  		s.D[1].Hi, _ = bits.Add64(s.D[1].Hi, ScalarN.D[1].Hi, carry)
 155  	}
 156  	return s
 157  }
 158  
 159  // Negate sets s = -a mod n.
 160  func (s *Scalar) Negate(a *Scalar) *Scalar {
 161  	if a.IsZero() {
 162  		*s = ScalarZero
 163  		return s
 164  	}
 165  	// s = n - a
 166  	var borrow uint64
 167  	s.D[0].Lo, borrow = bits.Sub64(ScalarN.D[0].Lo, a.D[0].Lo, 0)
 168  	s.D[0].Hi, borrow = bits.Sub64(ScalarN.D[0].Hi, a.D[0].Hi, borrow)
 169  	s.D[1].Lo, borrow = bits.Sub64(ScalarN.D[1].Lo, a.D[1].Lo, borrow)
 170  	s.D[1].Hi, _ = bits.Sub64(ScalarN.D[1].Hi, a.D[1].Hi, borrow)
 171  	return s
 172  }
 173  
 174  // Mul sets s = a * b mod n.
 175  func (s *Scalar) Mul(a, b *Scalar) *Scalar {
 176  	// Compute 512-bit product
 177  	var prod [8]uint64
 178  	scalarMul512(&prod, a, b)
 179  
 180  	// Reduce mod n
 181  	scalarReduce512(s, &prod)
 182  	return s
 183  }
 184  
 185  // scalarMul512 computes the 512-bit product of two 256-bit scalars.
 186  // Result is stored in prod[0..7] where prod[0] is the least significant.
 187  func scalarMul512(prod *[8]uint64, a, b *Scalar) {
 188  	// Using schoolbook multiplication with 64-bit limbs
 189  	// a = a[0] + a[1]*2^64 + a[2]*2^128 + a[3]*2^192
 190  	// b = b[0] + b[1]*2^64 + b[2]*2^128 + b[3]*2^192
 191  
 192  	aLimbs := [4]uint64{a.D[0].Lo, a.D[0].Hi, a.D[1].Lo, a.D[1].Hi}
 193  	bLimbs := [4]uint64{b.D[0].Lo, b.D[0].Hi, b.D[1].Lo, b.D[1].Hi}
 194  
 195  	// Clear product
 196  	for i := range prod {
 197  		prod[i] = 0
 198  	}
 199  
 200  	// Schoolbook multiplication
 201  	for i := 0; i < 4; i++ {
 202  		var carry uint64
 203  		for j := 0; j < 4; j++ {
 204  			hi, lo := bits.Mul64(aLimbs[i], bLimbs[j])
 205  			lo, c := bits.Add64(lo, prod[i+j], 0)
 206  			hi, _ = bits.Add64(hi, 0, c)
 207  			lo, c = bits.Add64(lo, carry, 0)
 208  			hi, _ = bits.Add64(hi, 0, c)
 209  			prod[i+j] = lo
 210  			carry = hi
 211  		}
 212  		prod[i+4] = carry
 213  	}
 214  }
 215  
 216  // scalarReduce512 reduces a 512-bit value mod n.
 217  func scalarReduce512(s *Scalar, prod *[8]uint64) {
 218  	// Barrett reduction or simple repeated subtraction
 219  	// For now, use a simpler approach: extract high 256 bits, multiply by (2^256 mod n), add to low
 220  
 221  	// 2^256 mod n = 2^256 - n = ScalarNC (approximately 0x14551231950B75FC4...etc)
 222  	// This is a simplified reduction - a full implementation would use Barrett reduction
 223  
 224  	// Copy low 256 bits to result
 225  	s.D[0].Lo = prod[0]
 226  	s.D[0].Hi = prod[1]
 227  	s.D[1].Lo = prod[2]
 228  	s.D[1].Hi = prod[3]
 229  
 230  	// If high 256 bits are non-zero, we need to reduce
 231  	if prod[4] != 0 || prod[5] != 0 || prod[6] != 0 || prod[7] != 0 {
 232  		// high * (2^256 mod n) + low
 233  		// This is a simplified version - multiply high by NC and add
 234  		highScalar := Scalar{
 235  			D: [2]Uint128{
 236  				{Lo: prod[4], Hi: prod[5]},
 237  				{Lo: prod[6], Hi: prod[7]},
 238  			},
 239  		}
 240  
 241  		// Multiply high by NC (which is small: ~2^129)
 242  		// For correctness, we'd need full multiplication, but NC is small enough
 243  		// that we can use a simplified approach
 244  
 245  		// NC = 0x14551231950B75FC4402DA1732FC9BEBF
 246  		// NC.D[0] = {Lo: 0x402DA1732FC9BEBF, Hi: 0x4551231950B75FC4}
 247  		// NC.D[1] = {Lo: 0x1, Hi: 0}
 248  
 249  		// Approximate: high * NC ≈ high * 2^129 (since NC ≈ 2^129)
 250  		// This means we shift high left by 129 bits and add
 251  
 252  		// For a correct implementation, compute high * NC properly:
 253  		var reduction [8]uint64
 254  		ncLimbs := [4]uint64{ScalarNC.D[0].Lo, ScalarNC.D[0].Hi, ScalarNC.D[1].Lo, ScalarNC.D[1].Hi}
 255  		highLimbs := [4]uint64{highScalar.D[0].Lo, highScalar.D[0].Hi, highScalar.D[1].Lo, highScalar.D[1].Hi}
 256  
 257  		for i := 0; i < 4; i++ {
 258  			var carry uint64
 259  			for j := 0; j < 4; j++ {
 260  				hi, lo := bits.Mul64(highLimbs[i], ncLimbs[j])
 261  				lo, c := bits.Add64(lo, reduction[i+j], 0)
 262  				hi, _ = bits.Add64(hi, 0, c)
 263  				lo, c = bits.Add64(lo, carry, 0)
 264  				hi, _ = bits.Add64(hi, 0, c)
 265  				reduction[i+j] = lo
 266  				carry = hi
 267  			}
 268  			if i+4 < 8 {
 269  				reduction[i+4], _ = bits.Add64(reduction[i+4], carry, 0)
 270  			}
 271  		}
 272  
 273  		// Add reduction to s
 274  		var carry uint64
 275  		s.D[0].Lo, carry = bits.Add64(s.D[0].Lo, reduction[0], 0)
 276  		s.D[0].Hi, carry = bits.Add64(s.D[0].Hi, reduction[1], carry)
 277  		s.D[1].Lo, carry = bits.Add64(s.D[1].Lo, reduction[2], carry)
 278  		s.D[1].Hi, carry = bits.Add64(s.D[1].Hi, reduction[3], carry)
 279  
 280  		// Handle any remaining high bits by repeated reduction
 281  		// If there's a carry, it represents 2^256 which equals NC mod n
 282  		// If reduction[4..7] are non-zero, we need to reduce those too
 283  		if carry != 0 || reduction[4] != 0 || reduction[5] != 0 || reduction[6] != 0 || reduction[7] != 0 {
 284  			// The carry and reduction[4..7] together represent additional multiples of 2^256
 285  			// Each 2^256 ≡ NC (mod n), so we add (carry + reduction[4..7]) * NC
 286  
 287  			// First, handle the carry
 288  			if carry != 0 {
 289  				// carry * NC
 290  				var c uint64
 291  				s.D[0].Lo, c = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
 292  				s.D[0].Hi, c = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, c)
 293  				s.D[1].Lo, c = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, c)
 294  				s.D[1].Hi, c = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, c)
 295  
 296  				// If there's still a carry, add NC again
 297  				for c != 0 {
 298  					s.D[0].Lo, c = bits.Add64(s.D[0].Lo, ScalarNC.D[0].Lo, 0)
 299  					s.D[0].Hi, c = bits.Add64(s.D[0].Hi, ScalarNC.D[0].Hi, c)
 300  					s.D[1].Lo, c = bits.Add64(s.D[1].Lo, ScalarNC.D[1].Lo, c)
 301  					s.D[1].Hi, c = bits.Add64(s.D[1].Hi, ScalarNC.D[1].Hi, c)
 302  				}
 303  			}
 304  
 305  			// Handle reduction[4..7] if non-zero
 306  			if reduction[4] != 0 || reduction[5] != 0 || reduction[6] != 0 || reduction[7] != 0 {
 307  				// Compute reduction[4..7] * NC and add
 308  				highScalar2 := Scalar{
 309  					D: [2]Uint128{
 310  						{Lo: reduction[4], Hi: reduction[5]},
 311  						{Lo: reduction[6], Hi: reduction[7]},
 312  					},
 313  				}
 314  
 315  				var reduction2 [8]uint64
 316  				high2Limbs := [4]uint64{highScalar2.D[0].Lo, highScalar2.D[0].Hi, highScalar2.D[1].Lo, highScalar2.D[1].Hi}
 317  
 318  				for i := 0; i < 4; i++ {
 319  					var c uint64
 320  					for j := 0; j < 4; j++ {
 321  						hi, lo := bits.Mul64(high2Limbs[i], ncLimbs[j])
 322  						lo, cc := bits.Add64(lo, reduction2[i+j], 0)
 323  						hi, _ = bits.Add64(hi, 0, cc)
 324  						lo, cc = bits.Add64(lo, c, 0)
 325  						hi, _ = bits.Add64(hi, 0, cc)
 326  						reduction2[i+j] = lo
 327  						c = hi
 328  					}
 329  					if i+4 < 8 {
 330  						reduction2[i+4], _ = bits.Add64(reduction2[i+4], c, 0)
 331  					}
 332  				}
 333  
 334  				var c uint64
 335  				s.D[0].Lo, c = bits.Add64(s.D[0].Lo, reduction2[0], 0)
 336  				s.D[0].Hi, c = bits.Add64(s.D[0].Hi, reduction2[1], c)
 337  				s.D[1].Lo, c = bits.Add64(s.D[1].Lo, reduction2[2], c)
 338  				s.D[1].Hi, c = bits.Add64(s.D[1].Hi, reduction2[3], c)
 339  
 340  				// Handle cascading carries
 341  				for c != 0 || reduction2[4] != 0 || reduction2[5] != 0 || reduction2[6] != 0 || reduction2[7] != 0 {
 342  					// This case is extremely rare but handle it
 343  					for s.checkOverflow() {
 344  						s.reduce()
 345  					}
 346  					break
 347  				}
 348  			}
 349  		}
 350  	}
 351  
 352  	// Final reduction if needed
 353  	if s.checkOverflow() {
 354  		s.reduce()
 355  	}
 356  }
 357  
 358  // Sqr sets s = a^2 mod n.
 359  func (s *Scalar) Sqr(a *Scalar) *Scalar {
 360  	return s.Mul(a, a)
 361  }
 362  
 363  // Inverse sets s = a^(-1) mod n using Fermat's little theorem.
 364  // a^(-1) = a^(n-2) mod n
 365  func (s *Scalar) Inverse(a *Scalar) *Scalar {
 366  	// n-2 in binary is used for square-and-multiply
 367  	// This is a simplified implementation using binary exponentiation
 368  
 369  	var result, base Scalar
 370  	result = ScalarOne
 371  	base = *a
 372  
 373  	// n-2 bytes (big-endian)
 374  	nMinus2 := [32]byte{
 375  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
 376  		0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE,
 377  		0xBA, 0xAE, 0xDC, 0xE6, 0xAF, 0x48, 0xA0, 0x3B,
 378  		0xBF, 0xD2, 0x5E, 0x8C, 0xD0, 0x36, 0x41, 0x3F,
 379  	}
 380  
 381  	for i := 0; i < 32; i++ {
 382  		b := nMinus2[31-i]
 383  		for j := 0; j < 8; j++ {
 384  			if (b>>j)&1 == 1 {
 385  				result.Mul(&result, &base)
 386  			}
 387  			base.Sqr(&base)
 388  		}
 389  	}
 390  
 391  	*s = result
 392  	return s
 393  }
 394  
 395  // IsHigh returns true if s > n/2.
 396  func (s *Scalar) IsHigh() bool {
 397  	// Compare with n/2
 398  	if s.D[1].Hi > ScalarNHalf.D[1].Hi {
 399  		return true
 400  	}
 401  	if s.D[1].Hi < ScalarNHalf.D[1].Hi {
 402  		return false
 403  	}
 404  	if s.D[1].Lo > ScalarNHalf.D[1].Lo {
 405  		return true
 406  	}
 407  	if s.D[1].Lo < ScalarNHalf.D[1].Lo {
 408  		return false
 409  	}
 410  	if s.D[0].Hi > ScalarNHalf.D[0].Hi {
 411  		return true
 412  	}
 413  	if s.D[0].Hi < ScalarNHalf.D[0].Hi {
 414  		return false
 415  	}
 416  	return s.D[0].Lo > ScalarNHalf.D[0].Lo
 417  }
 418  
 419  // CondNegate negates s if cond is true.
 420  func (s *Scalar) CondNegate(cond bool) *Scalar {
 421  	if cond {
 422  		s.Negate(s)
 423  	}
 424  	return s
 425  }
 426