wasmspawn.go raw
1 package compiler
2
3 // spawn_wasm.go — wasm spawn path.
4 //
5 // In wasm (GOOS=js GOARCH=wasm), spawn() creates a new Worker instead of
6 // forking a process. Channels at spawn boundaries are SAB-backed; all other
7 // channel operations use the regular runtime. Scalar args are packed into a
8 // linear-memory buffer and deserialized by __spawn_entry on the child side.
9 //
10 // Architecture:
11 // Parent side (createWasmSpawn):
12 // 1. For each chan arg: bridge.channel_create → int32 SAB handle stored in
13 // b.sabChannels keyed by the SSA channel value.
14 // 2. Pack scalar args into a stack alloca.
15 // 3. Pack SAB handles into a second alloca.
16 // 4. Call bridge.spawn_domain(fnIdx, argPtr, argLen, chanHandlesPtr, nChans).
17 //
18 // Child side (__spawn_entry, emitted at link time by emitWasmSpawnEntry):
19 // A switch over fnIdx dispatches to a per-target case that:
20 // 1. Deserializes scalars from argPtr.
21 // 2. Loads int32 handles from chanHandlesPtr, inttoptr → dataPtrType.
22 // 3. Calls the target function directly.
23 // The target function's channel params are added to sabChannels in its
24 // builder (via setupWasmSpawnTargetChannels) so send/recv/close use bridge.
25
26 import (
27 "fmt"
28 "go/types"
29 "strings"
30
31 "golang.org/x/tools/go/ssa"
32 "tinygo.org/x/go-llvm"
33 )
34
35 // wasmSpawnTarget records a spawn-eligible function and its dispatch index.
36 type wasmSpawnTarget struct {
37 fn *ssa.Function
38 fnIdx int32
39 chanParamIdxs []int // indices into fn.Params that are channel types
40 scalarParamIdxs []int // indices that are non-channel (scalar/struct)
41 }
42
43 // isWasmTarget returns true when building for GOOS=js GOARCH=wasm.
44 func (b *builder) isWasmTarget() bool {
45 return strings.HasPrefix(b.Triple, "wasm") && b.GOOS == "js"
46 }
47
48 // scanWasmSpawnTargets pre-scans the package for spawn() calls, collecting
49 // all statically-known target functions and assigning each a stable fnIdx.
50 // Must be called after ssaPkg.Build() and before createPackage.
51 func (c *compilerContext) scanWasmSpawnTargets(pkg *ssa.Package) {
52 if !strings.HasPrefix(c.Triple, "wasm") || c.GOOS != "js" {
53 return
54 }
55 seen := make(map[*ssa.Function]bool)
56 for _, mem := range pkg.Members {
57 fn, ok := mem.(*ssa.Function)
58 if !ok {
59 continue
60 }
61 c.scanFnForSpawn(fn, seen)
62 }
63 }
64
65 func (c *compilerContext) scanFnForSpawn(fn *ssa.Function, seen map[*ssa.Function]bool) {
66 for _, blk := range fn.Blocks {
67 for _, instr := range blk.Instrs {
68 call, ok := instr.(*ssa.Call)
69 if !ok {
70 continue
71 }
72 bi, ok := call.Call.Value.(*ssa.Builtin)
73 if !ok || bi.Name() != "spawn" {
74 continue
75 }
76 target := extractWasmSpawnTargetFn(call.Call.Args)
77 if target == nil || seen[target] {
78 continue
79 }
80 seen[target] = true
81 sig := target.Signature
82 var chanIdxs, scalarIdxs []int
83 for i := 0; i < sig.Params().Len(); i++ {
84 if _, isChan := sig.Params().At(i).Type().Underlying().(*types.Chan); isChan {
85 chanIdxs = append(chanIdxs, i)
86 } else {
87 scalarIdxs = append(scalarIdxs, i)
88 }
89 }
90 idx := int32(len(c.wasmSpawnTargets))
91 c.wasmSpawnTargets = append(c.wasmSpawnTargets, &wasmSpawnTarget{
92 fn: target,
93 fnIdx: idx,
94 chanParamIdxs: chanIdxs,
95 scalarParamIdxs: scalarIdxs,
96 })
97 if c.wasmSpawnIndex == nil {
98 c.wasmSpawnIndex = make(map[*ssa.Function]int32)
99 }
100 c.wasmSpawnIndex[target] = idx
101 }
102 }
103 }
104
105 func extractWasmSpawnTargetFn(args []ssa.Value) *ssa.Function {
106 start := 0
107 if len(args) > 0 {
108 if _, ok := extractTransportString(args[0]); ok {
109 start = 1
110 }
111 }
112 if start >= len(args) {
113 return nil
114 }
115 switch v := args[start].(type) {
116 case *ssa.Function:
117 return v
118 case *ssa.MakeClosure:
119 if f, ok := v.Fn.(*ssa.Function); ok {
120 return f
121 }
122 }
123 return nil
124 }
125
126 // validateWasmSpawnArg checks that a type is safe across the wasm spawn
127 // boundary. No Codec requirement. Rejects pointers, functions, interfaces.
128 // Top-level channel types are allowed (SAB-backed); nested channels are not.
129 func validateWasmSpawnArg(t types.Type, argIdx int) error {
130 if _, ok := t.Underlying().(*types.Chan); ok {
131 return nil
132 }
133 return validateWasmSpawnElem(t, argIdx)
134 }
135
136 func validateWasmSpawnElem(t types.Type, argIdx int) error {
137 switch ut := t.Underlying().(type) {
138 case *types.Basic:
139 return nil
140 case *types.Pointer:
141 return fmt.Errorf("spawn: argument %d is a pointer — not allowed across wasm spawn boundary", argIdx+1)
142 case *types.Signature:
143 return fmt.Errorf("spawn: argument %d is a function value — not allowed across wasm spawn boundary", argIdx+1)
144 case *types.Interface:
145 return fmt.Errorf("spawn: argument %d is an interface — not allowed across wasm spawn boundary", argIdx+1)
146 case *types.Chan:
147 return fmt.Errorf("spawn: argument %d contains a nested channel — channels must be top-level params", argIdx+1)
148 case *types.Slice:
149 if basic, ok := ut.Elem().Underlying().(*types.Basic); ok && basic.Kind() == types.Byte {
150 return nil // []byte / string
151 }
152 return validateWasmSpawnElem(ut.Elem(), argIdx)
153 case *types.Struct:
154 for i := 0; i < ut.NumFields(); i++ {
155 if err := validateWasmSpawnElem(ut.Field(i).Type(), argIdx); err != nil {
156 return err
157 }
158 }
159 return nil
160 case *types.Array:
161 return validateWasmSpawnElem(ut.Elem(), argIdx)
162 }
163 return fmt.Errorf("spawn: argument %d type %s not allowed across wasm spawn boundary", argIdx+1, t)
164 }
165
166 // createWasmSpawn handles spawn() in the wasm target. It emits:
167 // 1. bridge.channel_create for each channel arg → SAB handle in b.sabChannels.
168 // 2. A packed scalar arg buffer.
169 // 3. A packed chan handle array.
170 // 4. bridge.spawn_domain(fnIdx, argPtr, argLen, chanHandlesPtr, nChans).
171 func (b *builder) createWasmSpawn(instr *ssa.CallCommon) (llvm.Value, error) {
172 argStart := 0
173 if len(instr.Args) > 0 {
174 if _, ok := extractTransportString(instr.Args[0]); ok {
175 argStart = 1
176 }
177 }
178 if argStart >= len(instr.Args) {
179 b.addError(instr.Pos(), "spawn: requires a function argument")
180 return llvm.Value{}, nil
181 }
182
183 fnArg := instr.Args[argStart]
184 var targetFn *ssa.Function
185 switch v := fnArg.(type) {
186 case *ssa.Function:
187 targetFn = v
188 case *ssa.MakeClosure:
189 if f, ok := v.Fn.(*ssa.Function); ok {
190 targetFn = f
191 }
192 }
193 if targetFn == nil {
194 b.addError(instr.Pos(), "spawn: first argument must be a static top-level function")
195 return llvm.Value{}, nil
196 }
197 if targetFn.Package() != nil && b.fn.Package() != nil &&
198 targetFn.Package().Pkg.Path() != b.fn.Package().Pkg.Path() {
199 b.addError(instr.Pos(), "spawn: wasm spawn target must be in the same package as the caller")
200 return llvm.Value{}, nil
201 }
202
203 fnIdx, ok := b.wasmSpawnIndex[targetFn]
204 if !ok {
205 b.addError(instr.Pos(), "spawn: internal — function not in wasm dispatch table (scan missed it)")
206 return llvm.Value{}, nil
207 }
208
209 concreteArgs := instr.Args[argStart+1:]
210 sig := targetFn.Signature
211 if len(concreteArgs) != sig.Params().Len() {
212 b.addError(instr.Pos(), fmt.Sprintf("spawn: %s expects %d args, got %d",
213 targetFn.Name(), sig.Params().Len(), len(concreteArgs)))
214 return llvm.Value{}, nil
215 }
216
217 for i, arg := range concreteArgs {
218 if err := validateWasmSpawnArg(sig.Params().At(i).Type(), i); err != nil {
219 b.addError(instr.Pos(), err.Error())
220 return llvm.Value{}, nil
221 }
222 if mc := traceToMakeChan(arg); mc != nil {
223 if cv, ok := mc.Size.(*ssa.Const); ok {
224 if cv.Value != nil {
225 // positive size = buffered; warn but allow
226 _ = cv
227 }
228 }
229 }
230 }
231
232 i32 := b.ctx.Int32Type()
233
234 // --- channel args: create SAB handles ---
235 var chanHandles []llvm.Value
236 for _, arg := range concreteArgs {
237 if _, isChan := arg.Type().Underlying().(*types.Chan); !isChan {
238 continue
239 }
240 elemType := arg.Type().Underlying().(*types.Chan).Elem()
241 elemSz := b.targetData.TypeAllocSize(b.getLLVMType(elemType))
242 slotSize := nextPow2u32(uint32(elemSz) + 12)
243 if slotSize < 64 {
244 slotSize = 64
245 }
246 slotCount := uint32(16)
247 handle := b.emitBridgeCall("channel_create", []llvm.Value{
248 llvm.ConstInt(i32, uint64(slotSize), false),
249 llvm.ConstInt(i32, uint64(slotCount), false),
250 }, i32)
251 chanHandles = append(chanHandles, handle)
252 if b.sabChannels == nil {
253 b.sabChannels = make(map[ssa.Value]llvm.Value)
254 }
255 b.sabChannels[arg] = handle
256 }
257
258 // --- scalar args: pack into a stack buffer ---
259 var scalarVals []llvm.Value
260 var scalarTypes []llvm.Type
261 var totalScalarBytes uint64
262 for i, arg := range concreteArgs {
263 if _, isChan := sig.Params().At(i).Type().Underlying().(*types.Chan); isChan {
264 continue
265 }
266 v := b.getValue(arg, instr.Pos())
267 llvmType := b.getLLVMType(arg.Type())
268 scalarVals = append(scalarVals, v)
269 scalarTypes = append(scalarTypes, llvmType)
270 totalScalarBytes += b.targetData.TypeAllocSize(llvmType)
271 }
272
273 argPtr := llvm.ConstNull(b.dataPtrType)
274 argLen := llvm.ConstInt(i32, 0, false)
275 if totalScalarBytes > 0 {
276 bufType := llvm.ArrayType(b.ctx.Int8Type(), int(totalScalarBytes))
277 bufAlloca, _ := b.createTemporaryAlloca(bufType, "spawn.args")
278 var off uint64
279 for idx, v := range scalarVals {
280 sz := b.targetData.TypeAllocSize(scalarTypes[idx])
281 gep := b.CreateGEP(bufType, bufAlloca, []llvm.Value{
282 llvm.ConstInt(i32, 0, false),
283 llvm.ConstInt(i32, off, false),
284 }, "spawn.arg.slot")
285 b.CreateStore(v, gep)
286 off += sz
287 }
288 argPtr = b.CreateGEP(bufType, bufAlloca, []llvm.Value{
289 llvm.ConstInt(i32, 0, false),
290 llvm.ConstInt(i32, 0, false),
291 }, "spawn.argptr")
292 argLen = llvm.ConstInt(i32, totalScalarBytes, false)
293 }
294
295 // --- channel handle array ---
296 nChans := len(chanHandles)
297 chanPtr := llvm.ConstNull(b.dataPtrType)
298 nChansVal := llvm.ConstInt(i32, uint64(nChans), false)
299 if nChans > 0 {
300 handleArrType := llvm.ArrayType(i32, nChans)
301 handleAlloca, _ := b.createTemporaryAlloca(handleArrType, "spawn.chans")
302 for hIdx, h := range chanHandles {
303 gep := b.CreateGEP(handleArrType, handleAlloca, []llvm.Value{
304 llvm.ConstInt(i32, 0, false),
305 llvm.ConstInt(i32, uint64(hIdx), false),
306 }, "")
307 b.CreateStore(h, gep)
308 }
309 chanPtr = b.CreateGEP(handleArrType, handleAlloca, []llvm.Value{
310 llvm.ConstInt(i32, 0, false),
311 llvm.ConstInt(i32, 0, false),
312 }, "spawn.chanptr")
313 }
314
315 // bridge.spawn_domain(fnIdx, argPtr, argLen, chanPtr, nChans)
316 b.emitBridgeCall("spawn_domain", []llvm.Value{
317 llvm.ConstInt(i32, uint64(fnIdx), false),
318 argPtr, argLen,
319 chanPtr, nChansVal,
320 }, b.ctx.VoidType())
321
322 return llvm.Undef(b.uintptrType), nil
323 }
324
325 // emitBridgeCall emits a call to a wasm import from the "bridge" module.
326 // The function declaration is created on first use.
327 func (b *builder) emitBridgeCall(name string, args []llvm.Value, retType llvm.Type) llvm.Value {
328 return emitBridgeCallCtx(b.compilerContext, b.Builder, name, args, retType)
329 }
330
331 func emitBridgeCallCtx(c *compilerContext, builder llvm.Builder, name string, args []llvm.Value, retType llvm.Type) llvm.Value {
332 internalName := "__bridge_" + name
333 argTypes := make([]llvm.Type, len(args))
334 for i, a := range args {
335 argTypes[i] = a.Type()
336 }
337 fnType := llvm.FunctionType(retType, argTypes, false)
338 fn := c.mod.NamedFunction(internalName)
339 if fn.IsNil() {
340 fn = llvm.AddFunction(c.mod, internalName, fnType)
341 fn.AddFunctionAttr(c.ctx.CreateStringAttribute("wasm-import-module", "bridge"))
342 fn.AddFunctionAttr(c.ctx.CreateStringAttribute("wasm-import-name", name))
343 fn.SetLinkage(llvm.ExternalLinkage)
344 }
345 if retType == c.ctx.VoidType() {
346 builder.CreateCall(fnType, fn, args, "")
347 return llvm.Value{}
348 }
349 return builder.CreateCall(fnType, fn, args, "bridge."+name)
350 }
351
352 // setupWasmSpawnTargetChannels pre-populates b.sabChannels for channel
353 // parameters of spawn target functions. Called at the start of createFunction
354 // for every function in wasm mode.
355 func (b *builder) setupWasmSpawnTargetChannels() {
356 if !b.isWasmTarget() || b.fn == nil {
357 return
358 }
359 target := b.wasmSpawnTargetFor(b.fn)
360 if target == nil {
361 return
362 }
363 if b.sabChannels == nil {
364 b.sabChannels = make(map[ssa.Value]llvm.Value)
365 }
366 i32 := b.ctx.Int32Type()
367 for _, paramIdx := range target.chanParamIdxs {
368 if paramIdx >= len(b.fn.Params) {
369 continue
370 }
371 param := b.fn.Params[paramIdx]
372 llvmParam, ok := b.locals[param]
373 if !ok {
374 continue
375 }
376 // The param holds the int32 SAB handle stored as a pointer (inttoptr in __spawn_entry).
377 // ptrtoint recovers the int32 handle.
378 handle := b.CreatePtrToInt(llvmParam, i32, "sab.handle")
379 b.sabChannels[param] = handle
380 }
381 }
382
383 // wasmSpawnTargetFor returns the wasmSpawnTarget descriptor if fn is a spawn
384 // target in the current program, nil otherwise.
385 func (b *builder) wasmSpawnTargetFor(fn *ssa.Function) *wasmSpawnTarget {
386 if _, ok := b.wasmSpawnIndex[fn]; !ok {
387 return nil
388 }
389 for _, t := range b.wasmSpawnTargets {
390 if t.fn == fn {
391 return t
392 }
393 }
394 return nil
395 }
396
397 // emitWasmSpawnEntry emits the __spawn_entry export after all package
398 // functions have been compiled. This function is the Worker entry point
399 // called by wasm-worker-host.mjs instead of _start.
400 //
401 // Signature: void __spawn_entry(fnIdx i32, argPtr ptr, argLen i32, chanHandlesPtr ptr, nChans i32)
402 func (c *compilerContext) emitWasmSpawnEntry(irbuilder llvm.Builder) {
403 if !strings.HasPrefix(c.Triple, "wasm") || c.GOOS != "js" {
404 return
405 }
406 // Only emit __spawn_entry in packages that actually contain spawn calls.
407 // Each package is compiled into a separate LLVM module; packages with no
408 // spawn calls must not define the symbol (link-time multiply-defined error).
409 if len(c.wasmSpawnTargets) == 0 {
410 return
411 }
412
413 i32 := c.ctx.Int32Type()
414 ptrType := c.dataPtrType
415 fnType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{i32, ptrType, i32, ptrType, i32}, false)
416
417 entryFn := llvm.AddFunction(c.mod, "__spawn_entry", fnType)
418 entryFn.AddFunctionAttr(c.ctx.CreateStringAttribute("wasm-export-name", "__spawn_entry"))
419 // wasm-export-name places the function in the wasm export table,
420 // which prevents DCE. No llvm.used needed.
421
422 entryBlock := c.ctx.AddBasicBlock(entryFn, "entry")
423 irbuilder.SetInsertPointAtEnd(entryBlock)
424
425 fnIdx := entryFn.Param(0)
426 argPtr := entryFn.Param(1)
427 argLen := entryFn.Param(2)
428 chanHandlesPtr := entryFn.Param(3)
429 nChans := entryFn.Param(4)
430 _, _ = argLen, nChans // used indirectly through GEPs
431
432 defaultBlock := c.ctx.AddBasicBlock(entryFn, "default")
433 irbuilder.SetInsertPointAtEnd(defaultBlock)
434 irbuilder.CreateUnreachable()
435 irbuilder.SetInsertPointAtEnd(entryBlock)
436
437 sw := irbuilder.CreateSwitch(fnIdx, defaultBlock, len(c.wasmSpawnTargets))
438
439 for _, target := range c.wasmSpawnTargets {
440 caseBlock := c.ctx.AddBasicBlock(entryFn, "case."+target.fn.Name())
441 irbuilder.SetInsertPointAtEnd(caseBlock)
442
443 _, targetLLVM := c.getFunction(target.fn)
444 targetFnType, _ := c.getFunction(target.fn)
445 sig := target.fn.Signature
446
447 var callArgs []llvm.Value
448 var scalarOffset uint64
449 chanIdx := 0
450
451 for i := 0; i < sig.Params().Len(); i++ {
452 param := sig.Params().At(i)
453 llvmType := c.getLLVMType(param.Type())
454
455 if _, isChan := param.Type().Underlying().(*types.Chan); isChan {
456 // Load int32 handle from chanHandlesPtr[chanIdx], inttoptr.
457 handleGEP := irbuilder.CreateGEP(i32, chanHandlesPtr, []llvm.Value{
458 llvm.ConstInt(i32, uint64(chanIdx), false),
459 }, "chan.handle.gep")
460 handle := irbuilder.CreateLoad(i32, handleGEP, "chan.handle")
461 chanPtr := irbuilder.CreateIntToPtr(handle, ptrType, "chan.ptr")
462 callArgs = append(callArgs, chanPtr)
463 chanIdx++
464 } else {
465 sz := c.targetData.TypeAllocSize(llvmType)
466 gep := irbuilder.CreateGEP(c.ctx.Int8Type(), argPtr, []llvm.Value{
467 llvm.ConstInt(i32, scalarOffset, false),
468 }, "arg.gep")
469 val := irbuilder.CreateLoad(llvmType, gep, fmt.Sprintf("arg%d", i))
470 callArgs = append(callArgs, val)
471 scalarOffset += sz
472 }
473 }
474
475 // Non-exported Moxie functions have a trailing context ptr parameter.
476 // Spawn targets are never closures, so pass undef context.
477 info := c.getFunctionInfo(target.fn)
478 if !info.exported {
479 callArgs = append(callArgs, llvm.Undef(ptrType))
480 }
481
482 irbuilder.CreateCall(targetFnType, targetLLVM, callArgs, "")
483 irbuilder.CreateRetVoid()
484
485 sw.AddCase(llvm.ConstInt(i32, uint64(target.fnIdx), false), caseBlock)
486 }
487 }
488
489 func nextPow2u32(n uint32) uint32 {
490 if n <= 1 {
491 return 1
492 }
493 n--
494 n |= n >> 1
495 n |= n >> 2
496 n |= n >> 4
497 n |= n >> 8
498 n |= n >> 16
499 return n + 1
500 }
501