1 // Copyright 2021 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 package bigmod
6 7 import (
8 _ "crypto/internal/fips140/check"
9 "crypto/internal/fips140deps/byteorder"
10 "errors"
11 "math/bits"
12 )
const (
	// _W is the width of one limb in bits, equal to the platform word size.
	_W = bits.UintSize
	// _S is the width of one limb in bytes.
	_S = _W >> 3
)
20 21 // Note: These functions make many loops over all the words in a Nat.
22 // These loops used to be in assembly, invisible to -race, -asan, and -msan,
23 // but now they are in Go and incur significant overhead in those modes.
24 // To bring the old performance back, we mark all functions that loop
25 // over Nat words with //go:norace. Because //go:norace does not
26 // propagate across inlining, we must also mark functions that inline
27 // //go:norace functions - specifically, those that inline add, addMulVVW,
28 // assign, cmpGeq, rshift1, and sub.
// choice represents a constant-time boolean. The value of choice is always
// either 1 or 0. We use an unsigned integer instead of bool in order to make
// decisions in constant time by turning it into a mask (see ctMask).
type choice uint

// not returns the logical negation of c: 1 becomes 0 and 0 becomes 1.
func not(c choice) choice { return 1 ^ c }

const yes = choice(1)
const no = choice(0)

// ctMask is all 1s if on is yes, and all 0s otherwise.
func ctMask(on choice) uint { return -uint(on) }

// ctEq returns 1 if x == y, and 0 otherwise. The execution time of this
// function does not depend on its inputs.
func ctEq(x, y uint) choice {
	// If x != y, then either x - y or y - x will generate a borrow.
	_, c1 := bits.Sub(x, y, 0)
	_, c2 := bits.Sub(y, x, 0)
	return not(choice(c1 | c2))
}
// Nat represents an arbitrary natural number.
//
// Every Nat carries an announced length: the number of limbs it stores.
// Operations may leak that length, but must not leak anything about the
// values held in those limbs.
type Nat struct {
	// limbs holds the value in little-endian order, base 2^W with
	// W = bits.UintSize: limbs[0] is the least significant word.
	limbs []uint
}
61 62 // preallocTarget is the size in bits of the numbers used to implement the most
63 // common and most performant RSA key size. It's also enough to cover some of
64 // the operations of key sizes up to 4096.
65 const preallocTarget = 2048
66 const preallocLimbs = (preallocTarget + _W - 1) / _W
67 68 // NewNat returns a new nat with a size of zero, just like new(Nat), but with
69 // the preallocated capacity to hold a number of up to preallocTarget bits.
70 // NewNat inlines, so the allocation can live on the stack.
71 func NewNat() *Nat {
72 limbs := []uint{:0:preallocLimbs}
73 return &Nat{limbs}
74 }
75 76 // expand expands x to n limbs, leaving its value unchanged.
77 func (x *Nat) expand(n int) *Nat {
78 if len(x.limbs) > n {
79 panic("bigmod: internal error: shrinking nat")
80 }
81 if cap(x.limbs) < n {
82 newLimbs := []uint{:n}
83 copy(newLimbs, x.limbs)
84 x.limbs = newLimbs
85 return x
86 }
87 extraLimbs := x.limbs[len(x.limbs):n]
88 clear(extraLimbs)
89 x.limbs = x.limbs[:n]
90 return x
91 }
92 93 // reset returns a zero nat of n limbs, reusing x's storage if n <= cap(x.limbs).
94 func (x *Nat) reset(n int) *Nat {
95 if cap(x.limbs) < n {
96 x.limbs = []uint{:n}
97 return x
98 }
99 // Clear both the returned limbs and the previously used ones.
100 clear(x.limbs[:max(n, len(x.limbs))])
101 x.limbs = x.limbs[:n]
102 return x
103 }
104 105 // resetToBytes assigns x = b, where b is a slice of big-endian bytes, resizing
106 // n to the appropriate size.
107 //
108 // The announced length of x is set based on the actual bit size of the input,
109 // ignoring leading zeroes.
110 func (x *Nat) resetToBytes(b []byte) *Nat {
111 x.reset((len(b) + _S - 1) / _S)
112 if err := x.setBytes(b); err != nil {
113 panic("bigmod: internal error: bad arithmetic")
114 }
115 return x.trim()
116 }
117 118 // trim reduces the size of x to match its value.
119 func (x *Nat) trim() *Nat {
120 // Trim most significant (trailing in little-endian) zero limbs.
121 // We assume comparison with zero (but not the branch) is constant time.
122 for i := len(x.limbs) - 1; i >= 0; i-- {
123 if x.limbs[i] != 0 {
124 break
125 }
126 x.limbs = x.limbs[:i]
127 }
128 return x
129 }
130 131 // set assigns x = y, optionally resizing x to the appropriate size.
132 func (x *Nat) set(y *Nat) *Nat {
133 x.reset(len(y.limbs))
134 copy(x.limbs, y.limbs)
135 return x
136 }
137 138 // Bits returns x as a little-endian slice of uint. The length of the slice
139 // matches the announced length of x. The result and x share the same underlying
140 // array.
141 func (x *Nat) Bits() []uint {
142 return x.limbs
143 }
144 145 // Bytes returns x as a zero-extended big-endian byte slice. The size of the
146 // slice will match the size of m.
147 //
148 // x must have the same size as m and it must be less than or equal to m.
149 func (x *Nat) Bytes(m *Modulus) []byte {
150 i := m.Size()
151 bytes := []byte{:i}
152 for _, limb := range x.limbs {
153 for j := 0; j < _S; j++ {
154 i--
155 if i < 0 {
156 if limb == 0 {
157 break
158 }
159 panic("bigmod: modulus is smaller than nat")
160 }
161 bytes[i] = byte(limb)
162 limb >>= 8
163 }
164 }
165 return bytes
166 }
167 168 // SetBytes assigns x = b, where b is a slice of big-endian bytes.
169 // SetBytes returns an error if b >= m.
170 //
171 // The output will be resized to the size of m and overwritten.
172 //
173 //go:norace
174 func (x *Nat) SetBytes(b []byte, m *Modulus) (*Nat, error) {
175 x.resetFor(m)
176 if err := x.setBytes(b); err != nil {
177 return nil, err
178 }
179 if x.cmpGeq(m.nat) == yes {
180 return nil, errors.New("input overflows the modulus")
181 }
182 return x, nil
183 }
184 185 // SetOverflowingBytes assigns x = b, where b is a slice of big-endian bytes.
186 // SetOverflowingBytes returns an error if b has a longer bit length than m, but
187 // reduces overflowing values up to 2^⌈log2(m)⌉ - 1.
188 //
189 // The output will be resized to the size of m and overwritten.
190 func (x *Nat) SetOverflowingBytes(b []byte, m *Modulus) (*Nat, error) {
191 x.resetFor(m)
192 if err := x.setBytes(b); err != nil {
193 return nil, err
194 }
195 // setBytes would have returned an error if the input overflowed the limb
196 // size of the modulus, so now we only need to check if the most significant
197 // limb of x has more bits than the most significant limb of the modulus.
198 if bitLen(x.limbs[len(x.limbs)-1]) > bitLen(m.nat.limbs[len(m.nat.limbs)-1]) {
199 return nil, errors.New("input overflows the modulus size")
200 }
201 x.maybeSubtractModulus(no, m)
202 return x, nil
203 }
204 205 // bigEndianUint returns the contents of buf interpreted as a
206 // big-endian encoded uint value.
207 func bigEndianUint(buf []byte) uint {
208 if _W == 64 {
209 return uint(byteorder.BEUint64(buf))
210 }
211 return uint(byteorder.BEUint32(buf))
212 }
213 214 func (x *Nat) setBytes(b []byte) error {
215 i, k := len(b), 0
216 for k < len(x.limbs) && i >= _S {
217 x.limbs[k] = bigEndianUint(b[i-_S : i])
218 i -= _S
219 k++
220 }
221 for s := 0; s < _W && k < len(x.limbs) && i > 0; s += 8 {
222 x.limbs[k] |= uint(b[i-1]) << s
223 i--
224 }
225 if i > 0 {
226 return errors.New("input overflows the modulus size")
227 }
228 return nil
229 }
230 231 // SetUint assigns x = y.
232 //
233 // The output will be resized to a single limb and overwritten.
234 func (x *Nat) SetUint(y uint) *Nat {
235 x.reset(1)
236 x.limbs[0] = y
237 return x
238 }
239 240 // Equal returns 1 if x == y, and 0 otherwise.
241 //
242 // Both operands must have the same announced length.
243 //
244 //go:norace
245 func (x *Nat) Equal(y *Nat) choice {
246 // Eliminate bounds checks in the loop.
247 size := len(x.limbs)
248 xLimbs := x.limbs[:size]
249 yLimbs := y.limbs[:size]
250 251 equal := yes
252 for i := 0; i < size; i++ {
253 equal &= ctEq(xLimbs[i], yLimbs[i])
254 }
255 return equal
256 }
257 258 // IsZero returns 1 if x == 0, and 0 otherwise.
259 //
260 //go:norace
261 func (x *Nat) IsZero() choice {
262 // Eliminate bounds checks in the loop.
263 size := len(x.limbs)
264 xLimbs := x.limbs[:size]
265 266 zero := yes
267 for i := 0; i < size; i++ {
268 zero &= ctEq(xLimbs[i], 0)
269 }
270 return zero
271 }
272 273 // IsOne returns 1 if x == 1, and 0 otherwise.
274 //
275 //go:norace
276 func (x *Nat) IsOne() choice {
277 // Eliminate bounds checks in the loop.
278 size := len(x.limbs)
279 xLimbs := x.limbs[:size]
280 281 if len(xLimbs) == 0 {
282 return no
283 }
284 285 one := ctEq(xLimbs[0], 1)
286 for i := 1; i < size; i++ {
287 one &= ctEq(xLimbs[i], 0)
288 }
289 return one
290 }
291 292 // IsMinusOne returns 1 if x == -1 mod m, and 0 otherwise.
293 //
294 // The length of x must be the same as the modulus. x must already be reduced
295 // modulo m.
296 //
297 //go:norace
298 func (x *Nat) IsMinusOne(m *Modulus) choice {
299 minusOne := m.Nat()
300 minusOne.SubOne(m)
301 return x.Equal(minusOne)
302 }
303 304 // IsOdd returns 1 if x is odd, and 0 otherwise.
305 func (x *Nat) IsOdd() choice {
306 if len(x.limbs) == 0 {
307 return no
308 }
309 return choice(x.limbs[0] & 1)
310 }
311 312 // TrailingZeroBitsVarTime returns the number of trailing zero bits in x.
313 func (x *Nat) TrailingZeroBitsVarTime() uint {
314 var t uint
315 limbs := x.limbs
316 for _, l := range limbs {
317 if l == 0 {
318 t += _W
319 continue
320 }
321 t += uint(bits.TrailingZeros(l))
322 break
323 }
324 return t
325 }
326 327 // cmpGeq returns 1 if x >= y, and 0 otherwise.
328 //
329 // Both operands must have the same announced length.
330 //
331 //go:norace
332 func (x *Nat) cmpGeq(y *Nat) choice {
333 // Eliminate bounds checks in the loop.
334 size := len(x.limbs)
335 xLimbs := x.limbs[:size]
336 yLimbs := y.limbs[:size]
337 338 var c uint
339 for i := 0; i < size; i++ {
340 _, c = bits.Sub(xLimbs[i], yLimbs[i], c)
341 }
342 // If there was a carry, then subtracting y underflowed, so
343 // x is not greater than or equal to y.
344 return not(choice(c))
345 }
346 347 // assign sets x <- y if on == 1, and does nothing otherwise.
348 //
349 // Both operands must have the same announced length.
350 //
351 //go:norace
352 func (x *Nat) assign(on choice, y *Nat) *Nat {
353 // Eliminate bounds checks in the loop.
354 size := len(x.limbs)
355 xLimbs := x.limbs[:size]
356 yLimbs := y.limbs[:size]
357 358 mask := ctMask(on)
359 for i := 0; i < size; i++ {
360 xLimbs[i] ^= mask & (xLimbs[i] ^ yLimbs[i])
361 }
362 return x
363 }
364 365 // add computes x += y and returns the carry.
366 //
367 // Both operands must have the same announced length.
368 //
369 //go:norace
370 func (x *Nat) add(y *Nat) (c uint) {
371 // Eliminate bounds checks in the loop.
372 size := len(x.limbs)
373 xLimbs := x.limbs[:size]
374 yLimbs := y.limbs[:size]
375 376 for i := 0; i < size; i++ {
377 xLimbs[i], c = bits.Add(xLimbs[i], yLimbs[i], c)
378 }
379 return
380 }
381 382 // sub computes x -= y. It returns the borrow of the subtraction.
383 //
384 // Both operands must have the same announced length.
385 //
386 //go:norace
387 func (x *Nat) sub(y *Nat) (c uint) {
388 // Eliminate bounds checks in the loop.
389 size := len(x.limbs)
390 xLimbs := x.limbs[:size]
391 yLimbs := y.limbs[:size]
392 393 for i := 0; i < size; i++ {
394 xLimbs[i], c = bits.Sub(xLimbs[i], yLimbs[i], c)
395 }
396 return
397 }
398 399 // ShiftRightVarTime sets x = x >> n.
400 //
401 // The announced length of x is unchanged.
402 //
403 //go:norace
404 func (x *Nat) ShiftRightVarTime(n uint) *Nat {
405 // Eliminate bounds checks in the loop.
406 size := len(x.limbs)
407 xLimbs := x.limbs[:size]
408 409 shift := int(n % _W)
410 shiftLimbs := int(n / _W)
411 412 var shiftedLimbs []uint
413 if shiftLimbs < size {
414 shiftedLimbs = xLimbs[shiftLimbs:]
415 }
416 417 for i := range xLimbs {
418 if i >= len(shiftedLimbs) {
419 xLimbs[i] = 0
420 continue
421 }
422 423 xLimbs[i] = shiftedLimbs[i] >> shift
424 if i+1 < len(shiftedLimbs) {
425 xLimbs[i] |= shiftedLimbs[i+1] << (_W - shift)
426 }
427 }
428 429 return x
430 }
431 432 // BitLenVarTime returns the actual size of x in bits.
433 //
434 // The actual size of x (but nothing more) leaks through timing side-channels.
435 // Note that this is ordinarily secret, as opposed to the announced size of x.
436 func (x *Nat) BitLenVarTime() int {
437 // Eliminate bounds checks in the loop.
438 size := len(x.limbs)
439 xLimbs := x.limbs[:size]
440 441 for i := size - 1; i >= 0; i-- {
442 if xLimbs[i] != 0 {
443 return i*_W + bitLen(xLimbs[i])
444 }
445 }
446 return 0
447 }
// bitLen returns the number of bits needed to represent n (0 for n == 0).
//
// Unlike bits.Len and bits.LeadingZeros, which use a lookup table for the
// low-order bits on some architectures, this leaks only the bit length of n,
// not its value.
func bitLen(n uint) int {
	count := 0
	// We assume, here and elsewhere, that comparison to zero is constant time
	// with respect to different non-zero values.
	for ; n != 0; n >>= 1 {
		count++
	}
	return count
}
462 463 // Modulus is used for modular arithmetic, precomputing relevant constants.
464 //
465 // A Modulus can leak the exact number of bits needed to store its value
466 // and is stored without padding. Its actual value is still kept secret.
467 type Modulus struct {
468 // The underlying natural number for this modulus.
469 //
470 // This will be stored without any padding, and shouldn't alias with any
471 // other natural number being used.
472 nat *Nat
473 474 // If m is even, the following fields are not set.
475 odd bool
476 m0inv uint // -nat.limbs[0]⁻¹ mod _W
477 rr *Nat // R*R for montgomeryRepresentation
478 }
// rr returns R*R mod m, with R = 2^(_W * n) and n = len(m.nat.limbs).
//
// The result is the constant montgomeryRepresentation multiplies by to move a
// value into the Montgomery domain. rr itself calls montgomeryMul, so m.m0inv
// must already be set (newModulus computes it before calling rr).
func rr(m *Modulus) *Nat {
	rr := NewNat().ExpandFor(m)
	n := uint(len(rr.limbs))
	mLen := uint(m.BitLen())
	logR := _W * n

	// We start by computing R = 2^(_W * n) mod m. We can get pretty close, to
	// 2^⌊log₂m⌋, by setting the highest bit we can without having to reduce.
	rr.limbs[n-1] = 1 << ((mLen - 1) % _W)
	// Then we double until we reach 2^(_W * n). rr.Add(rr, m) computes
	// rr = 2·rr mod m.
	for i := mLen - 1; i < logR; i++ {
		rr.Add(rr, m)
	}

	// Next we need to get from R to 2^(_W * n) R mod m (aka from one to R in
	// the Montgomery domain, meaning we can use Montgomery multiplication now).
	// We could do that by doubling _W * n times, or with a square-and-double
	// chain log2(_W * n) long. Turns out the fastest thing is to start out with
	// doublings, and switch to square-and-double once the exponent is large
	// enough to justify the cost of the multiplications.

	// The threshold is selected experimentally as a linear function of n.
	threshold := n / 4

	// We calculate how many of the most-significant bits of the exponent we can
	// compute before crossing the threshold, and we do it with doublings.
	// After this loop, logR>>i is the largest prefix of logR's bits that is
	// still above the threshold.
	i := bits.UintSize
	for logR>>i <= threshold {
		i--
	}
	for k := uint(0); k < logR>>i; k++ {
		rr.Add(rr, m)
	}

	// Then we process the remaining bits of the exponent with a
	// square-and-double chain. In the Montgomery domain, montgomeryMul(rr, rr)
	// squares the represented value, doubling the exponent of 2.
	for i > 0 {
		rr.montgomeryMul(rr, rr, m)
		i--
		if logR>>i&1 != 0 {
			rr.Add(rr, m)
		}
	}

	return rr
}
// minusInverseModW computes -x⁻¹ mod 2^_W for odd x, i.e. the value y such
// that x * y ≡ -1 (mod 2^_W).
//
// This operation is used to precompute a constant involved in Montgomery
// multiplication.
func minusInverseModW(x uint) uint {
	// This is a Newton iteration for the inverse of x modulo 2^_W. Each
	// iteration doubles the number of correct least-significant bits of y.
	// The first three bits are already correct (1⁻¹ = 1, 3⁻¹ = 3, 5⁻¹ = 5, and
	// 7⁻¹ = 7 mod 8), so doubling five times (3 → 96 bits) is enough for 64
	// bits (and wastes only one iteration for 32 bits).
	//
	// See https://crypto.stackexchange.com/a/47496.
	y := x
	for i := 0; i < 5; i++ {
		y = y * (2 - x*y)
	}
	return -y
}
545 546 // NewModulus creates a new Modulus from a slice of big-endian bytes. The
547 // modulus must be greater than one.
548 //
549 // The number of significant bits and whether the modulus is even is leaked
550 // through timing side-channels.
551 func NewModulus(b []byte) (*Modulus, error) {
552 n := NewNat().resetToBytes(b)
553 return newModulus(n)
554 }
555 556 // NewModulusProduct creates a new Modulus from the product of two numbers
557 // represented as big-endian byte slices. The result must be greater than one.
558 //
559 //go:norace
560 func NewModulusProduct(a, b []byte) (*Modulus, error) {
561 x := NewNat().resetToBytes(a)
562 y := NewNat().resetToBytes(b)
563 n := NewNat().reset(len(x.limbs) + len(y.limbs))
564 for i := range y.limbs {
565 n.limbs[i+len(x.limbs)] = addMulVVW(n.limbs[i:i+len(x.limbs)], x.limbs, y.limbs[i])
566 }
567 return newModulus(n.trim())
568 }
569 570 func newModulus(n *Nat) (*Modulus, error) {
571 m := &Modulus{nat: n}
572 if m.nat.IsZero() == yes || m.nat.IsOne() == yes {
573 return nil, errors.New("modulus must be > 1")
574 }
575 if m.nat.IsOdd() == 1 {
576 m.odd = true
577 m.m0inv = minusInverseModW(m.nat.limbs[0])
578 m.rr = rr(m)
579 }
580 return m, nil
581 }
582 583 // Size returns the size of m in bytes.
584 func (m *Modulus) Size() int {
585 return (m.BitLen() + 7) / 8
586 }
587 588 // BitLen returns the size of m in bits.
589 func (m *Modulus) BitLen() int {
590 return m.nat.BitLenVarTime()
591 }
592 593 // Nat returns m as a Nat.
594 func (m *Modulus) Nat() *Nat {
595 // Make a copy so that the caller can't modify m.nat or alias it with
596 // another Nat in a modulus operation.
597 n := NewNat()
598 n.set(m.nat)
599 return n
600 }
601 602 // shiftIn calculates x = x << _W + y mod m.
603 //
604 // This assumes that x is already reduced mod m.
605 //
606 //go:norace
607 func (x *Nat) shiftIn(y uint, m *Modulus) *Nat {
608 d := NewNat().resetFor(m)
609 610 // Eliminate bounds checks in the loop.
611 size := len(m.nat.limbs)
612 xLimbs := x.limbs[:size]
613 dLimbs := d.limbs[:size]
614 mLimbs := m.nat.limbs[:size]
615 616 // Each iteration of this loop computes x = 2x + b mod m, where b is a bit
617 // from y. Effectively, it left-shifts x and adds y one bit at a time,
618 // reducing it every time.
619 //
620 // To do the reduction, each iteration computes both 2x + b and 2x + b - m.
621 // The next iteration (and finally the return line) will use either result
622 // based on whether 2x + b overflows m.
623 needSubtraction := no
624 for i := _W - 1; i >= 0; i-- {
625 carry := (y >> i) & 1
626 var borrow uint
627 mask := ctMask(needSubtraction)
628 for i := 0; i < size; i++ {
629 l := xLimbs[i] ^ (mask & (xLimbs[i] ^ dLimbs[i]))
630 xLimbs[i], carry = bits.Add(l, l, carry)
631 dLimbs[i], borrow = bits.Sub(xLimbs[i], mLimbs[i], borrow)
632 }
633 // Like in maybeSubtractModulus, we need the subtraction if either it
634 // didn't underflow (meaning 2x + b > m) or if computing 2x + b
635 // overflowed (meaning 2x + b > 2^_W*n > m).
636 needSubtraction = not(choice(borrow)) | choice(carry)
637 }
638 return x.assign(needSubtraction, d)
639 }
640 641 // Mod calculates out = x mod m.
642 //
643 // This works regardless how large the value of x is.
644 //
645 // The output will be resized to the size of m and overwritten.
646 //
647 //go:norace
648 func (out *Nat) Mod(x *Nat, m *Modulus) *Nat {
649 out.resetFor(m)
650 // Working our way from the most significant to the least significant limb,
651 // we can insert each limb at the least significant position, shifting all
652 // previous limbs left by _W. This way each limb will get shifted by the
653 // correct number of bits. We can insert at least N - 1 limbs without
654 // overflowing m. After that, we need to reduce every time we shift.
655 i := len(x.limbs) - 1
656 // For the first N - 1 limbs we can skip the actual shifting and position
657 // them at the shifted position, which starts at min(N - 2, i).
658 start := len(m.nat.limbs) - 2
659 if i < start {
660 start = i
661 }
662 for j := start; j >= 0; j-- {
663 out.limbs[j] = x.limbs[i]
664 i--
665 }
666 // We shift in the remaining limbs, reducing modulo m each time.
667 for i >= 0 {
668 out.shiftIn(x.limbs[i], m)
669 i--
670 }
671 return out
672 }
673 674 // ExpandFor ensures x has the right size to work with operations modulo m.
675 //
676 // The announced size of x must be smaller than or equal to that of m.
677 func (x *Nat) ExpandFor(m *Modulus) *Nat {
678 return x.expand(len(m.nat.limbs))
679 }
680 681 // resetFor ensures out has the right size to work with operations modulo m.
682 //
683 // out is zeroed and may start at any size.
684 func (out *Nat) resetFor(m *Modulus) *Nat {
685 return out.reset(len(m.nat.limbs))
686 }
687 688 // maybeSubtractModulus computes x -= m if and only if x >= m or if "always" is yes.
689 //
690 // It can be used to reduce modulo m a value up to 2m - 1, which is a common
691 // range for results computed by higher level operations.
692 //
693 // always is usually a carry that indicates that the operation that produced x
694 // overflowed its size, meaning abstractly x > 2^_W*n > m even if x < m.
695 //
696 // x and m operands must have the same announced length.
697 //
698 //go:norace
699 func (x *Nat) maybeSubtractModulus(always choice, m *Modulus) {
700 t := NewNat().set(x)
701 underflow := t.sub(m.nat)
702 // We keep the result if x - m didn't underflow (meaning x >= m)
703 // or if always was set.
704 keep := not(choice(underflow)) | choice(always)
705 x.assign(keep, t)
706 }
707 708 // Sub computes x = x - y mod m.
709 //
710 // The length of both operands must be the same as the modulus. Both operands
711 // must already be reduced modulo m.
712 //
713 //go:norace
714 func (x *Nat) Sub(y *Nat, m *Modulus) *Nat {
715 underflow := x.sub(y)
716 // If the subtraction underflowed, add m.
717 t := NewNat().set(x)
718 t.add(m.nat)
719 x.assign(choice(underflow), t)
720 return x
721 }
722 723 // SubOne computes x = x - 1 mod m.
724 //
725 // The length of x must be the same as the modulus.
726 func (x *Nat) SubOne(m *Modulus) *Nat {
727 one := NewNat().ExpandFor(m)
728 one.limbs[0] = 1
729 // Sub asks for x to be reduced modulo m, while SubOne doesn't, but when
730 // y = 1, it works, and this is an internal use.
731 return x.Sub(one, m)
732 }
733 734 // Add computes x = x + y mod m.
735 //
736 // The length of both operands must be the same as the modulus. Both operands
737 // must already be reduced modulo m.
738 //
739 //go:norace
740 func (x *Nat) Add(y *Nat, m *Modulus) *Nat {
741 overflow := x.add(y)
742 x.maybeSubtractModulus(choice(overflow), m)
743 return x
744 }
745 746 // montgomeryRepresentation calculates x = x * R mod m, with R = 2^(_W * n) and
747 // n = len(m.nat.limbs).
748 //
749 // Faster Montgomery multiplication replaces standard modular multiplication for
750 // numbers in this representation.
751 //
752 // This assumes that x is already reduced mod m.
753 func (x *Nat) montgomeryRepresentation(m *Modulus) *Nat {
754 // A Montgomery multiplication (which computes a * b / R) by R * R works out
755 // to a multiplication by R, which takes the value out of the Montgomery domain.
756 return x.montgomeryMul(x, m.rr, m)
757 }
758 759 // montgomeryReduction calculates x = x / R mod m, with R = 2^(_W * n) and
760 // n = len(m.nat.limbs).
761 //
762 // This assumes that x is already reduced mod m.
763 func (x *Nat) montgomeryReduction(m *Modulus) *Nat {
764 // By Montgomery multiplying with 1 not in Montgomery representation, we
765 // convert out back from Montgomery representation, because it works out to
766 // dividing by R.
767 one := NewNat().ExpandFor(m)
768 one.limbs[0] = 1
769 return x.montgomeryMul(x, one, m)
770 }
771 772 // montgomeryMul calculates x = a * b / R mod m, with R = 2^(_W * n) and
773 // n = len(m.nat.limbs), also known as a Montgomery multiplication.
774 //
775 // All inputs should be the same length and already reduced modulo m.
776 // x will be resized to the size of m and overwritten.
777 //
778 //go:norace
779 func (x *Nat) montgomeryMul(a *Nat, b *Nat, m *Modulus) *Nat {
780 n := len(m.nat.limbs)
781 mLimbs := m.nat.limbs[:n]
782 aLimbs := a.limbs[:n]
783 bLimbs := b.limbs[:n]
784 785 switch n {
786 default:
787 // Attempt to use a stack-allocated backing array.
788 T := []uint{:0:preallocLimbs*2}
789 if cap(T) < n*2 {
790 T = []uint{:0:n*2}
791 }
792 T = T[:n*2]
793 794 // This loop implements Word-by-Word Montgomery Multiplication, as
795 // described in Algorithm 4 (Fig. 3) of "Efficient Software
796 // Implementations of Modular Exponentiation" by Shay Gueron
797 // [https://eprint.iacr.org/2011/239.pdf].
798 var c uint
799 for i := 0; i < n; i++ {
800 _ = T[n+i] // bounds check elimination hint
801 802 // Step 1 (T = a × b) is computed as a large pen-and-paper column
803 // multiplication of two numbers with n base-2^_W digits. If we just
804 // wanted to produce 2n-wide T, we would do
805 //
806 // for i := 0; i < n; i++ {
807 // d := bLimbs[i]
808 // T[n+i] = addMulVVW(T[i:n+i], aLimbs, d)
809 // }
810 //
811 // where d is a digit of the multiplier, T[i:n+i] is the shifted
812 // position of the product of that digit, and T[n+i] is the final carry.
813 // Note that T[i] isn't modified after processing the i-th digit.
814 //
815 // Instead of running two loops, one for Step 1 and one for Steps 2–6,
816 // the result of Step 1 is computed during the next loop. This is
817 // possible because each iteration only uses T[i] in Step 2 and then
818 // discards it in Step 6.
819 d := bLimbs[i]
820 c1 := addMulVVW(T[i:n+i], aLimbs, d)
821 822 // Step 6 is replaced by shifting the virtual window we operate
823 // over: T of the algorithm is T[i:] for us. That means that T1 in
824 // Step 2 (T mod 2^_W) is simply T[i]. k0 in Step 3 is our m0inv.
825 Y := T[i] * m.m0inv
826 827 // Step 4 and 5 add Y × m to T, which as mentioned above is stored
828 // at T[i:]. The two carries (from a × d and Y × m) are added up in
829 // the next word T[n+i], and the carry bit from that addition is
830 // brought forward to the next iteration.
831 c2 := addMulVVW(T[i:n+i], mLimbs, Y)
832 T[n+i], c = bits.Add(c1, c2, c)
833 }
834 835 // Finally for Step 7 we copy the final T window into x, and subtract m
836 // if necessary (which as explained in maybeSubtractModulus can be the
837 // case both if x >= m, or if x overflowed).
838 //
839 // The paper suggests in Section 4 that we can do an "Almost Montgomery
840 // Multiplication" by subtracting only in the overflow case, but the
841 // cost is very similar since the constant time subtraction tells us if
842 // x >= m as a side effect, and taking care of the broken invariant is
843 // highly undesirable (see https://go.dev/issue/13907).
844 copy(x.reset(n).limbs, T[n:])
845 x.maybeSubtractModulus(choice(c), m)
846 847 // The following specialized cases follow the exact same algorithm, but
848 // optimized for the sizes most used in RSA. addMulVVW is implemented in
849 // assembly with loop unrolling depending on the architecture and bounds
850 // checks are removed by the compiler thanks to the constant size.
851 case 1024 / _W:
852 const n = 1024 / _W // compiler hint
853 T := []uint{:n*2}
854 var c uint
855 for i := 0; i < n; i++ {
856 d := bLimbs[i]
857 c1 := addMulVVW1024(&T[i], &aLimbs[0], d)
858 Y := T[i] * m.m0inv
859 c2 := addMulVVW1024(&T[i], &mLimbs[0], Y)
860 T[n+i], c = bits.Add(c1, c2, c)
861 }
862 copy(x.reset(n).limbs, T[n:])
863 x.maybeSubtractModulus(choice(c), m)
864 865 case 1536 / _W:
866 const n = 1536 / _W // compiler hint
867 T := []uint{:n*2}
868 var c uint
869 for i := 0; i < n; i++ {
870 d := bLimbs[i]
871 c1 := addMulVVW1536(&T[i], &aLimbs[0], d)
872 Y := T[i] * m.m0inv
873 c2 := addMulVVW1536(&T[i], &mLimbs[0], Y)
874 T[n+i], c = bits.Add(c1, c2, c)
875 }
876 copy(x.reset(n).limbs, T[n:])
877 x.maybeSubtractModulus(choice(c), m)
878 879 case 2048 / _W:
880 const n = 2048 / _W // compiler hint
881 T := []uint{:n*2}
882 var c uint
883 for i := 0; i < n; i++ {
884 d := bLimbs[i]
885 c1 := addMulVVW2048(&T[i], &aLimbs[0], d)
886 Y := T[i] * m.m0inv
887 c2 := addMulVVW2048(&T[i], &mLimbs[0], Y)
888 T[n+i], c = bits.Add(c1, c2, c)
889 }
890 copy(x.reset(n).limbs, T[n:])
891 x.maybeSubtractModulus(choice(c), m)
892 }
893 894 return x
895 }
// addMulVVW multiplies the multi-word value x by the single-word value y,
// adds the product into the multi-word value z in place, and returns the word
// that carries out of the top of z. It is one row of a pen-and-paper column
// multiplication.
//
// x must have at least len(z) words, and z must not be empty.
//
//go:norace
func addMulVVW(z, x []uint, y uint) (carry uint) {
	// Bounds check elimination hint: proves x[i] is in range for every i < len(z).
	_ = x[len(z)-1]
	for i := range z {
		hi, lo := bits.Mul(x[i], y)
		var c uint
		// Add the current word of z. bits.Add with zero is used as an
		// add-with-carry that absorbs the carry out of the low word.
		lo, c = bits.Add(lo, z[i], 0)
		hi, _ = bits.Add(hi, 0, c)
		// Then add the carry word produced by the previous iteration. hi cannot
		// overflow: x[i]*y + z[i] + carry <= 2^(2*_W) - 1.
		lo, c = bits.Add(lo, carry, 0)
		hi, _ = bits.Add(hi, 0, c)
		z[i], carry = lo, hi
	}
	return carry
}
917 918 // Mul calculates x = x * y mod m.
919 //
920 // The length of both operands must be the same as the modulus. Both operands
921 // must already be reduced modulo m.
922 //
923 //go:norace
924 func (x *Nat) Mul(y *Nat, m *Modulus) *Nat {
925 if m.odd {
926 // A Montgomery multiplication by a value out of the Montgomery domain
927 // takes the result out of Montgomery representation.
928 xR := NewNat().set(x).montgomeryRepresentation(m) // xR = x * R mod m
929 return x.montgomeryMul(xR, y, m) // x = xR * y / R mod m
930 }
931 932 n := len(m.nat.limbs)
933 xLimbs := x.limbs[:n]
934 yLimbs := y.limbs[:n]
935 936 switch n {
937 default:
938 // Attempt to use a stack-allocated backing array.
939 T := []uint{:0:preallocLimbs*2}
940 if cap(T) < n*2 {
941 T = []uint{:0:n*2}
942 }
943 T = T[:n*2]
944 945 // T = x * y
946 for i := 0; i < n; i++ {
947 T[n+i] = addMulVVW(T[i:n+i], xLimbs, yLimbs[i])
948 }
949 950 // x = T mod m
951 return x.Mod(&Nat{limbs: T}, m)
952 953 // The following specialized cases follow the exact same algorithm, but
954 // optimized for the sizes most used in RSA. See montgomeryMul for details.
955 case 1024 / _W:
956 const n = 1024 / _W // compiler hint
957 T := []uint{:n*2}
958 for i := 0; i < n; i++ {
959 T[n+i] = addMulVVW1024(&T[i], &xLimbs[0], yLimbs[i])
960 }
961 return x.Mod(&Nat{limbs: T}, m)
962 case 1536 / _W:
963 const n = 1536 / _W // compiler hint
964 T := []uint{:n*2}
965 for i := 0; i < n; i++ {
966 T[n+i] = addMulVVW1536(&T[i], &xLimbs[0], yLimbs[i])
967 }
968 return x.Mod(&Nat{limbs: T}, m)
969 case 2048 / _W:
970 const n = 2048 / _W // compiler hint
971 T := []uint{:n*2}
972 for i := 0; i < n; i++ {
973 T[n+i] = addMulVVW2048(&T[i], &xLimbs[0], yLimbs[i])
974 }
975 return x.Mod(&Nat{limbs: T}, m)
976 }
977 }
// Exp calculates out = x^e mod m.
//
// The exponent e is represented in big-endian order. The output will be resized
// to the size of m and overwritten. x must already be reduced modulo m.
//
// Every 4-bit window of e costs the same four squarings, a scan of the whole
// table, and one multiplication, so the sequence of operations does not branch
// on the bits of e.
//
// m must be odd, or Exp will panic.
//
//go:norace
func (out *Nat) Exp(x *Nat, e []byte, m *Modulus) *Nat {
	if !m.odd {
		panic("bigmod: modulus for Exp must be odd")
	}

	// We use a 4 bit window. For our RSA workload, 4 bit windows are faster
	// than 2 bit windows, but use an extra 12 nats worth of scratch space.
	// Using bit sizes that don't divide 8 are more complex to implement, but
	// are likely to be more efficient if necessary.

	table := [(1 << 4) - 1]*Nat{ // table[i] = x ^ (i+1)
		// NewNat calls are unrolled so they are allocated on the stack.
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
		NewNat(), NewNat(), NewNat(), NewNat(), NewNat(),
	}
	// All table entries are kept in Montgomery form: table[0] = x·R, and each
	// montgomeryMul by table[0] multiplies the represented value by x.
	table[0].set(x).montgomeryRepresentation(m)
	for i := 1; i < len(table); i++ {
		table[i].montgomeryMul(table[i-1], table[0], m)
	}

	// out starts as 1, converted into Montgomery form.
	out.resetFor(m)
	out.limbs[0] = 1
	out.montgomeryRepresentation(m)
	tmp := NewNat().ExpandFor(m)
	// Walk e one nibble at a time, most significant first: for each byte the
	// high nibble (shift 4) comes before the low nibble (shift 0).
	for _, b := range e {
		for _, j := range []int{4, 0} {
			// Square four times. Optimization note: this can be implemented
			// more efficiently than with generic Montgomery multiplication.
			out.montgomeryMul(out, out, m)
			out.montgomeryMul(out, out, m)
			out.montgomeryMul(out, out, m)
			out.montgomeryMul(out, out, m)

			// Select x^k in constant time from the table: every entry is read
			// and only the one whose index matches k is copied into tmp.
			k := uint((b >> j) & 0b1111)
			for i := range table {
				tmp.assign(ctEq(k, uint(i+1)), table[i])
			}

			// Multiply by x^k, discarding the result if k = 0 (tmp then holds
			// a stale value, but the product is always computed).
			tmp.montgomeryMul(out, tmp, m)
			out.assign(not(ctEq(k, 0)), tmp)
		}
	}

	// Leave the Montgomery domain.
	return out.montgomeryReduction(m)
}
1035 1036 // ExpShortVarTime calculates out = x^e mod m.
1037 //
1038 // The output will be resized to the size of m and overwritten. x must already
1039 // be reduced modulo m. This leaks the exponent through timing side-channels.
1040 //
1041 // m must be odd, or ExpShortVarTime will panic.
1042 func (out *Nat) ExpShortVarTime(x *Nat, e uint, m *Modulus) *Nat {
1043 if !m.odd {
1044 panic("bigmod: modulus for ExpShortVarTime must be odd")
1045 }
1046 // For short exponents, precomputing a table and using a window like in Exp
1047 // doesn't pay off. Instead, we do a simple conditional square-and-multiply
1048 // chain, skipping the initial run of zeroes.
1049 xR := NewNat().set(x).montgomeryRepresentation(m)
1050 out.set(xR)
1051 for i := bits.UintSize - bits.Len(e) + 1; i < bits.UintSize; i++ {
1052 out.montgomeryMul(out, out, m)
1053 if k := (e >> (bits.UintSize - i - 1)) & 1; k != 0 {
1054 out.montgomeryMul(out, xR, m)
1055 }
1056 }
1057 return out.montgomeryReduction(m)
1058 }
1059 1060 // InverseVarTime calculates x = a⁻¹ mod m and returns (x, true) if a is
1061 // invertible. Otherwise, InverseVarTime returns (x, false) and x is not
1062 // modified.
1063 //
1064 // a must be reduced modulo m, but doesn't need to have the same size. The
1065 // output will be resized to the size of m and overwritten.
1066 //
1067 //go:norace
1068 func (x *Nat) InverseVarTime(a *Nat, m *Modulus) (*Nat, bool) {
1069 u, A, err := extendedGCD(a, m.nat)
1070 if err != nil {
1071 return x, false
1072 }
1073 if u.IsOne() == no {
1074 return x, false
1075 }
1076 return x.set(A), true
1077 }
1078 1079 // GCDVarTime calculates x = GCD(a, b) where at least one of a or b is odd, and
1080 // both are non-zero. If GCDVarTime returns an error, x is not modified.
1081 //
1082 // The output will be resized to the size of the larger of a and b.
1083 func (x *Nat) GCDVarTime(a, b *Nat) (*Nat, error) {
1084 u, _, err := extendedGCD(a, b)
1085 if err != nil {
1086 return nil, err
1087 }
1088 return x.set(u), nil
1089 }

// extendedGCD computes u and A such that u = GCD(a, m) and u = A*a - B*m.
//
// u will have the size of the larger of a and m, and A will have the size of m.
//
// It is an error if either a or m is zero, or if they are both even.
func extendedGCD(a, m *Nat) (u, A *Nat, err error) {
	// This is the extended binary GCD algorithm described in the Handbook of
	// Applied Cryptography, Algorithm 14.61, adapted by BoringSSL to bound
	// coefficients and avoid negative numbers. For more details and proof of
	// correctness, see https://github.com/mit-plv/fiat-crypto/pull/333/files.
	//
	// Following the proof linked in the PR above, the changes are:
	//
	// 1. Negate [B] and [C] so they are positive. The invariant now involves a
	//    subtraction.
	// 2. If step 2 (both [x] and [y] are even) runs, abort immediately. This
	//    case needs to be handled by the caller.
	// 3. Subtract copies of [x] and [y] as needed in step 6 (both [u] and [v]
	//    are odd) so coefficients stay in bounds.
	// 4. Replace the [u >= v] check with [u > v]. This changes the end
	//    condition to [v = 0] rather than [u = 0]. This saves an extra
	//    subtraction due to which coefficients were negated.
	// 5. Rename x and y to a and n, to capture that one is a modulus. (In this
	//    implementation they are called a and m.)
	// 6. Rearrange steps 4 through 6 slightly. Merge the loops in steps 4 and
	//    5 into the main loop (step 7's goto), and move step 6 to the start of
	//    the loop iteration, ensuring each loop iteration halves at least one
	//    value.
	//
	// Note this algorithm does not handle either input being zero.

	if a.IsZero() == yes || m.IsZero() == yes {
		return nil, nil, errors.New("extendedGCD: a or m is zero")
	}
	if a.IsOdd() == no && m.IsOdd() == no {
		return nil, nil, errors.New("extendedGCD: both a and m are even")
	}

	// u and v share one width so they can be compared and subtracted directly.
	size := max(len(a.limbs), len(m.limbs))
	u = NewNat().set(a).expand(size)
	v := NewNat().set(m).expand(size)

	// Start with A = 1, B = 0, C = 0, D = 1, so u = 1*a - 0*m and v = 1*m - 0*a.
	// A and C are reduced modulo m; B and D are reduced modulo a.
	A = NewNat().reset(len(m.limbs))
	A.limbs[0] = 1
	B := NewNat().reset(len(a.limbs))
	C := NewNat().reset(len(m.limbs))
	D := NewNat().reset(len(a.limbs))
	D.limbs[0] = 1

	// Before and after each loop iteration, the following hold:
	//
	//   u = A*a - B*m
	//   v = D*m - C*a
	//   0 < u <= a
	//   0 <= v <= m
	//   0 <= A < m
	//   0 <= B <= a
	//   0 <= C < m
	//   0 <= D <= a
	//
	// After each loop iteration, u and v only get smaller, and at least one of
	// them shrinks by at least a factor of two.
	for {
		// If both u and v are odd, subtract the smaller from the larger.
		// If u = v, we need to subtract from v to hit the modified exit condition.
		//
		// Algebraically, u - v = (A+C)*a - (B+D)*m and v - u = (D+B)*m - (C+A)*a,
		// so adding the other pair's coefficients preserves the invariants. The
		// modular reductions keep the coefficients in bounds; see the linked
		// proof for why reducing A mod m and B mod a (or C and D) is consistent.
		//
		// NOTE(review): these Modulus literals set only the nat field; this
		// assumes Add reads nothing else from its Modulus argument — confirm.
		if u.IsOdd() == yes && v.IsOdd() == yes {
			if v.cmpGeq(u) == no {
				u.sub(v)
				A.Add(C, &Modulus{nat: m})
				B.Add(D, &Modulus{nat: a})
			} else {
				v.sub(u)
				C.Add(A, &Modulus{nat: m})
				D.Add(B, &Modulus{nat: a})
			}
		}

		// Exactly one of u and v is now even.
		if u.IsOdd() == v.IsOdd() {
			panic("bigmod: internal error: u and v are not in the expected state")
		}

		// Halve the even one and adjust the corresponding coefficient.
		//
		// If either coefficient is odd, first use (A+m)*a - (B+a)*m = A*a - B*m
		// (likewise for C and D). This leaves u (or v) unchanged, and per the
		// linked proof makes both coefficients even. The value returned by
		// add (presumably the carry out of the addition) is fed to rshift1 as
		// the new top bit, so the halving is exact even if the sum overflowed.
		if u.IsOdd() == no {
			rshift1(u, 0)
			if A.IsOdd() == yes || B.IsOdd() == yes {
				rshift1(A, A.add(m))
				rshift1(B, B.add(a))
			} else {
				rshift1(A, 0)
				rshift1(B, 0)
			}
		} else { // v.IsOdd() == no
			rshift1(v, 0)
			if C.IsOdd() == yes || D.IsOdd() == yes {
				rshift1(C, C.add(m))
				rshift1(D, D.add(a))
			} else {
				rshift1(C, 0)
				rshift1(D, 0)
			}
		}

		// Modified exit condition (change 4 above): stop once v reaches zero,
		// at which point u holds the GCD.
		if v.IsZero() == yes {
			return u, A, nil
		}
	}
}
1198 1199 //go:norace
1200 func rshift1(a *Nat, carry uint) {
1201 size := len(a.limbs)
1202 aLimbs := a.limbs[:size]
1203 1204 for i := range size {
1205 aLimbs[i] >>= 1
1206 if i+1 < size {
1207 aLimbs[i] |= aLimbs[i+1] << (_W - 1)
1208 } else {
1209 aLimbs[i] |= carry << (_W - 1)
1210 }
1211 }
1212 }
1213 1214 // DivShortVarTime calculates x = x / y and returns the remainder.
1215 //
1216 // It panics if y is zero.
1217 //
1218 //go:norace
1219 func (x *Nat) DivShortVarTime(y uint) uint {
1220 if y == 0 {
1221 panic("bigmod: division by zero")
1222 }
1223 1224 var r uint
1225 for i := len(x.limbs) - 1; i >= 0; i-- {
1226 x.limbs[i], r = bits.Div(r, x.limbs[i], y)
1227 }
1228 return r
1229 }
1230