nat.mx raw

   1  // Copyright 2009 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // This file implements unsigned multi-precision integers (natural
   6  // numbers). They are the building blocks for the implementation
   7  // of signed integers, rationals, and floating-point numbers.
   8  //
   9  // Caution: This implementation relies on the function "alias"
  10  //          which assumes that (nat) slice capacities are never
  11  //          changed (no 3-operand slice expressions). If that
  12  //          changes, alias needs to be updated for correctness.
  13  
  14  package big
  15  
  16  import (
  17  	"internal/byteorder"
  18  	"math/bits"
  19  	"math/rand"
  20  	"slices"
  21  	"sync"
  22  )
  23  
// nat is an unsigned multi-precision integer (a natural number).
//
// An unsigned integer x of the form
//
//	x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements. In other words, x[0] is
// the least significant digit and x[n-1] the most significant one.
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
type nat []Word
  36  
// Small constant values kept as ready-made nats.
// NOTE(review): presumably these are only ever used as read-only
// operands; confirm that no caller passes them as a result receiver.
var (
	natOne  = nat{1}
	natTwo  = nat{2}
	natFive = nat{5}
	natTen  = nat{10}
)
  43  
  44  func (z nat) String() string {
  45  	return "0x" + string(z.itoa(false, 16))
  46  }
  47  
  48  func (z nat) norm() nat {
  49  	i := len(z)
  50  	for i > 0 && z[i-1] == 0 {
  51  		i--
  52  	}
  53  	return z[0:i]
  54  }
  55  
  56  func (z nat) make(n int) nat {
  57  	if n <= cap(z) {
  58  		return z[:n] // reuse z
  59  	}
  60  	if n == 1 {
  61  		// Most nats start small and stay that way; don't over-allocate.
  62  		return make(nat, 1)
  63  	}
  64  	// Choosing a good value for e has significant performance impact
  65  	// because it increases the chance that a value can be reused.
  66  	const e = 4 // extra capacity
  67  	return make(nat, n, n+e)
  68  }
  69  
  70  func (z nat) setWord(x Word) nat {
  71  	if x == 0 {
  72  		return z[:0]
  73  	}
  74  	z = z.make(1)
  75  	z[0] = x
  76  	return z
  77  }
  78  
  79  func (z nat) setUint64(x uint64) nat {
  80  	// single-word value
  81  	if w := Word(x); uint64(w) == x {
  82  		return z.setWord(w)
  83  	}
  84  	// 2-word value
  85  	z = z.make(2)
  86  	z[1] = Word(x >> 32)
  87  	z[0] = Word(x)
  88  	return z
  89  }
  90  
  91  func (z nat) set(x nat) nat {
  92  	z = z.make(len(x))
  93  	copy(z, x)
  94  	return z
  95  }
  96  
  97  func (z nat) add(x, y nat) nat {
  98  	m := len(x)
  99  	n := len(y)
 100  
 101  	switch {
 102  	case m < n:
 103  		return z.add(y, x)
 104  	case m == 0:
 105  		// n == 0 because m >= n; result is 0
 106  		return z[:0]
 107  	case n == 0:
 108  		// result is x
 109  		return z.set(x)
 110  	}
 111  	// m > 0
 112  
 113  	z = z.make(m + 1)
 114  	c := addVV(z[:n], x[:n], y[:n])
 115  	if m > n {
 116  		c = addVW(z[n:m], x[n:], c)
 117  	}
 118  	z[m] = c
 119  
 120  	return z.norm()
 121  }
 122  
 123  func (z nat) sub(x, y nat) nat {
 124  	m := len(x)
 125  	n := len(y)
 126  
 127  	switch {
 128  	case m < n:
 129  		panic("underflow")
 130  	case m == 0:
 131  		// n == 0 because m >= n; result is 0
 132  		return z[:0]
 133  	case n == 0:
 134  		// result is x
 135  		return z.set(x)
 136  	}
 137  	// m > 0
 138  
 139  	z = z.make(m)
 140  	c := subVV(z[:n], x[:n], y[:n])
 141  	if m > n {
 142  		c = subVW(z[n:], x[n:], c)
 143  	}
 144  	if c != 0 {
 145  		panic("underflow")
 146  	}
 147  
 148  	return z.norm()
 149  }
 150  
 151  func (x nat) cmp(y nat) (r int) {
 152  	m := len(x)
 153  	n := len(y)
 154  	if m != n || m == 0 {
 155  		switch {
 156  		case m < n:
 157  			r = -1
 158  		case m > n:
 159  			r = 1
 160  		}
 161  		return
 162  	}
 163  
 164  	i := m - 1
 165  	for i > 0 && x[i] == y[i] {
 166  		i--
 167  	}
 168  
 169  	switch {
 170  	case x[i] < y[i]:
 171  		r = -1
 172  	case x[i] > y[i]:
 173  		r = 1
 174  	}
 175  	return
 176  }
 177  
// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
//
// montgomery panics if x, y and m do not all have length n.
// The returned slice has length n and is not normalized.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat {
	// This code assumes x, y, m are all the same length, n.
	// (required by addMulVVW and the for loop).
	// It also assumes that x, y are already reduced mod m,
	// or else the result will not be properly reduced.
	if len(x) != n || len(y) != n || len(m) != n {
		panic("math/big: mismatched montgomery number lengths")
	}
	// Work in a zeroed 2n-word buffer; the n-word window z[i:n+i]
	// slides up by one word per iteration.
	z = z.make(n * 2)
	clear(z)
	var c Word // carry out of the previous iteration's top word
	for i := 0; i < n; i++ {
		d := y[i]
		// Accumulate x*d into the window.
		// NOTE(review): presumably addMulVVWW returns the carry word out
		// of the window; confirm against its definition.
		c2 := addMulVVWW(z[i:n+i], z[i:n+i], x, d, 0)
		// With k = -1/m mod 2**_W, adding m*t is meant to clear z[i].
		t := z[i] * k
		c3 := addMulVVWW(z[i:n+i], z[i:n+i], m, t, 0)
		// Fold the three carries (c, c2, c3) into the next word and
		// record whether that sum itself overflowed a Word.
		cx := c + c2
		cy := cx + c3
		z[n+i] = cy
		if cx < c2 || cy < c3 {
			c = 1
		} else {
			c = 0
		}
	}
	// The result is in the upper half z[n:]. Move it to the lower half,
	// subtracting m once if the final carry shows an overflow past n words.
	if c != 0 {
		subVV(z[:n], z[n:], m)
	} else {
		copy(z[:n], z[n:])
	}
	return z[:n]
}
 219  
 220  // alias reports whether x and y share the same base array.
 221  //
 222  // Note: alias assumes that the capacity of underlying arrays
 223  // is never changed for nat values; i.e. that there are
 224  // no 3-operand slice expressions in this code (or worse,
 225  // reflect-based operations to the same effect).
 226  func alias(x, y nat) bool {
 227  	return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1]
 228  }
 229  
 230  // addTo implements z += x; z must be long enough.
 231  // (we don't use nat.add because we need z to stay the same
 232  // slice, and we don't need to normalize z after each addition)
 233  func addTo(z, x nat) {
 234  	if n := len(x); n > 0 {
 235  		if c := addVV(z[:n], z[:n], x[:n]); c != 0 {
 236  			if n < len(z) {
 237  				addVW(z[n:], z[n:], c)
 238  			}
 239  		}
 240  	}
 241  }
 242  
 243  // mulRange computes the product of all the unsigned integers in the
 244  // range [a, b] inclusively. If a > b (empty range), the result is 1.
 245  // The caller may pass stk == nil to request that mulRange obtain and release one itself.
 246  func (z nat) mulRange(stk *stack, a, b uint64) nat {
 247  	switch {
 248  	case a == 0:
 249  		// cut long ranges short (optimization)
 250  		return z.setUint64(0)
 251  	case a > b:
 252  		return z.setUint64(1)
 253  	case a == b:
 254  		return z.setUint64(a)
 255  	case a+1 == b:
 256  		return z.mul(stk, nat(nil).setUint64(a), nat(nil).setUint64(b))
 257  	}
 258  
 259  	if stk == nil {
 260  		stk = getStack()
 261  		defer stk.free()
 262  	}
 263  
 264  	m := a + (b-a)/2 // avoid overflow
 265  	return z.mul(stk, nat(nil).mulRange(stk, a, m), nat(nil).mulRange(stk, m+1, b))
 266  }
 267  
// A stack provides temporary storage for complex calculations
// such as multiplication and division.
// The stack is a simple slice of words, extended as needed
// to hold all the temporary storage for a calculation.
// In general, if a function takes a *stack, it expects a non-nil *stack.
// However, certain functions may allow passing a nil *stack instead,
// so that they can handle trivial stack-free cases without forcing the
// caller to obtain and free a stack that will be unused. These functions
// document that they accept a nil *stack in their doc comments.
type stack struct {
	w []Word // backing storage; len(w) is the current stack pointer
}
 280  
// stackPool holds stacks released by stack.free so that
// getStack can hand them out again.
var stackPool sync.Pool
 282  
 283  // getStack returns a temporary stack.
 284  // The caller must call [stack.free] to give up use of the stack when finished.
 285  func getStack() *stack {
 286  	s, _ := stackPool.Get().(*stack)
 287  	if s == nil {
 288  		s = &stack{}
 289  	}
 290  	return s
 291  }
 292  
 293  // free returns the stack for use by another calculation.
 294  func (s *stack) free() {
 295  	s.w = s.w[:0]
 296  	stackPool.Put(s)
 297  }
 298  
 299  // save returns the current stack pointer.
 300  // A future call to restore with the same value
 301  // frees any temporaries allocated on the stack after the call to save.
 302  func (s *stack) save() int {
 303  	return len(s.w)
 304  }
 305  
 306  // restore restores the stack pointer to n.
 307  // It is almost always invoked as
 308  //
 309  //	defer stk.restore(stk.save())
 310  //
 311  // which makes sure to pop any temporaries allocated in the current function
 312  // from the stack before returning.
 313  func (s *stack) restore(n int) {
 314  	s.w = s.w[:n]
 315  }
 316  
 317  // nat returns a nat of n words, allocated on the stack.
 318  func (s *stack) nat(n int) nat {
 319  	nr := (n + 3) &^ 3 // round up to multiple of 4
 320  	off := len(s.w)
 321  	s.w = slices.Grow(s.w, nr)
 322  	s.w = s.w[:off+nr]
 323  	x := s.w[off : off+n : off+n]
 324  	if n > 0 {
 325  		x[0] = 0xfedcb
 326  	}
 327  	return x
 328  }
 329  
// bitLen returns the length of x in bits.
// Unlike most methods, it works even if x is not normalized.
// The result is (len(x)-1)*_W plus the bit length of the top word
// x[len(x)-1]; it is 0 for an empty x.
func (x nat) bitLen() int {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	if i := len(x) - 1; i >= 0 {
		// bits.Len uses a lookup table for the low-order bits on some
		// architectures. Neutralize any input-dependent behavior by setting all
		// bits after the first one bit.
		top := uint(x[i])
		// Smear the highest set bit into every lower position; this
		// leaves the position of the highest set bit unchanged.
		top |= top >> 1
		top |= top >> 2
		top |= top >> 4
		top |= top >> 8
		top |= top >> 16
		top |= top >> 16 >> 16 // ">> 32" doesn't compile on 32-bit architectures
		return i*_W + bits.Len(top)
	}
	return 0
}
 351  
 352  // trailingZeroBits returns the number of consecutive least significant zero
 353  // bits of x.
 354  func (x nat) trailingZeroBits() uint {
 355  	if len(x) == 0 {
 356  		return 0
 357  	}
 358  	var i uint
 359  	for x[i] == 0 {
 360  		i++
 361  	}
 362  	// x[i] != 0
 363  	return i*_W + uint(bits.TrailingZeros(uint(x[i])))
 364  }
 365  
 366  // isPow2 returns i, true when x == 2**i and 0, false otherwise.
 367  func (x nat) isPow2() (uint, bool) {
 368  	var i uint
 369  	for x[i] == 0 {
 370  		i++
 371  	}
 372  	if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 {
 373  		return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true
 374  	}
 375  	return 0, false
 376  }
 377  
 378  func same(x, y nat) bool {
 379  	return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0]
 380  }
 381  
 382  // z = x << s
 383  func (z nat) lsh(x nat, s uint) nat {
 384  	if s == 0 {
 385  		if same(z, x) {
 386  			return z
 387  		}
 388  		if !alias(z, x) {
 389  			return z.set(x)
 390  		}
 391  	}
 392  
 393  	m := len(x)
 394  	if m == 0 {
 395  		return z[:0]
 396  	}
 397  	// m > 0
 398  
 399  	n := m + int(s/_W)
 400  	z = z.make(n + 1)
 401  	if s %= _W; s == 0 {
 402  		copy(z[n-m:n], x)
 403  		z[n] = 0
 404  	} else {
 405  		z[n] = lshVU(z[n-m:n], x, s)
 406  	}
 407  	clear(z[0 : n-m])
 408  
 409  	return z.norm()
 410  }
 411  
 412  // z = x >> s
 413  func (z nat) rsh(x nat, s uint) nat {
 414  	if s == 0 {
 415  		if same(z, x) {
 416  			return z
 417  		}
 418  		if !alias(z, x) {
 419  			return z.set(x)
 420  		}
 421  	}
 422  
 423  	m := len(x)
 424  	n := m - int(s/_W)
 425  	if n <= 0 {
 426  		return z[:0]
 427  	}
 428  	// n > 0
 429  
 430  	z = z.make(n)
 431  	if s %= _W; s == 0 {
 432  		copy(z, x[m-n:])
 433  	} else {
 434  		rshVU(z, x[m-n:], s)
 435  	}
 436  
 437  	return z.norm()
 438  }
 439  
 440  func (z nat) setBit(x nat, i uint, b uint) nat {
 441  	j := int(i / _W)
 442  	m := Word(1) << (i % _W)
 443  	n := len(x)
 444  	switch b {
 445  	case 0:
 446  		z = z.make(n)
 447  		copy(z, x)
 448  		if j >= n {
 449  			// no need to grow
 450  			return z
 451  		}
 452  		z[j] &^= m
 453  		return z.norm()
 454  	case 1:
 455  		if j >= n {
 456  			z = z.make(j + 1)
 457  			clear(z[n:])
 458  		} else {
 459  			z = z.make(n)
 460  		}
 461  		copy(z, x)
 462  		z[j] |= m
 463  		// no need to normalize
 464  		return z
 465  	}
 466  	panic("set bit is not 0 or 1")
 467  }
 468  
 469  // bit returns the value of the i'th bit, with lsb == bit 0.
 470  func (x nat) bit(i uint) uint {
 471  	j := i / _W
 472  	if j >= uint(len(x)) {
 473  		return 0
 474  	}
 475  	// 0 <= j < len(x)
 476  	return uint(x[j] >> (i % _W) & 1)
 477  }
 478  
 479  // sticky returns 1 if there's a 1 bit within the
 480  // i least significant bits, otherwise it returns 0.
 481  func (x nat) sticky(i uint) uint {
 482  	j := i / _W
 483  	if j >= uint(len(x)) {
 484  		if len(x) == 0 {
 485  			return 0
 486  		}
 487  		return 1
 488  	}
 489  	// 0 <= j < len(x)
 490  	for _, x := range x[:j] {
 491  		if x != 0 {
 492  			return 1
 493  		}
 494  	}
 495  	if x[j]<<(_W-i%_W) != 0 {
 496  		return 1
 497  	}
 498  	return 0
 499  }
 500  
 501  func (z nat) and(x, y nat) nat {
 502  	m := len(x)
 503  	n := len(y)
 504  	if m > n {
 505  		m = n
 506  	}
 507  	// m <= n
 508  
 509  	z = z.make(m)
 510  	for i := 0; i < m; i++ {
 511  		z[i] = x[i] & y[i]
 512  	}
 513  
 514  	return z.norm()
 515  }
 516  
 517  // trunc returns z = x mod 2ⁿ.
 518  func (z nat) trunc(x nat, n uint) nat {
 519  	w := (n + _W - 1) / _W
 520  	if uint(len(x)) < w {
 521  		return z.set(x)
 522  	}
 523  	z = z.make(int(w))
 524  	copy(z, x)
 525  	if n%_W != 0 {
 526  		z[len(z)-1] &= 1<<(n%_W) - 1
 527  	}
 528  	return z.norm()
 529  }
 530  
 531  func (z nat) andNot(x, y nat) nat {
 532  	m := len(x)
 533  	n := len(y)
 534  	if n > m {
 535  		n = m
 536  	}
 537  	// m >= n
 538  
 539  	z = z.make(m)
 540  	for i := 0; i < n; i++ {
 541  		z[i] = x[i] &^ y[i]
 542  	}
 543  	copy(z[n:m], x[n:m])
 544  
 545  	return z.norm()
 546  }
 547  
 548  func (z nat) or(x, y nat) nat {
 549  	m := len(x)
 550  	n := len(y)
 551  	s := x
 552  	if m < n {
 553  		n, m = m, n
 554  		s = y
 555  	}
 556  	// m >= n
 557  
 558  	z = z.make(m)
 559  	for i := 0; i < n; i++ {
 560  		z[i] = x[i] | y[i]
 561  	}
 562  	copy(z[n:m], s[n:m])
 563  
 564  	return z.norm()
 565  }
 566  
 567  func (z nat) xor(x, y nat) nat {
 568  	m := len(x)
 569  	n := len(y)
 570  	s := x
 571  	if m < n {
 572  		n, m = m, n
 573  		s = y
 574  	}
 575  	// m >= n
 576  
 577  	z = z.make(m)
 578  	for i := 0; i < n; i++ {
 579  		z[i] = x[i] ^ y[i]
 580  	}
 581  	copy(z[n:m], s[n:m])
 582  
 583  	return z.norm()
 584  }
 585  
 586  // random creates a random integer in [0..limit), using the space in z if
 587  // possible. n is the bit length of limit.
 588  func (z nat) random(rand *rand.Rand, limit nat, n int) nat {
 589  	if alias(z, limit) {
 590  		z = nil // z is an alias for limit - cannot reuse
 591  	}
 592  	z = z.make(len(limit))
 593  
 594  	bitLengthOfMSW := uint(n % _W)
 595  	if bitLengthOfMSW == 0 {
 596  		bitLengthOfMSW = _W
 597  	}
 598  	mask := Word((1 << bitLengthOfMSW) - 1)
 599  
 600  	for {
 601  		switch _W {
 602  		case 32:
 603  			for i := range z {
 604  				z[i] = Word(rand.Uint32())
 605  			}
 606  		case 64:
 607  			for i := range z {
 608  				z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32
 609  			}
 610  		default:
 611  			panic("unknown word size")
 612  		}
 613  		z[len(limit)-1] &= mask
 614  		if z.cmp(limit) < 0 {
 615  			break
 616  		}
 617  	}
 618  
 619  	return z.norm()
 620  }
 621  
// If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m;
// otherwise it sets z to x**y. The result is the value of z.
// The caller may pass stk == nil to request that expNN obtain and release one itself.
//
// If slow is true, the specialized routines (expNNMontgomery,
// expNNWindowed, expNNMontgomeryEven) are never used and the generic
// bit-by-bit square-and-multiply loop handles every modulus.
// If z aliases x or y, z's storage is not reused for the result.
func (z nat) expNN(stk *stack, x, y, m nat, slow bool) nat {
	if alias(z, x) || alias(z, y) {
		// We cannot allow in-place modification of x or y.
		z = nil
	}

	// x**y mod 1 == 0
	if len(m) == 1 && m[0] == 1 {
		return z.setWord(0)
	}
	// m == 0 || m > 1

	// x**0 == 1
	if len(y) == 0 {
		return z.setWord(1)
	}
	// y > 0

	// 0**y = 0
	if len(x) == 0 {
		return z.setWord(0)
	}
	// x > 0

	// 1**y = 1
	if len(x) == 1 && x[0] == 1 {
		return z.setWord(1)
	}
	// x > 1

	// x**1 == x
	if len(y) == 1 && y[0] == 1 && len(m) == 0 {
		return z.set(x)
	}
	// All remaining paths need temporary storage.
	if stk == nil {
		stk = getStack()
		defer stk.free()
	}
	if len(y) == 1 && y[0] == 1 { // len(m) > 0
		return z.rem(stk, x, m)
	}

	// y > 1

	if len(m) != 0 {
		// We likely end up being as long as the modulus.
		z = z.make(len(m))

		// If the exponent is large, we use the Montgomery method for odd values,
		// and a 4-bit, windowed exponentiation for powers of two,
		// and a CRT-decomposed Montgomery method for the remaining values
		// (even values times non-trivial odd values, which decompose into one
		// instance of each of the first two cases).
		if len(y) > 1 && !slow {
			if m[0]&1 == 1 {
				return z.expNNMontgomery(stk, x, y, m)
			}
			if logM, ok := m.isPow2(); ok {
				return z.expNNWindowed(stk, x, y, logM)
			}
			return z.expNNMontgomeryEven(stk, x, y, m)
		}
	}

	// Generic path: z starts as x, which accounts for the leading 1 bit
	// of the exponent; that bit is shifted out of v below.
	z = z.set(x)
	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
	shift := nlz(v) + 1
	v <<= shift
	var q nat

	// mask selects the most significant bit of a Word.
	const mask = 1 << (_W - 1)

	// We walk through the bits of the exponent one by one. Each time we
	// see a bit, we square, thus doubling the power. If the bit is a one,
	// we also multiply by x, thus adding one to the power.

	// w is the number of bits of the top word left after the leading 1.
	w := _W - int(shift)
	// zz and r are used to avoid allocating in mul and div as
	// otherwise the arguments would alias.
	var zz, r nat
	for j := 0; j < w; j++ {
		zz = zz.sqr(stk, z)
		zz, z = z, zz

		if v&mask != 0 {
			zz = zz.mul(stk, z, x)
			zz, z = z, zz
		}

		if len(m) != 0 {
			// NOTE(review): presumably div returns (quotient, remainder);
			// the rotation then makes the remainder the new z and
			// recycles the other buffers as scratch space.
			zz, r = zz.div(stk, r, z, m)
			zz, r, q, z = q, z, zz, r
		}

		v <<= 1
	}

	// Process the remaining words of the exponent, most significant first,
	// in exactly the same way.
	for i := len(y) - 2; i >= 0; i-- {
		v = y[i]

		for j := 0; j < _W; j++ {
			zz = zz.sqr(stk, z)
			zz, z = z, zz

			if v&mask != 0 {
				zz = zz.mul(stk, z, x)
				zz, z = z, zz
			}

			if len(m) != 0 {
				zz, r = zz.div(stk, r, z, m)
				zz, r, q, z = q, z, zz, r
			}

			v <<= 1
		}
	}

	return z.norm()
}
 745  
// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd.
// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2
// and then uses the Chinese Remainder Theorem to combine the results.
// The recursive call using m1 will use expNNWindowed,
// while the recursive call using m2 will use expNNMontgomery.
// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
func (z nat) expNNMontgomeryEven(stk *stack, x, y, m nat) nat {
	// Split m = m₁ × m₂ where m₁ = 2ⁿ
	// (n is the number of trailing zero bits of m).
	n := m.trailingZeroBits()
	m1 := nat(nil).lsh(natOne, n)
	m2 := nat(nil).rsh(m, n)

	// We want z = x**y mod m.
	// z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
	// z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
	// (We are using the math/big convention for names here,
	// where the computation is z = x**y mod m, so its parts are z1 and z2.
	// The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
	z1 := nat(nil).expNN(stk, x, y, m1, false)
	z2 := nat(nil).expNN(stk, x, y, m2, false)

	// Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
	// which uses only a single modInverse (and an easy one at that).
	//	p = (z₁ - z₂) × m₂⁻¹ (mod m₁)
	//	z = z₂ + p × m₂
	// The final addition is in range because:
	//	z = z₂ + p × m₂
	//	  ≤ z₂ + (m₁-1) × m₂
	//	  < m₂ + (m₁-1) × m₂
	//	  = m₁ × m₂
	//	  = m.
	// z takes a copy of z₂ now because z2's storage is reused for p below.
	z = z.set(z2)

	// Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1.
	z1 = z1.subMod2N(z1, z2, n)

	// Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
	m2inv := nat(nil).modInverse(m2, m1)
	z2 = z2.mul(stk, z1, m2inv)
	z2 = z2.trunc(z2, n)

	// Reuse z1 for p * m2.
	z = z.add(z, z1.mul(stk, z2, m2))

	return z
}
 794  
// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
// where m = 2**logM.
// It panics if len(y) <= 1. x must be non-empty because x[0] is
// inspected, and stk must be non-nil because temporaries are
// allocated on it.
func (z nat) expNNWindowed(stk *stack, x, y nat, logM uint) nat {
	if len(y) <= 1 {
		panic("big: misuse of expNNWindowed")
	}
	if x[0]&1 == 0 {
		// len(y) > 1, so y > logM.
		// x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM.
		return z.setWord(0)
	}
	if logM == 1 {
		// x is odd here, so x**y mod 2 == 1.
		return z.setWord(1)
	}

	// zz is used to avoid allocating in mul as otherwise
	// the arguments would alias.
	defer stk.restore(stk.save())
	// w is the number of words needed to hold logM bits.
	w := int((logM + _W - 1) / _W)
	zz := stk.nat(w)

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	for i := range powers {
		powers[i] = stk.nat(w)
	}
	powers[0] = powers[0].set(natOne)
	powers[1] = powers[1].trunc(x, logM)
	// Fill the table two entries at a time: the even power as the square
	// of powers[i/2], the following odd power as that times x, each
	// reduced mod 2**logM.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.sqr(stk, *p2)
		*p = p.trunc(*p, logM)
		*p1 = p1.mul(stk, *p, x)
		*p1 = p1.trunc(*p1, logM)
	}

	// Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1,
	// so we can compute x**(y mod 2**(logM-1)) instead of x**y.
	// That is, we can throw away all but the bottom logM-1 bits of y.
	// Instead of allocating a new y, we start reading y at the right word
	// and truncate it appropriately at the start of the loop.
	i := len(y) - 1
	mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word.
	mmask := ^Word(0)
	if mbits := (logM - 1) & (_W - 1); mbits != 0 {
		mmask = (1 << mbits) - 1
	}
	if i > mtop {
		i = mtop
	}
	// advance is false only for the very first window, which needs no
	// preceding squarings.
	advance := false
	z = z.setWord(1)
	for ; i >= 0; i-- {
		yi := y[i]
		if i == mtop {
			yi &= mmask
		}
		for j := 0; j < _W; j += n {
			if advance {
				// Account for use of 4 bits in previous iteration.
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.sqr(stk, z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(stk, z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(stk, z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(stk, z)
				zz, z = z, zz
				z = z.trunc(z, logM)
			}

			// Multiply by the table entry selected by the top n bits of yi.
			zz = zz.mul(stk, z, powers[yi>>(_W-n)])
			zz, z = z, zz
			z = z.trunc(z, logM)

			yi <<= n
			advance = true
		}
	}

	return z.norm()
}
 887  
// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
// m must be non-empty because m[0] is read; expNN only calls this
// routine for odd m.
func (z nat) expNNMontgomery(stk *stack, x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(stk, nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		// Zero-extend x to numWords words.
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	RR := nat(nil).setWord(1)
	zz := nat(nil).lsh(RR, uint(2*numWords*_W))
	_, RR = nat(nil).div(stk, RR, zz, m)
	if len(RR) < numWords {
		// Zero-extend RR to numWords words, as montgomery requires
		// operands of exactly that length.
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i (in Montgomery representation)
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Four squarings per window, except before the very first one.
			if i != len(y)-1 || j != 0 {
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			// Multiply by the table entry selected by the top n bits of yi.
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(stk, nil, zz, m)
		}
	}

	return zz.norm()
}
 980  
// bytes writes the value of z into buf using big-endian encoding.
// The value of z is encoded in the slice buf[i:]. If the value of z
// cannot be represented in buf, bytes panics. The number i of unused
// bytes at the beginning of buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	i = len(buf)
	// Emit the words least significant first, filling buf from its end
	// towards its start, one byte at a time.
	for _, d := range z {
		for j := 0; j < _S; j++ {
			i--
			if i >= 0 {
				buf[i] = byte(d)
			} else if byte(d) != 0 {
				// Out of room: only zero bytes may be dropped.
				panic("math/big: buffer too small to fit value")
			}
			d >>= 8
		}
	}

	// i is negative if z has more bytes than buf (all dropped bytes were 0).
	if i < 0 {
		i = 0
	}
	// Skip leading zero bytes so that buf[i:] is the minimal encoding.
	for i < len(buf) && buf[i] == 0 {
		i++
	}

	return
}
1011  
1012  // bigEndianWord returns the contents of buf interpreted as a big-endian encoded Word value.
1013  func bigEndianWord(buf []byte) Word {
1014  	if _W == 64 {
1015  		return Word(byteorder.BEUint64(buf))
1016  	}
1017  	return Word(byteorder.BEUint32(buf))
1018  }
1019  
1020  // setBytes interprets buf as the bytes of a big-endian unsigned
1021  // integer, sets z to that value, and returns z.
1022  func (z nat) setBytes(buf []byte) nat {
1023  	z = z.make((len(buf) + _S - 1) / _S)
1024  
1025  	i := len(buf)
1026  	for k := 0; i >= _S; k++ {
1027  		z[k] = bigEndianWord(buf[i-_S : i])
1028  		i -= _S
1029  	}
1030  	if i > 0 {
1031  		var d Word
1032  		for s := uint(0); i > 0; s += 8 {
1033  			d |= Word(buf[i-1]) << s
1034  			i--
1035  		}
1036  		z[len(z)-1] = d
1037  	}
1038  
1039  	return z.norm()
1040  }
1041  
// sqrt sets z = ⌊√x⌋
// The caller may pass stk == nil to request that sqrt obtain and release one itself.
func (z nat) sqrt(stk *stack, x nat) nat {
	// 0 and 1 are their own square roots.
	if x.cmp(natOne) <= 0 {
		return z.set(x)
	}
	// x is read in every iteration below, so it must not be overwritten
	// by intermediate results stored in z.
	if alias(z, x) {
		z = nil
	}

	if stk == nil {
		stk = getStack()
		defer stk.free()
	}

	// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
	// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
	// https://members.loria.fr/PZimmermann/mca/pub226.html
	// If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
	// otherwise it converges to the correct z and stays there.
	var z1, z2 nat
	// z1 starts out using z's storage; z1 and z2 swap roles each iteration.
	z1 = z
	z1 = z1.setUint64(1)
	z1 = z1.lsh(z1, uint(x.bitLen()+1)/2) // must be ≥ √x
	for n := 0; ; n++ {
		// z2 = ⌊(z1 + ⌊x/z1⌋)/2⌋
		z2, _ = z2.div(stk, nil, x, z1)
		z2 = z2.add(z2, z1)
		z2 = z2.rsh(z2, 1)
		if z2.cmp(z1) >= 0 {
			// z1 is answer.
			// Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
			if n&1 == 0 {
				return z1
			}
			return z.set(z1)
		}
		z1, z2 = z2, z1
	}
}
1081  
// subMod2N returns z = (x - y) mod 2ⁿ.
// x and y are first reduced mod 2ⁿ if they are longer than n bits.
// That reduction is done in place when the operand shares storage
// with z and into a fresh slice otherwise.
func (z nat) subMod2N(x, y nat, n uint) nat {
	if uint(x.bitLen()) > n {
		if alias(z, x) {
			// ok to overwrite x in place
			x = x.trunc(x, n)
		} else {
			x = nat(nil).trunc(x, n)
		}
	}
	if uint(y.bitLen()) > n {
		if alias(z, y) {
			// ok to overwrite y in place
			y = y.trunc(y, n)
		} else {
			y = nat(nil).trunc(y, n)
		}
	}
	// Non-negative difference: a plain subtraction is already reduced.
	if x.cmp(y) >= 0 {
		return z.sub(x, y)
	}
	// x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x).
	z = z.sub(y, x)
	// Extend z with zero words until it covers n bits, so that the
	// complement below sets every bit up to position n.
	for uint(len(z))*_W < n {
		z = append(z, 0)
	}
	for i := range z {
		z[i] = ^z[i]
	}
	// Drop the complemented bits at and above position n, then add the 1.
	z = z.trunc(z, n)
	return z.add(z, natOne)
}
1114