1 // Copyright 2025 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 package asmgen
6 7 // mulAddVWW generates mulAddVWW, which does z, c = x*m + a.
8 func mulAddVWW(a *Asm) {
9 f := a.Func("func mulAddVWW(z, x []Word, m, a Word) (c Word)")
10 11 if a.AltCarry().Valid() {
12 addMulVirtualCarry(f, 0)
13 return
14 }
15 addMul(f, "", "x", 0)
16 }
17 18 // addMulVVWW generates addMulVVWW which does z, c = x + y*m + a.
19 // (A more pedantic name would be addMulAddVVWW.)
20 func addMulVVWW(a *Asm) {
21 f := a.Func("func addMulVVWW(z, x, y []Word, m, a Word) (c Word)")
22 23 // If the architecture has virtual carries, emit that version unconditionally.
24 if a.AltCarry().Valid() {
25 addMulVirtualCarry(f, 1)
26 return
27 }
28 29 // If the architecture optionally has two carries, test and emit both versions.
30 if a.JmpEnable(OptionAltCarry, "altcarry") {
31 regs := a.RegsUsed()
32 addMul(f, "x", "y", 1)
33 a.Label("altcarry")
34 a.SetOption(OptionAltCarry, true)
35 a.SetRegsUsed(regs)
36 addMulAlt(f)
37 a.SetOption(OptionAltCarry, false)
38 return
39 }
40 41 // Otherwise emit the one-carry form.
42 addMul(f, "x", "y", 1)
43 }
44 45 // Computing z = addsrc + m*mulsrc + a, we need:
46 //
47 // for i := range z {
48 // lo, hi := m * mulsrc[i]
49 // lo, carry = bits.Add(lo, a, 0)
50 // lo, carryAlt = bits.Add(lo, addsrc[i], 0)
51 // z[i] = lo
52 // a = hi + carry + carryAlt // cannot overflow
53 // }
54 //
55 // The final addition cannot overflow because after processing N words,
56 // the maximum possible value is (for a 64-bit system):
57 //
58 // (2**64N - 1) + (2**64 - 1)*(2**64N - 1) + (2**64 - 1)
59 // = (2**64)*(2**64N - 1) + (2**64 - 1)
60 // = 2**64(N+1) - 1,
61 //
62 // which fits in N+1 words (the high order one being the new value of a).
63 //
64 // (For example, with 3 decimal words, 999 + 9*999 + 9 = 999*10 + 9 = 9999.)
65 //
66 // If we unroll the loop a bit, then we can chain the carries in two passes.
67 // Consider:
68 //
// lo0, hi0 := m * mulsrc[i]
// lo0, carry = bits.Add(lo0, a, 0)
// lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
// z[i] = lo0
// a = hi0 + carry + carryAlt // cannot overflow
//
// lo1, hi1 := m * mulsrc[i+1]
// lo1, carry = bits.Add(lo1, a, 0)
// lo1, carryAlt = bits.Add(lo1, addsrc[i+1], 0)
// z[i+1] = lo1
// a = hi1 + carry + carryAlt // cannot overflow
//
// lo2, hi2 := m * mulsrc[i+2]
// lo2, carry = bits.Add(lo2, a, 0)
// lo2, carryAlt = bits.Add(lo2, addsrc[i+2], 0)
// z[i+2] = lo2
// a = hi2 + carry + carryAlt // cannot overflow
//
// lo3, hi3 := m * mulsrc[i+3]
// lo3, carry = bits.Add(lo3, a, 0)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], 0)
// z[i+3] = lo3
// a = hi3 + carry + carryAlt // cannot overflow
92 //
93 // There are three ways we can optimize this sequence.
94 //
95 // (1) Reordering, we can chain carries so that we can use one hardware carry flag
96 // but amortize the cost of saving and restoring it across multiple instructions:
97 //
98 // // multiply
99 // lo0, hi0 := m * mulsrc[i]
100 // lo1, hi1 := m * mulsrc[i+1]
101 // lo2, hi2 := m * mulsrc[i+2]
102 // lo3, hi3 := m * mulsrc[i+3]
103 //
104 // lo0, carry = bits.Add(lo0, a, 0)
105 // lo1, carry = bits.Add(lo1, hi0, carry)
106 // lo2, carry = bits.Add(lo2, hi1, carry)
107 // lo3, carry = bits.Add(lo3, hi2, carry)
108 // a = hi3 + carry // cannot overflow
109 //
110 // // add
111 // lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
112 // lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
113 // lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
115 // a = a + carryAlt // cannot overflow
116 //
117 // z[i] = lo0
118 // z[i+1] = lo1
119 // z[i+2] = lo2
120 // z[i+3] = lo3
121 //
122 // addMul takes this approach, using the hardware carry flag
123 // first for carry and then for carryAlt.
124 //
125 // (2) addMulAlt assumes there are two hardware carry flags available.
126 // It dedicates one each to carry and carryAlt, so that a multi-block
127 // unrolling can keep the flags in hardware across all the blocks.
128 // So even if the block size is 1, the code can do:
129 //
130 // // multiply and add
131 // lo0, hi0 := m * mulsrc[i]
132 // lo0, carry = bits.Add(lo0, a, 0)
133 // lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
134 // z[i] = lo0
135 //
136 // lo1, hi1 := m * mulsrc[i+1]
137 // lo1, carry = bits.Add(lo1, hi0, carry)
138 // lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
139 // z[i+1] = lo1
140 //
141 // lo2, hi2 := m * mulsrc[i+2]
142 // lo2, carry = bits.Add(lo2, hi1, carry)
143 // lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
144 // z[i+2] = lo2
145 //
146 // lo3, hi3 := m * mulsrc[i+3]
147 // lo3, carry = bits.Add(lo3, hi2, carry)
// lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
// z[i+3] = lo3
150 //
151 // a = hi3 + carry + carryAlt // cannot overflow
152 //
153 // (3) addMulVirtualCarry optimizes for systems with explicitly computed carry bits
154 // (loong64, mips, riscv64), cutting the number of actual instructions almost by half.
155 // Look again at the original word-at-a-time version:
156 //
157 // lo1, hi1 := m * mulsrc[i]
158 // lo1, carry = bits.Add(lo1, a, 0)
159 // lo1, carryAlt = bits.Add(lo1, addsrc[i], 0)
160 // z[i] = lo1
// a = hi1 + carry + carryAlt // cannot overflow
162 //
163 // Although it uses four adds per word, those are cheap adds: the two bits.Add adds
164 // use two instructions each (ADD+SLTU) and the final + adds only use one ADD each,
165 // for a total of 6 instructions per word. In contrast, the middle stanzas in (2) use
166 // only two “adds” per word, but these are SetCarry|UseCarry adds, which compile to
167 // five instruction each, for a total of 10 instructions per word. So the word-at-a-time
168 // loop is actually better. And we can reorder things slightly to use only a single carry bit:
169 //
170 // lo1, hi1 := m * mulsrc[i]
171 // lo1, carry = bits.Add(lo1, a, 0)
// a = hi1 + carry
173 // lo1, carry = bits.Add(lo1, addsrc[i], 0)
174 // a = a + carry
175 // z[i] = lo1
176 func addMul(f *Func, addsrc, mulsrc []byte, mulIndex int) {
177 a := f.Asm
178 mh := HintNone
179 if a.Arch == Arch386 && addsrc != "" {
180 mh = HintMemOK // too few registers otherwise
181 }
182 m := f.ArgHint("m", mh)
183 c := f.Arg("a")
184 n := f.Arg("z_len")
185 186 p := f.Pipe()
187 if addsrc != "" {
188 p.SetHint(addsrc, HintMemOK)
189 }
190 p.SetHint(mulsrc, HintMulSrc)
191 unroll := []int{1, 4}
192 switch a.Arch {
193 case Arch386:
194 unroll = []int{1} // too few registers
195 case ArchARM:
196 p.SetMaxColumns(2) // too few registers (but more than 386)
197 case ArchARM64:
198 unroll = []int{1, 8} // 5% speedup on c4as16
199 }
200 201 // See the large comment above for an explanation of the code being generated.
202 // This is optimization strategy 1.
203 p.Start(n, unroll...)
204 p.Loop(func(in, out [][]Reg) {
205 a.Comment("multiply")
206 prev := c
207 flag := SetCarry
208 for i, x := range in[mulIndex] {
209 hi := a.RegHint(HintMulHi)
210 a.MulWide(m, x, x, hi)
211 a.Add(prev, x, x, flag)
212 flag = UseCarry | SetCarry
213 if prev != c {
214 a.Free(prev)
215 }
216 out[0][i] = x
217 prev = hi
218 }
219 a.Add(a.Imm(0), prev, c, UseCarry|SmashCarry)
220 if addsrc != "" {
221 a.Comment("add")
222 flag := SetCarry
223 for i, x := range in[0] {
224 a.Add(x, out[0][i], out[0][i], flag)
225 flag = UseCarry | SetCarry
226 }
227 a.Add(a.Imm(0), c, c, UseCarry|SmashCarry)
228 }
229 p.StoreN(out)
230 })
231 232 f.StoreArg(c, "c")
233 a.Ret()
234 }
235 236 func addMulAlt(f *Func) {
237 a := f.Asm
238 m := f.ArgHint("m", HintMulSrc)
239 c := f.Arg("a")
240 n := f.Arg("z_len")
241 242 // On amd64, we need a non-immediate for the AtUnrollEnd adds.
243 r0 := a.ZR()
244 if !r0.Valid() {
245 r0 = a.Reg()
246 a.Mov(a.Imm(0), r0)
247 }
248 249 p := f.Pipe()
250 p.SetLabel("alt")
251 p.SetHint("x", HintMemOK)
252 p.SetHint("y", HintMemOK)
253 if a.Arch == ArchAMD64 {
254 p.SetMaxColumns(2)
255 }
256 257 // See the large comment above for an explanation of the code being generated.
258 // This is optimization strategy (2).
259 var hi Reg
260 prev := c
261 p.Start(n, 1, 8)
262 p.AtUnrollStart(func() {
263 a.Comment("multiply and add")
264 a.ClearCarry(AddCarry | AltCarry)
265 a.ClearCarry(AddCarry)
266 hi = a.Reg()
267 })
268 p.AtUnrollEnd(func() {
269 a.Add(r0, prev, c, UseCarry|SmashCarry)
270 a.Add(r0, c, c, UseCarry|SmashCarry|AltCarry)
271 prev = c
272 })
273 p.Loop(func(in, out [][]Reg) {
274 for i, y := range in[1] {
275 x := in[0][i]
276 lo := y
277 if lo.IsMem() {
278 lo = a.Reg()
279 }
280 a.MulWide(m, y, lo, hi)
281 a.Add(prev, lo, lo, UseCarry|SetCarry)
282 a.Add(x, lo, lo, UseCarry|SetCarry|AltCarry)
283 out[0][i] = lo
284 prev, hi = hi, prev
285 }
286 p.StoreN(out)
287 })
288 289 f.StoreArg(c, "c")
290 a.Ret()
291 }
292 293 func addMulVirtualCarry(f *Func, mulIndex int) {
294 a := f.Asm
295 m := f.Arg("m")
296 c := f.Arg("a")
297 n := f.Arg("z_len")
298 299 // See the large comment above for an explanation of the code being generated.
300 // This is optimization strategy (3).
301 p := f.Pipe()
302 p.Start(n, 1, 4)
303 p.Loop(func(in, out [][]Reg) {
304 a.Comment("synthetic carry, one column at a time")
305 lo, hi := a.Reg(), a.Reg()
306 for i, x := range in[mulIndex] {
307 a.MulWide(m, x, lo, hi)
308 if mulIndex == 1 {
309 a.Add(in[0][i], lo, lo, SetCarry)
310 a.Add(a.Imm(0), hi, hi, UseCarry|SmashCarry)
311 }
312 a.Add(c, lo, x, SetCarry)
313 a.Add(a.Imm(0), hi, c, UseCarry|SmashCarry)
314 out[0][i] = x
315 }
316 p.StoreN(out)
317 })
318 f.StoreArg(c, "c")
319 a.Ret()
320 }
321