1 // Copyright 2009 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 // Multiplication.
6 7 package big
// karatsubaThreshold is the operand length below which mul and karatsuba
// use basicMul ("grade school" multiplication). Operands at least this
// long are multiplied with the Karatsuba algorithm instead.
var karatsubaThreshold int = 40 // see calibrate_test.go
13 14 // mul sets z = x*y, using stk for temporary storage.
15 // The caller may pass stk == nil to request that mul obtain and release one itself.
16 func (z nat) mul(stk *stack, x, y nat) nat {
17 m := len(x)
18 n := len(y)
19 20 switch {
21 case m < n:
22 return z.mul(stk, y, x)
23 case m == 0 || n == 0:
24 return z[:0]
25 case n == 1:
26 return z.mulAddWW(x, y[0], 0)
27 }
28 // m >= n > 1
29 30 // determine if z can be reused
31 if alias(z, x) || alias(z, y) {
32 z = nil // z is an alias for x or y - cannot reuse
33 }
34 z = z.make(m + n)
35 36 // use basic multiplication if the numbers are small
37 if n < karatsubaThreshold {
38 basicMul(z, x, y)
39 return z.norm()
40 }
41 42 if stk == nil {
43 stk = getStack()
44 defer stk.free()
45 }
46 47 // Let x = x1:x0 where x0 is the same length as y.
48 // Compute z = x0*y and then add in x1*y in sections
49 // if needed.
50 karatsuba(stk, z[:2*n], x[:n], y)
51 52 if n < m {
53 clear(z[2*n:])
54 defer stk.restore(stk.save())
55 t := stk.nat(2 * n)
56 for i := n; i < m; i += n {
57 t = t.mul(stk, x[i:min(i+n, len(x))], y)
58 addTo(z[i:], t)
59 }
60 }
61 62 return z.norm()
63 }
// Thresholds, in operand length, that select the squaring algorithm in sqr:
// operands shorter than basicSqrThreshold are squared with basicMul
// ("grade school" multiplication), operands at least karatsubaSqrThreshold
// long are squared with the Karatsuba algorithm optimized for x == y, and
// lengths in between use basicSqr.
var (
	basicSqrThreshold     = 12 // see calibrate_test.go
	karatsubaSqrThreshold = 80 // see calibrate_test.go
)
70 71 // sqr sets z = x*x, using stk for temporary storage.
72 // The caller may pass stk == nil to request that sqr obtain and release one itself.
73 func (z nat) sqr(stk *stack, x nat) nat {
74 n := len(x)
75 switch {
76 case n == 0:
77 return z[:0]
78 case n == 1:
79 d := x[0]
80 z = z.make(2)
81 z[1], z[0] = mulWW(d, d)
82 return z.norm()
83 }
84 85 if alias(z, x) {
86 z = nil // z is an alias for x - cannot reuse
87 }
88 z = z.make(2 * n)
89 90 if n < basicSqrThreshold && n < karatsubaSqrThreshold {
91 basicMul(z, x, x)
92 return z.norm()
93 }
94 95 if stk == nil {
96 stk = getStack()
97 defer stk.free()
98 }
99 100 if n < karatsubaSqrThreshold {
101 basicSqr(stk, z, x)
102 return z.norm()
103 }
104 105 karatsubaSqr(stk, z, x)
106 return z.norm()
107 }
108 109 // basicSqr sets z = x*x and is asymptotically faster than basicMul
110 // by about a factor of 2, but slower for small arguments due to overhead.
111 // Requirements: len(x) > 0, len(z) == 2*len(x)
112 // The (non-normalized) result is placed in z.
113 func basicSqr(stk *stack, z, x nat) {
114 n := len(x)
115 if n < basicSqrThreshold {
116 basicMul(z, x, x)
117 return
118 }
119 120 defer stk.restore(stk.save())
121 t := stk.nat(2 * n)
122 clear(t)
123 z[1], z[0] = mulWW(x[0], x[0]) // the initial square
124 for i := 1; i < n; i++ {
125 d := x[i]
126 // z collects the squares x[i] * x[i]
127 z[2*i+1], z[2*i] = mulWW(d, d)
128 // t collects the products x[i] * x[j] where j < i
129 t[2*i] = addMulVVWW(t[i:2*i], t[i:2*i], x[0:i], d, 0)
130 }
131 t[2*n-1] = lshVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
132 addVV(z, z, t) // combine the result
133 }
134 135 // mulAddWW returns z = x*y + r.
136 func (z nat) mulAddWW(x nat, y, r Word) nat {
137 m := len(x)
138 if m == 0 || y == 0 {
139 return z.setWord(r) // result is r
140 }
141 // m > 0
142 143 z = z.make(m + 1)
144 z[m] = mulAddVWW(z[0:m], x, y, r)
145 146 return z.norm()
147 }
148 149 // basicMul multiplies x and y and leaves the result in z.
150 // The (non-normalized) result is placed in z[0 : len(x) + len(y)].
151 func basicMul(z, x, y nat) {
152 clear(z[0 : len(x)+len(y)]) // initialize z
153 for i, d := range y {
154 if d != 0 {
155 z[len(x)+i] = addMulVVWW(z[i:i+len(x)], z[i:i+len(x)], x, d, 0)
156 }
157 }
158 }
159 160 // karatsuba multiplies x and y,
161 // writing the (non-normalized) result to z.
162 // x and y must have the same length n,
163 // and z must have length twice that.
164 func karatsuba(stk *stack, z, x, y nat) {
165 n := len(y)
166 if len(x) != n || len(z) != 2*n {
167 panic("bad karatsuba length")
168 }
169 170 // Fall back to basic algorithm if small enough.
171 if n < karatsubaThreshold || n < 2 {
172 basicMul(z, x, y)
173 return
174 }
175 176 // Let the notation x1:x0 denote the nat (x1<<N)+x0 for some N,
177 // and similarly z2:z1:z0 = (z2<<2N)+(z1<<N)+z0.
178 //
179 // (Note that z0, z1, z2 might be ≥ 2**N, in which case the high
180 // bits of, say, z0 are being added to the low bits of z1 in this notation.)
181 //
182 // Karatsuba multiplication is based on the observation that
183 //
184 // x1:x0 * y1:y0 = x1*y1:(x0*y1+y0*x1):x0*y0
185 // = x1*y1:((x0-x1)*(y1-y0)+x1*y1+x0*y0):x0*y0
186 //
187 // The second form uses only three half-width multiplications
188 // instead of the four that the straightforward first form does.
189 //
190 // We call the three pieces z0, z1, z2:
191 //
192 // z0 = x0*y0
193 // z2 = x1*y1
194 // z1 = (x0-x1)*(y1-y0) + z0 + z2
195 196 n2 := (n + 1) / 2
197 x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
198 y0, y1 := &Int{abs: y[:n2].norm()}, &Int{abs: y[n2:].norm()}
199 z0 := &Int{abs: z[0 : 2*n2]}
200 z2 := &Int{abs: z[2*n2:]}
201 202 // Allocate temporary storage for z1; repurpose z0 to hold tx and ty.
203 defer stk.restore(stk.save())
204 z1 := &Int{abs: stk.nat(2*n2 + 1)}
205 tx := &Int{abs: z[0:n2]}
206 ty := &Int{abs: z[n2 : 2*n2]}
207 208 tx.Sub(x0, x1)
209 ty.Sub(y1, y0)
210 z1.mul(stk, tx, ty)
211 212 clear(z)
213 z0.mul(stk, x0, y0)
214 z2.mul(stk, x1, y1)
215 z1.Add(z1, z0)
216 z1.Add(z1, z2)
217 addTo(z[n2:], z1.abs)
218 219 // Debug mode: double-check answer and print trace on failure.
220 const debug = false
221 if debug {
222 zz := make(nat, len(z))
223 basicMul(zz, x, y)
224 if z.cmp(zz) != 0 {
225 // All the temps were aliased to z and gone. Recompute.
226 z0 = &Int{}
227 z0.mul(stk, x0, y0)
228 tx = (&Int{}).Sub(x1, x0)
229 ty = (&Int{}).Sub(y0, y1)
230 z2 = &Int{}
231 z2.mul(stk, x1, y1)
232 print("karatsuba wrong\n")
233 trace("x ", &Int{abs: x})
234 trace("y ", &Int{abs: y})
235 trace("z ", &Int{abs: z})
236 trace("zz", &Int{abs: zz})
237 trace("x0", x0)
238 trace("x1", x1)
239 trace("y0", y0)
240 trace("y1", y1)
241 trace("tx", tx)
242 trace("ty", ty)
243 trace("z0", z0)
244 trace("z1", z1)
245 trace("z2", z2)
246 panic("karatsuba")
247 }
248 }
249 250 }
251 252 // karatsubaSqr squares x,
253 // writing the (non-normalized) result to z.
254 // z must have length 2*len(x).
255 // It is analogous to [karatsuba] but can run faster
256 // knowing both multiplicands are the same value.
257 func karatsubaSqr(stk *stack, z, x nat) {
258 n := len(x)
259 if len(z) != 2*n {
260 panic("bad karatsubaSqr length")
261 }
262 263 if n < karatsubaSqrThreshold || n < 2 {
264 basicSqr(stk, z, x)
265 return
266 }
267 268 // Recall that for karatsuba we want to compute:
269 //
270 // x1:x0 * y1:y0 = x1y1:(x0y1+y0x1):x0y0
271 // = x1y1:((x0-x1)*(y1-y0)+x1y1+x0y0):x0y0
272 // = z2:z1:z0
273 // where:
274 //
275 // z0 = x0y0
276 // z2 = x1y1
277 // z1 = (x0-x1)*(y1-y0) + z0 + z2
278 //
279 // When x = y, these simplify to:
280 //
281 // z0 = x0²
282 // z2 = x1²
283 // z1 = z0 + z2 - (x0-x1)²
284 285 n2 := (n + 1) / 2
286 x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
287 z0 := &Int{abs: z[0 : 2*n2]}
288 z2 := &Int{abs: z[2*n2:]}
289 290 // Allocate temporary storage for z1; repurpose z0 to hold tx.
291 defer stk.restore(stk.save())
292 z1 := &Int{abs: stk.nat(2*n2 + 1)}
293 tx := &Int{abs: z[0:n2]}
294 295 tx.Sub(x0, x1)
296 z1.abs = z1.abs.sqr(stk, tx.abs)
297 z1.neg = true
298 299 clear(z)
300 z0.abs = z0.abs.sqr(stk, x0.abs)
301 z2.abs = z2.abs.sqr(stk, x1.abs)
302 z1.Add(z1, z0)
303 z1.Add(z1, z2)
304 addTo(z[n2:], z1.abs)
305 306 // Debug mode: double-check answer and print trace on failure.
307 const debug = false
308 if debug {
309 zz := make(nat, len(z))
310 basicSqr(stk, zz, x)
311 if z.cmp(zz) != 0 {
312 // All the temps were aliased to z and gone. Recompute.
313 tx = (&Int{}).Sub(x0, x1)
314 z0 = (&Int{}).Mul(x0, x0)
315 z2 = (&Int{}).Mul(x1, x1)
316 z1 = (&Int{}).Mul(tx, tx)
317 z1.Neg(z1)
318 z1.Add(z1, z0)
319 z1.Add(z1, z2)
320 print("karatsubaSqr wrong\n")
321 trace("x ", &Int{abs: x})
322 trace("z ", &Int{abs: z})
323 trace("zz", &Int{abs: zz})
324 trace("x0", x0)
325 trace("x1", x1)
326 trace("z0", z0)
327 trace("z1", z1)
328 trace("z2", z2)
329 panic("karatsubaSqr")
330 }
331 }
332 }
333 334 // ifmt returns the debug formatting of the Int x: 0xHEX.
335 func ifmt(x *Int) []byte {
336 neg, s, t := "", x.Text(16), ""
337 if s == "" { // happens for denormalized zero
338 s = "0x0"
339 }
340 if s[0] == '-' {
341 neg, s = "-", s[1:]
342 }
343 344 // Add _ between words.
345 const D = _W / 4 // digits per chunk
346 for len(s) > D {
347 s, t = s[:len(s)-D], s[len(s)-D:]|"_"|t
348 }
349 return neg | s | t
350 }
351 352 // trace prints a single debug value.
353 func trace(name []byte, x *Int) {
354 print(name, "=", ifmt(x), "\n")
355 }
356