rat.mx raw

   1  // Copyright 2010 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // This file implements multi-precision rational numbers.
   6  
   7  package big
   8  
   9  import (
  10  	"fmt"
  11  	"math"
  12  )
  13  
  14  // A Rat represents a quotient a/b of arbitrary precision.
  15  // The zero value for a Rat represents the value 0.
  16  //
  17  // Operations always take pointer arguments (*Rat) rather
  18  // than Rat values, and each unique Rat value requires
  19  // its own unique *Rat pointer. To "copy" a Rat value,
  20  // an existing (or newly allocated) Rat must be set to
  21  // a new value using the [Rat.Set] method; shallow copies
  22  // of Rats are not supported and may lead to errors.
  23  type Rat struct {
  24  	// To make zero values for Rat work w/o initialization,
  25  	// a zero value of b (len(b) == 0) acts like b == 1. At
  26  	// the earliest opportunity (when an assignment to the Rat
  27  	// is made), such uninitialized denominators are set to 1.
  28  	// a.neg determines the sign of the Rat, b.neg is ignored.
  29  	a, b Int
  30  }
  31  
  32  // NewRat creates a new [Rat] with numerator a and denominator b.
  33  func NewRat(a, b int64) *Rat {
  34  	return (&Rat{}).SetFrac64(a, b)
  35  }
  36  
  37  // SetFloat64 sets z to exactly f and returns z.
  38  // If f is not finite, SetFloat returns nil.
  39  func (z *Rat) SetFloat64(f float64) *Rat {
  40  	const expMask = 1<<11 - 1
  41  	bits := math.Float64bits(f)
  42  	mantissa := bits & (1<<52 - 1)
  43  	exp := int((bits >> 52) & expMask)
  44  	switch exp {
  45  	case expMask: // non-finite
  46  		return nil
  47  	case 0: // denormal
  48  		exp -= 1022
  49  	default: // normal
  50  		mantissa |= 1 << 52
  51  		exp -= 1023
  52  	}
  53  
  54  	shift := 52 - exp
  55  
  56  	// Optimization (?): partially pre-normalise.
  57  	for mantissa&1 == 0 && shift > 0 {
  58  		mantissa >>= 1
  59  		shift--
  60  	}
  61  
  62  	z.a.SetUint64(mantissa)
  63  	z.a.neg = f < 0
  64  	z.b.Set(intOne)
  65  	if shift > 0 {
  66  		z.b.Lsh(&z.b, uint(shift))
  67  	} else {
  68  		z.a.Lsh(&z.a, uint(-shift))
  69  	}
  70  	return z.norm()
  71  }
  72  
// quotToFloat32 returns the non-negative float32 value
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// The exact result reports whether f represents a/b exactly;
// it is false whenever rounding discarded non-zero bits or
// the result is an infinity.
// If a is zero, the result is (0, true).
// stk is handed to the division in step 2.
// Preconditions: b is non-zero; a and b have no common factors.
// quotToFloat32 panics with "division by zero" if b is zero.
func quotToFloat32(stk *stack, a, b nat) (f float32, exact bool) {
	const (
		// float size in bits
		Fsize = 32

		// mantissa
		Msize  = 23
		Msize1 = Msize + 1 // incl. implicit 1
		Msize2 = Msize1 + 1 // incl. implicit 1 and one extra rounding bit

		// exponent
		Esize = Fsize - Msize1
		Ebias = 1<<(Esize-1) - 1
		Emin  = 1 - Ebias
		Emax  = Ebias
	)

	// TODO(adonovan): specialize common degenerate cases: 1.0, integers.
	alen := a.bitLen()
	if alen == 0 {
		return 0, true
	}
	blen := b.bitLen()
	if blen == 0 {
		panic("division by zero")
	}

	// 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1))
	// (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
	// This is 2 or 3 more than the float32 mantissa field width of Msize:
	// - the optional extra bit is shifted away in step 3 below.
	// - the high-order 1 is omitted in "normal" representation;
	// - the low-order 1 will be used during rounding then discarded.
	exp := alen - blen
	// Work on copies so that the caller's a and b stay untouched.
	var a2, b2 nat
	a2 = a2.set(a)
	b2 = b2.set(b)
	if shift := Msize2 - exp; shift > 0 {
		a2 = a2.lsh(a2, uint(shift))
	} else if shift < 0 {
		b2 = b2.lsh(b2, uint(-shift))
	}

	// 2. Compute quotient and remainder (q, r).  NB: due to the
	// extra shift, the low-order bit of q is logically the
	// high-order bit of r.
	var q nat
	q, r := q.div(stk, a2, a2, b2) // (recycle a2)
	mantissa := low32(q)
	haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half

	// 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
	// (in effect---we accomplish this incrementally).
	if mantissa>>Msize2 == 1 {
		// The bit about to be shifted out becomes part of the remainder.
		if mantissa&1 == 1 {
			haveRem = true
		}
		mantissa >>= 1
		exp++
	}
	// Sanity check: the mantissa must now have exactly Msize2 bits.
	if mantissa>>Msize1 != 1 {
		panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
	}

	// 4. Rounding.
	if Emin-Msize <= exp && exp <= Emin {
		// Denormal case; lose 'shift' bits of precision.
		shift := uint(Emin - (exp - 1)) // [1..Msize1]
		lostbits := mantissa & (1<<shift - 1)
		haveRem = haveRem || lostbits != 0
		mantissa >>= shift
		exp = 2 - Ebias // == exp + shift
	}
	// Round q using round-half-to-even.
	exact = !haveRem
	if mantissa&1 != 0 {
		// The rounding bit is set, so the result cannot be exact.
		exact = false
		// Round up if above the halfway point (haveRem), or exactly
		// halfway with an odd kept mantissa (mantissa&2 != 0).
		if haveRem || mantissa&2 != 0 {
			if mantissa++; mantissa >= 1<<Msize2 {
				// Complete rollover 11...1 => 100...0, so shift is safe
				mantissa >>= 1
				exp++
			}
		}
	}
	mantissa >>= 1 // discard rounding bit.  Mantissa now scaled by 1<<Msize1.

	// Assemble the value as mantissa * 2**(exp-Msize1).
	f = float32(math.Ldexp(float64(mantissa), exp-Msize1))
	if math.IsInf(float64(f), 0) {
		// An infinite result never represents a/b exactly.
		exact = false
	}
	return
}
 170  
// quotToFloat64 returns the non-negative float64 value
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// The exact result reports whether f represents a/b exactly;
// it is false whenever rounding discarded non-zero bits or
// the result is an infinity.
// If a is zero, the result is (0, true).
// stk is handed to the division in step 2.
// Preconditions: b is non-zero; a and b have no common factors.
// quotToFloat64 panics with "division by zero" if b is zero.
func quotToFloat64(stk *stack, a, b nat) (f float64, exact bool) {
	const (
		// float size in bits
		Fsize = 64

		// mantissa
		Msize  = 52
		Msize1 = Msize + 1 // incl. implicit 1
		Msize2 = Msize1 + 1 // incl. implicit 1 and one extra rounding bit

		// exponent
		Esize = Fsize - Msize1
		Ebias = 1<<(Esize-1) - 1
		Emin  = 1 - Ebias
		Emax  = Ebias
	)

	// TODO(adonovan): specialize common degenerate cases: 1.0, integers.
	alen := a.bitLen()
	if alen == 0 {
		return 0, true
	}
	blen := b.bitLen()
	if blen == 0 {
		panic("division by zero")
	}

	// 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1))
	// (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B).
	// This is 2 or 3 more than the float64 mantissa field width of Msize:
	// - the optional extra bit is shifted away in step 3 below.
	// - the high-order 1 is omitted in "normal" representation;
	// - the low-order 1 will be used during rounding then discarded.
	exp := alen - blen
	// Work on copies so that the caller's a and b stay untouched.
	var a2, b2 nat
	a2 = a2.set(a)
	b2 = b2.set(b)
	if shift := Msize2 - exp; shift > 0 {
		a2 = a2.lsh(a2, uint(shift))
	} else if shift < 0 {
		b2 = b2.lsh(b2, uint(-shift))
	}

	// 2. Compute quotient and remainder (q, r).  NB: due to the
	// extra shift, the low-order bit of q is logically the
	// high-order bit of r.
	var q nat
	q, r := q.div(stk, a2, a2, b2) // (recycle a2)
	mantissa := low64(q)
	haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half

	// 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1
	// (in effect---we accomplish this incrementally).
	if mantissa>>Msize2 == 1 {
		// The bit about to be shifted out becomes part of the remainder.
		if mantissa&1 == 1 {
			haveRem = true
		}
		mantissa >>= 1
		exp++
	}
	// Sanity check: the mantissa must now have exactly Msize2 bits.
	if mantissa>>Msize1 != 1 {
		panic(fmt.Sprintf("expected exactly %d bits of result", Msize2))
	}

	// 4. Rounding.
	if Emin-Msize <= exp && exp <= Emin {
		// Denormal case; lose 'shift' bits of precision.
		shift := uint(Emin - (exp - 1)) // [1..Msize1]
		lostbits := mantissa & (1<<shift - 1)
		haveRem = haveRem || lostbits != 0
		mantissa >>= shift
		exp = 2 - Ebias // == exp + shift
	}
	// Round q using round-half-to-even.
	exact = !haveRem
	if mantissa&1 != 0 {
		// The rounding bit is set, so the result cannot be exact.
		exact = false
		// Round up if above the halfway point (haveRem), or exactly
		// halfway with an odd kept mantissa (mantissa&2 != 0).
		if haveRem || mantissa&2 != 0 {
			if mantissa++; mantissa >= 1<<Msize2 {
				// Complete rollover 11...1 => 100...0, so shift is safe
				mantissa >>= 1
				exp++
			}
		}
	}
	mantissa >>= 1 // discard rounding bit.  Mantissa now scaled by 1<<Msize1.

	// Assemble the value as mantissa * 2**(exp-Msize1).
	f = math.Ldexp(float64(mantissa), exp-Msize1)
	if math.IsInf(f, 0) {
		// An infinite result never represents a/b exactly.
		exact = false
	}
	return
}
 268  
 269  // Float32 returns the nearest float32 value for x and a bool indicating
 270  // whether f represents x exactly. If the magnitude of x is too large to
 271  // be represented by a float32, f is an infinity and exact is false.
 272  // The sign of f always matches the sign of x, even if f == 0.
 273  func (x *Rat) Float32() (f float32, exact bool) {
 274  	b := x.b.abs
 275  	if len(b) == 0 {
 276  		b = natOne
 277  	}
 278  	stk := getStack()
 279  	defer stk.free()
 280  	f, exact = quotToFloat32(stk, x.a.abs, b)
 281  	if x.a.neg {
 282  		f = -f
 283  	}
 284  	return
 285  }
 286  
 287  // Float64 returns the nearest float64 value for x and a bool indicating
 288  // whether f represents x exactly. If the magnitude of x is too large to
 289  // be represented by a float64, f is an infinity and exact is false.
 290  // The sign of f always matches the sign of x, even if f == 0.
 291  func (x *Rat) Float64() (f float64, exact bool) {
 292  	b := x.b.abs
 293  	if len(b) == 0 {
 294  		b = natOne
 295  	}
 296  	stk := getStack()
 297  	defer stk.free()
 298  	f, exact = quotToFloat64(stk, x.a.abs, b)
 299  	if x.a.neg {
 300  		f = -f
 301  	}
 302  	return
 303  }
 304  
 305  // SetFrac sets z to a/b and returns z.
 306  // If b == 0, SetFrac panics.
 307  func (z *Rat) SetFrac(a, b *Int) *Rat {
 308  	z.a.neg = a.neg != b.neg
 309  	babs := b.abs
 310  	if len(babs) == 0 {
 311  		panic("division by zero")
 312  	}
 313  	if &z.a == b || alias(z.a.abs, babs) {
 314  		babs = nat(nil).set(babs) // make a copy
 315  	}
 316  	z.a.abs = z.a.abs.set(a.abs)
 317  	z.b.abs = z.b.abs.set(babs)
 318  	return z.norm()
 319  }
 320  
 321  // SetFrac64 sets z to a/b and returns z.
 322  // If b == 0, SetFrac64 panics.
 323  func (z *Rat) SetFrac64(a, b int64) *Rat {
 324  	if b == 0 {
 325  		panic("division by zero")
 326  	}
 327  	z.a.SetInt64(a)
 328  	if b < 0 {
 329  		b = -b
 330  		z.a.neg = !z.a.neg
 331  	}
 332  	z.b.abs = z.b.abs.setUint64(uint64(b))
 333  	return z.norm()
 334  }
 335  
 336  // SetInt sets z to x (by making a copy of x) and returns z.
 337  func (z *Rat) SetInt(x *Int) *Rat {
 338  	z.a.Set(x)
 339  	z.b.abs = z.b.abs.setWord(1)
 340  	return z
 341  }
 342  
 343  // SetInt64 sets z to x and returns z.
 344  func (z *Rat) SetInt64(x int64) *Rat {
 345  	z.a.SetInt64(x)
 346  	z.b.abs = z.b.abs.setWord(1)
 347  	return z
 348  }
 349  
 350  // SetUint64 sets z to x and returns z.
 351  func (z *Rat) SetUint64(x uint64) *Rat {
 352  	z.a.SetUint64(x)
 353  	z.b.abs = z.b.abs.setWord(1)
 354  	return z
 355  }
 356  
 357  // Set sets z to x (by making a copy of x) and returns z.
 358  func (z *Rat) Set(x *Rat) *Rat {
 359  	if z != x {
 360  		z.a.Set(&x.a)
 361  		z.b.Set(&x.b)
 362  	}
 363  	if len(z.b.abs) == 0 {
 364  		z.b.abs = z.b.abs.setWord(1)
 365  	}
 366  	return z
 367  }
 368  
 369  // Abs sets z to |x| (the absolute value of x) and returns z.
 370  func (z *Rat) Abs(x *Rat) *Rat {
 371  	z.Set(x)
 372  	z.a.neg = false
 373  	return z
 374  }
 375  
 376  // Neg sets z to -x and returns z.
 377  func (z *Rat) Neg(x *Rat) *Rat {
 378  	z.Set(x)
 379  	z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign
 380  	return z
 381  }
 382  
 383  // Inv sets z to 1/x and returns z.
 384  // If x == 0, Inv panics.
 385  func (z *Rat) Inv(x *Rat) *Rat {
 386  	if len(x.a.abs) == 0 {
 387  		panic("division by zero")
 388  	}
 389  	z.Set(x)
 390  	z.a.abs, z.b.abs = z.b.abs, z.a.abs
 391  	return z
 392  }
 393  
 394  // Sign returns:
 395  //   - -1 if x < 0;
 396  //   - 0 if x == 0;
 397  //   - +1 if x > 0.
 398  func (x *Rat) Sign() int {
 399  	return x.a.Sign()
 400  }
 401  
 402  // IsInt reports whether the denominator of x is 1.
 403  func (x *Rat) IsInt() bool {
 404  	return len(x.b.abs) == 0 || x.b.abs.cmp(natOne) == 0
 405  }
 406  
 407  // Num returns the numerator of x; it may be <= 0.
 408  // The result is a reference to x's numerator; it
 409  // may change if a new value is assigned to x, and vice versa.
 410  // The sign of the numerator corresponds to the sign of x.
 411  func (x *Rat) Num() *Int {
 412  	return &x.a
 413  }
 414  
 415  // Denom returns the denominator of x; it is always > 0.
 416  // The result is a reference to x's denominator, unless
 417  // x is an uninitialized (zero value) [Rat], in which case
 418  // the result is a new [Int] of value 1. (To initialize x,
 419  // any operation that sets x will do, including x.Set(x).)
 420  // If the result is a reference to x's denominator it
 421  // may change if a new value is assigned to x, and vice versa.
 422  func (x *Rat) Denom() *Int {
 423  	// Note that x.b.neg is guaranteed false.
 424  	if len(x.b.abs) == 0 {
 425  		// Note: If this proves problematic, we could
 426  		//       panic instead and require the Rat to
 427  		//       be explicitly initialized.
 428  		return &Int{abs: nat{1}}
 429  	}
 430  	return &x.b
 431  }
 432  
 433  func (z *Rat) norm() *Rat {
 434  	switch {
 435  	case len(z.a.abs) == 0:
 436  		// z == 0; normalize sign and denominator
 437  		z.a.neg = false
 438  		z.b.abs = z.b.abs.setWord(1)
 439  	case len(z.b.abs) == 0:
 440  		// z is integer; normalize denominator
 441  		z.b.abs = z.b.abs.setWord(1)
 442  	default:
 443  		// z is fraction; normalize numerator and denominator
 444  		stk := getStack()
 445  		defer stk.free()
 446  		neg := z.a.neg
 447  		z.a.neg = false
 448  		z.b.neg = false
 449  		if f := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); f.Cmp(intOne) != 0 {
 450  			z.a.abs, _ = z.a.abs.div(stk, nil, z.a.abs, f.abs)
 451  			z.b.abs, _ = z.b.abs.div(stk, nil, z.b.abs, f.abs)
 452  		}
 453  		z.a.neg = neg
 454  	}
 455  	return z
 456  }
 457  
 458  // mulDenom sets z to the denominator product x*y (by taking into
 459  // account that 0 values for x or y must be interpreted as 1) and
 460  // returns z.
 461  func mulDenom(stk *stack, z, x, y nat) nat {
 462  	switch {
 463  	case len(x) == 0 && len(y) == 0:
 464  		return z.setWord(1)
 465  	case len(x) == 0:
 466  		return z.set(y)
 467  	case len(y) == 0:
 468  		return z.set(x)
 469  	}
 470  	return z.mul(stk, x, y)
 471  }
 472  
 473  // scaleDenom sets z to the product x*f.
 474  // If f == 0 (zero value of denominator), z is set to (a copy of) x.
 475  func (z *Int) scaleDenom(stk *stack, x *Int, f nat) {
 476  	if len(f) == 0 {
 477  		z.Set(x)
 478  		return
 479  	}
 480  	z.abs = z.abs.mul(stk, x.abs, f)
 481  	z.neg = x.neg
 482  }
 483  
 484  // Cmp compares x and y and returns:
 485  //   - -1 if x < y;
 486  //   - 0 if x == y;
 487  //   - +1 if x > y.
 488  func (x *Rat) Cmp(y *Rat) int {
 489  	var a, b Int
 490  	stk := getStack()
 491  	defer stk.free()
 492  	a.scaleDenom(stk, &x.a, y.b.abs)
 493  	b.scaleDenom(stk, &y.a, x.b.abs)
 494  	return a.Cmp(&b)
 495  }
 496  
 497  // Add sets z to the sum x+y and returns z.
 498  func (z *Rat) Add(x, y *Rat) *Rat {
 499  	stk := getStack()
 500  	defer stk.free()
 501  
 502  	var a1, a2 Int
 503  	a1.scaleDenom(stk, &x.a, y.b.abs)
 504  	a2.scaleDenom(stk, &y.a, x.b.abs)
 505  	z.a.Add(&a1, &a2)
 506  	z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
 507  	return z.norm()
 508  }
 509  
 510  // Sub sets z to the difference x-y and returns z.
 511  func (z *Rat) Sub(x, y *Rat) *Rat {
 512  	stk := getStack()
 513  	defer stk.free()
 514  
 515  	var a1, a2 Int
 516  	a1.scaleDenom(stk, &x.a, y.b.abs)
 517  	a2.scaleDenom(stk, &y.a, x.b.abs)
 518  	z.a.Sub(&a1, &a2)
 519  	z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
 520  	return z.norm()
 521  }
 522  
 523  // Mul sets z to the product x*y and returns z.
 524  func (z *Rat) Mul(x, y *Rat) *Rat {
 525  	stk := getStack()
 526  	defer stk.free()
 527  
 528  	if x == y {
 529  		// a squared Rat is positive and can't be reduced (no need to call norm())
 530  		z.a.neg = false
 531  		z.a.abs = z.a.abs.sqr(stk, x.a.abs)
 532  		if len(x.b.abs) == 0 {
 533  			z.b.abs = z.b.abs.setWord(1)
 534  		} else {
 535  			z.b.abs = z.b.abs.sqr(stk, x.b.abs)
 536  		}
 537  		return z
 538  	}
 539  
 540  	z.a.mul(stk, &x.a, &y.a)
 541  	z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs)
 542  	return z.norm()
 543  }
 544  
 545  // Quo sets z to the quotient x/y and returns z.
 546  // If y == 0, Quo panics.
 547  func (z *Rat) Quo(x, y *Rat) *Rat {
 548  	stk := getStack()
 549  	defer stk.free()
 550  
 551  	if len(y.a.abs) == 0 {
 552  		panic("division by zero")
 553  	}
 554  	var a, b Int
 555  	a.scaleDenom(stk, &x.a, y.b.abs)
 556  	b.scaleDenom(stk, &y.a, x.b.abs)
 557  	z.a.abs = a.abs
 558  	z.b.abs = b.abs
 559  	z.a.neg = a.neg != b.neg
 560  	return z.norm()
 561  }
 562