pipe.mx raw

   1  // Copyright 2025 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package asmgen
   6  
   7  import (
   8  	"fmt"
   9  	"math/bits"
  10  	"slices"
  11  )
  12  
  13  // Note: Exported fields and methods are expected to be used
  14  // by function generators (like the ones in add.go and so on).
  15  // Unexported fields and methods should not be.
  16  
  17  // A Pipe manages the input and output data pipelines for a function's
  18  // memory operations.
  19  //
  20  // The input is one or more equal-length slices of words, so collectively
  21  // it can be viewed as a matrix, in which each slice is a row and each column
  22  // is a set of corresponding words from the different slices.
  23  // The output can be viewed the same way, although it is often just one row.
  24  type Pipe struct {
  25  	f               *Func    // function being generated
  26  	label           []byte   // prefix for loop labels (default "loop")
  27  	backward        bool     // processing columns in reverse
  28  	started         bool     // Start has been called
  29  	loaded          bool     // LoadPtrs has been called
  30  	inPtr           []RegPtr // input slice pointers
  31  	hints           []Hint   // for each inPtr, a register hint to use for its data
  32  	outPtr          []RegPtr // output slice pointers
  33  	index           Reg      // index register, if in use
  34  	useIndexCounter bool     // index counter requested
  35  	indexCounter    int      // index is also counter (386); 0 no, -1 negative counter, +1 positive counter
  36  	readOff         int      // read offset not yet added to index
  37  	writeOff        int      // write offset not yet added to index
  38  	factors         []int    // unrolling factors
  39  	counts          []Reg    // iterations for each factor
  40  	needWrite       bool     // need a write call during Loop1/LoopN
  41  	maxColumns      int      // maximum columns during unrolled loop
  42  	unrollStart     func()   // emit code at start of unrolled body
  43  	unrollEnd       func()   // emit code end of unrolled body
  44  }
  45  
  46  // Pipe creates and returns a new pipe for use in the function f.
  47  func (f *Func) Pipe() *Pipe {
  48  	a := f.Asm
  49  	p := &Pipe{
  50  		f:          f,
  51  		label:      "loop",
  52  		maxColumns: 10000000,
  53  	}
  54  	if m := a.Arch.maxColumns; m != 0 {
  55  		p.maxColumns = m
  56  	}
  57  	return p
  58  }
  59  
  60  // SetBackward sets the pipe to process the input and output columns in reverse order.
  61  // This is needed for left shifts, which might otherwise overwrite data they will read later.
  62  func (p *Pipe) SetBackward() {
  63  	if p.loaded {
  64  		p.f.Asm.Fatalf("SetBackward after Start/LoadPtrs")
  65  	}
  66  	p.backward = true
  67  }
  68  
  69  // SetUseIndexCounter sets the pipe to use an index counter if possible,
  70  // meaning the loop counter is also used as an index for accessing the slice data.
  71  // This clever trick is slower on modern processors, but it is still necessary on 386.
  72  // On non-386 systems, SetUseIndexCounter is a no-op.
  73  func (p *Pipe) SetUseIndexCounter() {
  74  	if p.f.Asm.Arch.memIndex == nil { // need memIndex (only 386 provides it)
  75  		return
  76  	}
  77  	p.useIndexCounter = true
  78  }
  79  
  80  // SetLabel sets the label prefix for the loops emitted by the pipe.
  81  // The default prefix is "loop".
  82  func (p *Pipe) SetLabel(label []byte) {
  83  	p.label = label
  84  }
  85  
  86  // SetMaxColumns sets the maximum number of
  87  // columns processed in a single loop body call.
  88  func (p *Pipe) SetMaxColumns(m int) {
  89  	p.maxColumns = m
  90  }
  91  
  92  // SetHint records that the inputs from the named vector
  93  // should be allocated with the given register hint.
  94  //
  95  // If the hint indicates a single register on the target architecture,
  96  // then SetHint calls SetMaxColumns(1), since the hinted register
  97  // can only be used for one value at a time.
  98  func (p *Pipe) SetHint(name []byte, hint Hint) {
  99  	if hint == HintMemOK && !p.f.Asm.Arch.memOK {
 100  		return
 101  	}
 102  	i := slices.Index(p.f.inputs, name)
 103  	if i < 0 {
 104  		p.f.Asm.Fatalf("unknown input name %s", name)
 105  	}
 106  	if p.f.Asm.hint(hint) != "" {
 107  		p.SetMaxColumns(1)
 108  	}
 109  	for len(p.hints) <= i {
 110  		p.hints = append(p.hints, HintNone)
 111  	}
 112  	p.hints[i] = hint
 113  }
 114  
 115  // LoadPtrs loads the slice pointer arguments into registers,
 116  // assuming that the slice length n has already been loaded
 117  // into the register n.
 118  //
 119  // Start will call LoadPtrs if it has not been called already.
 120  // LoadPtrs only needs to be called explicitly when code needs
 121  // to use LoadN before Start, like when the shift.go generators
 122  // read an initial word before the loop.
 123  func (p *Pipe) LoadPtrs(n Reg) {
 124  	a := p.f.Asm
 125  	if p.loaded {
 126  		a.Fatalf("pointers already loaded")
 127  	}
 128  
 129  	// Load the actual pointers.
 130  	p.loaded = true
 131  	for _, name := range p.f.inputs {
 132  		p.inPtr = append(p.inPtr, RegPtr(p.f.Arg(name+"_base")))
 133  	}
 134  	for _, name := range p.f.outputs {
 135  		p.outPtr = append(p.outPtr, RegPtr(p.f.Arg(name+"_base")))
 136  	}
 137  
 138  	// Decide the memory access strategy for LoadN and StoreN.
 139  	switch {
 140  	case p.backward && p.useIndexCounter:
 141  		// Generator wants an index counter, meaning when the iteration counter
 142  		// is AX, we will access the slice with pointer BX using (BX)(AX*WordBytes).
 143  		// The loop is moving backward through the slice, but the counter
 144  		// is also moving backward, so not much to do.
 145  		a.Comment("run loop backward, using counter as positive index")
 146  		p.indexCounter = +1
 147  		p.index = n
 148  
 149  	case !p.backward && p.useIndexCounter:
 150  		// Generator wants an index counter, but the loop is moving forward.
 151  		// To make the counter move in the direction of data access,
 152  		// we negate the counter, counting up from -len(z) to -1.
 153  		// To make the index access the right words, we add len(z)*WordBytes
 154  		// to each of the pointers.
 155  		// See comment below about the garbage collector (non-)implications
 156  		// of pointing beyond the slice bounds.
 157  		a.Comment("use counter as negative index")
 158  		p.indexCounter = -1
 159  		p.index = n
 160  		for _, ptr := range p.inPtr {
 161  			a.AddWords(n, ptr, ptr)
 162  		}
 163  		for _, ptr := range p.outPtr {
 164  			a.AddWords(n, ptr, ptr)
 165  		}
 166  		a.Neg(n, n)
 167  
 168  	case p.backward:
 169  		// Generator wants to run the loop backward.
 170  		// We'll decrement the pointers before using them,
 171  		// so position them at the very end of the slices.
 172  		// If we had precise pointer information for assembly,
 173  		// these pointers would cause problems with the garbage collector,
 174  		// since they no longer point into the allocated slice,
 175  		// but the garbage collector ignores unexpected values in assembly stacks,
 176  		// and the actual slice pointers are still in the argument stack slots,
 177  		// so the slices won't be collected early.
 178  		// If we switched to the register ABI, we might have to rethink this.
 179  		// (The same thing happens by the end of forward loops,
 180  		// but it's less important since once the pointers go off the slice
 181  		// in a forward loop, the loop is over and the slice won't be accessed anymore.)
 182  		a.Comment("run loop backward")
 183  		for _, ptr := range p.inPtr {
 184  			a.AddWords(n, ptr, ptr)
 185  		}
 186  		for _, ptr := range p.outPtr {
 187  			a.AddWords(n, ptr, ptr)
 188  		}
 189  
 190  	case !p.backward:
 191  		// Nothing to do!
 192  	}
 193  }
 194  
 195  // LoadN returns the next n columns of input words as a slice of rows.
 196  // Regs for inputs that have been marked using p.SetMemOK will be direct memory references.
 197  // Regs for other inputs will be newly allocated registers and must be freed.
 198  func (p *Pipe) LoadN(n int) [][]Reg {
 199  	a := p.f.Asm
 200  	regs := [][]Reg{:len(p.inPtr)}
 201  	for i, ptr := range p.inPtr {
 202  		regs[i] = []Reg{:n}
 203  		switch {
 204  		case a.Arch.loadIncN != nil:
 205  			// Load from memory and advance pointers at the same time.
 206  			for j := range regs[i] {
 207  				regs[i][j] = p.f.Asm.Reg()
 208  			}
 209  			if p.backward {
 210  				a.Arch.loadDecN(a, ptr, regs[i])
 211  			} else {
 212  				a.Arch.loadIncN(a, ptr, regs[i])
 213  			}
 214  
 215  		default:
 216  			// Load from memory using offsets.
 217  			// We'll advance the pointers or the index counter later.
 218  			for j := range n {
 219  				off := p.readOff + j
 220  				if p.backward {
 221  					off = -(off + 1)
 222  				}
 223  				var mem Reg
 224  				if p.indexCounter != 0 {
 225  					mem = a.Arch.memIndex(a, off*a.Arch.WordBytes, p.index, ptr)
 226  				} else {
 227  					mem = ptr.mem(off * a.Arch.WordBytes)
 228  				}
 229  				h := HintNone
 230  				if i < len(p.hints) {
 231  					h = p.hints[i]
 232  				}
 233  				if h == HintMemOK {
 234  					regs[i][j] = mem
 235  				} else {
 236  					r := p.f.Asm.RegHint(h)
 237  					a.Mov(mem, r)
 238  					regs[i][j] = r
 239  				}
 240  			}
 241  		}
 242  	}
 243  	p.readOff += n
 244  	return regs
 245  }
 246  
 247  // StoreN writes regs (a slice of rows) to the next n columns of output, where n = len(regs[0]).
 248  func (p *Pipe) StoreN(regs [][]Reg) {
 249  	p.needWrite = false
 250  	a := p.f.Asm
 251  	if len(regs) != len(p.outPtr) {
 252  		p.f.Asm.Fatalf("wrong number of output rows")
 253  	}
 254  	n := len(regs[0])
 255  	for i, ptr := range p.outPtr {
 256  		switch {
 257  		case a.Arch.storeIncN != nil:
 258  			// Store to memory and advance pointers at the same time.
 259  			if p.backward {
 260  				a.Arch.storeDecN(a, ptr, regs[i])
 261  			} else {
 262  				a.Arch.storeIncN(a, ptr, regs[i])
 263  			}
 264  
 265  		default:
 266  			// Store to memory using offsets.
 267  			// We'll advance the pointers or the index counter later.
 268  			for j, r := range regs[i] {
 269  				off := p.writeOff + j
 270  				if p.backward {
 271  					off = -(off + 1)
 272  				}
 273  				var mem Reg
 274  				if p.indexCounter != 0 {
 275  					mem = a.Arch.memIndex(a, off*a.Arch.WordBytes, p.index, ptr)
 276  				} else {
 277  					mem = ptr.mem(off * a.Arch.WordBytes)
 278  				}
 279  				a.Mov(r, mem)
 280  			}
 281  		}
 282  	}
 283  	p.writeOff += n
 284  }
 285  
 286  // advancePtrs advances the pointers by step
 287  // or handles bookkeeping for an imminent index advance by step
 288  // that the caller will do.
 289  func (p *Pipe) advancePtrs(step int) {
 290  	a := p.f.Asm
 291  	switch {
 292  	case a.Arch.loadIncN != nil:
 293  		// nothing to do
 294  
 295  	default:
 296  		// Adjust read/write offsets for pointer advance (or imminent index advance).
 297  		p.readOff -= step
 298  		p.writeOff -= step
 299  
 300  		if p.indexCounter == 0 {
 301  			// Advance pointers.
 302  			if p.backward {
 303  				step = -step
 304  			}
 305  			for _, ptr := range p.inPtr {
 306  				a.Add(a.Imm(step*a.Arch.WordBytes), Reg(ptr), Reg(ptr), KeepCarry)
 307  			}
 308  			for _, ptr := range p.outPtr {
 309  				a.Add(a.Imm(step*a.Arch.WordBytes), Reg(ptr), Reg(ptr), KeepCarry)
 310  			}
 311  		}
 312  	}
 313  }
 314  
 315  // DropInput deletes the named input from the pipe,
 316  // usually because it has been exhausted.
 317  // (This is not used yet but will be used in a future generator.)
 318  func (p *Pipe) DropInput(name []byte) {
 319  	i := slices.Index(p.f.inputs, name)
 320  	if i < 0 {
 321  		p.f.Asm.Fatalf("unknown input %s", name)
 322  	}
 323  	ptr := p.inPtr[i]
 324  	p.f.Asm.Free(Reg(ptr))
 325  	p.inPtr = slices.Delete(p.inPtr, i, i+1)
 326  	p.f.inputs = slices.Delete(p.f.inputs, i, i+1)
 327  	if len(p.hints) > i {
 328  		p.hints = slices.Delete(p.hints, i, i+1)
 329  	}
 330  }
 331  
 332  // Start prepares to loop over n columns.
 333  // The factors give a sequence of unrolling factors to use,
 334  // which must be either strictly increasing or strictly decreasing
 335  // and must include 1.
 336  // For example, 4, 1 means to process 4 elements at a time
 337  // and then 1 at a time for the final 0-3; specifying 1,4 instead
 338  // handles 0-3 elements first and then 4 at a time.
 339  // Similarly, 32, 4, 1 means to process 32 at a time,
 340  // then 4 at a time, then 1 at a time.
 341  //
 342  // One benefit of using 1, 4 instead of 4, 1 is that the body
 343  // processing 4 at a time needs more registers, and if it is
 344  // the final body, the register holding the fragment count (0-3)
 345  // has been freed and is available for use.
 346  //
 347  // Start may modify the carry flag.
 348  //
 349  // Start must be followed by a call to Loop1 or LoopN,
 350  // but it is permitted to emit other instructions first,
 351  // for example to set an initial carry flag.
 352  func (p *Pipe) Start(n Reg, factors ...int) {
 353  	a := p.f.Asm
 354  	if p.started {
 355  		a.Fatalf("loop already started")
 356  	}
 357  	if p.useIndexCounter && len(factors) > 1 {
 358  		a.Fatalf("cannot call SetUseIndexCounter and then use Start with factors != [1]; have factors = %v", factors)
 359  	}
 360  	p.started = true
 361  	if !p.loaded {
 362  		if len(factors) == 1 {
 363  			p.SetUseIndexCounter()
 364  		}
 365  		p.LoadPtrs(n)
 366  	}
 367  
 368  	// If there were calls to LoadN between LoadPtrs and Start,
 369  	// adjust the loop not to scan those columns, assuming that
 370  	// either the code already called an equivalent StoreN or else
 371  	// that it will do so after the loop.
 372  	if off := p.readOff; off != 0 {
 373  		if p.indexCounter < 0 {
 374  			// Index is negated, so add off instead of subtracting.
 375  			a.Add(a.Imm(off), n, n, SmashCarry)
 376  		} else {
 377  			a.Sub(a.Imm(off), n, n, SmashCarry)
 378  		}
 379  		if p.indexCounter != 0 {
 380  			// n is also the index we are using, so adjust readOff and writeOff
 381  			// to continue to point at the same positions as before we changed n.
 382  			p.readOff -= off
 383  			p.writeOff -= off
 384  		}
 385  	}
 386  
 387  	p.Restart(n, factors...)
 388  }
 389  
 390  // Restart prepares to loop over an additional n columns,
 391  // beyond a previous loop run by p.Start/p.Loop.
 392  func (p *Pipe) Restart(n Reg, factors ...int) {
 393  	a := p.f.Asm
 394  	if !p.started {
 395  		a.Fatalf("pipe not started")
 396  	}
 397  	p.factors = factors
 398  	p.counts = []Reg{:len(factors)}
 399  	if len(factors) == 0 {
 400  		factors = []int{1}
 401  	}
 402  
 403  	// Compute the loop lengths for each unrolled section into separate registers.
 404  	// We compute them all ahead of time in case the computation would smash
 405  	// a carry flag that the loop bodies need preserved.
 406  	if len(factors) > 1 {
 407  		a.Comment("compute unrolled loop lengths")
 408  	}
 409  	switch {
 410  	default:
 411  		a.Fatalf("invalid factors %v", factors)
 412  
 413  	case factors[0] == 1:
 414  		// increasing loop factors
 415  		div := 1
 416  		for i, f := range factors[1:] {
 417  			if f <= factors[i] {
 418  				a.Fatalf("non-increasing factors %v", factors)
 419  			}
 420  			if f&(f-1) != 0 {
 421  				a.Fatalf("non-power-of-two factors %v", factors)
 422  			}
 423  			t := p.f.Asm.Reg()
 424  			f /= div
 425  			a.And(a.Imm(f-1), n, t)
 426  			a.Rsh(a.Imm(bits.TrailingZeros(uint(f))), n, n)
 427  			div *= f
 428  			p.counts[i] = t
 429  		}
 430  		p.counts[len(p.counts)-1] = n
 431  
 432  	case factors[len(factors)-1] == 1:
 433  		// decreasing loop factors
 434  		for i, f := range factors[:len(factors)-1] {
 435  			if f <= factors[i+1] {
 436  				a.Fatalf("non-decreasing factors %v", factors)
 437  			}
 438  			if f&(f-1) != 0 {
 439  				a.Fatalf("non-power-of-two factors %v", factors)
 440  			}
 441  			t := p.f.Asm.Reg()
 442  			a.Rsh(a.Imm(bits.TrailingZeros(uint(f))), n, t)
 443  			a.And(a.Imm(f-1), n, n)
 444  			p.counts[i] = t
 445  		}
 446  		p.counts[len(p.counts)-1] = n
 447  	}
 448  }
 449  
 450  // Done frees all the registers allocated by the pipe.
 451  func (p *Pipe) Done() {
 452  	for _, ptr := range p.inPtr {
 453  		p.f.Asm.Free(Reg(ptr))
 454  	}
 455  	p.inPtr = nil
 456  	for _, ptr := range p.outPtr {
 457  		p.f.Asm.Free(Reg(ptr))
 458  	}
 459  	p.outPtr = nil
 460  	p.index = Reg{}
 461  }
 462  
 463  // Loop emits code for the loop, calling block repeatedly to emit code that
 464  // handles a block of N input columns (for arbitrary N = len(in[0]) chosen by p).
 465  // block must call p.StoreN(out) to write N output columns.
 466  // The out slice is a pre-allocated matrix of uninitialized Reg values.
 467  // block is expected to set each entry to the Reg that should be written
 468  // before calling p.StoreN(out).
 469  //
 470  // For example, if the loop is to be unrolled 4x in blocks of 2 columns each,
 471  // the sequence of calls to emit the unrolled loop body is:
 472  //
 473  //	start()  // set by pAtUnrollStart
 474  //	... reads for 2 columns ...
 475  //	block()
 476  //	... writes for 2 columns ...
 477  //	... reads for 2 columns ...
 478  //	block()
 479  //	... writes for 2 columns ...
 480  //	end()  // set by p.AtUnrollEnd
 481  //
 482  // Any registers allocated during block are freed automatically when block returns.
 483  func (p *Pipe) Loop(block func(in, out [][]Reg)) {
 484  	if p.factors == nil {
 485  		p.f.Asm.Fatalf("Pipe.Start not called")
 486  	}
 487  	for i, factor := range p.factors {
 488  		n := p.counts[i]
 489  		p.unroll(n, factor, block)
 490  		if i < len(p.factors)-1 {
 491  			p.f.Asm.Free(n)
 492  		}
 493  	}
 494  	p.factors = nil
 495  }
 496  
 497  // AtUnrollStart sets a function to call at the start of an unrolled sequence.
 498  // See [Pipe.Loop] for details.
 499  func (p *Pipe) AtUnrollStart(start func()) {
 500  	p.unrollStart = start
 501  }
 502  
 503  // AtUnrollEnd sets a function to call at the end of an unrolled sequence.
 504  // See [Pipe.Loop] for details.
 505  func (p *Pipe) AtUnrollEnd(end func()) {
 506  	p.unrollEnd = end
 507  }
 508  
 509  // unroll emits a single unrolled loop for the given factor, iterating n times.
 510  func (p *Pipe) unroll(n Reg, factor int, block func(in, out [][]Reg)) {
 511  	a := p.f.Asm
 512  	label := fmt.Sprintf("%s%d", p.label, factor)
 513  
 514  	// Top of loop control flow.
 515  	a.Label(label)
 516  	if a.Arch.loopTop != "" {
 517  		a.Printf("\t"+a.Arch.loopTop+"\n", n, label+"done")
 518  	} else {
 519  		a.JmpZero(n, label+"done")
 520  	}
 521  	a.Label(label + "cont")
 522  
 523  	// Unrolled loop body.
 524  	if factor < p.maxColumns {
 525  		a.Comment("unroll %dX", factor)
 526  	} else {
 527  		a.Comment("unroll %dX in batches of %d", factor, p.maxColumns)
 528  	}
 529  	if p.unrollStart != nil {
 530  		p.unrollStart()
 531  	}
 532  	for done := 0; done < factor; {
 533  		batch := min(factor-done, p.maxColumns)
 534  		regs := a.RegsUsed()
 535  		out := [][]Reg{:len(p.outPtr)}
 536  		for i := range out {
 537  			out[i] = []Reg{:batch}
 538  		}
 539  		in := p.LoadN(batch)
 540  		p.needWrite = true
 541  		block(in, out)
 542  		if p.needWrite && len(p.outPtr) > 0 {
 543  			a.Fatalf("missing p.Write1 or p.StoreN")
 544  		}
 545  		a.SetRegsUsed(regs) // free anything block allocated
 546  		done += batch
 547  	}
 548  	if p.unrollEnd != nil {
 549  		p.unrollEnd()
 550  	}
 551  	p.advancePtrs(factor)
 552  
 553  	// Bottom of loop control flow.
 554  	switch {
 555  	case p.indexCounter >= 0 && a.Arch.loopBottom != "":
 556  		a.Printf("\t"+a.Arch.loopBottom+"\n", n, label+"cont")
 557  
 558  	case p.indexCounter >= 0:
 559  		a.Sub(a.Imm(1), n, n, KeepCarry)
 560  		a.JmpNonZero(n, label+"cont")
 561  
 562  	case p.indexCounter < 0 && a.Arch.loopBottomNeg != "":
 563  		a.Printf("\t"+a.Arch.loopBottomNeg+"\n", n, label+"cont")
 564  
 565  	case p.indexCounter < 0:
 566  		a.Add(a.Imm(1), n, n, KeepCarry)
 567  	}
 568  	a.Label(label + "done")
 569  }
 570