1 package transform
2 3 import (
4 "tinygo.org/x/go-llvm"
5 )
// shiftExcludeArgMem drops the argument-memory bits from the integer value of
// an LLVM "memory" function attribute (the encoded MemoryEffects). After
// shifting right by this amount only the bits describing the other memory
// locations remain, so a result of zero means the function touches at most
// its argument memory. The encoding is awkward to reach through the LLVM API,
// so it is hard-coded here; see:
// https://github.com/llvm/llvm-project/blob/94ebcfd16dac67486bae624f74e1c5c789448bae/llvm/include/llvm/Support/ModRef.h#L62
// https://github.com/llvm/llvm-project/blob/94ebcfd16dac67486bae624f74e1c5c789448bae/llvm/include/llvm/Support/ModRef.h#L87
const shiftExcludeArgMem = 2
11 12 // MakeGCStackSlots converts all calls to runtime.trackPointer to explicit
13 // stores to stack slots that are scannable by the GC.
14 func MakeGCStackSlots(mod llvm.Module) bool {
15 // Check whether there are allocations at all.
16 alloc := mod.NamedFunction("runtime.alloc")
17 if alloc.IsNil() {
18 // Nothing to. Make sure all remaining bits and pieces for stack
19 // chains are neutralized.
20 for _, call := range getUses(mod.NamedFunction("runtime.trackPointer")) {
21 call.EraseFromParentAsInstruction()
22 }
23 stackChainStart := mod.NamedGlobal("runtime.stackChainStart")
24 if !stackChainStart.IsNil() {
25 stackChainStart.SetLinkage(llvm.InternalLinkage)
26 stackChainStart.SetInitializer(llvm.ConstNull(stackChainStart.GlobalValueType()))
27 stackChainStart.SetGlobalConstant(true)
28 }
29 return false
30 }
31 32 trackPointer := mod.NamedFunction("runtime.trackPointer")
33 if trackPointer.IsNil() || trackPointer.FirstUse().IsNil() {
34 return false // nothing to do
35 }
36 37 ctx := mod.Context()
38 builder := ctx.NewBuilder()
39 defer builder.Dispose()
40 targetData := llvm.NewTargetData(mod.DataLayout())
41 defer targetData.Dispose()
42 uintptrType := ctx.IntType(targetData.PointerSize() * 8)
43 44 // All functions that call runtime.alloc needs stack objects.
45 trackFuncs := map[llvm.Value]struct{}{}
46 markParentFunctions(trackFuncs, alloc)
47 48 // External functions may indirectly suspend the goroutine or perform a heap allocation.
49 // Their callers should get stack objects.
50 memAttr := llvm.AttributeKindID("memory")
51 for fn := mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) {
52 if _, ok := trackFuncs[fn]; ok {
53 continue // already found
54 }
55 if !fn.FirstBasicBlock().IsNil() {
56 // This is not an external function.
57 continue
58 }
59 if fn == trackPointer {
60 // Manually exclude trackPointer.
61 continue
62 }
63 64 mem := fn.GetEnumFunctionAttribute(memAttr)
65 if !mem.IsNil() && mem.GetEnumValue()>>shiftExcludeArgMem == 0 {
66 // This does not access non-argument memory.
67 // Exclude it.
68 continue
69 }
70 71 // The callers need stack objects.
72 markParentFunctions(trackFuncs, fn)
73 }
74 75 // Look at all other functions to see whether they contain function pointer
76 // calls.
77 // This takes less than 5ms for ~100kB of WebAssembly but would perhaps be
78 // faster when written in C++ (to avoid the CGo overhead).
79 for fn := mod.FirstFunction(); !fn.IsNil(); fn = llvm.NextFunction(fn) {
80 if _, ok := trackFuncs[fn]; ok {
81 continue // already found
82 }
83 84 scanBody:
85 for bb := fn.FirstBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) {
86 for call := bb.FirstInstruction(); !call.IsNil(); call = llvm.NextInstruction(call) {
87 if call.IsACallInst().IsNil() {
88 continue // only looking at calls
89 }
90 called := call.CalledValue()
91 if !called.IsAFunction().IsNil() {
92 continue // only looking for function pointers
93 }
94 trackFuncs[fn] = struct{}{}
95 markParentFunctions(trackFuncs, fn)
96 break scanBody
97 }
98 }
99 }
100 101 // Collect some variables used below in the loop.
102 stackChainStart := mod.NamedGlobal("runtime.stackChainStart")
103 if stackChainStart.IsNil() {
104 // This may be reached in a weird scenario where we call runtime.alloc but the garbage collector is unreachable.
105 // This can be accomplished by allocating 0 bytes.
106 // There is no point in tracking anything.
107 for _, use := range getUses(trackPointer) {
108 use.EraseFromParentAsInstruction()
109 }
110 return false
111 }
112 stackChainStart.SetLinkage(llvm.InternalLinkage)
113 stackChainStartType := stackChainStart.GlobalValueType()
114 stackChainStart.SetInitializer(llvm.ConstNull(stackChainStartType))
115 116 // Iterate until runtime.trackPointer has no uses left.
117 for use := trackPointer.FirstUse(); !use.IsNil(); use = trackPointer.FirstUse() {
118 // Pick the first use of runtime.trackPointer.
119 call := use.User()
120 if call.IsACallInst().IsNil() {
121 panic("expected runtime.trackPointer use to be a call")
122 }
123 124 // Pick the parent function.
125 fn := call.InstructionParent().Parent()
126 127 if _, ok := trackFuncs[fn]; !ok {
128 // This function nor any of the functions it calls (recursively)
129 // allocate anything from the heap, so it will not trigger a garbage
130 // collection cycle. Thus, it does not need to track local pointer
131 // values.
132 // This is a useful optimization but not as big as you might guess,
133 // as described above (it avoids stack objects for ~12% of
134 // functions).
135 call.EraseFromParentAsInstruction()
136 continue
137 }
138 139 // Find all calls to runtime.trackPointer in this function.
140 var calls []llvm.Value
141 var returns []llvm.Value
142 for bb := fn.FirstBasicBlock(); !bb.IsNil(); bb = llvm.NextBasicBlock(bb) {
143 for inst := bb.FirstInstruction(); !inst.IsNil(); inst = llvm.NextInstruction(inst) {
144 switch inst.InstructionOpcode() {
145 case llvm.Call:
146 if inst.CalledValue() == trackPointer {
147 calls = append(calls, inst)
148 }
149 case llvm.Ret:
150 returns = append(returns, inst)
151 }
152 }
153 }
154 155 // Determine what to do with each call.
156 var pointers []llvm.Value
157 for _, call := range calls {
158 ptr := call.Operand(0)
159 call.EraseFromParentAsInstruction()
160 161 // Some trivial optimizations.
162 if ptr.IsAInstruction().IsNil() {
163 continue
164 }
165 switch ptr.InstructionOpcode() {
166 case llvm.GetElementPtr:
167 // Check for all zero offsets.
168 // Sometimes LLVM rewrites bitcasts to zero-index GEPs, and we still need to track the GEP.
169 n := ptr.OperandsCount()
170 var hasOffset bool
171 for i := 1; i < n; i++ {
172 offset := ptr.Operand(i)
173 if offset.IsAConstantInt().IsNil() || offset.ZExtValue() != 0 {
174 hasOffset = true
175 break
176 }
177 }
178 179 if hasOffset {
180 // These values do not create new values: the values already
181 // existed locally in this function so must have been tracked
182 // already.
183 continue
184 }
185 case llvm.PHI:
186 // While the value may have already been tracked, it may be overwritten in a loop.
187 // Therefore, a second copy must be created to ensure that it is tracked over the entirety of its lifetime.
188 case llvm.ExtractValue, llvm.BitCast:
189 // These instructions do not create new values, but their
190 // original value may not be tracked. So keep tracking them for
191 // now.
192 // With more analysis, it should be possible to optimize a
193 // significant chunk of these away.
194 case llvm.Call, llvm.Load, llvm.IntToPtr:
195 // These create new values so must be stored locally. But
196 // perhaps some of these can be fused when they actually refer
197 // to the same value.
198 default:
199 // Ambiguous. These instructions are uncommon, but perhaps could
200 // be optimized if needed.
201 }
202 203 if ptr := stripPointerCasts(ptr); !ptr.IsAAllocaInst().IsNil() {
204 // Allocas don't need to be tracked because they are allocated
205 // on the C stack which is scanned separately.
206 continue
207 }
208 pointers = append(pointers, ptr)
209 }
210 211 if len(pointers) == 0 {
212 // This function does not need to keep track of stack pointers.
213 continue
214 }
215 216 // Determine the type of the required stack slot.
217 fields := []llvm.Type{
218 stackChainStartType, // Pointer to parent frame.
219 uintptrType, // Number of elements in this frame.
220 }
221 for _, ptr := range pointers {
222 fields = append(fields, ptr.Type())
223 }
224 stackObjectType := ctx.StructType(fields, false)
225 226 // Create the stack object at the function entry.
227 builder.SetInsertPointBefore(fn.EntryBasicBlock().FirstInstruction())
228 stackObject := builder.CreateAlloca(stackObjectType, "gc.stackobject")
229 initialStackObject := llvm.ConstNull(stackObjectType)
230 numSlots := (targetData.TypeAllocSize(stackObjectType) - uint64(targetData.PointerSize())*2) / uint64(targetData.ABITypeAlignment(uintptrType))
231 numSlotsValue := llvm.ConstInt(uintptrType, numSlots, false)
232 initialStackObject = builder.CreateInsertValue(initialStackObject, numSlotsValue, 1, "")
233 builder.CreateStore(initialStackObject, stackObject)
234 235 // Update stack start.
236 parent := builder.CreateLoad(stackChainStartType, stackChainStart, "")
237 gep := builder.CreateGEP(stackObjectType, stackObject, []llvm.Value{
238 llvm.ConstInt(ctx.Int32Type(), 0, false),
239 llvm.ConstInt(ctx.Int32Type(), 0, false),
240 }, "")
241 builder.CreateStore(parent, gep)
242 builder.CreateStore(stackObject, stackChainStart)
243 244 // Do a store to the stack object after each new pointer that is created.
245 pointerStores := make(map[llvm.Value]struct{})
246 for i, ptr := range pointers {
247 // Insert the store after the pointer value is created.
248 insertionPoint := llvm.NextInstruction(ptr)
249 for !insertionPoint.IsAPHINode().IsNil() {
250 // PHI nodes are required to be at the start of the block.
251 // Insert after the last PHI node.
252 insertionPoint = llvm.NextInstruction(insertionPoint)
253 }
254 builder.SetInsertPointBefore(insertionPoint)
255 256 // Extract a pointer to the appropriate section of the stack object.
257 gep := builder.CreateGEP(stackObjectType, stackObject, []llvm.Value{
258 llvm.ConstInt(ctx.Int32Type(), 0, false),
259 llvm.ConstInt(ctx.Int32Type(), uint64(2+i), false),
260 }, "")
261 262 // Store the pointer into the stack slot.
263 store := builder.CreateStore(ptr, gep)
264 pointerStores[store] = struct{}{}
265 }
266 267 // Make sure this stack object is popped from the linked list of stack
268 // objects at return.
269 for _, ret := range returns {
270 // Check for any tail calls at this return.
271 prev := llvm.PrevInstruction(ret)
272 if !prev.IsNil() && !prev.IsABitCastInst().IsNil() {
273 // A bitcast can appear before a tail call, so skip backwards more.
274 prev = llvm.PrevInstruction(prev)
275 }
276 if !prev.IsNil() && !prev.IsACallInst().IsNil() {
277 // This is no longer a tail call.
278 prev.SetTailCall(false)
279 }
280 builder.SetInsertPointBefore(ret)
281 builder.CreateStore(parent, stackChainStart)
282 }
283 }
284 285 return true
286 }
287 288 // markParentFunctions traverses all parent function calls (recursively) and
289 // adds them to the set of marked functions. It only considers function calls:
290 // any other uses of such a function is ignored.
291 func markParentFunctions(marked map[llvm.Value]struct{}, fn llvm.Value) {
292 worklist := []llvm.Value{fn}
293 for len(worklist) != 0 {
294 fn := worklist[len(worklist)-1]
295 worklist = worklist[:len(worklist)-1]
296 for _, use := range getUses(fn) {
297 if use.IsACallInst().IsNil() || use.CalledValue() != fn {
298 // Not the parent function.
299 continue
300 }
301 parent := use.InstructionParent().Parent()
302 if _, ok := marked[parent]; !ok {
303 marked[parent] = struct{}{}
304 worklist = append(worklist, parent)
305 }
306 }
307 }
308 }
309