1 // Copyright 2010 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 // This file implements multi-precision rational numbers.
6 7 package big
8 9 import (
10 "fmt"
11 "math"
12 )
13 14 // A Rat represents a quotient a/b of arbitrary precision.
15 // The zero value for a Rat represents the value 0.
16 //
17 // Operations always take pointer arguments (*Rat) rather
18 // than Rat values, and each unique Rat value requires
19 // its own unique *Rat pointer. To "copy" a Rat value,
20 // an existing (or newly allocated) Rat must be set to
21 // a new value using the [Rat.Set] method; shallow copies
22 // of Rats are not supported and may lead to errors.
23 type Rat struct {
24 // To make zero values for Rat work w/o initialization,
25 // a zero value of b (len(b) == 0) acts like b == 1. At
26 // the earliest opportunity (when an assignment to the Rat
27 // is made), such uninitialized denominators are set to 1.
28 // a.neg determines the sign of the Rat, b.neg is ignored.
29 a, b Int
30 }
31 32 // NewRat creates a new [Rat] with numerator a and denominator b.
33 func NewRat(a, b int64) *Rat {
34 return (&Rat{}).SetFrac64(a, b)
35 }
36 37 // SetFloat64 sets z to exactly f and returns z.
38 // If f is not finite, SetFloat returns nil.
39 func (z *Rat) SetFloat64(f float64) *Rat {
40 const expMask = 1<<11 - 1
41 bits := math.Float64bits(f)
42 mantissa := bits & (1<<52 - 1)
43 exp := int((bits >> 52) & expMask)
44 switch exp {
45 case expMask: // non-finite
46 return nil
47 case 0: // denormal
48 exp -= 1022
49 default: // normal
50 mantissa |= 1 << 52
51 exp -= 1023
52 }
53 54 shift := 52 - exp
55 56 // Optimization (?): partially pre-normalise.
57 for mantissa&1 == 0 && shift > 0 {
58 mantissa >>= 1
59 shift--
60 }
61 62 z.a.SetUint64(mantissa)
63 z.a.neg = f < 0
64 z.b.Set(intOne)
65 if shift > 0 {
66 z.b.Lsh(&z.b, uint(shift))
67 } else {
68 z.a.Lsh(&z.a, uint(-shift))
69 }
70 return z.norm()
71 }
72 73 // quotToFloat32 returns the non-negative float32 value
74 // nearest to the quotient a/b, using round-to-even in
75 // halfway cases. It does not mutate its arguments.
76 // Preconditions: b is non-zero; a and b have no common factors.
77 func quotToFloat32(stk *stack, a, b nat) (f float32, exact bool) {
78 const (
79 // float size in bits
80 Fsize = 32
81 82 // mantissa
83 Msize = 23
84 Msize1 = Msize + 1 // incl. implicit 1
85 Msize2 = Msize1 + 1
86 87 // exponent
88 Esize = Fsize - Msize1
89 Ebias = 1<<(Esize-1) - 1
90 Emin = 1 - Ebias
91 Emax = Ebias
92 )
93 94 // TODO(adonovan): specialize common degenerate cases: 1.0, integers.
95 alen := a.bitLen()
96 if alen == 0 {
97 return 0, true
98 }
99 blen := b.bitLen()
100 if blen == 0 {
101 panic("division by zero")
102 }
103 104 // 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1)
105 // (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
106 // This is 2 or 3 more than the float32 mantissa field width of Msize:
107 // - the optional extra bit is shifted away in step 3 below.
108 // - the high-order 1 is omitted in "normal" representation;
109 // - the low-order 1 will be used during rounding then discarded.
110 exp := alen - blen
111 var a2, b2 nat
112 a2 = a2.set(a)
113 b2 = b2.set(b)
114 if shift := Msize2 - exp; shift > 0 {
115 a2 = a2.lsh(a2, uint(shift))
116 } else if shift < 0 {
117 b2 = b2.lsh(b2, uint(-shift))
118 }
119 120 // 2. Compute quotient and remainder (q, r). NB: due to the
121 // extra shift, the low-order bit of q is logically the
122 // high-order bit of r.
123 var q nat
124 q, r := q.div(stk, a2, a2, b2) // (recycle a2)
125 mantissa := low32(q)
126 haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
127 128 // 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
129 // (in effect---we accomplish this incrementally).
130 if mantissa>>Msize2 == 1 {
131 if mantissa&1 == 1 {
132 haveRem = true
133 }
134 mantissa >>= 1
135 exp++
136 }
137 if mantissa>>Msize1 != 1 {
138 panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
139 }
140 141 // 4. Rounding.
142 if Emin-Msize <= exp && exp <= Emin {
143 // Denormal case; lose 'shift' bits of precision.
144 shift := uint(Emin - (exp - 1)) // [1..Esize1)
145 lostbits := mantissa & (1<<shift - 1)
146 haveRem = haveRem || lostbits != 0
147 mantissa >>= shift
148 exp = 2 - Ebias // == exp + shift
149 }
150 // Round q using round-half-to-even.
151 exact = !haveRem
152 if mantissa&1 != 0 {
153 exact = false
154 if haveRem || mantissa&2 != 0 {
155 if mantissa++; mantissa >= 1<<Msize2 {
156 // Complete rollover 11...1 => 100...0, so shift is safe
157 mantissa >>= 1
158 exp++
159 }
160 }
161 }
162 mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<<Msize1.
163 164 f = float32(math.Ldexp(float64(mantissa), exp-Msize1))
165 if math.IsInf(float64(f), 0) {
166 exact = false
167 }
168 return
169 }
170 171 // quotToFloat64 returns the non-negative float64 value
172 // nearest to the quotient a/b, using round-to-even in
173 // halfway cases. It does not mutate its arguments.
174 // Preconditions: b is non-zero; a and b have no common factors.
175 func quotToFloat64(stk *stack, a, b nat) (f float64, exact bool) {
176 const (
177 // float size in bits
178 Fsize = 64
179 180 // mantissa
181 Msize = 52
182 Msize1 = Msize + 1 // incl. implicit 1
183 Msize2 = Msize1 + 1
184 185 // exponent
186 Esize = Fsize - Msize1
187 Ebias = 1<<(Esize-1) - 1
188 Emin = 1 - Ebias
189 Emax = Ebias
190 )
191 192 // TODO(adonovan): specialize common degenerate cases: 1.0, integers.
193 alen := a.bitLen()
194 if alen == 0 {
195 return 0, true
196 }
197 blen := b.bitLen()
198 if blen == 0 {
199 panic("division by zero")
200 }
201 202 // 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1)
203 // (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
204 // This is 2 or 3 more than the float64 mantissa field width of Msize:
205 // - the optional extra bit is shifted away in step 3 below.
206 // - the high-order 1 is omitted in "normal" representation;
207 // - the low-order 1 will be used during rounding then discarded.
208 exp := alen - blen
209 var a2, b2 nat
210 a2 = a2.set(a)
211 b2 = b2.set(b)
212 if shift := Msize2 - exp; shift > 0 {
213 a2 = a2.lsh(a2, uint(shift))
214 } else if shift < 0 {
215 b2 = b2.lsh(b2, uint(-shift))
216 }
217 218 // 2. Compute quotient and remainder (q, r). NB: due to the
219 // extra shift, the low-order bit of q is logically the
220 // high-order bit of r.
221 var q nat
222 q, r := q.div(stk, a2, a2, b2) // (recycle a2)
223 mantissa := low64(q)
224 haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half
225 226 // 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
227 // (in effect---we accomplish this incrementally).
228 if mantissa>>Msize2 == 1 {
229 if mantissa&1 == 1 {
230 haveRem = true
231 }
232 mantissa >>= 1
233 exp++
234 }
235 if mantissa>>Msize1 != 1 {
236 panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
237 }
238 239 // 4. Rounding.
240 if Emin-Msize <= exp && exp <= Emin {
241 // Denormal case; lose 'shift' bits of precision.
242 shift := uint(Emin - (exp - 1)) // [1..Esize1)
243 lostbits := mantissa & (1<<shift - 1)
244 haveRem = haveRem || lostbits != 0
245 mantissa >>= shift
246 exp = 2 - Ebias // == exp + shift
247 }
248 // Round q using round-half-to-even.
249 exact = !haveRem
250 if mantissa&1 != 0 {
251 exact = false
252 if haveRem || mantissa&2 != 0 {
253 if mantissa++; mantissa >= 1<<Msize2 {
254 // Complete rollover 11...1 => 100...0, so shift is safe
255 mantissa >>= 1
256 exp++
257 }
258 }
259 }
260 mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<<Msize1.
261 262 f = math.Ldexp(float64(mantissa), exp-Msize1)
263 if math.IsInf(f, 0) {
264 exact = false
265 }
266 return
267 }
268 269 // Float32 returns the nearest float32 value for x and a bool indicating
270 // whether f represents x exactly. If the magnitude of x is too large to
271 // be represented by a float32, f is an infinity and exact is false.
272 // The sign of f always matches the sign of x, even if f == 0.
273 func (x *Rat) Float32() (f float32, exact bool) {
274 b := x.b.abs
275 if len(b) == 0 {
276 b = natOne
277 }
278 stk := getStack()
279 defer stk.free()
280 f, exact = quotToFloat32(stk, x.a.abs, b)
281 if x.a.neg {
282 f = -f
283 }
284 return
285 }
286 287 // Float64 returns the nearest float64 value for x and a bool indicating
288 // whether f represents x exactly. If the magnitude of x is too large to
289 // be represented by a float64, f is an infinity and exact is false.
290 // The sign of f always matches the sign of x, even if f == 0.
291 func (x *Rat) Float64() (f float64, exact bool) {
292 b := x.b.abs
293 if len(b) == 0 {
294 b = natOne
295 }
296 stk := getStack()
297 defer stk.free()
298 f, exact = quotToFloat64(stk, x.a.abs, b)
299 if x.a.neg {
300 f = -f
301 }
302 return
303 }
304 305 // SetFrac sets z to a/b and returns z.
306 // If b == 0, SetFrac panics.
307 func (z *Rat) SetFrac(a, b *Int) *Rat {
308 z.a.neg = a.neg != b.neg
309 babs := b.abs
310 if len(babs) == 0 {
311 panic("division by zero")
312 }
313 if &z.a == b || alias(z.a.abs, babs) {
314 babs = nat(nil).set(babs) // make a copy
315 }
316 z.a.abs = z.a.abs.set(a.abs)
317 z.b.abs = z.b.abs.set(babs)
318 return z.norm()
319 }
320 321 // SetFrac64 sets z to a/b and returns z.
322 // If b == 0, SetFrac64 panics.
323 func (z *Rat) SetFrac64(a, b int64) *Rat {
324 if b == 0 {
325 panic("division by zero")
326 }
327 z.a.SetInt64(a)
328 if b < 0 {
329 b = -b
330 z.a.neg = !z.a.neg
331 }
332 z.b.abs = z.b.abs.setUint64(uint64(b))
333 return z.norm()
334 }
335 336 // SetInt sets z to x (by making a copy of x) and returns z.
337 func (z *Rat) SetInt(x *Int) *Rat {
338 z.a.Set(x)
339 z.b.abs = z.b.abs.setWord(1)
340 return z
341 }
342 343 // SetInt64 sets z to x and returns z.
344 func (z *Rat) SetInt64(x int64) *Rat {
345 z.a.SetInt64(x)
346 z.b.abs = z.b.abs.setWord(1)
347 return z
348 }
349 350 // SetUint64 sets z to x and returns z.
351 func (z *Rat) SetUint64(x uint64) *Rat {
352 z.a.SetUint64(x)
353 z.b.abs = z.b.abs.setWord(1)
354 return z
355 }
356 357 // Set sets z to x (by making a copy of x) and returns z.
358 func (z *Rat) Set(x *Rat) *Rat {
359 if z != x {
360 z.a.Set(&x.a)
361 z.b.Set(&x.b)
362 }
363 if len(z.b.abs) == 0 {
364 z.b.abs = z.b.abs.setWord(1)
365 }
366 return z
367 }
368 369 // Abs sets z to |x| (the absolute value of x) and returns z.
370 func (z *Rat) Abs(x *Rat) *Rat {
371 z.Set(x)
372 z.a.neg = false
373 return z
374 }
375 376 // Neg sets z to -x and returns z.
377 func (z *Rat) Neg(x *Rat) *Rat {
378 z.Set(x)
379 z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign
380 return z
381 }
382 383 // Inv sets z to 1/x and returns z.
384 // If x == 0, Inv panics.
385 func (z *Rat) Inv(x *Rat) *Rat {
386 if len(x.a.abs) == 0 {
387 panic("division by zero")
388 }
389 z.Set(x)
390 z.a.abs, z.b.abs = z.b.abs, z.a.abs
391 return z
392 }
393 394 // Sign returns:
395 // - -1 if x < 0;
396 // - 0 if x == 0;
397 // - +1 if x > 0.
398 func (x *Rat) Sign() int {
399 return x.a.Sign()
400 }
401 402 // IsInt reports whether the denominator of x is 1.
403 func (x *Rat) IsInt() bool {
404 return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0
405 }
406 407 // Num returns the numerator of x; it may be <= 0.
408 // The result is a reference to x's numerator; it
409 // may change if a new value is assigned to x, and vice versa.
410 // The sign of the numerator corresponds to the sign of x.
411 func (x *Rat) Num() *Int {
412 return &x.a
413 }
414 415 // Denom returns the denominator of x; it is always > 0.
416 // The result is a reference to x's denominator, unless
417 // x is an uninitialized (zero value) [Rat], in which case
418 // the result is a new [Int] of value 1. (To initialize x,
419 // any operation that sets x will do, including x.Set(x).)
420 // If the result is a reference to x's denominator it
421 // may change if a new value is assigned to x, and vice versa.
422 func (x *Rat) Denom() *Int {
423 // Note that x.b.neg is guaranteed false.
424 if len(x.b.abs) == 0 {
425 // Note: If this proves problematic, we could
426 // panic instead and require the Rat to
427 // be explicitly initialized.
428 return &Int{abs: nat{1}}
429 }
430 return &x.b
431 }
432 433 func (z *Rat) norm() *Rat {
434 switch {
435 case len(z.a.abs) == 0:
436 // z == 0; normalize sign and denominator
437 z.a.neg = false
438 z.b.abs = z.b.abs.setWord(1)
439 case len(z.b.abs) == 0:
440 // z is integer; normalize denominator
441 z.b.abs = z.b.abs.setWord(1)
442 default:
443 // z is fraction; normalize numerator and denominator
444 stk := getStack()
445 defer stk.free()
446 neg := z.a.neg
447 z.a.neg = false
448 z.b.neg = false
449 if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 {
450 z.a.abs, _ = z.a.abs.div(stk, nil, z.a.abs, f.abs)
451 z.b.abs, _ = z.b.abs.div(stk, nil, z.b.abs, f.abs)
452 }
453 z.a.neg = neg
454 }
455 return z
456 }
457 458 // mulDenom sets z to the denominator product x*y (by taking into
459 // account that 0 values for x or y must be interpreted as 1) and
460 // returns z.
461 func mulDenom(stk *stack, z, x, y nat) nat {
462 switch {
463 case len(x) == 0 && len(y) == 0:
464 return z.setWord(1)
465 case len(x) == 0:
466 return z.set(y)
467 case len(y) == 0:
468 return z.set(x)
469 }
470 return z.mul(stk, x, y)
471 }
472 473 // scaleDenom sets z to the product x*f.
474 // If f == 0 (zero value of denominator), z is set to (a copy of) x.
475 func (z *Int) scaleDenom(stk *stack, x *Int, f nat) {
476 if len(f) == 0 {
477 z.Set(x)
478 return
479 }
480 z.abs = z.abs.mul(stk, x.abs, f)
481 z.neg = x.neg
482 }
483 484 // Cmp compares x and y and returns:
485 // - -1 if x < y;
486 // - 0 if x == y;
487 // - +1 if x > y.
488 func (x *Rat) Cmp(y *Rat) int {
489 var a, b Int
490 stk := getStack()
491 defer stk.free()
492 a.scaleDenom(stk, &x.a, y.b.abs)
493 b.scaleDenom(stk, &y.a, x.b.abs)
494 return a.Cmp(&b)
495 }
496 497 // Add sets z to the sum x+y and returns z.
498 func (z *Rat) Add(x, y *Rat) *Rat {
499 stk := getStack()
500 defer stk.free()
501 502 var a1, a2 Int
503 a1.scaleDenom(stk, &x.a, y.b.abs)
504 a2.scaleDenom(stk, &y.a, x.b.abs)
505 z.a.Add(&a1, &a2)
506 z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
507 return z.norm()
508 }
509 510 // Sub sets z to the difference x-y and returns z.
511 func (z *Rat) Sub(x, y *Rat) *Rat {
512 stk := getStack()
513 defer stk.free()
514 515 var a1, a2 Int
516 a1.scaleDenom(stk, &x.a, y.b.abs)
517 a2.scaleDenom(stk, &y.a, x.b.abs)
518 z.a.Sub(&a1, &a2)
519 z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
520 return z.norm()
521 }
522 523 // Mul sets z to the product x*y and returns z.
524 func (z *Rat) Mul(x, y *Rat) *Rat {
525 stk := getStack()
526 defer stk.free()
527 528 if x == y {
529 // a squared Rat is positive and can't be reduced (no need to call norm())
530 z.a.neg = false
531 z.a.abs = z.a.abs.sqr(stk, x.a.abs)
532 if len(x.b.abs) == 0 {
533 z.b.abs = z.b.abs.setWord(1)
534 } else {
535 z.b.abs = z.b.abs.sqr(stk, x.b.abs)
536 }
537 return z
538 }
539 540 z.a.mul(stk, &x.a, &y.a)
541 z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
542 return z.norm()
543 }
544 545 // Quo sets z to the quotient x/y and returns z.
546 // If y == 0, Quo panics.
547 func (z *Rat) Quo(x, y *Rat) *Rat {
548 stk := getStack()
549 defer stk.free()
550 551 if len(y.a.abs) == 0 {
552 panic("division by zero")
553 }
554 var a, b Int
555 a.scaleDenom(stk, &x.a, y.b.abs)
556 b.scaleDenom(stk, &y.a, x.b.abs)
557 z.a.abs = a.abs
558 z.b.abs = b.abs
559 z.a.neg = a.neg != b.neg
560 return z.norm()
561 }
562