mul.mx raw

   1  // Copyright 2025 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package asmgen
   6  
   7  // mulAddVWW generates mulAddVWW, which does z, c = x*m + a.
   8  func mulAddVWW(a *Asm) {
   9  	f := a.Func("func mulAddVWW(z, x []Word, m, a Word) (c Word)")
  10  
  11  	if a.AltCarry().Valid() {
  12  		addMulVirtualCarry(f, 0)
  13  		return
  14  	}
  15  	addMul(f, "", "x", 0)
  16  }
  17  
  18  // addMulVVWW generates addMulVVWW which does z, c = x + y*m + a.
  19  // (A more pedantic name would be addMulAddVVWW.)
  20  func addMulVVWW(a *Asm) {
  21  	f := a.Func("func addMulVVWW(z, x, y []Word, m, a Word) (c Word)")
  22  
  23  	// If the architecture has virtual carries, emit that version unconditionally.
  24  	if a.AltCarry().Valid() {
  25  		addMulVirtualCarry(f, 1)
  26  		return
  27  	}
  28  
  29  	// If the architecture optionally has two carries, test and emit both versions.
  30  	if a.JmpEnable(OptionAltCarry, "altcarry") {
  31  		regs := a.RegsUsed()
  32  		addMul(f, "x", "y", 1)
  33  		a.Label("altcarry")
  34  		a.SetOption(OptionAltCarry, true)
  35  		a.SetRegsUsed(regs)
  36  		addMulAlt(f)
  37  		a.SetOption(OptionAltCarry, false)
  38  		return
  39  	}
  40  
  41  	// Otherwise emit the one-carry form.
  42  	addMul(f, "x", "y", 1)
  43  }
  44  
  45  // Computing z = addsrc + m*mulsrc + a, we need:
  46  //
  47  //	for i := range z {
  48  //		lo, hi := m * mulsrc[i]
  49  //		lo, carry = bits.Add(lo, a, 0)
  50  //		lo, carryAlt = bits.Add(lo, addsrc[i], 0)
  51  //		z[i] = lo
  52  //		a = hi + carry + carryAlt  // cannot overflow
  53  //	}
  54  //
  55  // The final addition cannot overflow because after processing N words,
  56  // the maximum possible value is (for a 64-bit system):
  57  //
  58  //	  (2**64N - 1) + (2**64 - 1)*(2**64N - 1) + (2**64 - 1)
  59  //	= (2**64)*(2**64N - 1) + (2**64 - 1)
  60  //	= 2**64(N+1) - 1,
  61  //
  62  // which fits in N+1 words (the high order one being the new value of a).
  63  //
  64  // (For example, with 3 decimal words, 999 + 9*999 + 9 = 999*10 + 9 = 9999.)
  65  //
  66  // If we unroll the loop a bit, then we can chain the carries in two passes.
  67  // Consider:
  68  //
//	lo0, hi0 := m * mulsrc[i]
//	lo0, carry = bits.Add(lo0, a, 0)
//	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
//	z[i] = lo0
//	a = hi0 + carry + carryAlt // cannot overflow
//
//	lo1, hi1 := m * mulsrc[i+1]
//	lo1, carry = bits.Add(lo1, a, 0)
//	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], 0)
//	z[i+1] = lo1
//	a = hi1 + carry + carryAlt // cannot overflow
//
//	lo2, hi2 := m * mulsrc[i+2]
//	lo2, carry = bits.Add(lo2, a, 0)
//	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], 0)
//	z[i+2] = lo2
//	a = hi2 + carry + carryAlt // cannot overflow
//
//	lo3, hi3 := m * mulsrc[i+3]
//	lo3, carry = bits.Add(lo3, a, 0)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], 0)
//	z[i+3] = lo3
//	a = hi3 + carry + carryAlt // cannot overflow
  92  //
  93  // There are three ways we can optimize this sequence.
  94  //
// (1) By reordering, we can chain the carries so that we use a single hardware carry flag
// and amortize the cost of saving and restoring it across multiple instructions:
  97  //
  98  //	// multiply
  99  //	lo0, hi0 := m * mulsrc[i]
 100  //	lo1, hi1 := m * mulsrc[i+1]
 101  //	lo2, hi2 := m * mulsrc[i+2]
 102  //	lo3, hi3 := m * mulsrc[i+3]
 103  //
 104  //	lo0, carry = bits.Add(lo0, a, 0)
 105  //	lo1, carry = bits.Add(lo1, hi0, carry)
 106  //	lo2, carry = bits.Add(lo2, hi1, carry)
 107  //	lo3, carry = bits.Add(lo3, hi2, carry)
 108  //	a = hi3 + carry // cannot overflow
 109  //
 110  //	// add
 111  //	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
 112  //	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
 113  //	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
 115  //	a = a + carryAlt // cannot overflow
 116  //
 117  //	z[i] = lo0
 118  //	z[i+1] = lo1
 119  //	z[i+2] = lo2
 120  //	z[i+3] = lo3
 121  //
 122  // addMul takes this approach, using the hardware carry flag
 123  // first for carry and then for carryAlt.
 124  //
 125  // (2) addMulAlt assumes there are two hardware carry flags available.
 126  // It dedicates one each to carry and carryAlt, so that a multi-block
 127  // unrolling can keep the flags in hardware across all the blocks.
 128  // So even if the block size is 1, the code can do:
 129  //
 130  //	// multiply and add
 131  //	lo0, hi0 := m * mulsrc[i]
 132  //	lo0, carry = bits.Add(lo0, a, 0)
 133  //	lo0, carryAlt = bits.Add(lo0, addsrc[i], 0)
 134  //	z[i] = lo0
 135  //
 136  //	lo1, hi1 := m * mulsrc[i+1]
 137  //	lo1, carry = bits.Add(lo1, hi0, carry)
 138  //	lo1, carryAlt = bits.Add(lo1, addsrc[i+1], carryAlt)
 139  //	z[i+1] = lo1
 140  //
 141  //	lo2, hi2 := m * mulsrc[i+2]
 142  //	lo2, carry = bits.Add(lo2, hi1, carry)
 143  //	lo2, carryAlt = bits.Add(lo2, addsrc[i+2], carryAlt)
 144  //	z[i+2] = lo2
 145  //
 146  //	lo3, hi3 := m * mulsrc[i+3]
 147  //	lo3, carry = bits.Add(lo3, hi2, carry)
//	lo3, carryAlt = bits.Add(lo3, addsrc[i+3], carryAlt)
//	z[i+3] = lo3
 150  //
 151  //	a = hi3 + carry + carryAlt // cannot overflow
 152  //
 153  // (3) addMulVirtualCarry optimizes for systems with explicitly computed carry bits
 154  // (loong64, mips, riscv64), cutting the number of actual instructions almost by half.
 155  // Look again at the original word-at-a-time version:
 156  //
//	lo, hi := m * mulsrc[i]
//	lo, carry = bits.Add(lo, a, 0)
//	lo, carryAlt = bits.Add(lo, addsrc[i], 0)
//	z[i] = lo
//	a = hi + carry + carryAlt // cannot overflow
 162  //
// Although it uses four adds per word, those are cheap adds: the two bits.Add adds
// use two instructions each (ADD+SLTU) and the final + adds only use one ADD each,
// for a total of 6 instructions per word. In contrast, the middle stanzas in (2) use
// only two “adds” per word, but these are SetCarry|UseCarry adds, which compile to
// five instructions each, for a total of 10 instructions per word. So the word-at-a-time
// loop is actually better. And we can reorder things slightly to use only a single carry bit:
 169  //
//	lo, hi := m * mulsrc[i]
//	lo, carry = bits.Add(lo, a, 0)
//	a = hi + carry
//	lo, carry = bits.Add(lo, addsrc[i], 0)
//	a = a + carry
//	z[i] = lo
 176  func addMul(f *Func, addsrc, mulsrc []byte, mulIndex int) {
 177  	a := f.Asm
 178  	mh := HintNone
 179  	if a.Arch == Arch386 && addsrc != "" {
 180  		mh = HintMemOK // too few registers otherwise
 181  	}
 182  	m := f.ArgHint("m", mh)
 183  	c := f.Arg("a")
 184  	n := f.Arg("z_len")
 185  
 186  	p := f.Pipe()
 187  	if addsrc != "" {
 188  		p.SetHint(addsrc, HintMemOK)
 189  	}
 190  	p.SetHint(mulsrc, HintMulSrc)
 191  	unroll := []int{1, 4}
 192  	switch a.Arch {
 193  	case Arch386:
 194  		unroll = []int{1} // too few registers
 195  	case ArchARM:
 196  		p.SetMaxColumns(2) // too few registers (but more than 386)
 197  	case ArchARM64:
 198  		unroll = []int{1, 8} // 5% speedup on c4as16
 199  	}
 200  
 201  	// See the large comment above for an explanation of the code being generated.
 202  	// This is optimization strategy 1.
 203  	p.Start(n, unroll...)
 204  	p.Loop(func(in, out [][]Reg) {
 205  		a.Comment("multiply")
 206  		prev := c
 207  		flag := SetCarry
 208  		for i, x := range in[mulIndex] {
 209  			hi := a.RegHint(HintMulHi)
 210  			a.MulWide(m, x, x, hi)
 211  			a.Add(prev, x, x, flag)
 212  			flag = UseCarry | SetCarry
 213  			if prev != c {
 214  				a.Free(prev)
 215  			}
 216  			out[0][i] = x
 217  			prev = hi
 218  		}
 219  		a.Add(a.Imm(0), prev, c, UseCarry|SmashCarry)
 220  		if addsrc != "" {
 221  			a.Comment("add")
 222  			flag := SetCarry
 223  			for i, x := range in[0] {
 224  				a.Add(x, out[0][i], out[0][i], flag)
 225  				flag = UseCarry | SetCarry
 226  			}
 227  			a.Add(a.Imm(0), c, c, UseCarry|SmashCarry)
 228  		}
 229  		p.StoreN(out)
 230  	})
 231  
 232  	f.StoreArg(c, "c")
 233  	a.Ret()
 234  }
 235  
// addMulAlt emits optimization strategy (2) for addMulVVWW, computing
// z = x + y*m + a with two independent hardware carry chains. The ordinary
// carry chains each column's high word into the next column, and AltCarry
// chains the addition of x. Because of this, the flags stay live across
// every column of an unrolled block. The final carry word is stored in c.
// addMulVVWW calls it after placing the "altcarry" label and enabling
// OptionAltCarry.
func addMulAlt(f *Func) {
	a := f.Asm
	m := f.ArgHint("m", HintMulSrc)
	c := f.Arg("a")
	n := f.Arg("z_len")

	// On amd64, we need a non-immediate for the AtUnrollEnd adds.
	// Use the architecture's zero register if it has one; otherwise
	// materialize 0 in a scratch register.
	r0 := a.ZR()
	if !r0.Valid() {
		r0 = a.Reg()
		a.Mov(a.Imm(0), r0)
	}

	p := f.Pipe()
	p.SetLabel("alt")
	p.SetHint("x", HintMemOK)
	p.SetHint("y", HintMemOK)
	if a.Arch == ArchAMD64 {
		p.SetMaxColumns(2)
	}

	// See the large comment above for an explanation of the code being generated.
	// This is optimization strategy (2).
	//
	// prev holds the register containing the previous column's high word.
	// Initially that is c, the incoming carry word. After each column, the
	// registers behind prev and hi swap roles. As a result, c's register is
	// reused as a high-word scratch once its value has been added in.
	var hi Reg
	prev := c
	p.Start(n, 1, 8)
	p.AtUnrollStart(func() {
		a.Comment("multiply and add")
		// Start each unrolled block with both carry chains cleared.
		// NOTE(review): the second ClearCarry(AddCarry) looks redundant after
		// clearing AddCarry|AltCarry; confirm whether some backend needs it.
		a.ClearCarry(AddCarry | AltCarry)
		a.ClearCarry(AddCarry)
		hi = a.Reg()
	})
	p.AtUnrollEnd(func() {
		// Fold both pending carries into c: first the last high word plus the
		// ordinary carry, then the alternate carry. The next block then
		// starts chaining from c again.
		a.Add(r0, prev, c, UseCarry|SmashCarry)
		a.Add(r0, c, c, UseCarry|SmashCarry|AltCarry)
		prev = c
	})
	p.Loop(func(in, out [][]Reg) {
		for i, y := range in[1] {
			x := in[0][i]
			lo := y
			if lo.IsMem() {
				// y was left in memory; the low word needs a register.
				lo = a.Reg()
			}
			a.MulWide(m, y, lo, hi)
			a.Add(prev, lo, lo, UseCarry|SetCarry)       // ordinary chain: + previous high word
			a.Add(x, lo, lo, UseCarry|SetCarry|AltCarry) // alternate chain: + x
			out[0][i] = lo
			prev, hi = hi, prev
		}
		p.StoreN(out)
	})

	f.StoreArg(c, "c")
	a.Ret()
}
 292  
 293  func addMulVirtualCarry(f *Func, mulIndex int) {
 294  	a := f.Asm
 295  	m := f.Arg("m")
 296  	c := f.Arg("a")
 297  	n := f.Arg("z_len")
 298  
 299  	// See the large comment above for an explanation of the code being generated.
 300  	// This is optimization strategy (3).
 301  	p := f.Pipe()
 302  	p.Start(n, 1, 4)
 303  	p.Loop(func(in, out [][]Reg) {
 304  		a.Comment("synthetic carry, one column at a time")
 305  		lo, hi := a.Reg(), a.Reg()
 306  		for i, x := range in[mulIndex] {
 307  			a.MulWide(m, x, lo, hi)
 308  			if mulIndex == 1 {
 309  				a.Add(in[0][i], lo, lo, SetCarry)
 310  				a.Add(a.Imm(0), hi, hi, UseCarry|SmashCarry)
 311  			}
 312  			a.Add(c, lo, x, SetCarry)
 313  			a.Add(a.Imm(0), hi, c, UseCarry|SmashCarry)
 314  			out[0][i] = x
 315  		}
 316  		p.StoreN(out)
 317  	})
 318  	f.StoreArg(c, "c")
 319  	a.Ret()
 320  }
 321