callback.go raw

   1  // Copyright (C) 2019 Yasuhiro Matsumoto <mattn.jp@gmail.com>.
   2  //
   3  // Use of this source code is governed by an MIT-style
   4  // license that can be found in the LICENSE file.
   5  
   6  package sqlite3
   7  
   8  // You can't export a Go function to C and have definitions in the C
   9  // preamble in the same file, so we have to have callbackTrampoline in
  10  // its own file. Because we need a separate file anyway, the support
  11  // code for SQLite custom functions is in here.
  12  
  13  /*
  14  #ifndef USE_LIBSQLITE3
  15  #include "sqlite3-binding.h"
  16  #else
  17  #include <sqlite3.h>
  18  #endif
  19  #include <stdlib.h>
  20  
  21  void _sqlite3_result_text(sqlite3_context* ctx, const char* s);
  22  void _sqlite3_result_blob(sqlite3_context* ctx, const void* b, int l);
  23  */
  24  import "C"
  25  
  26  import (
  27  	"errors"
  28  	"fmt"
  29  	"math"
  30  	"reflect"
  31  	"sync"
  32  	"unsafe"
  33  )
  34  
  35  //export callbackTrampoline
  36  func callbackTrampoline(ctx *C.sqlite3_context, argc int, argv **C.sqlite3_value) {
  37  	args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:argc:argc]
  38  	fi := lookupHandle(C.sqlite3_user_data(ctx)).(*functionInfo)
  39  	fi.Call(ctx, args)
  40  }
  41  
  42  //export stepTrampoline
  43  func stepTrampoline(ctx *C.sqlite3_context, argc C.int, argv **C.sqlite3_value) {
  44  	args := (*[(math.MaxInt32 - 1) / unsafe.Sizeof((*C.sqlite3_value)(nil))]*C.sqlite3_value)(unsafe.Pointer(argv))[:int(argc):int(argc)]
  45  	ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo)
  46  	ai.Step(ctx, args)
  47  }
  48  
  49  //export doneTrampoline
  50  func doneTrampoline(ctx *C.sqlite3_context) {
  51  	ai := lookupHandle(C.sqlite3_user_data(ctx)).(*aggInfo)
  52  	ai.Done(ctx)
  53  }
  54  
  55  //export compareTrampoline
  56  func compareTrampoline(handlePtr unsafe.Pointer, la C.int, a *C.char, lb C.int, b *C.char) C.int {
  57  	cmp := lookupHandle(handlePtr).(func(string, string) int)
  58  	return C.int(cmp(C.GoStringN(a, la), C.GoStringN(b, lb)))
  59  }
  60  
  61  //export commitHookTrampoline
  62  func commitHookTrampoline(handle unsafe.Pointer) int {
  63  	callback := lookupHandle(handle).(func() int)
  64  	return callback()
  65  }
  66  
  67  //export rollbackHookTrampoline
  68  func rollbackHookTrampoline(handle unsafe.Pointer) {
  69  	callback := lookupHandle(handle).(func())
  70  	callback()
  71  }
  72  
  73  //export updateHookTrampoline
  74  func updateHookTrampoline(handle unsafe.Pointer, op int, db *C.char, table *C.char, rowid int64) {
  75  	callback := lookupHandle(handle).(func(int, string, string, int64))
  76  	callback(op, C.GoString(db), C.GoString(table), rowid)
  77  }
  78  
  79  //export authorizerTrampoline
  80  func authorizerTrampoline(handle unsafe.Pointer, op int, arg1 *C.char, arg2 *C.char, arg3 *C.char) int {
  81  	callback := lookupHandle(handle).(func(int, string, string, string) int)
  82  	return callback(op, C.GoString(arg1), C.GoString(arg2), C.GoString(arg3))
  83  }
  84  
  85  //export preUpdateHookTrampoline
  86  func preUpdateHookTrampoline(handle unsafe.Pointer, dbHandle uintptr, op int, db *C.char, table *C.char, oldrowid int64, newrowid int64) {
  87  	hval := lookupHandleVal(handle)
  88  	data := SQLitePreUpdateData{
  89  		Conn:         hval.db,
  90  		Op:           op,
  91  		DatabaseName: C.GoString(db),
  92  		TableName:    C.GoString(table),
  93  		OldRowID:     oldrowid,
  94  		NewRowID:     newrowid,
  95  	}
  96  	callback := hval.val.(func(SQLitePreUpdateData))
  97  	callback(data)
  98  }
  99  
 100  // Use handles to avoid passing Go pointers to C.
 101  type handleVal struct {
 102  	db  *SQLiteConn
 103  	val any
 104  }
 105  
 106  var handleLock sync.Mutex
 107  var handleVals = make(map[unsafe.Pointer]handleVal)
 108  
 109  func newHandle(db *SQLiteConn, v any) unsafe.Pointer {
 110  	handleLock.Lock()
 111  	defer handleLock.Unlock()
 112  	val := handleVal{db: db, val: v}
 113  	var p unsafe.Pointer = C.malloc(C.size_t(1))
 114  	if p == nil {
 115  		panic("can't allocate 'cgo-pointer hack index pointer': ptr == nil")
 116  	}
 117  	handleVals[p] = val
 118  	return p
 119  }
 120  
 121  func lookupHandleVal(handle unsafe.Pointer) handleVal {
 122  	handleLock.Lock()
 123  	defer handleLock.Unlock()
 124  	return handleVals[handle]
 125  }
 126  
 127  func lookupHandle(handle unsafe.Pointer) any {
 128  	return lookupHandleVal(handle).val
 129  }
 130  
 131  func deleteHandles(db *SQLiteConn) {
 132  	handleLock.Lock()
 133  	defer handleLock.Unlock()
 134  	for handle, val := range handleVals {
 135  		if val.db == db {
 136  			delete(handleVals, handle)
 137  			C.free(handle)
 138  		}
 139  	}
 140  }
 141  
 142  // This is only here so that tests can refer to it.
 143  type callbackArgRaw C.sqlite3_value
 144  
 145  type callbackArgConverter func(*C.sqlite3_value) (reflect.Value, error)
 146  
 147  type callbackArgCast struct {
 148  	f   callbackArgConverter
 149  	typ reflect.Type
 150  }
 151  
 152  func (c callbackArgCast) Run(v *C.sqlite3_value) (reflect.Value, error) {
 153  	val, err := c.f(v)
 154  	if err != nil {
 155  		return reflect.Value{}, err
 156  	}
 157  	if !val.Type().ConvertibleTo(c.typ) {
 158  		return reflect.Value{}, fmt.Errorf("cannot convert %s to %s", val.Type(), c.typ)
 159  	}
 160  	return val.Convert(c.typ), nil
 161  }
 162  
 163  func callbackArgInt64(v *C.sqlite3_value) (reflect.Value, error) {
 164  	if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
 165  		return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
 166  	}
 167  	return reflect.ValueOf(int64(C.sqlite3_value_int64(v))), nil
 168  }
 169  
 170  func callbackArgBool(v *C.sqlite3_value) (reflect.Value, error) {
 171  	if C.sqlite3_value_type(v) != C.SQLITE_INTEGER {
 172  		return reflect.Value{}, fmt.Errorf("argument must be an INTEGER")
 173  	}
 174  	i := int64(C.sqlite3_value_int64(v))
 175  	val := false
 176  	if i != 0 {
 177  		val = true
 178  	}
 179  	return reflect.ValueOf(val), nil
 180  }
 181  
 182  func callbackArgFloat64(v *C.sqlite3_value) (reflect.Value, error) {
 183  	if C.sqlite3_value_type(v) != C.SQLITE_FLOAT {
 184  		return reflect.Value{}, fmt.Errorf("argument must be a FLOAT")
 185  	}
 186  	return reflect.ValueOf(float64(C.sqlite3_value_double(v))), nil
 187  }
 188  
 189  func callbackArgBytes(v *C.sqlite3_value) (reflect.Value, error) {
 190  	switch C.sqlite3_value_type(v) {
 191  	case C.SQLITE_BLOB:
 192  		l := C.sqlite3_value_bytes(v)
 193  		p := C.sqlite3_value_blob(v)
 194  		return reflect.ValueOf(C.GoBytes(p, l)), nil
 195  	case C.SQLITE_TEXT:
 196  		l := C.sqlite3_value_bytes(v)
 197  		c := unsafe.Pointer(C.sqlite3_value_text(v))
 198  		return reflect.ValueOf(C.GoBytes(c, l)), nil
 199  	default:
 200  		return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
 201  	}
 202  }
 203  
 204  func callbackArgString(v *C.sqlite3_value) (reflect.Value, error) {
 205  	switch C.sqlite3_value_type(v) {
 206  	case C.SQLITE_BLOB:
 207  		l := C.sqlite3_value_bytes(v)
 208  		p := (*C.char)(C.sqlite3_value_blob(v))
 209  		return reflect.ValueOf(C.GoStringN(p, l)), nil
 210  	case C.SQLITE_TEXT:
 211  		c := (*C.char)(unsafe.Pointer(C.sqlite3_value_text(v)))
 212  		return reflect.ValueOf(C.GoString(c)), nil
 213  	default:
 214  		return reflect.Value{}, fmt.Errorf("argument must be BLOB or TEXT")
 215  	}
 216  }
 217  
 218  func callbackArgGeneric(v *C.sqlite3_value) (reflect.Value, error) {
 219  	switch C.sqlite3_value_type(v) {
 220  	case C.SQLITE_INTEGER:
 221  		return callbackArgInt64(v)
 222  	case C.SQLITE_FLOAT:
 223  		return callbackArgFloat64(v)
 224  	case C.SQLITE_TEXT:
 225  		return callbackArgString(v)
 226  	case C.SQLITE_BLOB:
 227  		return callbackArgBytes(v)
 228  	case C.SQLITE_NULL:
 229  		// Interpret NULL as a nil byte slice.
 230  		var ret []byte
 231  		return reflect.ValueOf(ret), nil
 232  	default:
 233  		panic("unreachable")
 234  	}
 235  }
 236  
 237  func callbackArg(typ reflect.Type) (callbackArgConverter, error) {
 238  	switch typ.Kind() {
 239  	case reflect.Interface:
 240  		if typ.NumMethod() != 0 {
 241  			return nil, errors.New("the only supported interface type is any")
 242  		}
 243  		return callbackArgGeneric, nil
 244  	case reflect.Slice:
 245  		if typ.Elem().Kind() != reflect.Uint8 {
 246  			return nil, errors.New("the only supported slice type is []byte")
 247  		}
 248  		return callbackArgBytes, nil
 249  	case reflect.String:
 250  		return callbackArgString, nil
 251  	case reflect.Bool:
 252  		return callbackArgBool, nil
 253  	case reflect.Int64:
 254  		return callbackArgInt64, nil
 255  	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
 256  		c := callbackArgCast{callbackArgInt64, typ}
 257  		return c.Run, nil
 258  	case reflect.Float64:
 259  		return callbackArgFloat64, nil
 260  	case reflect.Float32:
 261  		c := callbackArgCast{callbackArgFloat64, typ}
 262  		return c.Run, nil
 263  	default:
 264  		return nil, fmt.Errorf("don't know how to convert to %s", typ)
 265  	}
 266  }
 267  
 268  func callbackConvertArgs(argv []*C.sqlite3_value, converters []callbackArgConverter, variadic callbackArgConverter) ([]reflect.Value, error) {
 269  	var args []reflect.Value
 270  
 271  	if len(argv) < len(converters) {
 272  		return nil, fmt.Errorf("function requires at least %d arguments", len(converters))
 273  	}
 274  
 275  	for i, arg := range argv[:len(converters)] {
 276  		v, err := converters[i](arg)
 277  		if err != nil {
 278  			return nil, err
 279  		}
 280  		args = append(args, v)
 281  	}
 282  
 283  	if variadic != nil {
 284  		for _, arg := range argv[len(converters):] {
 285  			v, err := variadic(arg)
 286  			if err != nil {
 287  				return nil, err
 288  			}
 289  			args = append(args, v)
 290  		}
 291  	}
 292  	return args, nil
 293  }
 294  
 295  type callbackRetConverter func(*C.sqlite3_context, reflect.Value) error
 296  
 297  func callbackRetInteger(ctx *C.sqlite3_context, v reflect.Value) error {
 298  	switch v.Type().Kind() {
 299  	case reflect.Int64:
 300  	case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
 301  		v = v.Convert(reflect.TypeOf(int64(0)))
 302  	case reflect.Bool:
 303  		b := v.Interface().(bool)
 304  		if b {
 305  			v = reflect.ValueOf(int64(1))
 306  		} else {
 307  			v = reflect.ValueOf(int64(0))
 308  		}
 309  	default:
 310  		return fmt.Errorf("cannot convert %s to INTEGER", v.Type())
 311  	}
 312  
 313  	C.sqlite3_result_int64(ctx, C.sqlite3_int64(v.Interface().(int64)))
 314  	return nil
 315  }
 316  
 317  func callbackRetFloat(ctx *C.sqlite3_context, v reflect.Value) error {
 318  	switch v.Type().Kind() {
 319  	case reflect.Float64:
 320  	case reflect.Float32:
 321  		v = v.Convert(reflect.TypeOf(float64(0)))
 322  	default:
 323  		return fmt.Errorf("cannot convert %s to FLOAT", v.Type())
 324  	}
 325  
 326  	C.sqlite3_result_double(ctx, C.double(v.Interface().(float64)))
 327  	return nil
 328  }
 329  
 330  func callbackRetBlob(ctx *C.sqlite3_context, v reflect.Value) error {
 331  	if v.Type().Kind() != reflect.Slice || v.Type().Elem().Kind() != reflect.Uint8 {
 332  		return fmt.Errorf("cannot convert %s to BLOB", v.Type())
 333  	}
 334  	i := v.Interface()
 335  	if i == nil || len(i.([]byte)) == 0 {
 336  		C.sqlite3_result_null(ctx)
 337  	} else {
 338  		bs := i.([]byte)
 339  		C._sqlite3_result_blob(ctx, unsafe.Pointer(&bs[0]), C.int(len(bs)))
 340  	}
 341  	return nil
 342  }
 343  
 344  func callbackRetText(ctx *C.sqlite3_context, v reflect.Value) error {
 345  	if v.Type().Kind() != reflect.String {
 346  		return fmt.Errorf("cannot convert %s to TEXT", v.Type())
 347  	}
 348  	cstr := C.CString(v.Interface().(string))
 349  	C._sqlite3_result_text(ctx, cstr)
 350  	return nil
 351  }
 352  
 353  func callbackRetNil(ctx *C.sqlite3_context, v reflect.Value) error {
 354  	return nil
 355  }
 356  
 357  func callbackRetGeneric(ctx *C.sqlite3_context, v reflect.Value) error {
 358  	if v.IsNil() {
 359  		C.sqlite3_result_null(ctx)
 360  		return nil
 361  	}
 362  
 363  	cb, err := callbackRet(v.Elem().Type())
 364  	if err != nil {
 365  		return err
 366  	}
 367  
 368  	return cb(ctx, v.Elem())
 369  }
 370  
 371  func callbackRet(typ reflect.Type) (callbackRetConverter, error) {
 372  	switch typ.Kind() {
 373  	case reflect.Interface:
 374  		errorInterface := reflect.TypeOf((*error)(nil)).Elem()
 375  		if typ.Implements(errorInterface) {
 376  			return callbackRetNil, nil
 377  		}
 378  
 379  		if typ.NumMethod() == 0 {
 380  			return callbackRetGeneric, nil
 381  		}
 382  
 383  		fallthrough
 384  	case reflect.Slice:
 385  		if typ.Elem().Kind() != reflect.Uint8 {
 386  			return nil, errors.New("the only supported slice type is []byte")
 387  		}
 388  		return callbackRetBlob, nil
 389  	case reflect.String:
 390  		return callbackRetText, nil
 391  	case reflect.Bool, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Int, reflect.Uint:
 392  		return callbackRetInteger, nil
 393  	case reflect.Float32, reflect.Float64:
 394  		return callbackRetFloat, nil
 395  	default:
 396  		return nil, fmt.Errorf("don't know how to convert to %s", typ)
 397  	}
 398  }
 399  
 400  func callbackError(ctx *C.sqlite3_context, err error) {
 401  	cstr := C.CString(err.Error())
 402  	defer C.free(unsafe.Pointer(cstr))
 403  	C.sqlite3_result_error(ctx, cstr, C.int(-1))
 404  }
 405  
 406  // Test support code. Tests are not allowed to import "C", so we can't
 407  // declare any functions that use C.sqlite3_value.
 408  func callbackSyntheticForTests(v reflect.Value, err error) callbackArgConverter {
 409  	return func(*C.sqlite3_value) (reflect.Value, error) {
 410  		return v, err
 411  	}
 412  }
 413