1 // Copyright (C) 2019 G.J.R. Timmer <gjr.timmer@gmail.com>.
2 // Copyright (C) 2018 segment.com <friends@segment.com>
3 //
4 // Use of this source code is governed by an MIT-style
5 // license that can be found in the LICENSE file.
6 7 //go:build sqlite_preupdate_hook
8 // +build sqlite_preupdate_hook
9 10 package sqlite3
11 12 /*
13 #cgo CFLAGS: -DSQLITE_ENABLE_PREUPDATE_HOOK
14 #cgo LDFLAGS: -lm
15 16 #ifndef USE_LIBSQLITE3
17 #include "sqlite3-binding.h"
18 #else
19 #include <sqlite3.h>
20 #endif
21 #include <stdlib.h>
22 #include <string.h>
23 24 void preUpdateHookTrampoline(void*, sqlite3 *, int, char *, char *, sqlite3_int64, sqlite3_int64);
25 */
26 import "C"
27 import (
28 "errors"
29 "unsafe"
30 )
// RegisterPreUpdateHook installs callback as the pre-update hook of this
// connection. For every pending change the callback receives a
// SQLitePreUpdateData describing it, whose methods return copies of the
// affected row values.
//
// Registering a hook replaces whatever pre-update hook the connection had
// before. A nil callback uninstalls the current hook and installs nothing.
func (c *SQLiteConn) RegisterPreUpdateHook(callback func(SQLitePreUpdateData)) {
	if callback == nil {
		// NULL hook + NULL user data clears any registered hook.
		C.sqlite3_preupdate_hook(c.db, nil, nil)
		return
	}

	// A Go func value cannot be handed to C directly, so SQLite is given the
	// C-callable preUpdateHookTrampoline plus a handle for callback as its
	// user-data pointer.
	// NOTE(review): the handle belonging to a hook that gets replaced is not
	// released here — presumably it lives until the connection's handles are
	// cleaned up elsewhere; confirm.
	trampoline := (*[0]byte)(unsafe.Pointer(C.preUpdateHookTrampoline))
	userData := unsafe.Pointer(newHandle(c, callback))
	C.sqlite3_preupdate_hook(c.db, trampoline, userData)
}
47 48 // Depth returns the source path of the write, see sqlite3_preupdate_depth()
49 func (d *SQLitePreUpdateData) Depth() int {
50 return int(C.sqlite3_preupdate_depth(d.Conn.db))
51 }
52 53 // Count returns the number of columns in the row
54 func (d *SQLitePreUpdateData) Count() int {
55 return int(C.sqlite3_preupdate_count(d.Conn.db))
56 }
// row reads the column values of the row being changed and stores them into
// dest, one column per element. When useNew is true, the values come from
// sqlite3_preupdate_new (the row after the change). Otherwise they come from
// sqlite3_preupdate_old (the row before the change). Only the first
// min(d.Count(), len(dest)) columns are read; any extra columns or extra
// dest elements are left untouched.
//
// Each SQLite value is decoded to int64, float64, []byte (for both BLOB and
// TEXT), or nil. It is then passed to convertAssign together with &dest[i].
// NOTE(review): because convertAssign receives a *any, the assignment target
// is the slice element itself. Presumably (as with database/sql's
// convertAssign for *any destinations) the decoded value replaces dest[i]
// rather than being written through a pointer stored in it; confirm.
//
// row returns an error if SQLite cannot supply a column value or if
// convertAssign fails. Columns before the failing one have already been
// stored.
func (d *SQLitePreUpdateData) row(dest []any, useNew bool) error {
	// The column count does not change during a single callback, so read it
	// once instead of making a cgo call on every iteration.
	n := d.Count()
	if len(dest) < n {
		n = len(dest)
	}

	for i := 0; i < n; i++ {
		var val *C.sqlite3_value
		var rc C.int

		// Initially I tried making this just a function pointer argument, but
		// it's absurdly complicated to pass C function pointers.
		if useNew {
			rc = C.sqlite3_preupdate_new(d.Conn.db, C.int(i), &val)
		} else {
			rc = C.sqlite3_preupdate_old(d.Conn.db, C.int(i), &val)
		}
		// If the call fails, val cannot be trusted (it stays nil when SQLite
		// does not write it), and sqlite3_value_type must not be given a NULL
		// value. Report the failure instead of decoding garbage.
		if rc != C.SQLITE_OK {
			return errors.New("sqlite3: reading preupdate column value: " + C.GoString(C.sqlite3_errstr(rc)))
		}

		var src any
		switch C.sqlite3_value_type(val) {
		case C.SQLITE_INTEGER:
			src = int64(C.sqlite3_value_int64(val))
		case C.SQLITE_FLOAT:
			src = float64(C.sqlite3_value_double(val))
		case C.SQLITE_BLOB:
			// Fetch the pointer before the size, as the SQLite docs
			// recommend, so the size describes the buffer actually returned.
			ptr := C.sqlite3_value_blob(val)
			size := C.sqlite3_value_bytes(val)
			src = C.GoBytes(ptr, size)
		case C.SQLITE_TEXT:
			// Same ordering: sqlite3_value_text may convert the value, so
			// its byte count is only reliable afterwards.
			ptr := unsafe.Pointer(C.sqlite3_value_text(val))
			size := C.sqlite3_value_bytes(val)
			src = C.GoBytes(ptr, size)
		case C.SQLITE_NULL:
			src = nil
		}

		if err := convertAssign(&dest[i], src); err != nil {
			return err
		}
	}

	return nil
}
// Old fills dest with the values the affected row held before the change,
// one column per element, in the spirit of database/sql's Rows.Scan(). At
// most min(Count(), len(dest)) elements are filled.
//
// INSERT operations have no prior row, so Old returns an error for them
// without reading anything.
//
// NOTE(review): the values are assigned via convertAssign(&dest[i], ...),
// i.e. into the dest elements themselves. A caller observes them by spreading
// its own []any (data.Old(row...)); confirm against convertAssign.
func (d *SQLitePreUpdateData) Old(dest ...any) error {
	switch d.Op {
	case SQLITE_INSERT:
		return errors.New("There is no old row for INSERT operations")
	default:
		return d.row(dest, false)
	}
}
// New fills dest with the values the affected row will hold after the
// change, one column per element, in the spirit of database/sql's
// Rows.Scan(). At most min(Count(), len(dest)) elements are filled.
//
// DELETE operations have no resulting row, so New returns an error for them
// without reading anything.
//
// NOTE(review): the values are assigned via convertAssign(&dest[i], ...),
// i.e. into the dest elements themselves. A caller observes them by spreading
// its own []any (data.New(row...)); confirm against convertAssign.
func (d *SQLitePreUpdateData) New(dest ...any) error {
	switch d.Op {
	case SQLITE_DELETE:
		return errors.New("There is no new row for DELETE operations")
	default:
		return d.row(dest, true)
	}
}
114