sqlite.go raw
// Package sqlite defines an extensible sqlite3 store for Nostr events.
2 package sqlite
3
4 import (
5 "context"
6 "database/sql"
7 "encoding/json"
8 "errors"
9 "fmt"
10 "strings"
11 "sync/atomic"
12
13 _ "github.com/mattn/go-sqlite3"
14 "github.com/nbd-wtf/go-nostr"
15 )
16
// Sentinel errors returned by the Store; compare against them with errors.Is.
var (
	// ErrInvalidAddressableEvent is returned by Replace for an addressable-kind
	// event whose 'd' tag value is missing or empty.
	ErrInvalidAddressableEvent = errors.New("addressable event doesn't have a 'd' tag")

	// ErrInvalidReplacement is returned by Replace for events whose kind is
	// neither replaceable nor addressable.
	ErrInvalidReplacement = errors.New("called Replace on a non-replaceable event")

	// ErrInternalQuery wraps failures that happen while reading query results.
	ErrInternalQuery = errors.New("internal query error")
)
22
// schema is the base DDL that New applies to every database. Every statement
// uses IF NOT EXISTS, so re-applying it to an existing file is harmless.
//
// Tables:
//   - events: one row per stored event; tags are kept as a JSON document.
//   - tags: (key, value, event_id) rows that tag filters and Replace join
//     against. In this schema only the d_tags_ai trigger fills it, inserting
//     the value of the first 'd' tag (if any) of events whose kind is between
//     30000 and 39999.
//
// The ON DELETE CASCADE on tags.event_id only takes effect when SQLite foreign
// key enforcement is enabled on the connection.
const schema = `
CREATE TABLE IF NOT EXISTS events (
	id TEXT PRIMARY KEY,
	pubkey TEXT NOT NULL,
	created_at INTEGER NOT NULL,
	kind INTEGER NOT NULL,
	tags JSONB NOT NULL,
	content TEXT NOT NULL,
	sig TEXT NOT NULL
);

CREATE INDEX IF NOT EXISTS time_idx ON events(created_at DESC, id ASC);
CREATE INDEX IF NOT EXISTS kind_sorted_idx ON events(kind, created_at DESC, id ASC);
CREATE INDEX IF NOT EXISTS pubkey_kind_sorted_idx ON events(pubkey, kind, created_at DESC, id ASC);

CREATE TABLE IF NOT EXISTS tags (
	event_id TEXT NOT NULL,
	key TEXT NOT NULL,
	value TEXT NOT NULL,

	PRIMARY KEY (key, value, event_id),
	FOREIGN KEY (event_id) REFERENCES events(id) ON DELETE CASCADE
);

CREATE INDEX IF NOT EXISTS tags_event_id_idx ON tags(event_id);

CREATE TRIGGER IF NOT EXISTS d_tags_ai AFTER INSERT ON events
WHEN NEW.kind BETWEEN 30000 AND 39999
BEGIN
	INSERT INTO tags (event_id, key, value)
	SELECT NEW.id, 'd', json_extract(value, '$[1]')
	FROM json_each(NEW.tags)
	WHERE json_type(value) = 'array' AND json_array_length(value) > 1 AND json_extract(value, '$[0]') = 'd'
	LIMIT 1;
END;`
58
59 // Store of Nostr events that uses an sqlite3 database.
60 // It embeds the *sql.DB connection for direct interaction and can use
61 // custom filter/event policies and query builders.
62 //
63 // All methods are safe for concurrent use. However, due to the file-based
64 // architecture of SQLite, there can only be one writer at the time.
65 //
66 // This limitation remains even after applying the recommended concurrency optimisations,
67 // such as journal_mode=WAL and PRAGMA busy_timeout=1s, which are applied by default in this implementation.
68 //
69 // Therefore it remains possible that methods return the error [sqlite3.ErrBusy].
70 // To reduce the likelihood of this happening, you can:
71 // - increase the busy_timeout with the option [WithBusyTimeout] (default is 1s).
72 // - provide synchronisation, for example with a mutex or channel(s). This however won't
73 // help if there are other programs writing to the same sqlite file.
74 //
75 // If instead you want to handle all [sqlite3.ErrBusy] in your application,
76 // use WithBusyTimeout(0) to make blocked writers return immediatly.
77 //
78 // More about WAL mode and concurrency: https://sqlite.org/wal.html
79 type Store struct {
80 *sql.DB
81
82 optimizeEvery int32 // the threshold of writes that trigger PRAGMA optimize
83 writeCount atomic.Int32 // successful writes since last PRAGMA optimize
84
85 filterPolicy FilterPolicy
86 eventPolicy EventPolicy
87
88 queryBuilder QueryBuilder
89 countBuilder QueryBuilder
90 }
91
92 // New returns an sqlite3 store connected to the sqlite file located at the provided
93 // file path, after applying the base schema, and the provided options.
94 func New(path string, opts ...Option) (*Store, error) {
95 DB, err := sql.Open("sqlite3", path)
96 if err != nil {
97 return nil, fmt.Errorf("failed to connect to sqlite3 at %s: %w", path, err)
98 }
99
100 if _, err := DB.Exec(schema); err != nil {
101 return nil, fmt.Errorf("failed to apply base schema: %w", err)
102 }
103
104 if _, err := DB.Exec("PRAGMA journal_mode = WAL;"); err != nil {
105 return nil, fmt.Errorf("failed to set WAL mode: %w", err)
106 }
107
108 if _, err := DB.Exec("PRAGMA busy_timeout = 1000;"); err != nil {
109 return nil, fmt.Errorf("failed to set busy timeout: %w", err)
110 }
111
112 store := &Store{
113 DB: DB,
114 optimizeEvery: 5000,
115 filterPolicy: defaultFilterPolicy,
116 eventPolicy: defaultEventPolicy,
117 queryBuilder: DefaultQueryBuilder,
118 countBuilder: DefaultCountBuilder,
119 }
120
121 for _, opt := range opts {
122 if err := opt(store); err != nil {
123 return nil, err
124 }
125 }
126
127 // run full optimize after options, to inform the query planner about new indexes (if any).
128 if _, err := DB.Exec("PRAGMA optimize=0x10002;"); err != nil {
129 return nil, fmt.Errorf("failed to PRAGMA optimize: %w", err)
130 }
131 return store, nil
132 }
133
134 // QueryBuilder converts multiple nostr filters into one or more sqlite queries and lists of arguments.
135 // Not all filters can be combined into a single query, but many can.
136 // Filters passed to the query builder have been previously validated by the [Store.filterPolicy].
137 //
138 // It's useful to specify custom query/count builders to leverage additional schemas that have been
139 // provided in the [New] constructor.
140 //
141 // For examples, check out the [DefaultQueryBuilder] and [DefaultCountBuilder]
142 type QueryBuilder func(filters ...nostr.Filter) (queries []Query, err error)
143
// Query is a single SQL statement produced by a [QueryBuilder], together with
// the arguments bound to its placeholders, in placeholder order.
type Query struct {
	SQL  string // the statement text
	Args []any  // positional arguments for the statement's placeholders
}
148
149 // Size returns the number of events stored in the "events" table.
150 func (s *Store) Size(ctx context.Context) (size int, err error) {
151 row := s.DB.QueryRowContext(ctx, "SELECT COUNT(*) FROM events;")
152 if err = row.Scan(&size); err != nil {
153 return -1, err
154 }
155 return size, nil
156 }
157
158 // Optimize runs "PRAGMA optimize", which updates the statistics and heuristics
159 // of the query planner, which should result in improved read performance.
160 // The Store's write methods call Optimize (roughly) every [Store.optimizeEvery]
161 // successful insertions or deletions.
162 func (s *Store) Optimize(ctx context.Context) error {
163 _, err := s.DB.ExecContext(ctx, "PRAGMA optimize;")
164 return err
165 }
166
167 func (s *Store) checkOptimize(ctx context.Context) error {
168 tot := s.writeCount.Add(1)
169 if tot < s.optimizeEvery {
170 return nil
171 }
172
173 if s.writeCount.CompareAndSwap(tot, 0) {
174 return s.Optimize(ctx)
175 }
176 return nil
177 }
178
179 // Close the underlying database connection, committing all temporary data to disk.
180 func (s *Store) Close() error {
181 return s.DB.Close()
182 }
183
184 // Save the event in the store. Save is idempotent, meaning successful calls to Save
185 // with the same event are no-ops.
186 //
187 // Save returns true if the event has been saved, false in case of errors or if the event was already present.
188 // For replaceable/addressable events, it is recommended to call [Store.Replace] instead.
189 func (s *Store) Save(ctx context.Context, e *nostr.Event) (bool, error) {
190 if err := s.eventPolicy(e); err != nil {
191 return false, err
192 }
193
194 tags, err := json.Marshal(e.Tags)
195 if err != nil {
196 return false, fmt.Errorf("failed to marshal the event tags: %w", err)
197 }
198
199 res, err := s.DB.ExecContext(ctx, `INSERT OR IGNORE INTO events (id, pubkey, created_at, kind, tags, content, sig)
200 VALUES ($1, $2, $3, $4, $5, $6, $7)`, e.ID, e.PubKey, e.CreatedAt, e.Kind, tags, e.Content, e.Sig)
201
202 if err != nil {
203 return false, fmt.Errorf("failed to execute: %w", err)
204 }
205
206 rows, err := res.RowsAffected()
207 if err != nil {
208 return false, fmt.Errorf("failed to check for rows affected: %w", err)
209 }
210
211 if rows > 0 {
212 s.checkOptimize(ctx)
213 return true, nil
214 }
215 return false, nil
216 }
217
218 // Delete the event with the provided id. If the event is not found, nothing happens and nil is returned.
219 // Delete returns true if the event was deleted, false in case of errors or if the event
220 // was never present.
221 func (s *Store) Delete(ctx context.Context, id string) (bool, error) {
222 res, err := s.DB.ExecContext(ctx, "DELETE FROM events WHERE id = $1", id)
223 if err != nil {
224 return false, fmt.Errorf("failed to execute: %w", err)
225 }
226
227 rows, err := res.RowsAffected()
228 if err != nil {
229 return false, fmt.Errorf("failed to check for rows affected: %w", err)
230 }
231
232 if rows > 0 {
233 s.checkOptimize(ctx)
234 return true, nil
235 }
236 return false, nil
237 }
238
239 // Replace an old event with the new one according to NIP-01.
240 //
241 // The replacement happens if the event is strictly newer than the stored event
242 // within the same 'category' (kind, pubkey, and d-tag if addressable).
243 // If no such stored event exists, and the event is a replaceable/addressable kind, it is simply saved.
244 //
245 // Calling Replace on a non-replaceable/addressable event returns [ErrInvalidReplacement]
246 //
247 // Replace returns true if the event has been saved/superseded a previous one,
248 // false in case of errors or if a stored event in the same 'category' is newer or equal.
249 //
250 // More info here: https://github.com/nostr-protocol/nips/blob/master/01.md#kinds
251 func (s *Store) Replace(ctx context.Context, event *nostr.Event) (bool, error) {
252 if err := s.eventPolicy(event); err != nil {
253 return false, err
254 }
255
256 var query Query
257 switch {
258 case nostr.IsReplaceableKind(event.Kind):
259 query = Query{
260 SQL: "SELECT id, created_at FROM events WHERE kind = $1 AND pubkey = $2",
261 Args: []any{event.Kind, event.PubKey},
262 }
263
264 case nostr.IsAddressableKind(event.Kind):
265 dTag := event.Tags.GetD()
266 if dTag == "" {
267 return false, ErrInvalidAddressableEvent
268 }
269
270 query = Query{
271 SQL: `SELECT e.id, e.created_at FROM events AS e
272 JOIN tags AS t ON e.id = t.event_id
273 WHERE e.kind = $1 AND e.pubkey = $2 AND t.key = 'd' AND t.value = $3;`,
274 Args: []any{event.Kind, event.PubKey, dTag},
275 }
276
277 default:
278 return false, ErrInvalidReplacement
279 }
280
281 replaced, err := s.replace(ctx, event, query)
282 if err != nil {
283 return false, fmt.Errorf("failed to replace: %w", err)
284 }
285 return replaced, nil
286 }
287
288 // replace the event with the provided id with the new event.
289 // It's an atomic version of Save(ctx, new) + Delete(ctx, id)
290 func (s *Store) replace(ctx context.Context, event *nostr.Event, query Query) (bool, error) {
291 tags, err := json.Marshal(event.Tags)
292 if err != nil {
293 return false, fmt.Errorf("failed to marshal the event tags: %w", err)
294 }
295
296 tx, err := s.DB.BeginTx(ctx, nil)
297 if err != nil {
298 return false, fmt.Errorf("failed to begin the transaction: %w", err)
299 }
300 defer tx.Rollback()
301
302 var oldID string
303 var oldCreatedAt nostr.Timestamp
304
305 row := tx.QueryRowContext(ctx, query.SQL, query.Args...)
306 err = row.Scan(&oldID, &oldCreatedAt)
307
308 if errors.Is(err, sql.ErrNoRows) {
309 // no old event found, insert the new one
310 _, err := tx.ExecContext(ctx, `INSERT OR IGNORE INTO events (id, pubkey, created_at, kind, tags, content, sig)
311 VALUES ($1, $2, $3, $4, $5, $6, $7)`, event.ID, event.PubKey, event.CreatedAt, event.Kind, tags, event.Content, event.Sig)
312
313 if err != nil {
314 return false, err
315 }
316
317 if err = tx.Commit(); err != nil {
318 return false, err
319 }
320
321 s.checkOptimize(ctx)
322 return true, nil
323 }
324
325 if err != nil {
326 // fail on different non nil errors
327 return false, fmt.Errorf("failed to query for old event: %w", err)
328 }
329
330 if oldCreatedAt >= event.CreatedAt {
331 return false, nil
332 }
333
334 if _, err = tx.ExecContext(ctx, `INSERT OR IGNORE INTO events (id, pubkey, created_at, kind, tags, content, sig)
335 VALUES ($1, $2, $3, $4, $5, $6, $7)`, event.ID, event.PubKey, event.CreatedAt, event.Kind, tags, event.Content, event.Sig); err != nil {
336 return false, err
337 }
338
339 if _, err = tx.ExecContext(ctx, "DELETE FROM events WHERE id = $1", oldID); err != nil {
340 return false, err
341 }
342
343 if err := tx.Commit(); err != nil {
344 return false, err
345 }
346 return true, nil
347 }
348
349 // Query stored events matching the provided filters.
350 func (s *Store) Query(ctx context.Context, filters ...nostr.Filter) ([]nostr.Event, error) {
351 return s.QueryWithBuilder(ctx, s.queryBuilder, filters...)
352 }
353
354 // QueryWithBuilder generates an sqlite query for the filters with the provided [QueryBuilder], and executes it.
355 func (s *Store) QueryWithBuilder(ctx context.Context, build QueryBuilder, filters ...nostr.Filter) ([]nostr.Event, error) {
356 filters, err := s.filterPolicy(filters...)
357 if err != nil {
358 return nil, err
359 }
360
361 queries, err := build(filters...)
362 if err != nil {
363 return nil, fmt.Errorf("failed to build query: %w", err)
364 }
365
366 var events []nostr.Event
367 for i, query := range queries {
368 rows, err := s.DB.QueryContext(ctx, query.SQL, query.Args...)
369 if errors.Is(err, sql.ErrNoRows) {
370 continue
371 }
372 if err != nil {
373 return nil, fmt.Errorf("failed to fetch events with query %s: %w", queries[i], err)
374 }
375 defer rows.Close()
376
377 for rows.Next() {
378 var event nostr.Event
379 err = rows.Scan(&event.ID, &event.PubKey, &event.CreatedAt, &event.Kind, &event.Tags, &event.Content, &event.Sig)
380 if err != nil {
381 return events, fmt.Errorf("%w: failed to scan event row: %w", ErrInternalQuery, err)
382 }
383
384 events = append(events, event)
385 }
386
387 if err := rows.Err(); err != nil {
388 return events, fmt.Errorf("%w: failed to scan event row: %w", ErrInternalQuery, err)
389 }
390 }
391 return events, nil
392 }
393
394 // Count stored events matching the provided filters.
395 func (s *Store) Count(ctx context.Context, filters ...nostr.Filter) (int64, error) {
396 return s.CountWithBuilder(ctx, s.countBuilder, filters...)
397 }
398
399 // CountWithBuilder generates an sqlite query for the filters with the provided [QueryBuilder], and executes it.
400 func (s *Store) CountWithBuilder(ctx context.Context, build QueryBuilder, filters ...nostr.Filter) (int64, error) {
401 queries, err := build(filters...)
402 if err != nil {
403 return 0, fmt.Errorf("failed to build count query: %w", err)
404 }
405
406 var total int64
407 for i, query := range queries {
408 var count int64
409 row := s.DB.QueryRowContext(ctx, query.SQL, query.Args...)
410 err := row.Scan(&count)
411 if err != nil {
412 return 0, fmt.Errorf("failed to count events with query %s: %w", queries[i], err)
413 }
414
415 total += count
416 }
417 return total, nil
418 }
419
420 func DefaultQueryBuilder(filters ...nostr.Filter) ([]Query, error) {
421 switch len(filters) {
422 case 0:
423 return nil, nil
424
425 case 1:
426 query, args := buildQuery(filters[0])
427 query += " ORDER BY e.created_at DESC, e.id ASC LIMIT ?"
428 args = append(args, filters[0].Limit)
429 return []Query{{SQL: query, Args: args}}, nil
430
431 default:
432 subQueries := make([]string, 0, len(filters))
433 allArgs := make([]any, 0, len(filters))
434 limit := 0
435
436 for _, filter := range filters {
437 query, args := buildQuery(filter)
438 subQueries = append(subQueries, query)
439 allArgs = append(allArgs, args...)
440 limit += filter.Limit
441 }
442
443 query := "SELECT * FROM (" + strings.Join(subQueries, " UNION ALL ") + ")" +
444 " GROUP BY id ORDER BY created_at DESC, id ASC LIMIT ?"
445 allArgs = append(allArgs, limit)
446 return []Query{{SQL: query, Args: allArgs}}, nil
447 }
448 }
449
450 func DefaultCountBuilder(filters ...nostr.Filter) ([]Query, error) {
451 switch len(filters) {
452 case 0:
453 return nil, nil
454
455 case 1:
456 query, args := buildCount(filters[0])
457 return []Query{{SQL: query, Args: args}}, nil
458
459 default:
460 subQueries := make([]string, 0, len(filters))
461 allArgs := make([]any, 0, len(filters))
462
463 for _, filter := range filters {
464 query, args := buildCount(filter)
465 subQueries = append(subQueries, "("+query+")")
466 allArgs = append(allArgs, args...)
467 }
468
469 // TODO: we are summing all counts together, without any deduplication
470 query := "SELECT (" + strings.Join(subQueries, " + ") + ")"
471 return []Query{{SQL: query, Args: allArgs}}, nil
472 }
473 }
474
475 func buildQuery(filter nostr.Filter) (string, []any) {
476 sql := toSql(filter)
477 if sql.JoinTags {
478 query := "SELECT e.* FROM events AS e JOIN tags AS t ON t.event_id = e.id" +
479 " WHERE " + strings.Join(sql.Conditions, " AND ") + " GROUP BY e.id"
480 return query, sql.Args
481 }
482
483 query := "SELECT e.* FROM events AS e"
484 if len(sql.Conditions) > 0 {
485 query += " WHERE " + strings.Join(sql.Conditions, " AND ")
486 }
487 return query, sql.Args
488 }
489
490 func buildCount(filter nostr.Filter) (string, []any) {
491 sql := toSql(filter)
492 if sql.JoinTags {
493 query := "SELECT COUNT(DISTINCT e.id) FROM events AS e JOIN tags AS t ON t.event_id = e.id" +
494 " WHERE " + strings.Join(sql.Conditions, " AND ")
495 return query, sql.Args
496 }
497
498 query := "SELECT COUNT(e.id) FROM events AS e"
499 if len(sql.Conditions) > 0 {
500 query += " WHERE " + strings.Join(sql.Conditions, " AND ")
501 }
502 return query, sql.Args
503 }
504
// sqlFilter is the SQL form of a single nostr filter, as produced by toSql.
type sqlFilter struct {
	// Conditions are WHERE-clause fragments meant to be joined with " AND ".
	// They reference the events table as e and, when JoinTags is set, the tags table as t.
	Conditions []string

	// Args are the placeholder values of Conditions, in the same order.
	Args []any

	// JoinTags reports whether Conditions reference the tags table (alias t).
	JoinTags bool
}
510
511 func toSql(filter nostr.Filter) sqlFilter {
512 s := sqlFilter{}
513 if len(filter.IDs) > 0 {
514 s.Conditions = append(s.Conditions, "e.id"+equalityClause(filter.IDs))
515 for _, id := range filter.IDs {
516 s.Args = append(s.Args, id)
517 }
518 }
519
520 if len(filter.Kinds) > 0 {
521 s.Conditions = append(s.Conditions, "e.kind"+equalityClause(filter.Kinds))
522 for _, kind := range filter.Kinds {
523 s.Args = append(s.Args, kind)
524 }
525 }
526
527 if len(filter.Authors) > 0 {
528 s.Conditions = append(s.Conditions, "e.pubkey"+equalityClause(filter.Authors))
529 for _, pk := range filter.Authors {
530 s.Args = append(s.Args, pk)
531 }
532 }
533
534 if filter.Until != nil {
535 s.Conditions = append(s.Conditions, "e.created_at <= ?")
536 s.Args = append(s.Args, filter.Until.Time().Unix())
537 }
538
539 if filter.Since != nil {
540 s.Conditions = append(s.Conditions, "e.created_at >= ?")
541 s.Args = append(s.Args, filter.Since.Time().Unix())
542 }
543
544 if len(filter.Tags) > 0 {
545 conds := make([]string, 0, len(filter.Tags))
546 args := make([]any, 0, len(filter.Tags))
547
548 for key, vals := range filter.Tags {
549 if len(vals) == 0 {
550 continue
551 }
552
553 conds = append(conds, "(t.key = ? AND t.value"+equalityClause(vals)+")")
554 args = append(args, key)
555 for _, v := range vals {
556 args = append(args, v)
557 }
558 }
559
560 if len(conds) > 0 {
561 s.JoinTags = true
562 s.Conditions = append(s.Conditions, strings.Join(conds, " OR "))
563 s.Args = append(s.Args, args...)
564 }
565 }
566 return s
567 }
568
// equalityClause returns the SQL comparison operator and placeholder(s) for
// use in a WHERE clause, based on the number of values provided:
//   - one value: " = ?"
//   - several values: " IN (?,?,...)" with one placeholder per value
//   - no values (nil or empty): " IN ()". SQLite accepts an empty IN list
//     as an extension; it matches no rows.
func equalityClause[T any](vals []T) string {
	switch len(vals) {
	case 0:
		return " IN ()"
	case 1:
		return " = ?"
	default:
		return " IN (?" + strings.Repeat(",?", len(vals)-1) + ")"
	}
}
580