natmul.mx raw

   1  // Copyright 2009 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  // Multiplication.
   6  
   7  package big
   8  
   9  // Operands that are shorter than karatsubaThreshold are multiplied using
  10  // "grade school" multiplication; for longer operands the Karatsuba algorithm
  11  // is used.
  12  var karatsubaThreshold = 40 // see calibrate_test.go
  13  
  14  // mul sets z = x*y, using stk for temporary storage.
  15  // The caller may pass stk == nil to request that mul obtain and release one itself.
  16  func (z nat) mul(stk *stack, x, y nat) nat {
  17  	m := len(x)
  18  	n := len(y)
  19  
  20  	switch {
  21  	case m < n:
  22  		return z.mul(stk, y, x)
  23  	case m == 0 || n == 0:
  24  		return z[:0]
  25  	case n == 1:
  26  		return z.mulAddWW(x, y[0], 0)
  27  	}
  28  	// m >= n > 1
  29  
  30  	// determine if z can be reused
  31  	if alias(z, x) || alias(z, y) {
  32  		z = nil // z is an alias for x or y - cannot reuse
  33  	}
  34  	z = z.make(m + n)
  35  
  36  	// use basic multiplication if the numbers are small
  37  	if n < karatsubaThreshold {
  38  		basicMul(z, x, y)
  39  		return z.norm()
  40  	}
  41  
  42  	if stk == nil {
  43  		stk = getStack()
  44  		defer stk.free()
  45  	}
  46  
  47  	// Let x = x1:x0 where x0 is the same length as y.
  48  	// Compute z = x0*y and then add in x1*y in sections
  49  	// if needed.
  50  	karatsuba(stk, z[:2*n], x[:n], y)
  51  
  52  	if n < m {
  53  		clear(z[2*n:])
  54  		defer stk.restore(stk.save())
  55  		t := stk.nat(2 * n)
  56  		for i := n; i < m; i += n {
  57  			t = t.mul(stk, x[i:min(i+n, len(x))], y)
  58  			addTo(z[i:], t)
  59  		}
  60  	}
  61  
  62  	return z.norm()
  63  }
  64  
  65  // Operands that are shorter than basicSqrThreshold are squared using
  66  // "grade school" multiplication; for operands longer than karatsubaSqrThreshold
  67  // we use the Karatsuba algorithm optimized for x == y.
  68  var basicSqrThreshold = 12     // see calibrate_test.go
  69  var karatsubaSqrThreshold = 80 // see calibrate_test.go
  70  
  71  // sqr sets z = x*x, using stk for temporary storage.
  72  // The caller may pass stk == nil to request that sqr obtain and release one itself.
  73  func (z nat) sqr(stk *stack, x nat) nat {
  74  	n := len(x)
  75  	switch {
  76  	case n == 0:
  77  		return z[:0]
  78  	case n == 1:
  79  		d := x[0]
  80  		z = z.make(2)
  81  		z[1], z[0] = mulWW(d, d)
  82  		return z.norm()
  83  	}
  84  
  85  	if alias(z, x) {
  86  		z = nil // z is an alias for x - cannot reuse
  87  	}
  88  	z = z.make(2 * n)
  89  
  90  	if n < basicSqrThreshold && n < karatsubaSqrThreshold {
  91  		basicMul(z, x, x)
  92  		return z.norm()
  93  	}
  94  
  95  	if stk == nil {
  96  		stk = getStack()
  97  		defer stk.free()
  98  	}
  99  
 100  	if n < karatsubaSqrThreshold {
 101  		basicSqr(stk, z, x)
 102  		return z.norm()
 103  	}
 104  
 105  	karatsubaSqr(stk, z, x)
 106  	return z.norm()
 107  }
 108  
 109  // basicSqr sets z = x*x and is asymptotically faster than basicMul
 110  // by about a factor of 2, but slower for small arguments due to overhead.
 111  // Requirements: len(x) > 0, len(z) == 2*len(x)
 112  // The (non-normalized) result is placed in z.
 113  func basicSqr(stk *stack, z, x nat) {
 114  	n := len(x)
 115  	if n < basicSqrThreshold {
 116  		basicMul(z, x, x)
 117  		return
 118  	}
 119  
 120  	defer stk.restore(stk.save())
 121  	t := stk.nat(2 * n)
 122  	clear(t)
 123  	z[1], z[0] = mulWW(x[0], x[0]) // the initial square
 124  	for i := 1; i < n; i++ {
 125  		d := x[i]
 126  		// z collects the squares x[i] * x[i]
 127  		z[2*i+1], z[2*i] = mulWW(d, d)
 128  		// t collects the products x[i] * x[j] where j < i
 129  		t[2*i] = addMulVVWW(t[i:2*i], t[i:2*i], x[0:i], d, 0)
 130  	}
 131  	t[2*n-1] = lshVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
 132  	addVV(z, z, t)                              // combine the result
 133  }
 134  
 135  // mulAddWW returns z = x*y + r.
 136  func (z nat) mulAddWW(x nat, y, r Word) nat {
 137  	m := len(x)
 138  	if m == 0 || y == 0 {
 139  		return z.setWord(r) // result is r
 140  	}
 141  	// m > 0
 142  
 143  	z = z.make(m + 1)
 144  	z[m] = mulAddVWW(z[0:m], x, y, r)
 145  
 146  	return z.norm()
 147  }
 148  
 149  // basicMul multiplies x and y and leaves the result in z.
 150  // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
 151  func basicMul(z, x, y nat) {
 152  	clear(z[0 : len(x)+len(y)]) // initialize z
 153  	for i, d := range y {
 154  		if d != 0 {
 155  			z[len(x)+i] = addMulVVWW(z[i:i+len(x)], z[i:i+len(x)], x, d, 0)
 156  		}
 157  	}
 158  }
 159  
 160  // karatsuba multiplies x and y,
 161  // writing the (non-normalized) result to z.
 162  // x and y must have the same length n,
 163  // and z must have length twice that.
 164  func karatsuba(stk *stack, z, x, y nat) {
 165  	n := len(y)
 166  	if len(x) != n || len(z) != 2*n {
 167  		panic("bad karatsuba length")
 168  	}
 169  
 170  	// Fall back to basic algorithm if small enough.
 171  	if n < karatsubaThreshold || n < 2 {
 172  		basicMul(z, x, y)
 173  		return
 174  	}
 175  
 176  	// Let the notation x1:x0 denote the nat (x1<<N)+x0 for some N,
 177  	// and similarly z2:z1:z0 = (z2<<2N)+(z1<<N)+z0.
 178  	//
 179  	// (Note that z0, z1, z2 might be ≥ 2**N, in which case the high
 180  	// bits of, say, z0 are being added to the low bits of z1 in this notation.)
 181  	//
 182  	// Karatsuba multiplication is based on the observation that
 183  	//
 184  	//	x1:x0 * y1:y0 = x1*y1:(x0*y1+y0*x1):x0*y0
 185  	//	              = x1*y1:((x0-x1)*(y1-y0)+x1*y1+x0*y0):x0*y0
 186  	//
 187  	// The second form uses only three half-width multiplications
 188  	// instead of the four that the straightforward first form does.
 189  	//
 190  	// We call the three pieces z0, z1, z2:
 191  	//
 192  	//	z0 = x0*y0
 193  	//	z2 = x1*y1
 194  	//	z1 = (x0-x1)*(y1-y0) + z0 + z2
 195  
 196  	n2 := (n + 1) / 2
 197  	x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
 198  	y0, y1 := &Int{abs: y[:n2].norm()}, &Int{abs: y[n2:].norm()}
 199  	z0 := &Int{abs: z[0 : 2*n2]}
 200  	z2 := &Int{abs: z[2*n2:]}
 201  
 202  	// Allocate temporary storage for z1; repurpose z0 to hold tx and ty.
 203  	defer stk.restore(stk.save())
 204  	z1 := &Int{abs: stk.nat(2*n2 + 1)}
 205  	tx := &Int{abs: z[0:n2]}
 206  	ty := &Int{abs: z[n2 : 2*n2]}
 207  
 208  	tx.Sub(x0, x1)
 209  	ty.Sub(y1, y0)
 210  	z1.mul(stk, tx, ty)
 211  
 212  	clear(z)
 213  	z0.mul(stk, x0, y0)
 214  	z2.mul(stk, x1, y1)
 215  	z1.Add(z1, z0)
 216  	z1.Add(z1, z2)
 217  	addTo(z[n2:], z1.abs)
 218  
 219  	// Debug mode: double-check answer and print trace on failure.
 220  	const debug = false
 221  	if debug {
 222  		zz := make(nat, len(z))
 223  		basicMul(zz, x, y)
 224  		if z.cmp(zz) != 0 {
 225  			// All the temps were aliased to z and gone. Recompute.
 226  			z0 = &Int{}
 227  			z0.mul(stk, x0, y0)
 228  			tx = (&Int{}).Sub(x1, x0)
 229  			ty = (&Int{}).Sub(y0, y1)
 230  			z2 = &Int{}
 231  			z2.mul(stk, x1, y1)
 232  			print("karatsuba wrong\n")
 233  			trace("x ", &Int{abs: x})
 234  			trace("y ", &Int{abs: y})
 235  			trace("z ", &Int{abs: z})
 236  			trace("zz", &Int{abs: zz})
 237  			trace("x0", x0)
 238  			trace("x1", x1)
 239  			trace("y0", y0)
 240  			trace("y1", y1)
 241  			trace("tx", tx)
 242  			trace("ty", ty)
 243  			trace("z0", z0)
 244  			trace("z1", z1)
 245  			trace("z2", z2)
 246  			panic("karatsuba")
 247  		}
 248  	}
 249  
 250  }
 251  
 252  // karatsubaSqr squares x,
 253  // writing the (non-normalized) result to z.
 254  // z must have length 2*len(x).
 255  // It is analogous to [karatsuba] but can run faster
 256  // knowing both multiplicands are the same value.
 257  func karatsubaSqr(stk *stack, z, x nat) {
 258  	n := len(x)
 259  	if len(z) != 2*n {
 260  		panic("bad karatsubaSqr length")
 261  	}
 262  
 263  	if n < karatsubaSqrThreshold || n < 2 {
 264  		basicSqr(stk, z, x)
 265  		return
 266  	}
 267  
 268  	// Recall that for karatsuba we want to compute:
 269  	//
 270  	//	x1:x0 * y1:y0 = x1y1:(x0y1+y0x1):x0y0
 271  	//                = x1y1:((x0-x1)*(y1-y0)+x1y1+x0y0):x0y0
 272  	//	              = z2:z1:z0
 273  	// where:
 274  	//
 275  	//	z0 = x0y0
 276  	//	z2 = x1y1
 277  	//	z1 = (x0-x1)*(y1-y0) + z0 + z2
 278  	//
 279  	// When x = y, these simplify to:
 280  	//
 281  	//	z0 = x0²
 282  	//	z2 = x1²
 283  	//	z1 = z0 + z2 - (x0-x1)²
 284  
 285  	n2 := (n + 1) / 2
 286  	x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
 287  	z0 := &Int{abs: z[0 : 2*n2]}
 288  	z2 := &Int{abs: z[2*n2:]}
 289  
 290  	// Allocate temporary storage for z1; repurpose z0 to hold tx.
 291  	defer stk.restore(stk.save())
 292  	z1 := &Int{abs: stk.nat(2*n2 + 1)}
 293  	tx := &Int{abs: z[0:n2]}
 294  
 295  	tx.Sub(x0, x1)
 296  	z1.abs = z1.abs.sqr(stk, tx.abs)
 297  	z1.neg = true
 298  
 299  	clear(z)
 300  	z0.abs = z0.abs.sqr(stk, x0.abs)
 301  	z2.abs = z2.abs.sqr(stk, x1.abs)
 302  	z1.Add(z1, z0)
 303  	z1.Add(z1, z2)
 304  	addTo(z[n2:], z1.abs)
 305  
 306  	// Debug mode: double-check answer and print trace on failure.
 307  	const debug = false
 308  	if debug {
 309  		zz := make(nat, len(z))
 310  		basicSqr(stk, zz, x)
 311  		if z.cmp(zz) != 0 {
 312  			// All the temps were aliased to z and gone. Recompute.
 313  			tx = (&Int{}).Sub(x0, x1)
 314  			z0 = (&Int{}).Mul(x0, x0)
 315  			z2 = (&Int{}).Mul(x1, x1)
 316  			z1 = (&Int{}).Mul(tx, tx)
 317  			z1.Neg(z1)
 318  			z1.Add(z1, z0)
 319  			z1.Add(z1, z2)
 320  			print("karatsubaSqr wrong\n")
 321  			trace("x ", &Int{abs: x})
 322  			trace("z ", &Int{abs: z})
 323  			trace("zz", &Int{abs: zz})
 324  			trace("x0", x0)
 325  			trace("x1", x1)
 326  			trace("z0", z0)
 327  			trace("z1", z1)
 328  			trace("z2", z2)
 329  			panic("karatsubaSqr")
 330  		}
 331  	}
 332  }
 333  
 334  // ifmt returns the debug formatting of the Int x: 0xHEX.
 335  func ifmt(x *Int) []byte {
 336  	neg, s, t := "", x.Text(16), ""
 337  	if s == "" { // happens for denormalized zero
 338  		s = "0x0"
 339  	}
 340  	if s[0] == '-' {
 341  		neg, s = "-", s[1:]
 342  	}
 343  
 344  	// Add _ between words.
 345  	const D = _W / 4 // digits per chunk
 346  	for len(s) > D {
 347  		s, t = s[:len(s)-D], s[len(s)-D:]|"_"|t
 348  	}
 349  	return neg | s | t
 350  }
 351  
 352  // trace prints a single debug value.
 353  func trace(name []byte, x *Int) {
 354  	print(name, "=", ifmt(x), "\n")
 355  }
 356