gc.go raw

   1  package transform
   2  
   3  import (
   4  	"tinygo.org/x/go-llvm"
   5  )
   6  
// shiftExcludeArgMem is the number of low bits that MakeGCStackSlots shifts
// out of the value of the LLVM "memory" function attribute before comparing
// the remainder against zero. A zero remainder is treated there as "this
// function does not access non-argument memory".
//
// NOTE(review): presumably this mirrors the bit layout in LLVM's ModRef.h
// (argument memory stored in the lowest bit field) — confirm against the
// sources linked below whenever the LLVM version is bumped.
//
// This is somewhat ugly to access through the API.
// https://github.com/llvm/llvm-project/blob/94ebcfd16dac67486bae624f74e1c5c789448bae/llvm/include/llvm/Support/ModRef.h#L62
// https://github.com/llvm/llvm-project/blob/94ebcfd16dac67486bae624f74e1c5c789448bae/llvm/include/llvm/Support/ModRef.h#L87
const shiftExcludeArgMem = 2
  11  
  12  // MakeGCStackSlots converts all calls to runtime.trackPointer to explicit
  13  // stores to stack slots that are scannable by the GC.
  14  func MakeGCStackSlots(mod llvm.Module) bool {
  15  	// Check whether there are allocations at all.
  16  	alloc := mod.NamedFunction("runtime.alloc")
  17  	if alloc.IsNil() {
  18  		// Nothing to. Make sure all remaining bits and pieces for stack
  19  		// chains are neutralized.
  20  		for _, call := range getUses(mod.NamedFunction("runtime.trackPointer")) {
  21  			call.EraseFromParentAsInstruction()
  22  		}
  23  		stackChainStart := mod.NamedGlobal("runtime.stackChainStart")
  24  		if !stackChainStart.IsNil() {
  25  			stackChainStart.SetLinkage(llvm.InternalLinkage)
  26  			stackChainStart.SetInitializer(llvm.ConstNull(stackChainStart.GlobalValueType()))
  27  			stackChainStart.SetGlobalConstant(true)
  28  		}
  29  		return false
  30  	}
  31  
  32  	trackPointer := mod.NamedFunction("runtime.trackPointer")
  33  	if trackPointer.IsNil() || trackPointer.FirstUse().IsNil() {
  34  		return false // nothing to do
  35  	}
  36  
  37  	ctx := mod.Context()
  38  	builder := ctx.NewBuilder()
  39  	defer builder.Dispose()
  40  	targetData := llvm.NewTargetData(mod.DataLayout())
  41  	defer targetData.Dispose()
  42  	uintptrType := ctx.IntType(targetData.PointerSize() * 8)
  43  
  44  	// All functions that call runtime.alloc needs stack objects.
  45  	trackFuncs := map[llvm.Value]struct{}{}
  46  	markParentFunctions(trackFuncs, alloc)
  47  
  48  	// External functions may indirectly suspend the goroutine or perform a heap allocation.
  49  	// Their callers should get stack objects.
  50  	memAttr := llvm.AttributeKindID("memory")
  51  	for fn := mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) {
  52  		if _, ok := trackFuncs[fn]; ok {
  53  			continue // already found
  54  		}
  55  		if !fn.FirstBasicBlock().IsNil() {
  56  			// This is not an external function.
  57  			continue
  58  		}
  59  		if fn == trackPointer {
  60  			// Manually exclude trackPointer.
  61  			continue
  62  		}
  63  
  64  		mem := fn.GetEnumFunctionAttribute(memAttr)
  65  		if !mem.IsNil() && mem.GetEnumValue()>>shiftExcludeArgMem == 0 {
  66  			// This does not access non-argument memory.
  67  			// Exclude it.
  68  			continue
  69  		}
  70  
  71  		// The callers need stack objects.
  72  		markParentFunctions(trackFuncs, fn)
  73  	}
  74  
  75  	// Look at all other functions to see whether they contain function pointer
  76  	// calls.
  77  	// This takes less than 5ms for ~100kB of WebAssembly but would perhaps be
  78  	// faster when written in C++ (to avoid the CGo overhead).
  79  	for fn := mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) {
  80  		if _, ok := trackFuncs[fn]; ok {
  81  			continue // already found
  82  		}
  83  
  84  	scanBody:
  85  		for bb := fn.FirstBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) {
  86  			for call := bb.FirstInstruction(); !call.IsNil(); call = llvm.NextInstruction(call) {
  87  				if call.IsACallInst().IsNil() {
  88  					continue // only looking at calls
  89  				}
  90  				called := call.CalledValue()
  91  				if !called.IsAFunction().IsNil() {
  92  					continue // only looking for function pointers
  93  				}
  94  				trackFuncs[fn] = struct{}{}
  95  				markParentFunctions(trackFuncs, fn)
  96  				break scanBody
  97  			}
  98  		}
  99  	}
 100  
 101  	// Collect some variables used below in the loop.
 102  	stackChainStart := mod.NamedGlobal("runtime.stackChainStart")
 103  	if stackChainStart.IsNil() {
 104  		// This may be reached in a weird scenario where we call runtime.alloc but the garbage collector is unreachable.
 105  		// This can be accomplished by allocating 0 bytes.
 106  		// There is no point in tracking anything.
 107  		for _, use := range getUses(trackPointer) {
 108  			use.EraseFromParentAsInstruction()
 109  		}
 110  		return false
 111  	}
 112  	stackChainStart.SetLinkage(llvm.InternalLinkage)
 113  	stackChainStartType := stackChainStart.GlobalValueType()
 114  	stackChainStart.SetInitializer(llvm.ConstNull(stackChainStartType))
 115  
 116  	// Iterate until runtime.trackPointer has no uses left.
 117  	for use := trackPointer.FirstUse(); !use.IsNil(); use = trackPointer.FirstUse() {
 118  		// Pick the first use of runtime.trackPointer.
 119  		call := use.User()
 120  		if call.IsACallInst().IsNil() {
 121  			panic("expected runtime.trackPointer use to be a call")
 122  		}
 123  
 124  		// Pick the parent function.
 125  		fn := call.InstructionParent().Parent()
 126  
 127  		if _, ok := trackFuncs[fn]; !ok {
 128  			// This function nor any of the functions it calls (recursively)
 129  			// allocate anything from the heap, so it will not trigger a garbage
 130  			// collection cycle. Thus, it does not need to track local pointer
 131  			// values.
 132  			// This is a useful optimization but not as big as you might guess,
 133  			// as described above (it avoids stack objects for ~12% of
 134  			// functions).
 135  			call.EraseFromParentAsInstruction()
 136  			continue
 137  		}
 138  
 139  		// Find all calls to runtime.trackPointer in this function.
 140  		var calls []llvm.Value
 141  		var returns []llvm.Value
 142  		for bb := fn.FirstBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) {
 143  			for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) {
 144  				switch inst.InstructionOpcode() {
 145  				case llvm.Call:
 146  					if inst.CalledValue() == trackPointer {
 147  						calls = append(calls, inst)
 148  					}
 149  				case llvm.Ret:
 150  					returns = append(returns, inst)
 151  				}
 152  			}
 153  		}
 154  
 155  		// Determine what to do with each call.
 156  		var pointers []llvm.Value
 157  		for _, call := range calls {
 158  			ptr := call.Operand(0)
 159  			call.EraseFromParentAsInstruction()
 160  
 161  			// Some trivial optimizations.
 162  			if ptr.IsAInstruction().IsNil() {
 163  				continue
 164  			}
 165  			switch ptr.InstructionOpcode() {
 166  			case llvm.GetElementPtr:
 167  				// Check for all zero offsets.
 168  				// Sometimes LLVM rewrites bitcasts to zero-index GEPs, and we still need to track the GEP.
 169  				n := ptr.OperandsCount()
 170  				var hasOffset bool
 171  				for i := 1; i < n; i++ {
 172  					offset := ptr.Operand(i)
 173  					if offset.IsAConstantInt().IsNil() || offset.ZExtValue() != 0 {
 174  						hasOffset = true
 175  						break
 176  					}
 177  				}
 178  
 179  				if hasOffset {
 180  					// These values do not create new values: the values already
 181  					// existed locally in this function so must have been tracked
 182  					// already.
 183  					continue
 184  				}
 185  			case llvm.PHI:
 186  				// While the value may have already been tracked, it may be overwritten in a loop.
 187  				// Therefore, a second copy must be created to ensure that it is tracked over the entirety of its lifetime.
 188  			case llvm.ExtractValue, llvm.BitCast:
 189  				// These instructions do not create new values, but their
 190  				// original value may not be tracked. So keep tracking them for
 191  				// now.
 192  				// With more analysis, it should be possible to optimize a
 193  				// significant chunk of these away.
 194  			case llvm.Call, llvm.Load, llvm.IntToPtr:
 195  				// These create new values so must be stored locally. But
 196  				// perhaps some of these can be fused when they actually refer
 197  				// to the same value.
 198  			default:
 199  				// Ambiguous. These instructions are uncommon, but perhaps could
 200  				// be optimized if needed.
 201  			}
 202  
 203  			if ptr := stripPointerCasts(ptr); !ptr.IsAAllocaInst().IsNil() {
 204  				// Allocas don't need to be tracked because they are allocated
 205  				// on the C stack which is scanned separately.
 206  				continue
 207  			}
 208  			pointers = append(pointers, ptr)
 209  		}
 210  
 211  		if len(pointers) == 0 {
 212  			// This function does not need to keep track of stack pointers.
 213  			continue
 214  		}
 215  
 216  		// Determine the type of the required stack slot.
 217  		fields := []llvm.Type{
 218  			stackChainStartType, // Pointer to parent frame.
 219  			uintptrType,         // Number of elements in this frame.
 220  		}
 221  		for _, ptr := range pointers {
 222  			fields = append(fields, ptr.Type())
 223  		}
 224  		stackObjectType := ctx.StructType(fields, false)
 225  
 226  		// Create the stack object at the function entry.
 227  		builder.SetInsertPointBefore(fn.EntryBasicBlock().FirstInstruction())
 228  		stackObject := builder.CreateAlloca(stackObjectType, "gc.stackobject")
 229  		initialStackObject := llvm.ConstNull(stackObjectType)
 230  		numSlots := (targetData.TypeAllocSize(stackObjectType) - uint64(targetData.PointerSize())*2) / uint64(targetData.ABITypeAlignment(uintptrType))
 231  		numSlotsValue := llvm.ConstInt(uintptrType, numSlots, false)
 232  		initialStackObject = builder.CreateInsertValue(initialStackObject, numSlotsValue, 1, "")
 233  		builder.CreateStore(initialStackObject, stackObject)
 234  
 235  		// Update stack start.
 236  		parent := builder.CreateLoad(stackChainStartType, stackChainStart, "")
 237  		gep := builder.CreateGEP(stackObjectType, stackObject, []llvm.Value{
 238  			llvm.ConstInt(ctx.Int32Type(), 0, false),
 239  			llvm.ConstInt(ctx.Int32Type(), 0, false),
 240  		}, "")
 241  		builder.CreateStore(parent, gep)
 242  		builder.CreateStore(stackObject, stackChainStart)
 243  
 244  		// Do a store to the stack object after each new pointer that is created.
 245  		pointerStores := make(map[llvm.Value]struct{})
 246  		for i, ptr := range pointers {
 247  			// Insert the store after the pointer value is created.
 248  			insertionPoint := llvm.NextInstruction(ptr)
 249  			for !insertionPoint.IsAPHINode().IsNil() {
 250  				// PHI nodes are required to be at the start of the block.
 251  				// Insert after the last PHI node.
 252  				insertionPoint = llvm.NextInstruction(insertionPoint)
 253  			}
 254  			builder.SetInsertPointBefore(insertionPoint)
 255  
 256  			// Extract a pointer to the appropriate section of the stack object.
 257  			gep := builder.CreateGEP(stackObjectType, stackObject, []llvm.Value{
 258  				llvm.ConstInt(ctx.Int32Type(), 0, false),
 259  				llvm.ConstInt(ctx.Int32Type(), uint64(2+i), false),
 260  			}, "")
 261  
 262  			// Store the pointer into the stack slot.
 263  			store := builder.CreateStore(ptr, gep)
 264  			pointerStores[store] = struct{}{}
 265  		}
 266  
 267  		// Make sure this stack object is popped from the linked list of stack
 268  		// objects at return.
 269  		for _, ret := range returns {
 270  			// Check for any tail calls at this return.
 271  			prev := llvm.PrevInstruction(ret)
 272  			if !prev.IsNil() && !prev.IsABitCastInst().IsNil() {
 273  				// A bitcast can appear before a tail call, so skip backwards more.
 274  				prev = llvm.PrevInstruction(prev)
 275  			}
 276  			if !prev.IsNil() && !prev.IsACallInst().IsNil() {
 277  				// This is no longer a tail call.
 278  				prev.SetTailCall(false)
 279  			}
 280  			builder.SetInsertPointBefore(ret)
 281  			builder.CreateStore(parent, stackChainStart)
 282  		}
 283  	}
 284  
 285  	return true
 286  }
 287  
 288  // markParentFunctions traverses all parent function calls (recursively) and
 289  // adds them to the set of marked functions. It only considers function calls:
 290  // any other uses of such a function is ignored.
 291  func markParentFunctions(marked map[llvm.Value]struct{}, fn llvm.Value) {
 292  	worklist := []llvm.Value{fn}
 293  	for len(worklist) != 0 {
 294  		fn := worklist[len(worklist)-1]
 295  		worklist = worklist[:len(worklist)-1]
 296  		for _, use := range getUses(fn) {
 297  			if use.IsACallInst().IsNil() || use.CalledValue() != fn {
 298  				// Not the parent function.
 299  				continue
 300  			}
 301  			parent := use.InstructionParent().Parent()
 302  			if _, ok := marked[parent]; !ok {
 303  				marked[parent] = struct{}{}
 304  				worklist = append(worklist, parent)
 305  			}
 306  		}
 307  	}
 308  }
 309