intrinsics.go raw

   1  package compiler
   2  
   3  // This file contains helper functions to create calls to LLVM intrinsics.
   4  
   5  import (
   6  	"go/token"
   7  	"strconv"
   8  	"strings"
   9  
  10  	"tinygo.org/x/go-llvm"
  11  )
  12  
  13  // Define unimplemented intrinsic functions.
  14  //
  15  // Some functions are either normally implemented in Go assembly (like
  16  // sync/atomic functions) or intentionally left undefined to be implemented
  17  // directly in the compiler (like runtime/volatile functions). Either way, look
  18  // for these and implement them if this is the case.
  19  func (b *builder) defineIntrinsicFunction() {
  20  	name := b.fn.RelString(nil)
  21  	switch {
  22  	case name == "runtime.memcpy" || name == "runtime.memmove":
  23  		b.createMemoryCopyImpl()
  24  	case name == "runtime.memzero":
  25  		b.createMemoryZeroImpl()
  26  	case name == "runtime.stacksave":
  27  		b.createStackSaveImpl()
  28  	case name == "runtime.KeepAlive":
  29  		b.createKeepAliveImpl()
  30  	case name == "machine.keepAliveNoEscape":
  31  		b.createMachineKeepAliveImpl()
  32  	case strings.HasPrefix(name, "runtime/volatile.Load"):
  33  		b.createVolatileLoad()
  34  	case strings.HasPrefix(name, "runtime/volatile.Store"):
  35  		b.createVolatileStore()
  36  	case strings.HasPrefix(name, "sync/atomic.") && token.IsExported(b.fn.Name()):
  37  		b.createFunctionStart(true)
  38  		returnValue := b.createAtomicOp(b.fn.Name())
  39  		if !returnValue.IsNil() {
  40  			b.CreateRet(returnValue)
  41  		} else {
  42  			b.CreateRetVoid()
  43  		}
  44  	}
  45  }
  46  
  47  // createMemoryCopyImpl creates a call to a builtin LLVM memcpy or memmove
  48  // function, declaring this function if needed. These calls are treated
  49  // specially by optimization passes possibly resulting in better generated code,
  50  // and will otherwise be lowered to regular libc memcpy/memmove calls.
  51  func (b *builder) createMemoryCopyImpl() {
  52  	b.createFunctionStart(true)
  53  	fnName := "llvm." + b.fn.Name() + ".p0.p0.i" + strconv.Itoa(b.uintptrType.IntTypeWidth())
  54  	llvmFn := b.mod.NamedFunction(fnName)
  55  	if llvmFn.IsNil() {
  56  		fnType := llvm.FunctionType(b.ctx.VoidType(), []llvm.Type{b.dataPtrType, b.dataPtrType, b.uintptrType, b.ctx.Int1Type()}, false)
  57  		llvmFn = llvm.AddFunction(b.mod, fnName, fnType)
  58  	}
  59  	var params []llvm.Value
  60  	for _, param := range b.fn.Params {
  61  		params = append(params, b.getValue(param, getPos(b.fn)))
  62  	}
  63  	params = append(params, llvm.ConstInt(b.ctx.Int1Type(), 0, false))
  64  	b.CreateCall(llvmFn.GlobalValueType(), llvmFn, params, "")
  65  	b.CreateRetVoid()
  66  }
  67  
  68  // createMemoryZeroImpl creates calls to llvm.memset.* to zero a block of
  69  // memory, declaring the function if needed. These calls will be lowered to
  70  // regular libc memset calls if they aren't optimized out in a different way.
  71  func (b *builder) createMemoryZeroImpl() {
  72  	b.createFunctionStart(true)
  73  	llvmFn := b.getMemsetFunc()
  74  	params := []llvm.Value{
  75  		b.getValue(b.fn.Params[0], getPos(b.fn)),
  76  		llvm.ConstInt(b.ctx.Int8Type(), 0, false),
  77  		b.getValue(b.fn.Params[1], getPos(b.fn)),
  78  		llvm.ConstInt(b.ctx.Int1Type(), 0, false),
  79  	}
  80  	b.CreateCall(llvmFn.GlobalValueType(), llvmFn, params, "")
  81  	b.CreateRetVoid()
  82  }
  83  
  84  // createStackSaveImpl creates a call to llvm.stacksave.p0 to read the current
  85  // stack pointer.
  86  func (b *builder) createStackSaveImpl() {
  87  	b.createFunctionStart(true)
  88  	sp := b.readStackPointer()
  89  	b.CreateRet(sp)
  90  }
  91  
  92  // Return the llvm.memset.p0.i8 function declaration.
  93  func (c *compilerContext) getMemsetFunc() llvm.Value {
  94  	fnName := "llvm.memset.p0.i" + strconv.Itoa(c.uintptrType.IntTypeWidth())
  95  	llvmFn := c.mod.NamedFunction(fnName)
  96  	if llvmFn.IsNil() {
  97  		fnType := llvm.FunctionType(c.ctx.VoidType(), []llvm.Type{c.dataPtrType, c.ctx.Int8Type(), c.uintptrType, c.ctx.Int1Type()}, false)
  98  		llvmFn = llvm.AddFunction(c.mod, fnName, fnType)
  99  	}
 100  	return llvmFn
 101  }
 102  
 103  // createKeepAlive creates the runtime.KeepAlive function. It is implemented
 104  // using inline assembly.
 105  func (b *builder) createKeepAliveImpl() {
 106  	b.createFunctionStart(true)
 107  
 108  	// Get the underlying value of the interface value.
 109  	interfaceValue := b.getValue(b.fn.Params[0], getPos(b.fn))
 110  	pointerValue := b.CreateExtractValue(interfaceValue, 1, "")
 111  
 112  	// Create an equivalent of the following C code, which is basically just a
 113  	// nop but ensures the pointerValue is kept alive:
 114  	//
 115  	//     __asm__ __volatile__("" : : "r"(pointerValue))
 116  	//
 117  	// It should be portable to basically everything as the "r" register type
 118  	// exists basically everywhere.
 119  	asmType := llvm.FunctionType(b.ctx.VoidType(), []llvm.Type{b.dataPtrType}, false)
 120  	asmFn := llvm.InlineAsm(asmType, "", "r", true, false, 0, false)
 121  	b.createCall(asmType, asmFn, []llvm.Value{pointerValue}, "")
 122  
 123  	b.CreateRetVoid()
 124  }
 125  
 126  // createAbiEscapeImpl implements the generic internal/abi.Escape function. It
 127  // currently only supports pointer types.
 128  func (b *builder) createAbiEscapeImpl() {
 129  	b.createFunctionStart(true)
 130  
 131  	// The first parameter is assumed to be a pointer. This is checked at the
 132  	// call site of createAbiEscapeImpl.
 133  	pointerValue := b.getValue(b.fn.Params[0], getPos(b.fn))
 134  
 135  	// Create an equivalent of the following C code, which is basically just a
 136  	// nop but ensures the pointerValue is kept alive:
 137  	//
 138  	//     __asm__ __volatile__("" : : "r"(pointerValue))
 139  	//
 140  	// It should be portable to basically everything as the "r" register type
 141  	// exists basically everywhere.
 142  	asmType := llvm.FunctionType(b.dataPtrType, []llvm.Type{b.dataPtrType}, false)
 143  	asmFn := llvm.InlineAsm(asmType, "", "=r,0", true, false, 0, false)
 144  	result := b.createCall(asmType, asmFn, []llvm.Value{pointerValue}, "")
 145  
 146  	b.CreateRet(result)
 147  }
 148  
 149  // Implement machine.keepAliveNoEscape, which makes sure the compiler keeps the
 150  // pointer parameter alive until this point (for GC).
 151  func (b *builder) createMachineKeepAliveImpl() {
 152  	b.createFunctionStart(true)
 153  	pointerValue := b.getValue(b.fn.Params[0], getPos(b.fn))
 154  
 155  	// See createKeepAliveImpl for details.
 156  	asmType := llvm.FunctionType(b.ctx.VoidType(), []llvm.Type{b.dataPtrType}, false)
 157  	asmFn := llvm.InlineAsm(asmType, "", "r", true, false, 0, false)
 158  	b.createCall(asmType, asmFn, []llvm.Value{pointerValue}, "")
 159  
 160  	b.CreateRetVoid()
 161  }
 162  
 163  var mathToLLVMMapping = map[string]string{
 164  	"math.Ceil":  "llvm.ceil.f64",
 165  	"math.Exp":   "llvm.exp.f64",
 166  	"math.Exp2":  "llvm.exp2.f64",
 167  	"math.Floor": "llvm.floor.f64",
 168  	"math.Log":   "llvm.log.f64",
 169  	"math.Sqrt":  "llvm.sqrt.f64",
 170  	"math.Trunc": "llvm.trunc.f64",
 171  }
 172  
 173  // defineMathOp defines a math function body as a call to a LLVM intrinsic,
 174  // instead of the regular Go implementation. This allows LLVM to reason about
 175  // the math operation and (depending on the architecture) allows it to lower the
 176  // operation to very fast floating point instructions. If this is not possible,
 177  // LLVM will emit a call to a libm function that implements the same operation.
 178  //
 179  // One example of an optimization that LLVM can do is to convert
 180  // float32(math.Sqrt(float64(v))) to a 32-bit floating point operation, which is
 181  // beneficial on architectures where 64-bit floating point operations are (much)
 182  // more expensive than 32-bit ones.
 183  func (b *builder) defineMathOp() {
 184  	b.createFunctionStart(true)
 185  	llvmName := mathToLLVMMapping[b.fn.RelString(nil)]
 186  	if llvmName == "" {
 187  		panic("unreachable: unknown math operation") // sanity check
 188  	}
 189  	llvmFn := b.mod.NamedFunction(llvmName)
 190  	if llvmFn.IsNil() {
 191  		// The intrinsic doesn't exist yet, so declare it.
 192  		// At the moment, all supported intrinsics have the form "double
 193  		// foo(double %x)" so we can hardcode the signature here.
 194  		llvmType := llvm.FunctionType(b.ctx.DoubleType(), []llvm.Type{b.ctx.DoubleType()}, false)
 195  		llvmFn = llvm.AddFunction(b.mod, llvmName, llvmType)
 196  	}
 197  	// Create a call to the intrinsic.
 198  	args := make([]llvm.Value, len(b.fn.Params))
 199  	for i, param := range b.fn.Params {
 200  		args[i] = b.getValue(param, getPos(b.fn))
 201  	}
 202  	result := b.CreateCall(llvmFn.GlobalValueType(), llvmFn, args, "")
 203  	b.CreateRet(result)
 204  }
 205  
 206  // Implement most math/bits functions.
 207  //
 208  // This implements all the functions that operate on bits. It does not yet
 209  // implement the arithmetic functions (like bits.Add), which also have LLVM
 210  // intrinsics.
 211  func (b *builder) defineMathBitsIntrinsic() bool {
 212  	if b.fn.Pkg.Pkg.Path() != "math/bits" {
 213  		return false
 214  	}
 215  	name := b.fn.Name()
 216  	switch name {
 217  	case "LeadingZeros", "LeadingZeros8", "LeadingZeros16", "LeadingZeros32", "LeadingZeros64",
 218  		"TrailingZeros", "TrailingZeros8", "TrailingZeros16", "TrailingZeros32", "TrailingZeros64":
 219  		b.createFunctionStart(true)
 220  		param := b.getValue(b.fn.Params[0], b.fn.Pos())
 221  		valueType := param.Type()
 222  		var intrinsicName string
 223  		if strings.HasPrefix(name, "Leading") { // LeadingZeros
 224  			intrinsicName = "llvm.ctlz.i" + strconv.Itoa(valueType.IntTypeWidth())
 225  		} else { // TrailingZeros
 226  			intrinsicName = "llvm.cttz.i" + strconv.Itoa(valueType.IntTypeWidth())
 227  		}
 228  		llvmFn := b.mod.NamedFunction(intrinsicName)
 229  		llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, b.ctx.Int1Type()}, false)
 230  		if llvmFn.IsNil() {
 231  			llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
 232  		}
 233  		result := b.createCall(llvmFnType, llvmFn, []llvm.Value{
 234  			param,
 235  			llvm.ConstInt(b.ctx.Int1Type(), 0, false),
 236  		}, "")
 237  		result = b.createZExtOrTrunc(result, b.intType)
 238  		b.CreateRet(result)
 239  		return true
 240  	case "Len", "Len8", "Len16", "Len32", "Len64":
 241  		// bits.Len can be implemented as:
 242  		//     (unsafe.Sizeof(v) * 8) -  bits.LeadingZeros(n)
 243  		// Not sure why this isn't already done in the standard library, as it
 244  		// is much simpler than a lookup table.
 245  		b.createFunctionStart(true)
 246  		param := b.getValue(b.fn.Params[0], b.fn.Pos())
 247  		valueType := param.Type()
 248  		valueBits := valueType.IntTypeWidth()
 249  		intrinsicName := "llvm.ctlz.i" + strconv.Itoa(valueBits)
 250  		llvmFn := b.mod.NamedFunction(intrinsicName)
 251  		llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, b.ctx.Int1Type()}, false)
 252  		if llvmFn.IsNil() {
 253  			llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
 254  		}
 255  		result := b.createCall(llvmFnType, llvmFn, []llvm.Value{
 256  			param,
 257  			llvm.ConstInt(b.ctx.Int1Type(), 0, false),
 258  		}, "")
 259  		result = b.createZExtOrTrunc(result, b.intType)
 260  		maxLen := llvm.ConstInt(b.intType, uint64(valueBits), false) // number of bits in the value
 261  		result = b.CreateSub(maxLen, result, "")
 262  		b.CreateRet(result)
 263  		return true
 264  	case "OnesCount", "OnesCount8", "OnesCount16", "OnesCount32", "OnesCount64":
 265  		b.createFunctionStart(true)
 266  		param := b.getValue(b.fn.Params[0], b.fn.Pos())
 267  		valueType := param.Type()
 268  		intrinsicName := "llvm.ctpop.i" + strconv.Itoa(valueType.IntTypeWidth())
 269  		llvmFn := b.mod.NamedFunction(intrinsicName)
 270  		llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType}, false)
 271  		if llvmFn.IsNil() {
 272  			llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
 273  		}
 274  		result := b.createCall(llvmFnType, llvmFn, []llvm.Value{param}, "")
 275  		result = b.createZExtOrTrunc(result, b.intType)
 276  		b.CreateRet(result)
 277  		return true
 278  	case "Reverse", "Reverse8", "Reverse16", "Reverse32", "Reverse64",
 279  		"ReverseBytes", "ReverseBytes16", "ReverseBytes32", "ReverseBytes64":
 280  		b.createFunctionStart(true)
 281  		param := b.getValue(b.fn.Params[0], b.fn.Pos())
 282  		valueType := param.Type()
 283  		var intrinsicName string
 284  		if strings.HasPrefix(name, "ReverseBytes") {
 285  			intrinsicName = "llvm.bswap.i" + strconv.Itoa(valueType.IntTypeWidth())
 286  		} else { // Reverse
 287  			intrinsicName = "llvm.bitreverse.i" + strconv.Itoa(valueType.IntTypeWidth())
 288  		}
 289  		llvmFn := b.mod.NamedFunction(intrinsicName)
 290  		llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType}, false)
 291  		if llvmFn.IsNil() {
 292  			llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
 293  		}
 294  		result := b.createCall(llvmFnType, llvmFn, []llvm.Value{param}, "")
 295  		b.CreateRet(result)
 296  		return true
 297  	case "RotateLeft", "RotateLeft8", "RotateLeft16", "RotateLeft32", "RotateLeft64":
 298  		// Warning: the documentation says these functions must be constant time.
 299  		// I do not think LLVM guarantees this, but there's a good chance LLVM
 300  		// already recognized the rotate instruction so it probably won't get
 301  		// any _worse_ by implementing these rotate functions.
 302  		b.createFunctionStart(true)
 303  		x := b.getValue(b.fn.Params[0], b.fn.Pos())
 304  		k := b.getValue(b.fn.Params[1], b.fn.Pos())
 305  		valueType := x.Type()
 306  		intrinsicName := "llvm.fshl.i" + strconv.Itoa(valueType.IntTypeWidth())
 307  		llvmFn := b.mod.NamedFunction(intrinsicName)
 308  		llvmFnType := llvm.FunctionType(valueType, []llvm.Type{valueType, valueType, valueType}, false)
 309  		if llvmFn.IsNil() {
 310  			llvmFn = llvm.AddFunction(b.mod, intrinsicName, llvmFnType)
 311  		}
 312  		k = b.createZExtOrTrunc(k, valueType)
 313  		result := b.createCall(llvmFnType, llvmFn, []llvm.Value{x, x, k}, "")
 314  		b.CreateRet(result)
 315  		return true
 316  	default:
 317  		return false
 318  	}
 319  }
 320