wasmspawn.go raw

   1  package compiler
   2  
   3  // spawn_wasm.go — wasm spawn path.
   4  //
   5  // In wasm (GOOS=js GOARCH=wasm), spawn() creates a new Worker instead of
   6  // forking a process. Channels at spawn boundaries are SAB-backed; all other
   7  // channel operations use the regular runtime. Scalar args are packed into a
   8  // linear-memory buffer and deserialized by __spawn_entry on the child side.
   9  //
  10  // Architecture:
  11  //  Parent side (createWasmSpawn):
  12  //    1. For each chan arg: bridge.channel_create → int32 SAB handle stored in
  13  //       b.sabChannels keyed by the SSA channel value.
  14  //    2. Pack scalar args into a stack alloca.
  15  //    3. Pack SAB handles into a second alloca.
  16  //    4. Call bridge.spawn_domain(fnIdx, argPtr, argLen, chanHandlesPtr, nChans).
  17  //
  18  //  Child side (__spawn_entry, emitted at link time by emitWasmSpawnEntry):
  19  //    A switch over fnIdx dispatches to a per-target case that:
  20  //    1. Deserializes scalars from argPtr.
  21  //    2. Loads int32 handles from chanHandlesPtr, inttoptr → dataPtrType.
  22  //    3. Calls the target function directly.
  23  //    The target function's channel params are added to sabChannels in its
  24  //    builder (via setupWasmSpawnTargetChannels) so send/recv/close use bridge.
  25  
import (
	"fmt"
	"go/types"
	"sort"
	"strings"

	"golang.org/x/tools/go/ssa"
	"tinygo.org/x/go-llvm"
)
  34  
  35  // wasmSpawnTarget records a spawn-eligible function and its dispatch index.
  36  type wasmSpawnTarget struct {
  37  	fn              *ssa.Function
  38  	fnIdx           int32
  39  	chanParamIdxs   []int // indices into fn.Params that are channel types
  40  	scalarParamIdxs []int // indices that are non-channel (scalar/struct)
  41  }
  42  
  43  // isWasmTarget returns true when building for GOOS=js GOARCH=wasm.
  44  func (b *builder) isWasmTarget() bool {
  45  	return strings.HasPrefix(b.Triple, "wasm") && b.GOOS == "js"
  46  }
  47  
  48  // scanWasmSpawnTargets pre-scans the package for spawn() calls, collecting
  49  // all statically-known target functions and assigning each a stable fnIdx.
  50  // Must be called after ssaPkg.Build() and before createPackage.
  51  func (c *compilerContext) scanWasmSpawnTargets(pkg *ssa.Package) {
  52  	if !strings.HasPrefix(c.Triple, "wasm") || c.GOOS != "js" {
  53  		return
  54  	}
  55  	seen := make(map[*ssa.Function]bool)
  56  	for _, mem := range pkg.Members {
  57  		fn, ok := mem.(*ssa.Function)
  58  		if !ok {
  59  			continue
  60  		}
  61  		c.scanFnForSpawn(fn, seen)
  62  	}
  63  }
  64  
  65  func (c *compilerContext) scanFnForSpawn(fn *ssa.Function, seen map[*ssa.Function]bool) {
  66  	for _, blk := range fn.Blocks {
  67  		for _, instr := range blk.Instrs {
  68  			call, ok := instr.(*ssa.Call)
  69  			if !ok {
  70  				continue
  71  			}
  72  			bi, ok := call.Call.Value.(*ssa.Builtin)
  73  			if !ok || bi.Name() != "spawn" {
  74  				continue
  75  			}
  76  			target := extractWasmSpawnTargetFn(call.Call.Args)
  77  			if target == nil || seen[target] {
  78  				continue
  79  			}
  80  			seen[target] = true
  81  			sig := target.Signature
  82  			var chanIdxs, scalarIdxs []int
  83  			for i := 0; i < sig.Params().Len(); i++ {
  84  				if _, isChan := sig.Params().At(i).Type().Underlying().(*types.Chan); isChan {
  85  					chanIdxs = append(chanIdxs, i)
  86  				} else {
  87  					scalarIdxs = append(scalarIdxs, i)
  88  				}
  89  			}
  90  			idx := int32(len(c.wasmSpawnTargets))
  91  			c.wasmSpawnTargets = append(c.wasmSpawnTargets, &wasmSpawnTarget{
  92  				fn:              target,
  93  				fnIdx:           idx,
  94  				chanParamIdxs:   chanIdxs,
  95  				scalarParamIdxs: scalarIdxs,
  96  			})
  97  			if c.wasmSpawnIndex == nil {
  98  				c.wasmSpawnIndex = make(map[*ssa.Function]int32)
  99  			}
 100  			c.wasmSpawnIndex[target] = idx
 101  		}
 102  	}
 103  }
 104  
 105  func extractWasmSpawnTargetFn(args []ssa.Value) *ssa.Function {
 106  	start := 0
 107  	if len(args) > 0 {
 108  		if _, ok := extractTransportString(args[0]); ok {
 109  			start = 1
 110  		}
 111  	}
 112  	if start >= len(args) {
 113  		return nil
 114  	}
 115  	switch v := args[start].(type) {
 116  	case *ssa.Function:
 117  		return v
 118  	case *ssa.MakeClosure:
 119  		if f, ok := v.Fn.(*ssa.Function); ok {
 120  			return f
 121  		}
 122  	}
 123  	return nil
 124  }
 125  
 126  // validateWasmSpawnArg checks that a type is safe across the wasm spawn
 127  // boundary. No Codec requirement. Rejects pointers, functions, interfaces.
 128  // Top-level channel types are allowed (SAB-backed); nested channels are not.
 129  func validateWasmSpawnArg(t types.Type, argIdx int) error {
 130  	if _, ok := t.Underlying().(*types.Chan); ok {
 131  		return nil
 132  	}
 133  	return validateWasmSpawnElem(t, argIdx)
 134  }
 135  
 136  func validateWasmSpawnElem(t types.Type, argIdx int) error {
 137  	switch ut := t.Underlying().(type) {
 138  	case *types.Basic:
 139  		return nil
 140  	case *types.Pointer:
 141  		return fmt.Errorf("spawn: argument %d is a pointer — not allowed across wasm spawn boundary", argIdx+1)
 142  	case *types.Signature:
 143  		return fmt.Errorf("spawn: argument %d is a function value — not allowed across wasm spawn boundary", argIdx+1)
 144  	case *types.Interface:
 145  		return fmt.Errorf("spawn: argument %d is an interface — not allowed across wasm spawn boundary", argIdx+1)
 146  	case *types.Chan:
 147  		return fmt.Errorf("spawn: argument %d contains a nested channel — channels must be top-level params", argIdx+1)
 148  	case *types.Slice:
 149  		if basic, ok := ut.Elem().Underlying().(*types.Basic); ok && basic.Kind() == types.Byte {
 150  			return nil // []byte / string
 151  		}
 152  		return validateWasmSpawnElem(ut.Elem(), argIdx)
 153  	case *types.Struct:
 154  		for i := 0; i < ut.NumFields(); i++ {
 155  			if err := validateWasmSpawnElem(ut.Field(i).Type(), argIdx); err != nil {
 156  				return err
 157  			}
 158  		}
 159  		return nil
 160  	case *types.Array:
 161  		return validateWasmSpawnElem(ut.Elem(), argIdx)
 162  	}
 163  	return fmt.Errorf("spawn: argument %d type %s not allowed across wasm spawn boundary", argIdx+1, t)
 164  }
 165  
 166  // createWasmSpawn handles spawn() in the wasm target. It emits:
 167  //   1. bridge.channel_create for each channel arg → SAB handle in b.sabChannels.
 168  //   2. A packed scalar arg buffer.
 169  //   3. A packed chan handle array.
 170  //   4. bridge.spawn_domain(fnIdx, argPtr, argLen, chanHandlesPtr, nChans).
 171  func (b *builder) createWasmSpawn(instr *ssa.CallCommon) (llvm.Value, error) {
 172  	argStart := 0
 173  	if len(instr.Args) > 0 {
 174  		if _, ok := extractTransportString(instr.Args[0]); ok {
 175  			argStart = 1
 176  		}
 177  	}
 178  	if argStart >= len(instr.Args) {
 179  		b.addError(instr.Pos(), "spawn: requires a function argument")
 180  		return llvm.Value{}, nil
 181  	}
 182  
 183  	fnArg := instr.Args[argStart]
 184  	var targetFn *ssa.Function
 185  	switch v := fnArg.(type) {
 186  	case *ssa.Function:
 187  		targetFn = v
 188  	case *ssa.MakeClosure:
 189  		if f, ok := v.Fn.(*ssa.Function); ok {
 190  			targetFn = f
 191  		}
 192  	}
 193  	if targetFn == nil {
 194  		b.addError(instr.Pos(), "spawn: first argument must be a static top-level function")
 195  		return llvm.Value{}, nil
 196  	}
 197  	if targetFn.Package() != nil && b.fn.Package() != nil &&
 198  		targetFn.Package().Pkg.Path() != b.fn.Package().Pkg.Path() {
 199  		b.addError(instr.Pos(), "spawn: wasm spawn target must be in the same package as the caller")
 200  		return llvm.Value{}, nil
 201  	}
 202  
 203  	fnIdx, ok := b.wasmSpawnIndex[targetFn]
 204  	if !ok {
 205  		b.addError(instr.Pos(), "spawn: internal — function not in wasm dispatch table (scan missed it)")
 206  		return llvm.Value{}, nil
 207  	}
 208  
 209  	concreteArgs := instr.Args[argStart+1:]
 210  	sig := targetFn.Signature
 211  	if len(concreteArgs) != sig.Params().Len() {
 212  		b.addError(instr.Pos(), fmt.Sprintf("spawn: %s expects %d args, got %d",
 213  			targetFn.Name(), sig.Params().Len(), len(concreteArgs)))
 214  		return llvm.Value{}, nil
 215  	}
 216  
 217  	for i, arg := range concreteArgs {
 218  		if err := validateWasmSpawnArg(sig.Params().At(i).Type(), i); err != nil {
 219  			b.addError(instr.Pos(), err.Error())
 220  			return llvm.Value{}, nil
 221  		}
 222  		if mc := traceToMakeChan(arg); mc != nil {
 223  			if cv, ok := mc.Size.(*ssa.Const); ok {
 224  				if cv.Value != nil {
 225  					// positive size = buffered; warn but allow
 226  					_ = cv
 227  				}
 228  			}
 229  		}
 230  	}
 231  
 232  	i32 := b.ctx.Int32Type()
 233  
 234  	// --- channel args: create SAB handles ---
 235  	var chanHandles []llvm.Value
 236  	for _, arg := range concreteArgs {
 237  		if _, isChan := arg.Type().Underlying().(*types.Chan); !isChan {
 238  			continue
 239  		}
 240  		elemType := arg.Type().Underlying().(*types.Chan).Elem()
 241  		elemSz := b.targetData.TypeAllocSize(b.getLLVMType(elemType))
 242  		slotSize := nextPow2u32(uint32(elemSz) + 12)
 243  		if slotSize < 64 {
 244  			slotSize = 64
 245  		}
 246  		slotCount := uint32(16)
 247  		handle := b.emitBridgeCall("channel_create", []llvm.Value{
 248  			llvm.ConstInt(i32, uint64(slotSize), false),
 249  			llvm.ConstInt(i32, uint64(slotCount), false),
 250  		}, i32)
 251  		chanHandles = append(chanHandles, handle)
 252  		if b.sabChannels == nil {
 253  			b.sabChannels = make(map[ssa.Value]llvm.Value)
 254  		}
 255  		b.sabChannels[arg] = handle
 256  	}
 257  
 258  	// --- scalar args: pack into a stack buffer ---
 259  	var scalarVals []llvm.Value
 260  	var scalarTypes []llvm.Type
 261  	var totalScalarBytes uint64
 262  	for i, arg := range concreteArgs {
 263  		if _, isChan := sig.Params().At(i).Type().Underlying().(*types.Chan); isChan {
 264  			continue
 265  		}
 266  		v := b.getValue(arg, instr.Pos())
 267  		llvmType := b.getLLVMType(arg.Type())
 268  		scalarVals = append(scalarVals, v)
 269  		scalarTypes = append(scalarTypes, llvmType)
 270  		totalScalarBytes += b.targetData.TypeAllocSize(llvmType)
 271  	}
 272  
 273  	argPtr := llvm.ConstNull(b.dataPtrType)
 274  	argLen := llvm.ConstInt(i32, 0, false)
 275  	if totalScalarBytes > 0 {
 276  		bufType := llvm.ArrayType(b.ctx.Int8Type(), int(totalScalarBytes))
 277  		bufAlloca, _ := b.createTemporaryAlloca(bufType, "spawn.args")
 278  		var off uint64
 279  		for idx, v := range scalarVals {
 280  			sz := b.targetData.TypeAllocSize(scalarTypes[idx])
 281  			gep := b.CreateGEP(bufType, bufAlloca, []llvm.Value{
 282  				llvm.ConstInt(i32, 0, false),
 283  				llvm.ConstInt(i32, off, false),
 284  			}, "spawn.arg.slot")
 285  			b.CreateStore(v, gep)
 286  			off += sz
 287  		}
 288  		argPtr = b.CreateGEP(bufType, bufAlloca, []llvm.Value{
 289  			llvm.ConstInt(i32, 0, false),
 290  			llvm.ConstInt(i32, 0, false),
 291  		}, "spawn.argptr")
 292  		argLen = llvm.ConstInt(i32, totalScalarBytes, false)
 293  	}
 294  
 295  	// --- channel handle array ---
 296  	nChans := len(chanHandles)
 297  	chanPtr := llvm.ConstNull(b.dataPtrType)
 298  	nChansVal := llvm.ConstInt(i32, uint64(nChans), false)
 299  	if nChans > 0 {
 300  		handleArrType := llvm.ArrayType(i32, nChans)
 301  		handleAlloca, _ := b.createTemporaryAlloca(handleArrType, "spawn.chans")
 302  		for hIdx, h := range chanHandles {
 303  			gep := b.CreateGEP(handleArrType, handleAlloca, []llvm.Value{
 304  				llvm.ConstInt(i32, 0, false),
 305  				llvm.ConstInt(i32, uint64(hIdx), false),
 306  			}, "")
 307  			b.CreateStore(h, gep)
 308  		}
 309  		chanPtr = b.CreateGEP(handleArrType, handleAlloca, []llvm.Value{
 310  			llvm.ConstInt(i32, 0, false),
 311  			llvm.ConstInt(i32, 0, false),
 312  		}, "spawn.chanptr")
 313  	}
 314  
 315  	// bridge.spawn_domain(fnIdx, argPtr, argLen, chanPtr, nChans)
 316  	b.emitBridgeCall("spawn_domain", []llvm.Value{
 317  		llvm.ConstInt(i32, uint64(fnIdx), false),
 318  		argPtr, argLen,
 319  		chanPtr, nChansVal,
 320  	}, b.ctx.VoidType())
 321  
 322  	return llvm.Undef(b.uintptrType), nil
 323  }
 324  
 325  // emitBridgeCall emits a call to a wasm import from the "bridge" module.
 326  // The function declaration is created on first use.
 327  func (b *builder) emitBridgeCall(name string, args []llvm.Value, retType llvm.Type) llvm.Value {
 328  	return emitBridgeCallCtx(b.compilerContext, b.Builder, name, args, retType)
 329  }
 330  
 331  func emitBridgeCallCtx(c *compilerContext, builder llvm.Builder, name string, args []llvm.Value, retType llvm.Type) llvm.Value {
 332  	internalName := "__bridge_" + name
 333  	argTypes := make([]llvm.Type, len(args))
 334  	for i, a := range args {
 335  		argTypes[i] = a.Type()
 336  	}
 337  	fnType := llvm.FunctionType(retType, argTypes, false)
 338  	fn := c.mod.NamedFunction(internalName)
 339  	if fn.IsNil() {
 340  		fn = llvm.AddFunction(c.mod, internalName, fnType)
 341  		fn.AddFunctionAttr(c.ctx.CreateStringAttribute("wasm-import-module", "bridge"))
 342  		fn.AddFunctionAttr(c.ctx.CreateStringAttribute("wasm-import-name", name))
 343  		fn.SetLinkage(llvm.ExternalLinkage)
 344  	}
 345  	if retType == c.ctx.VoidType() {
 346  		builder.CreateCall(fnType, fn, args, "")
 347  		return llvm.Value{}
 348  	}
 349  	return builder.CreateCall(fnType, fn, args, "bridge."+name)
 350  }
 351  
 352  // setupWasmSpawnTargetChannels pre-populates b.sabChannels for channel
 353  // parameters of spawn target functions. Called at the start of createFunction
 354  // for every function in wasm mode.
 355  func (b *builder) setupWasmSpawnTargetChannels() {
 356  	if !b.isWasmTarget() || b.fn == nil {
 357  		return
 358  	}
 359  	target := b.wasmSpawnTargetFor(b.fn)
 360  	if target == nil {
 361  		return
 362  	}
 363  	if b.sabChannels == nil {
 364  		b.sabChannels = make(map[ssa.Value]llvm.Value)
 365  	}
 366  	i32 := b.ctx.Int32Type()
 367  	for _, paramIdx := range target.chanParamIdxs {
 368  		if paramIdx >= len(b.fn.Params) {
 369  			continue
 370  		}
 371  		param := b.fn.Params[paramIdx]
 372  		llvmParam, ok := b.locals[param]
 373  		if !ok {
 374  			continue
 375  		}
 376  		// The param holds the int32 SAB handle stored as a pointer (inttoptr in __spawn_entry).
 377  		// ptrtoint recovers the int32 handle.
 378  		handle := b.CreatePtrToInt(llvmParam, i32, "sab.handle")
 379  		b.sabChannels[param] = handle
 380  	}
 381  }
 382  
 383  // wasmSpawnTargetFor returns the wasmSpawnTarget descriptor if fn is a spawn
 384  // target in the current program, nil otherwise.
 385  func (b *builder) wasmSpawnTargetFor(fn *ssa.Function) *wasmSpawnTarget {
 386  	if _, ok := b.wasmSpawnIndex[fn]; !ok {
 387  		return nil
 388  	}
 389  	for _, t := range b.wasmSpawnTargets {
 390  		if t.fn == fn {
 391  			return t
 392  		}
 393  	}
 394  	return nil
 395  }
 396  
 397  // emitWasmSpawnEntry emits the __spawn_entry export after all package
 398  // functions have been compiled. This function is the Worker entry point
 399  // called by wasm-worker-host.mjs instead of _start.
 400  //
 401  // Signature: void __spawn_entry(fnIdx i32, argPtr ptr, argLen i32, chanHandlesPtr ptr, nChans i32)
 402  func (c *compilerContext) emitWasmSpawnEntry(irbuilder llvm.Builder) {
 403  	if !strings.HasPrefix(c.Triple, "wasm") || c.GOOS != "js" {
 404  		return
 405  	}
 406  	// Only emit __spawn_entry in packages that actually contain spawn calls.
 407  	// Each package is compiled into a separate LLVM module; packages with no
 408  	// spawn calls must not define the symbol (link-time multiply-defined error).
 409  	if len(c.wasmSpawnTargets) == 0 {
 410  		return
 411  	}
 412  
 413  	i32 := c.ctx.Int32Type()
 414  	ptrType := c.dataPtrType
 415  	fnType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{i32, ptrType, i32, ptrType, i32}, false)
 416  
 417  	entryFn := llvm.AddFunction(c.mod, "__spawn_entry", fnType)
 418  	entryFn.AddFunctionAttr(c.ctx.CreateStringAttribute("wasm-export-name", "__spawn_entry"))
 419  	// wasm-export-name places the function in the wasm export table,
 420  	// which prevents DCE. No llvm.used needed.
 421  
 422  	entryBlock := c.ctx.AddBasicBlock(entryFn, "entry")
 423  	irbuilder.SetInsertPointAtEnd(entryBlock)
 424  
 425  	fnIdx := entryFn.Param(0)
 426  	argPtr := entryFn.Param(1)
 427  	argLen := entryFn.Param(2)
 428  	chanHandlesPtr := entryFn.Param(3)
 429  	nChans := entryFn.Param(4)
 430  	_, _ = argLen, nChans // used indirectly through GEPs
 431  
 432  	defaultBlock := c.ctx.AddBasicBlock(entryFn, "default")
 433  	irbuilder.SetInsertPointAtEnd(defaultBlock)
 434  	irbuilder.CreateUnreachable()
 435  	irbuilder.SetInsertPointAtEnd(entryBlock)
 436  
 437  	sw := irbuilder.CreateSwitch(fnIdx, defaultBlock, len(c.wasmSpawnTargets))
 438  
 439  	for _, target := range c.wasmSpawnTargets {
 440  		caseBlock := c.ctx.AddBasicBlock(entryFn, "case."+target.fn.Name())
 441  		irbuilder.SetInsertPointAtEnd(caseBlock)
 442  
 443  		_, targetLLVM := c.getFunction(target.fn)
 444  		targetFnType, _ := c.getFunction(target.fn)
 445  		sig := target.fn.Signature
 446  
 447  		var callArgs []llvm.Value
 448  		var scalarOffset uint64
 449  		chanIdx := 0
 450  
 451  		for i := 0; i < sig.Params().Len(); i++ {
 452  			param := sig.Params().At(i)
 453  			llvmType := c.getLLVMType(param.Type())
 454  
 455  			if _, isChan := param.Type().Underlying().(*types.Chan); isChan {
 456  				// Load int32 handle from chanHandlesPtr[chanIdx], inttoptr.
 457  				handleGEP := irbuilder.CreateGEP(i32, chanHandlesPtr, []llvm.Value{
 458  					llvm.ConstInt(i32, uint64(chanIdx), false),
 459  				}, "chan.handle.gep")
 460  				handle := irbuilder.CreateLoad(i32, handleGEP, "chan.handle")
 461  				chanPtr := irbuilder.CreateIntToPtr(handle, ptrType, "chan.ptr")
 462  				callArgs = append(callArgs, chanPtr)
 463  				chanIdx++
 464  			} else {
 465  				sz := c.targetData.TypeAllocSize(llvmType)
 466  				gep := irbuilder.CreateGEP(c.ctx.Int8Type(), argPtr, []llvm.Value{
 467  					llvm.ConstInt(i32, scalarOffset, false),
 468  				}, "arg.gep")
 469  				val := irbuilder.CreateLoad(llvmType, gep, fmt.Sprintf("arg%d", i))
 470  				callArgs = append(callArgs, val)
 471  				scalarOffset += sz
 472  			}
 473  		}
 474  
 475  		// Non-exported Moxie functions have a trailing context ptr parameter.
 476  		// Spawn targets are never closures, so pass undef context.
 477  		info := c.getFunctionInfo(target.fn)
 478  		if !info.exported {
 479  			callArgs = append(callArgs, llvm.Undef(ptrType))
 480  		}
 481  
 482  		irbuilder.CreateCall(targetFnType, targetLLVM, callArgs, "")
 483  		irbuilder.CreateRetVoid()
 484  
 485  		sw.AddCase(llvm.ConstInt(i32, uint64(target.fnIdx), false), caseBlock)
 486  	}
 487  }
 488  
 489  func nextPow2u32(n uint32) uint32 {
 490  	if n <= 1 {
 491  		return 1
 492  	}
 493  	n--
 494  	n |= n >> 1
 495  	n |= n >> 2
 496  	n |= n >> 4
 497  	n |= n >> 8
 498  	n |= n >> 16
 499  	return n + 1
 500  }
 501