// scalar_32bit.go

   1  //go:build js || wasm || tinygo || wasm32
   2  
   3  // Copyright (c) 2024 mleku
   4  // Adapted from github.com/decred/dcrd/dcrec/secp256k1/v4
   5  // Copyright (c) 2020-2024 The Decred developers
   6  
   7  package p256k1
   8  
   9  import (
  10  	"crypto/subtle"
  11  	"unsafe"
  12  )
  13  
// Scalar represents a scalar value modulo the secp256k1 group order.
// This implementation uses 8 uint32 limbs in base 2^32, optimized for 32-bit platforms.
type Scalar struct {
	// n holds the value in little-endian limb order: n[0] is the least
	// significant 32 bits, n[7] the most significant.
	n [8]uint32
}
  19  
// Scalar constants in 8x32 representation
const (
	// Group order n of secp256k1, as eight 32-bit words from least to
	// most significant:
	// n = 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141
	orderWord0 uint32 = 0xd0364141
	orderWord1 uint32 = 0xbfd25e8c
	orderWord2 uint32 = 0xaf48a03b
	orderWord3 uint32 = 0xbaaedce6
	orderWord4 uint32 = 0xfffffffe
	orderWord5 uint32 = 0xffffffff
	orderWord6 uint32 = 0xffffffff
	orderWord7 uint32 = 0xffffffff

	// Low words of the two's complement of the order, i.e. 2^256 - n.
	// Words 4..7 of 2^256 - n are {1, 0, 0, 0}; the 1 at word 4 is
	// folded in inline where needed (see reduce256 / reduce512_32).
	orderCompWord0 uint32 = 0x2fc9bebf // ~orderWord0 + 1
	orderCompWord1 uint32 = 0x402da173 // ~orderWord1
	orderCompWord2 uint32 = 0x50b75fc4 // ~orderWord2
	orderCompWord3 uint32 = 0x45512319 // ~orderWord3

	// Half order words: (n - 1) / 2, i.e. floor(n/2) since n is odd.
	// Used by isHigh.
	halfOrderWord0 uint32 = 0x681b20a0
	halfOrderWord1 uint32 = 0xdfe92f46
	halfOrderWord2 uint32 = 0x57a4501d
	halfOrderWord3 uint32 = 0x5d576e73
	halfOrderWord4 uint32 = 0xffffffff
	halfOrderWord5 uint32 = 0xffffffff
	halfOrderWord6 uint32 = 0xffffffff
	halfOrderWord7 uint32 = 0x7fffffff

	// uint32Mask is an untyped all-ones 32-bit mask used to extract a
	// limb from 64-bit intermediates.
	uint32Mask = 0xffffffff
)
  50  
// Scalar element constants
var (
	ScalarZero = Scalar{n: [8]uint32{0, 0, 0, 0, 0, 0, 0, 0}}
	ScalarOne  = Scalar{n: [8]uint32{1, 0, 0, 0, 0, 0, 0, 0}}

	// GLV constants in 8x32 representation (limbs least-significant
	// first). scalarLambda is the endomorphism scalar used by
	// scalarSplitLambda; the remaining values are the lattice basis and
	// rounding constants for the k = k1 + k2*lambda decomposition.
	// NOTE(review): values appear to mirror the upstream libsecp256k1
	// constants — verify against that source when touching them.
	scalarLambda = Scalar{
		n: [8]uint32{
			0x1b23bd72, 0xdf02967c, 0x20816678, 0x122e22ea,
			0x8812645a, 0xa5261c02, 0xc05c30e0, 0x5363ad4c,
		},
	}

	scalarMinusB1 = Scalar{
		n: [8]uint32{
			0x0abfe4c3, 0x6f547fa9, 0x010e8828, 0xe4437ed6,
			0x00000000, 0x00000000, 0x00000000, 0x00000000,
		},
	}

	scalarMinusB2 = Scalar{
		n: [8]uint32{
			0x3db1562c, 0xd765cda8, 0x0774346d, 0x8a280ac5,
			0xfffffffe, 0xffffffff, 0xffffffff, 0xffffffff,
		},
	}

	scalarG1 = Scalar{
		n: [8]uint32{
			0x45dbb031, 0xe893209a, 0x71e8ca7f, 0x3daa8a14,
			0x9284eb15, 0xe86c90e4, 0xa7d46bcd, 0x3086d221,
		},
	}

	scalarG2 = Scalar{
		n: [8]uint32{
			0x8ac47f71, 0x1571b4ae, 0x9df506c6, 0x221208ac,
			0x0abfe4c4, 0x6f547fa9, 0x010e8828, 0xe4437ed6,
		},
	}
)
  92  
  93  // setInt sets a scalar to a small integer value
  94  func (s *Scalar) setInt(v uint) {
  95  	s.n[0] = uint32(v)
  96  	for i := 1; i < 8; i++ {
  97  		s.n[i] = 0
  98  	}
  99  }
 100  
 101  // setB32 sets a scalar from a 32-byte big-endian array
 102  func (s *Scalar) setB32(b []byte) bool {
 103  	if len(b) != 32 {
 104  		panic("scalar byte array must be 32 bytes")
 105  	}
 106  
 107  	s.n[0] = uint32(b[31]) | uint32(b[30])<<8 | uint32(b[29])<<16 | uint32(b[28])<<24
 108  	s.n[1] = uint32(b[27]) | uint32(b[26])<<8 | uint32(b[25])<<16 | uint32(b[24])<<24
 109  	s.n[2] = uint32(b[23]) | uint32(b[22])<<8 | uint32(b[21])<<16 | uint32(b[20])<<24
 110  	s.n[3] = uint32(b[19]) | uint32(b[18])<<8 | uint32(b[17])<<16 | uint32(b[16])<<24
 111  	s.n[4] = uint32(b[15]) | uint32(b[14])<<8 | uint32(b[13])<<16 | uint32(b[12])<<24
 112  	s.n[5] = uint32(b[11]) | uint32(b[10])<<8 | uint32(b[9])<<16 | uint32(b[8])<<24
 113  	s.n[6] = uint32(b[7]) | uint32(b[6])<<8 | uint32(b[5])<<16 | uint32(b[4])<<24
 114  	s.n[7] = uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24
 115  
 116  	overflow := s.overflows()
 117  	s.reduce256(overflow)
 118  	return overflow != 0
 119  }
 120  
 121  // setB32Seckey sets a scalar from a 32-byte secret key, returns true if valid
 122  func (s *Scalar) setB32Seckey(b []byte) bool {
 123  	overflow := s.setB32(b)
 124  	return !s.isZero() && !overflow
 125  }
 126  
 127  // getB32 converts a scalar to a 32-byte big-endian array
 128  func (s *Scalar) getB32(b []byte) {
 129  	if len(b) != 32 {
 130  		panic("scalar byte array must be 32 bytes")
 131  	}
 132  
 133  	b[31] = byte(s.n[0])
 134  	b[30] = byte(s.n[0] >> 8)
 135  	b[29] = byte(s.n[0] >> 16)
 136  	b[28] = byte(s.n[0] >> 24)
 137  	b[27] = byte(s.n[1])
 138  	b[26] = byte(s.n[1] >> 8)
 139  	b[25] = byte(s.n[1] >> 16)
 140  	b[24] = byte(s.n[1] >> 24)
 141  	b[23] = byte(s.n[2])
 142  	b[22] = byte(s.n[2] >> 8)
 143  	b[21] = byte(s.n[2] >> 16)
 144  	b[20] = byte(s.n[2] >> 24)
 145  	b[19] = byte(s.n[3])
 146  	b[18] = byte(s.n[3] >> 8)
 147  	b[17] = byte(s.n[3] >> 16)
 148  	b[16] = byte(s.n[3] >> 24)
 149  	b[15] = byte(s.n[4])
 150  	b[14] = byte(s.n[4] >> 8)
 151  	b[13] = byte(s.n[4] >> 16)
 152  	b[12] = byte(s.n[4] >> 24)
 153  	b[11] = byte(s.n[5])
 154  	b[10] = byte(s.n[5] >> 8)
 155  	b[9] = byte(s.n[5] >> 16)
 156  	b[8] = byte(s.n[5] >> 24)
 157  	b[7] = byte(s.n[6])
 158  	b[6] = byte(s.n[6] >> 8)
 159  	b[5] = byte(s.n[6] >> 16)
 160  	b[4] = byte(s.n[6] >> 24)
 161  	b[3] = byte(s.n[7])
 162  	b[2] = byte(s.n[7] >> 8)
 163  	b[1] = byte(s.n[7] >> 16)
 164  	b[0] = byte(s.n[7] >> 24)
 165  }
 166  
// overflows returns 1 if the scalar is >= the group order and 0 otherwise,
// in constant time. It scans limbs from most to least significant,
// tracking in highWordsEqual whether every limb examined so far equals the
// corresponding order limb. Limbs 7..5 of the order are 0xffffffff, so a
// strict "greater" test there is impossible and only equality is tracked.
// Relies on constantTimeEq32/constantTimeGreater32 (defined elsewhere in
// this package) returning 0 or 1.
func (s *Scalar) overflows() uint32 {
	highWordsEqual := constantTimeEq32(s.n[7], orderWord7)
	highWordsEqual &= constantTimeEq32(s.n[6], orderWord6)
	highWordsEqual &= constantTimeEq32(s.n[5], orderWord5)
	overflow := highWordsEqual & constantTimeGreater32(s.n[4], orderWord4)
	highWordsEqual &= constantTimeEq32(s.n[4], orderWord4)
	overflow |= highWordsEqual & constantTimeGreater32(s.n[3], orderWord3)
	highWordsEqual &= constantTimeEq32(s.n[3], orderWord3)
	overflow |= highWordsEqual & constantTimeGreater32(s.n[2], orderWord2)
	highWordsEqual &= constantTimeEq32(s.n[2], orderWord2)
	overflow |= highWordsEqual & constantTimeGreater32(s.n[1], orderWord1)
	highWordsEqual &= constantTimeEq32(s.n[1], orderWord1)
	// Final limb uses >= so that a scalar exactly equal to the order
	// also counts as overflowing.
	overflow |= highWordsEqual & constantTimeGreaterOrEq32(s.n[0], orderWord0)
	return overflow
}
 183  
// reduce256 conditionally reduces the scalar modulo the order. When
// overflows is 1 it adds 2^256 - n and discards the carry out of bit 256,
// which is equivalent to subtracting n; when overflows is 0 the value is
// unchanged. overflows must be 0 or 1. Words 4..7 of 2^256 - n are
// {1, 0, 0, 0}, hence the bare overflows64 term at limb 4 and plain carry
// propagation through limbs 5..7.
func (s *Scalar) reduce256(overflows uint32) {
	overflows64 := uint64(overflows)
	c := uint64(s.n[0]) + overflows64*uint64(orderCompWord0)
	s.n[0] = uint32(c & uint32Mask)
	c = (c >> 32) + uint64(s.n[1]) + overflows64*uint64(orderCompWord1)
	s.n[1] = uint32(c & uint32Mask)
	c = (c >> 32) + uint64(s.n[2]) + overflows64*uint64(orderCompWord2)
	s.n[2] = uint32(c & uint32Mask)
	c = (c >> 32) + uint64(s.n[3]) + overflows64*uint64(orderCompWord3)
	s.n[3] = uint32(c & uint32Mask)
	c = (c >> 32) + uint64(s.n[4]) + overflows64
	s.n[4] = uint32(c & uint32Mask)
	c = (c >> 32) + uint64(s.n[5])
	s.n[5] = uint32(c & uint32Mask)
	c = (c >> 32) + uint64(s.n[6])
	s.n[6] = uint32(c & uint32Mask)
	c = (c >> 32) + uint64(s.n[7])
	s.n[7] = uint32(c & uint32Mask)
}
 204  
 205  // checkOverflow checks if the scalar overflows
 206  func (s *Scalar) checkOverflow() bool {
 207  	return s.overflows() != 0
 208  }
 209  
// reduce reduces the scalar modulo the order. overflow must be 0 or 1
// (typically the result of a prior overflow check); it is forwarded to
// reduce256.
func (s *Scalar) reduce(overflow int) {
	s.reduce256(uint32(overflow))
}
 214  
 215  // add adds two scalars: r = a + b
 216  func (s *Scalar) add(a, b *Scalar) bool {
 217  	c := uint64(a.n[0]) + uint64(b.n[0])
 218  	s.n[0] = uint32(c & uint32Mask)
 219  	c = (c >> 32) + uint64(a.n[1]) + uint64(b.n[1])
 220  	s.n[1] = uint32(c & uint32Mask)
 221  	c = (c >> 32) + uint64(a.n[2]) + uint64(b.n[2])
 222  	s.n[2] = uint32(c & uint32Mask)
 223  	c = (c >> 32) + uint64(a.n[3]) + uint64(b.n[3])
 224  	s.n[3] = uint32(c & uint32Mask)
 225  	c = (c >> 32) + uint64(a.n[4]) + uint64(b.n[4])
 226  	s.n[4] = uint32(c & uint32Mask)
 227  	c = (c >> 32) + uint64(a.n[5]) + uint64(b.n[5])
 228  	s.n[5] = uint32(c & uint32Mask)
 229  	c = (c >> 32) + uint64(a.n[6]) + uint64(b.n[6])
 230  	s.n[6] = uint32(c & uint32Mask)
 231  	c = (c >> 32) + uint64(a.n[7]) + uint64(b.n[7])
 232  	s.n[7] = uint32(c & uint32Mask)
 233  
 234  	s.reduce256(uint32(c>>32) + s.overflows())
 235  	return false
 236  }
 237  
// addPureGo is an alias for add in 32-bit mode (there is no accelerated
// variant on these targets); it returns add's overflow result.
func (s *Scalar) addPureGo(a, b *Scalar) bool {
	return s.add(a, b)
}
 242  
 243  // sub subtracts two scalars: r = a - b
 244  func (s *Scalar) sub(a, b *Scalar) {
 245  	var negB Scalar
 246  	negB.negate(b)
 247  	s.add(a, &negB)
 248  }
 249  
// subPureGo is an alias for sub in 32-bit mode (there is no accelerated
// variant on these targets).
func (s *Scalar) subPureGo(a, b *Scalar) {
	s.sub(a, b)
}
 254  
// negate negates a scalar: s = (n - a) mod n, in constant time.
//
// It computes n + (~a + 1) = n + 2^256 - a; the low 256 bits of that sum
// are n - a (the carry past bit 256 is discarded). When a is zero the
// result must be zero, not n, so every limb is ANDed with a mask that is
// all-ones iff a is nonzero. Assumes a is fully reduced (a < n).
func (s *Scalar) negate(a *Scalar) {
	// bits is nonzero iff any limb of a is nonzero.
	bits := a.n[0] | a.n[1] | a.n[2] | a.n[3] | a.n[4] | a.n[5] | a.n[6] | a.n[7]
	mask := uint64(uint32Mask * constantTimeNotEq32(bits, 0))
	// Two's complement of a, added limb-wise to n: the +1 is applied only
	// at the lowest limb.
	c := uint64(orderWord0) + (uint64(^a.n[0]) + 1)
	s.n[0] = uint32(c & mask)
	c = (c >> 32) + uint64(orderWord1) + uint64(^a.n[1])
	s.n[1] = uint32(c & mask)
	c = (c >> 32) + uint64(orderWord2) + uint64(^a.n[2])
	s.n[2] = uint32(c & mask)
	c = (c >> 32) + uint64(orderWord3) + uint64(^a.n[3])
	s.n[3] = uint32(c & mask)
	c = (c >> 32) + uint64(orderWord4) + uint64(^a.n[4])
	s.n[4] = uint32(c & mask)
	c = (c >> 32) + uint64(orderWord5) + uint64(^a.n[5])
	s.n[5] = uint32(c & mask)
	c = (c >> 32) + uint64(orderWord6) + uint64(^a.n[6])
	s.n[6] = uint32(c & mask)
	c = (c >> 32) + uint64(orderWord7) + uint64(^a.n[7])
	s.n[7] = uint32(c & mask)
}
 276  
// mul multiplies two scalars: s = (a * b) mod n. s may alias a or b,
// because mulPureGo forms the full 512-bit product before s is written.
func (s *Scalar) mul(a, b *Scalar) {
	s.mulPureGo(a, b)
}
 281  
 282  // mulPureGo performs multiplication using 32-bit arithmetic
 283  func (s *Scalar) mulPureGo(a, b *Scalar) {
 284  	// Compute 512-bit product then reduce
 285  	var l [16]uint64
 286  
 287  	// Full 512-bit multiplication (using 64-bit intermediates for 32x32->64)
 288  	for i := 0; i < 8; i++ {
 289  		var c uint64
 290  		for j := 0; j < 8; j++ {
 291  			c += l[i+j] + uint64(a.n[i])*uint64(b.n[j])
 292  			l[i+j] = c & uint32Mask
 293  			c >>= 32
 294  		}
 295  		l[i+8] = c
 296  	}
 297  
 298  	// Reduce 512 bits to 256 bits modulo order
 299  	s.reduce512_32(l[:])
 300  }
 301  
// reduce512_32 reduces a 512-bit value (sixteen 32-bit limbs in uint64
// slots, least significant first) modulo the group order, writing the
// result into s.
//
// It repeatedly uses the identity 2^256 ≡ 2^256 - n (mod n): the high
// limbs are multiplied by the words of 2^256 - n and folded into the low
// limbs. Words 4..7 of 2^256 - n are {1, 0, 0, 0}, which is why bare
// l[8+k] / m[8+k] terms appear four limbs above their origin instead of
// an orderCompWord4 multiplication. Three passes shrink the value:
// 512 -> 385 bits, 385 -> 258 bits, 258 -> 256 bits, and a final
// conditional reduce256 handles any remaining overflow.
func (s *Scalar) reduce512_32(l []uint64) {
	// First reduction: 512 -> 385 bits. Fold l[8..15] down.
	var m [13]uint64
	var c uint64

	c = l[0] + l[8]*uint64(orderCompWord0)
	m[0] = c & uint32Mask
	c >>= 32
	c += l[1] + l[8]*uint64(orderCompWord1) + l[9]*uint64(orderCompWord0)
	m[1] = c & uint32Mask
	c >>= 32
	c += l[2] + l[8]*uint64(orderCompWord2) + l[9]*uint64(orderCompWord1) + l[10]*uint64(orderCompWord0)
	m[2] = c & uint32Mask
	c >>= 32
	c += l[3] + l[8]*uint64(orderCompWord3) + l[9]*uint64(orderCompWord2) + l[10]*uint64(orderCompWord1) + l[11]*uint64(orderCompWord0)
	m[3] = c & uint32Mask
	c >>= 32
	// The bare l[8] term below is l[8] * word 4 of 2^256 - n, which is 1.
	c += l[4] + l[8] + l[9]*uint64(orderCompWord3) + l[10]*uint64(orderCompWord2) + l[11]*uint64(orderCompWord1) + l[12]*uint64(orderCompWord0)
	m[4] = c & uint32Mask
	c >>= 32
	c += l[5] + l[9] + l[10]*uint64(orderCompWord3) + l[11]*uint64(orderCompWord2) + l[12]*uint64(orderCompWord1) + l[13]*uint64(orderCompWord0)
	m[5] = c & uint32Mask
	c >>= 32
	c += l[6] + l[10] + l[11]*uint64(orderCompWord3) + l[12]*uint64(orderCompWord2) + l[13]*uint64(orderCompWord1) + l[14]*uint64(orderCompWord0)
	m[6] = c & uint32Mask
	c >>= 32
	c += l[7] + l[11] + l[12]*uint64(orderCompWord3) + l[13]*uint64(orderCompWord2) + l[14]*uint64(orderCompWord1) + l[15]*uint64(orderCompWord0)
	m[7] = c & uint32Mask
	c >>= 32
	c += l[12] + l[13]*uint64(orderCompWord3) + l[14]*uint64(orderCompWord2) + l[15]*uint64(orderCompWord1)
	m[8] = c & uint32Mask
	c >>= 32
	c += l[13] + l[14]*uint64(orderCompWord3) + l[15]*uint64(orderCompWord2)
	m[9] = c & uint32Mask
	c >>= 32
	c += l[14] + l[15]*uint64(orderCompWord3)
	m[10] = c & uint32Mask
	c >>= 32
	c += l[15]
	m[11] = c & uint32Mask
	c >>= 32
	m[12] = c

	// Second reduction: 385 -> 258 bits. Fold m[8..12] down.
	var p [9]uint64
	c = m[0] + m[8]*uint64(orderCompWord0)
	p[0] = c & uint32Mask
	c >>= 32
	c += m[1] + m[8]*uint64(orderCompWord1) + m[9]*uint64(orderCompWord0)
	p[1] = c & uint32Mask
	c >>= 32
	c += m[2] + m[8]*uint64(orderCompWord2) + m[9]*uint64(orderCompWord1) + m[10]*uint64(orderCompWord0)
	p[2] = c & uint32Mask
	c >>= 32
	c += m[3] + m[8]*uint64(orderCompWord3) + m[9]*uint64(orderCompWord2) + m[10]*uint64(orderCompWord1) + m[11]*uint64(orderCompWord0)
	p[3] = c & uint32Mask
	c >>= 32
	c += m[4] + m[8] + m[9]*uint64(orderCompWord3) + m[10]*uint64(orderCompWord2) + m[11]*uint64(orderCompWord1) + m[12]*uint64(orderCompWord0)
	p[4] = c & uint32Mask
	c >>= 32
	c += m[5] + m[9] + m[10]*uint64(orderCompWord3) + m[11]*uint64(orderCompWord2) + m[12]*uint64(orderCompWord1)
	p[5] = c & uint32Mask
	c >>= 32
	c += m[6] + m[10] + m[11]*uint64(orderCompWord3) + m[12]*uint64(orderCompWord2)
	p[6] = c & uint32Mask
	c >>= 32
	c += m[7] + m[11] + m[12]*uint64(orderCompWord3)
	p[7] = c & uint32Mask
	c >>= 32
	p[8] = c + m[12]

	// Final reduction: 258 -> 256 bits. Fold the small p[8] down.
	c = p[0] + p[8]*uint64(orderCompWord0)
	s.n[0] = uint32(c & uint32Mask)
	c >>= 32
	c += p[1] + p[8]*uint64(orderCompWord1)
	s.n[1] = uint32(c & uint32Mask)
	c >>= 32
	c += p[2] + p[8]*uint64(orderCompWord2)
	s.n[2] = uint32(c & uint32Mask)
	c >>= 32
	c += p[3] + p[8]*uint64(orderCompWord3)
	s.n[3] = uint32(c & uint32Mask)
	c >>= 32
	c += p[4] + p[8]
	s.n[4] = uint32(c & uint32Mask)
	c >>= 32
	c += p[5]
	s.n[5] = uint32(c & uint32Mask)
	c >>= 32
	c += p[6]
	s.n[6] = uint32(c & uint32Mask)
	c >>= 32
	c += p[7]
	s.n[7] = uint32(c & uint32Mask)

	// One conditional subtraction of n finishes the reduction.
	s.reduce256(uint32(c>>32) + s.overflows())
}
 401  
// inverse computes the modular inverse: s = a^(-1) mod n, via Fermat's
// little theorem (a^(n-2) mod n, valid because n is prime). For a == 0
// the exponentiation yields 0, not a true inverse — callers must ensure
// a is nonzero. The exponent n-2 is public, so exp's bit-dependent
// branching does not leak secret data here.
func (s *Scalar) inverse(a *Scalar) {
	// Build the exponent n - 2. orderWord0 ends in 0x41, so subtracting 2
	// from the lowest word cannot borrow.
	var exp Scalar
	exp.n[0] = orderWord0 - 2
	exp.n[1] = orderWord1
	exp.n[2] = orderWord2
	exp.n[3] = orderWord3
	exp.n[4] = orderWord4
	exp.n[5] = orderWord5
	exp.n[6] = orderWord6
	exp.n[7] = orderWord7

	s.exp(a, &exp)
}
 417  
 418  // exp computes s = a^b mod n
 419  func (s *Scalar) exp(a, b *Scalar) {
 420  	*s = ScalarOne
 421  	base := *a
 422  
 423  	for i := 0; i < 8; i++ {
 424  		limb := b.n[i]
 425  		for j := 0; j < 32; j++ {
 426  			if limb&1 != 0 {
 427  				s.mul(s, &base)
 428  			}
 429  			base.mul(&base, &base)
 430  			limb >>= 1
 431  		}
 432  	}
 433  }
 434  
 435  // half computes s = a/2 mod n
 436  func (s *Scalar) half(a *Scalar) {
 437  	*s = *a
 438  	if s.n[0]&1 == 0 {
 439  		// Even: simple right shift
 440  		for i := 0; i < 7; i++ {
 441  			s.n[i] = (s.n[i] >> 1) | ((s.n[i+1] & 1) << 31)
 442  		}
 443  		s.n[7] >>= 1
 444  	} else {
 445  		// Odd: add n then divide by 2
 446  		var c uint64
 447  		c = uint64(s.n[0]) + uint64(orderWord0)
 448  		s.n[0] = uint32(c)
 449  		c = (c >> 32) + uint64(s.n[1]) + uint64(orderWord1)
 450  		s.n[1] = uint32(c)
 451  		c = (c >> 32) + uint64(s.n[2]) + uint64(orderWord2)
 452  		s.n[2] = uint32(c)
 453  		c = (c >> 32) + uint64(s.n[3]) + uint64(orderWord3)
 454  		s.n[3] = uint32(c)
 455  		c = (c >> 32) + uint64(s.n[4]) + uint64(orderWord4)
 456  		s.n[4] = uint32(c)
 457  		c = (c >> 32) + uint64(s.n[5]) + uint64(orderWord5)
 458  		s.n[5] = uint32(c)
 459  		c = (c >> 32) + uint64(s.n[6]) + uint64(orderWord6)
 460  		s.n[6] = uint32(c)
 461  		c = (c >> 32) + uint64(s.n[7]) + uint64(orderWord7)
 462  		s.n[7] = uint32(c)
 463  
 464  		// Divide by 2
 465  		for i := 0; i < 7; i++ {
 466  			s.n[i] = (s.n[i] >> 1) | ((s.n[i+1] & 1) << 31)
 467  		}
 468  		s.n[7] >>= 1
 469  	}
 470  }
 471  
 472  // isZero returns true if the scalar is zero
 473  func (s *Scalar) isZero() bool {
 474  	bits := s.n[0] | s.n[1] | s.n[2] | s.n[3] | s.n[4] | s.n[5] | s.n[6] | s.n[7]
 475  	return bits == 0
 476  }
 477  
 478  // isOne returns true if the scalar is one
 479  func (s *Scalar) isOne() bool {
 480  	return s.n[0] == 1 && s.n[1] == 0 && s.n[2] == 0 && s.n[3] == 0 &&
 481  		s.n[4] == 0 && s.n[5] == 0 && s.n[6] == 0 && s.n[7] == 0
 482  }
 483  
 484  // isEven returns true if the scalar is even
 485  func (s *Scalar) isEven() bool {
 486  	return s.n[0]&1 == 0
 487  }
 488  
// isHigh returns true if the scalar is greater than the half order
// (i.e. > (n-1)/2), in constant time. Like overflows, it scans limbs from
// most to least significant while tracking equality of the higher limbs.
// Limbs 6..4 of the half order are 0xffffffff, so no strict "greater"
// test is possible there and only equality is accumulated for them.
func (s *Scalar) isHigh() bool {
	result := constantTimeGreater32(s.n[7], halfOrderWord7)
	highWordsEqual := constantTimeEq32(s.n[7], halfOrderWord7)
	highWordsEqual &= constantTimeEq32(s.n[6], halfOrderWord6)
	highWordsEqual &= constantTimeEq32(s.n[5], halfOrderWord5)
	highWordsEqual &= constantTimeEq32(s.n[4], halfOrderWord4)
	result |= highWordsEqual & constantTimeGreater32(s.n[3], halfOrderWord3)
	highWordsEqual &= constantTimeEq32(s.n[3], halfOrderWord3)
	result |= highWordsEqual & constantTimeGreater32(s.n[2], halfOrderWord2)
	highWordsEqual &= constantTimeEq32(s.n[2], halfOrderWord2)
	result |= highWordsEqual & constantTimeGreater32(s.n[1], halfOrderWord1)
	highWordsEqual &= constantTimeEq32(s.n[1], halfOrderWord1)
	result |= highWordsEqual & constantTimeGreater32(s.n[0], halfOrderWord0)
	return result != 0
}
 505  
 506  // condNegate conditionally negates the scalar
 507  func (s *Scalar) condNegate(flag int) {
 508  	if flag != 0 {
 509  		var neg Scalar
 510  		neg.negate(s)
 511  		*s = neg
 512  	}
 513  }
 514  
// equal returns true if two scalars are equal, in constant time.
//
// Each 8x32 limb array is reinterpreted via unsafe as its raw 32-byte
// in-memory representation and compared with crypto/subtle; because both
// operands are viewed the same way, limb order and host endianness do not
// affect an equality check. Assumes both scalars are fully reduced —
// distinct representations of the same residue would compare unequal.
func (s *Scalar) equal(a *Scalar) bool {
	return subtle.ConstantTimeCompare(
		(*[32]byte)(unsafe.Pointer(&s.n[0]))[:32],
		(*[32]byte)(unsafe.Pointer(&a.n[0]))[:32],
	) == 1
}
 522  
 523  // getBits extracts count bits starting at offset
 524  func (s *Scalar) getBits(offset, count uint) uint32 {
 525  	if count == 0 || count > 32 {
 526  		panic("count must be 1-32")
 527  	}
 528  	if offset+count > 256 {
 529  		panic("offset + count must be <= 256")
 530  	}
 531  
 532  	limbIdx := offset / 32
 533  	bitIdx := offset % 32
 534  
 535  	if bitIdx+count <= 32 {
 536  		return (s.n[limbIdx] >> bitIdx) & ((1 << count) - 1)
 537  	}
 538  	lowBits := 32 - bitIdx
 539  	highBits := count - lowBits
 540  	low := (s.n[limbIdx] >> bitIdx) & ((1 << lowBits) - 1)
 541  	high := s.n[limbIdx+1] & ((1 << highBits) - 1)
 542  	return low | (high << lowBits)
 543  }
 544  
 545  // cmov conditionally moves a scalar
 546  func (s *Scalar) cmov(a *Scalar, flag int) {
 547  	mask := uint32(-(int32(flag) & 1))
 548  	for i := 0; i < 8; i++ {
 549  		s.n[i] ^= mask & (s.n[i] ^ a.n[i])
 550  	}
 551  }
 552  
 553  // clear clears a scalar
 554  func (s *Scalar) clear() {
 555  	for i := 0; i < 8; i++ {
 556  		s.n[i] = 0
 557  	}
 558  }
 559  
// wNAF converts a scalar to windowed non-adjacent form with window width
// w (2..8): wnaf[i] holds an odd digit in roughly [-2^(w-1), 2^(w-1)-1]
// at bit position i, with zeros in between, such that the digits sum
// (times powers of two) back to the scalar. Returns the number of wNAF
// entries produced (one past the highest set position). Runs in variable
// time — do not use with secret scalars.
func (s *Scalar) wNAF(wnaf *[257]int8, w uint) int {
	if w < 2 || w > 8 {
		panic("w must be between 2 and 8")
	}

	var k Scalar
	k = *s

	numBits := 0
	var carry uint32

	*wnaf = [257]int8{}

	bit := 0
	for bit < 256 {
		// Skip positions whose bit matches the pending carry; they
		// contribute a zero digit.
		if k.getBits(uint(bit), 1) == carry {
			bit++
			continue
		}

		// Clamp the window at the top of the scalar.
		window := w
		if bit+int(window) > 256 {
			window = uint(256 - bit)
		}

		// Extract the window and fold in the carry. If the window's top
		// bit is set, borrow 2^window from the next window (carry = 1)
		// and make this digit negative.
		word := k.getBits(uint(bit), window) + carry
		carry = (word >> (window - 1)) & 1
		word -= carry << window

		// word is odd here (the loop only enters when the low bit
		// differs from carry), so it fits int8 after the subtraction.
		wnaf[bit] = int8(int32(word))
		numBits = bit + int(window) - 1

		bit += int(window)
	}

	// A leftover carry becomes a final +1 digit at position 256.
	if carry != 0 {
		wnaf[256] = int8(carry)
		numBits = 256
	}

	return numBits + 1
}
 603  
 604  // wNAFSigned converts a scalar to wNAF representation with sign handling
 605  func (s *Scalar) wNAFSigned(wnaf *[257]int8, w uint) (int, bool) {
 606  	if w < 2 || w > 8 {
 607  		panic("w must be between 2 and 8")
 608  	}
 609  
 610  	var k Scalar
 611  	k = *s
 612  
 613  	negated := false
 614  	if k.getBits(255, 1) == 1 {
 615  		k.negate(&k)
 616  		negated = true
 617  	}
 618  
 619  	bits := k.wNAF(wnaf, w)
 620  	return bits, negated
 621  }
 622  
 623  // caddBit conditionally adds a power of 2
 624  func (s *Scalar) caddBit(bit uint, flag int) {
 625  	if flag == 0 {
 626  		return
 627  	}
 628  
 629  	limbIdx := bit >> 5
 630  	bitIdx := bit & 0x1F
 631  	addVal := uint32(1) << bitIdx
 632  
 633  	var c uint64
 634  	for i := limbIdx; i < 8; i++ {
 635  		if i == limbIdx {
 636  			c = uint64(s.n[i]) + uint64(addVal)
 637  		} else {
 638  			c = uint64(s.n[i]) + (c >> 32)
 639  		}
 640  		s.n[i] = uint32(c)
 641  		if c>>32 == 0 {
 642  			break
 643  		}
 644  	}
 645  }
 646  
// mulShiftVar computes s = round((a * b) >> shift) for shift >= 256:
// the full 512-bit product is formed, the bits [shift, shift+256) are
// extracted, and the bit just below the cut rounds the result up.
// Callers (scalarSplitLambda) use shift = 384, giving ~128-bit results,
// so the unreduced caddBit rounding step cannot overflow. Variable time —
// do not use with secret inputs.
func (s *Scalar) mulShiftVar(a, b *Scalar, shift uint) {
	if shift < 256 {
		panic("mulShiftVar requires shift >= 256")
	}

	// Compute full 512-bit product (same schoolbook scheme as mulPureGo).
	var l [16]uint64
	for i := 0; i < 8; i++ {
		var c uint64
		for j := 0; j < 8; j++ {
			c += l[i+j] + uint64(a.n[i])*uint64(b.n[j])
			l[i+j] = c & uint32Mask
			c >>= 32
		}
		l[i+8] = c
	}

	// Extract bits [shift, shift+256): each output limb is stitched from
	// one or two adjacent product limbs depending on the sub-limb offset.
	shiftLimbs := shift >> 5
	shiftLow := shift & 0x1F
	shiftHigh := 32 - shiftLow

	for i := 0; i < 8; i++ {
		srcIdx := shiftLimbs + uint(i)
		if srcIdx < 16 {
			// The shiftLow != 0 guard avoids an invalid 32-bit shift
			// (shiftHigh would be 32 when shiftLow is 0).
			if shiftLow != 0 && srcIdx+1 < 16 {
				s.n[i] = uint32((l[srcIdx] >> shiftLow) | (l[srcIdx+1] << shiftHigh))
			} else {
				s.n[i] = uint32(l[srcIdx] >> shiftLow)
			}
		} else {
			s.n[i] = 0
		}
	}

	// Round to nearest by adding the bit just below the cut.
	roundBit := int((l[(shift-1)>>5] >> ((shift - 1) & 0x1F)) & 1)
	s.caddBit(0, roundBit)
}
 687  
// Constant-time helper functions

// constantTimeNotEq32 returns 1 if a != b and 0 otherwise, without
// branching. It relies on constantTimeEq32 (defined elsewhere in this
// package) returning exactly 0 or 1.
func constantTimeNotEq32(a, b uint32) uint32 {
	return ^constantTimeEq32(a, b) & 1
}
 692  
 693  func constantTimeGreaterOrEq32(a, b uint32) uint32 {
 694  	return uint32((uint64(a) - uint64(b) - 1) >> 63) ^ 1
 695  }
 696  
// scalarSplitLambda decomposes k into r1, r2 such that
// k = r1 + r2*lambda (mod n), with r1 and r2 short, for GLV endomorphism
// multiplication. r2 is built from rounded lattice projections
// round(g_i * k / 2^384) scaled by the basis values -b1/-b2, and
// r1 = k - r2*lambda. Variable time (uses mulShiftVar) — intended for
// public scalars.
func scalarSplitLambda(r1, r2, k *Scalar) {
	var c1, c2 Scalar

	// c_i = round(g_i * k / 2^384)
	c1.mulShiftVar(k, &scalarG1, 384)
	c2.mulShiftVar(k, &scalarG2, 384)

	c1.mul(&c1, &scalarMinusB1)
	c2.mul(&c2, &scalarMinusB2)

	// r2 = c1 + c2; r1 = k - r2*lambda
	r2.add(&c1, &c2)
	r1.mul(r2, &scalarLambda)
	r1.negate(r1)
	r1.add(r1, k)
}
 712  
 713  // scalarSplit128 splits a scalar into two 128-bit halves
 714  func scalarSplit128(r1, r2, k *Scalar) {
 715  	r1.n[0] = k.n[0]
 716  	r1.n[1] = k.n[1]
 717  	r1.n[2] = k.n[2]
 718  	r1.n[3] = k.n[3]
 719  	r1.n[4] = 0
 720  	r1.n[5] = 0
 721  	r1.n[6] = 0
 722  	r1.n[7] = 0
 723  
 724  	r2.n[0] = k.n[4]
 725  	r2.n[1] = k.n[5]
 726  	r2.n[2] = k.n[6]
 727  	r2.n[3] = k.n[7]
 728  	r2.n[4] = 0
 729  	r2.n[5] = 0
 730  	r2.n[6] = 0
 731  	r2.n[7] = 0
 732  }
 733  
// Direct function versions for compatibility. Each forwards to the
// corresponding Scalar method; the first argument is the destination.
func scalarAdd(r, a, b *Scalar) bool  { return r.add(a, b) }
func scalarMul(r, a, b *Scalar)       { r.mul(a, b) }
func scalarGetB32(bin []byte, a *Scalar) { a.getB32(bin) }
func scalarIsZero(a *Scalar) bool     { return a.isZero() }
func scalarCheckOverflow(r *Scalar) bool { return r.checkOverflow() }
func scalarReduce(r *Scalar, overflow int) { r.reduce(overflow) }
 741  
// Stubs for AVX2 functions (not available on WASM) - these forward to pure Go.
// Note the overflow result of add is intentionally discarded here.
func scalarAddAVX2(r, a, b *Scalar) { r.add(a, b) }
func scalarSubAVX2(r, a, b *Scalar) { r.sub(a, b) }
func scalarMulAVX2(r, a, b *Scalar) { r.mul(a, b) }
 746  
// Compatibility constants for verify.go (these map to our 8x32 representation).
// They reassemble the eight 32-bit order words into the 4x64 limb layout
// used by the 64-bit code path, least significant limb first.
const (
	scalarN0 = uint64(orderWord0) | uint64(orderWord1)<<32
	scalarN1 = uint64(orderWord2) | uint64(orderWord3)<<32
	scalarN2 = uint64(orderWord4) | uint64(orderWord5)<<32
	scalarN3 = uint64(orderWord6) | uint64(orderWord7)<<32
)
 754  
 755  // d returns the scalar limbs as a 4-element uint64 array for compatibility
 756  // This converts from 8x32 to 4x64 representation
 757  func (s *Scalar) d() [4]uint64 {
 758  	return [4]uint64{
 759  		uint64(s.n[0]) | uint64(s.n[1])<<32,
 760  		uint64(s.n[2]) | uint64(s.n[3])<<32,
 761  		uint64(s.n[4]) | uint64(s.n[5])<<32,
 762  		uint64(s.n[6]) | uint64(s.n[7])<<32,
 763  	}
 764  }
 765