sqlite.go raw

   1  // The sqlite package defines an extensible sqlite3 store for Nostr events.
   2  package sqlite
   3  
   4  import (
   5  	"context"
   6  	"database/sql"
   7  	"encoding/json"
   8  	"errors"
   9  	"fmt"
  10  	"strings"
  11  	"sync/atomic"
  12  
  13  	_ "github.com/mattn/go-sqlite3"
  14  	"github.com/nbd-wtf/go-nostr"
  15  )
  16  
// Sentinel errors returned by the Store. Match them with errors.Is.
//
//   - ErrInvalidAddressableEvent: Replace was called with an addressable event whose 'd' tag is empty.
//   - ErrInvalidReplacement: Replace was called with an event that is neither replaceable nor addressable.
//   - ErrInternalQuery: wraps failures that happen while reading the rows returned by a query.
var (
	ErrInvalidAddressableEvent = errors.New("addressable event doesn't have a 'd' tag")
	ErrInvalidReplacement      = errors.New("called Replace on a non-replaceable event")
	ErrInternalQuery           = errors.New("internal query error")
)
  22  
// schema is the base schema applied by New. Every statement uses IF NOT EXISTS,
// so applying it to an already initialised database is a no-op.
//
// It defines:
//   - the "events" table, one row per event, with the tags stored as JSON;
//   - three indexes on "events", all sorted by (created_at DESC, id ASC);
//   - the "tags" table of (event_id, key, value) triples, which references
//     events(id) with ON DELETE CASCADE, plus an index on event_id;
//   - the trigger d_tags_ai, which after the insertion of an event with kind in
//     [30000, 39999] copies the value of its first 'd' tag into "tags".
//
// NOTE(review): in SQLite, ON DELETE CASCADE is only enforced on connections where
// foreign keys are enabled (PRAGMA foreign_keys = ON). Nothing in this file enables
// them — confirm it is done elsewhere (DSN or an Option), otherwise deleted events
// leave their rows in "tags" behind.
const schema = `
	CREATE TABLE IF NOT EXISTS events (
       id TEXT PRIMARY KEY,
       pubkey TEXT NOT NULL,
       created_at INTEGER NOT NULL,
       kind INTEGER NOT NULL,
       tags JSONB NOT NULL,
       content TEXT NOT NULL,
       sig TEXT NOT NULL
	);

	CREATE INDEX IF NOT EXISTS time_idx ON events(created_at DESC, id ASC);
	CREATE INDEX IF NOT EXISTS kind_sorted_idx ON events(kind, created_at DESC, id ASC);
	CREATE INDEX IF NOT EXISTS pubkey_kind_sorted_idx ON events(pubkey, kind, created_at DESC, id ASC);
	
	CREATE TABLE IF NOT EXISTS tags (
		event_id TEXT NOT NULL,
		key TEXT NOT NULL,
		value TEXT NOT NULL,
		
		PRIMARY KEY (key, value, event_id),
		FOREIGN KEY (event_id) REFERENCES events(id) ON DELETE CASCADE
	);

	CREATE INDEX IF NOT EXISTS tags_event_id_idx ON tags(event_id);

	CREATE TRIGGER IF NOT EXISTS d_tags_ai AFTER INSERT ON events
	WHEN NEW.kind BETWEEN 30000 AND 39999 
	BEGIN
	INSERT INTO tags (event_id, key, value)
		SELECT NEW.id, 'd', json_extract(value, '$[1]')
		FROM json_each(NEW.tags)
		WHERE json_type(value) = 'array' AND json_array_length(value) > 1 AND json_extract(value, '$[0]') = 'd'
		LIMIT 1;
	END;`
  58  
// Store of Nostr events that uses an sqlite3 database.
// It embeds the *sql.DB connection for direct interaction and can use
// custom filter/event policies and query builders.
//
// All methods are safe for concurrent use. However, due to the file-based
// architecture of SQLite, there can only be one writer at a time.
//
// This limitation remains even after applying the recommended concurrency optimisations,
// such as journal_mode=WAL and PRAGMA busy_timeout=1s, which are applied by default in this implementation.
//
// Therefore it remains possible that methods return the error [sqlite3.ErrBusy].
// To reduce the likelihood of this happening, you can:
//   - increase the busy_timeout with the option [WithBusyTimeout] (default is 1s).
//   - provide synchronisation, for example with a mutex or channel(s). This however won't
//     help if there are other programs writing to the same sqlite file.
//
// If instead you want to handle all [sqlite3.ErrBusy] in your application,
// use WithBusyTimeout(0) to make blocked writers return immediately.
//
// More about WAL mode and concurrency: https://sqlite.org/wal.html
type Store struct {
	// DB is the underlying database handle; its methods are promoted to the Store.
	*sql.DB

	optimizeEvery int32        // the threshold of writes that trigger PRAGMA optimize
	writeCount    atomic.Int32 // successful writes since last PRAGMA optimize

	// filterPolicy is applied to the filters before they are turned into queries by QueryWithBuilder.
	filterPolicy FilterPolicy
	// eventPolicy is applied to every event passed to Save and Replace; a non-nil error rejects the event.
	eventPolicy EventPolicy

	// queryBuilder and countBuilder are the builders used by Query and Count respectively.
	queryBuilder QueryBuilder
	countBuilder QueryBuilder
}
  91  
  92  // New returns an sqlite3 store connected to the sqlite file located at the provided
  93  // file path, after applying the base schema, and the provided options.
  94  func New(path string, opts ...Option) (*Store, error) {
  95  	DB, err := sql.Open("sqlite3", path)
  96  	if err != nil {
  97  		return nil, fmt.Errorf("failed to connect to sqlite3 at %s: %w", path, err)
  98  	}
  99  
 100  	if _, err := DB.Exec(schema); err != nil {
 101  		return nil, fmt.Errorf("failed to apply base schema: %w", err)
 102  	}
 103  
 104  	if _, err := DB.Exec("PRAGMA journal_mode = WAL;"); err != nil {
 105  		return nil, fmt.Errorf("failed to set WAL mode: %w", err)
 106  	}
 107  
 108  	if _, err := DB.Exec("PRAGMA busy_timeout = 1000;"); err != nil {
 109  		return nil, fmt.Errorf("failed to set busy timeout: %w", err)
 110  	}
 111  
 112  	store := &Store{
 113  		DB:            DB,
 114  		optimizeEvery: 5000,
 115  		filterPolicy:  defaultFilterPolicy,
 116  		eventPolicy:   defaultEventPolicy,
 117  		queryBuilder:  DefaultQueryBuilder,
 118  		countBuilder:  DefaultCountBuilder,
 119  	}
 120  
 121  	for _, opt := range opts {
 122  		if err := opt(store); err != nil {
 123  			return nil, err
 124  		}
 125  	}
 126  
 127  	// run full optimize after options, to inform the query planner about new indexes (if any).
 128  	if _, err := DB.Exec("PRAGMA optimize=0x10002;"); err != nil {
 129  		return nil, fmt.Errorf("failed to PRAGMA optimize: %w", err)
 130  	}
 131  	return store, nil
 132  }
 133  
// QueryBuilder converts multiple nostr filters into one or more sqlite queries and lists of arguments.
// Not all filters can be combined into a single query, but many can.
//
// When called through [Store.QueryWithBuilder], the filters have been previously
// processed by the store's filter policy. [Store.CountWithBuilder] instead passes
// the filters to the builder as they were received.
//
// A nil slice of queries is valid and means that there is nothing to execute.
//
// It's useful to specify custom query/count builders to leverage additional schemas that have been
// provided in the [New] constructor.
//
// For examples, check out the [DefaultQueryBuilder] and [DefaultCountBuilder]
type QueryBuilder func(filters ...nostr.Filter) (queries []Query, err error)
 143  
// Query is a single SQL statement together with its positional arguments,
// as produced by a [QueryBuilder].
type Query struct {
	SQL  string // the statement text, with placeholders for the arguments
	Args []any  // the values bound to the placeholders, in order
}
 148  
 149  // Size returns the number of events stored in the "events" table.
 150  func (s *Store) Size(ctx context.Context) (size int, err error) {
 151  	row := s.DB.QueryRowContext(ctx, "SELECT COUNT(*) FROM events;")
 152  	if err = row.Scan(&size); err != nil {
 153  		return -1, err
 154  	}
 155  	return size, nil
 156  }
 157  
 158  // Optimize runs "PRAGMA optimize", which updates the statistics and heuristics
 159  // of the query planner, which should result in improved read performance.
 160  // The Store's write methods call Optimize (roughly) every [Store.optimizeEvery]
 161  // successful insertions or deletions.
 162  func (s *Store) Optimize(ctx context.Context) error {
 163  	_, err := s.DB.ExecContext(ctx, "PRAGMA optimize;")
 164  	return err
 165  }
 166  
 167  func (s *Store) checkOptimize(ctx context.Context) error {
 168  	tot := s.writeCount.Add(1)
 169  	if tot < s.optimizeEvery {
 170  		return nil
 171  	}
 172  
 173  	if s.writeCount.CompareAndSwap(tot, 0) {
 174  		return s.Optimize(ctx)
 175  	}
 176  	return nil
 177  }
 178  
// Close the underlying database handle by calling Close on the embedded *sql.DB,
// and return its error.
//
// NOTE(review): an earlier version of this comment stated that closing commits all
// temporary data to disk; that depends on SQLite/driver behaviour that is not
// visible here — confirm before relying on it.
func (s *Store) Close() error {
	return s.DB.Close()
}
 183  
 184  // Save the event in the store. Save is idempotent, meaning successful calls to Save
 185  // with the same event are no-ops.
 186  //
 187  // Save returns true if the event has been saved, false in case of errors or if the event was already present.
 188  // For replaceable/addressable events, it is recommended to call [Store.Replace] instead.
 189  func (s *Store) Save(ctx context.Context, e *nostr.Event) (bool, error) {
 190  	if err := s.eventPolicy(e); err != nil {
 191  		return false, err
 192  	}
 193  
 194  	tags, err := json.Marshal(e.Tags)
 195  	if err != nil {
 196  		return false, fmt.Errorf("failed to marshal the event tags: %w", err)
 197  	}
 198  
 199  	res, err := s.DB.ExecContext(ctx, `INSERT OR IGNORE INTO events (id, pubkey, created_at, kind, tags, content, sig)
 200          VALUES ($1, $2, $3, $4, $5, $6, $7)`, e.ID, e.PubKey, e.CreatedAt, e.Kind, tags, e.Content, e.Sig)
 201  
 202  	if err != nil {
 203  		return false, fmt.Errorf("failed to execute: %w", err)
 204  	}
 205  
 206  	rows, err := res.RowsAffected()
 207  	if err != nil {
 208  		return false, fmt.Errorf("failed to check for rows affected: %w", err)
 209  	}
 210  
 211  	if rows > 0 {
 212  		s.checkOptimize(ctx)
 213  		return true, nil
 214  	}
 215  	return false, nil
 216  }
 217  
 218  // Delete the event with the provided id. If the event is not found, nothing happens and nil is returned.
 219  // Delete returns true if the event was deleted, false in case of errors or if the event
 220  // was never present.
 221  func (s *Store) Delete(ctx context.Context, id string) (bool, error) {
 222  	res, err := s.DB.ExecContext(ctx, "DELETE FROM events WHERE id = $1", id)
 223  	if err != nil {
 224  		return false, fmt.Errorf("failed to execute: %w", err)
 225  	}
 226  
 227  	rows, err := res.RowsAffected()
 228  	if err != nil {
 229  		return false, fmt.Errorf("failed to check for rows affected: %w", err)
 230  	}
 231  
 232  	if rows > 0 {
 233  		s.checkOptimize(ctx)
 234  		return true, nil
 235  	}
 236  	return false, nil
 237  }
 238  
 239  // Replace an old event with the new one according to NIP-01.
 240  //
 241  // The replacement happens if the event is strictly newer than the stored event
 242  // within the same 'category' (kind, pubkey, and d-tag if addressable).
 243  // If no such stored event exists, and the event is a replaceable/addressable kind, it is simply saved.
 244  //
 245  // Calling Replace on a non-replaceable/addressable event returns [ErrInvalidReplacement]
 246  //
 247  // Replace returns true if the event has been saved/superseded a previous one,
 248  // false in case of errors or if a stored event in the same 'category' is newer or equal.
 249  //
 250  // More info here: https://github.com/nostr-protocol/nips/blob/master/01.md#kinds
 251  func (s *Store) Replace(ctx context.Context, event *nostr.Event) (bool, error) {
 252  	if err := s.eventPolicy(event); err != nil {
 253  		return false, err
 254  	}
 255  
 256  	var query Query
 257  	switch {
 258  	case nostr.IsReplaceableKind(event.Kind):
 259  		query = Query{
 260  			SQL:  "SELECT id, created_at FROM events WHERE kind = $1 AND pubkey = $2",
 261  			Args: []any{event.Kind, event.PubKey},
 262  		}
 263  
 264  	case nostr.IsAddressableKind(event.Kind):
 265  		dTag := event.Tags.GetD()
 266  		if dTag == "" {
 267  			return false, ErrInvalidAddressableEvent
 268  		}
 269  
 270  		query = Query{
 271  			SQL: `SELECT e.id, e.created_at FROM events AS e 
 272  				JOIN tags AS t ON e.id = t.event_id 
 273  				WHERE e.kind = $1 AND e.pubkey = $2 AND t.key = 'd' AND t.value = $3;`,
 274  			Args: []any{event.Kind, event.PubKey, dTag},
 275  		}
 276  
 277  	default:
 278  		return false, ErrInvalidReplacement
 279  	}
 280  
 281  	replaced, err := s.replace(ctx, event, query)
 282  	if err != nil {
 283  		return false, fmt.Errorf("failed to replace: %w", err)
 284  	}
 285  	return replaced, nil
 286  }
 287  
 288  // replace the event with the provided id with the new event.
 289  // It's an atomic version of Save(ctx, new) + Delete(ctx, id)
 290  func (s *Store) replace(ctx context.Context, event *nostr.Event, query Query) (bool, error) {
 291  	tags, err := json.Marshal(event.Tags)
 292  	if err != nil {
 293  		return false, fmt.Errorf("failed to marshal the event tags: %w", err)
 294  	}
 295  
 296  	tx, err := s.DB.BeginTx(ctx, nil)
 297  	if err != nil {
 298  		return false, fmt.Errorf("failed to begin the transaction: %w", err)
 299  	}
 300  	defer tx.Rollback()
 301  
 302  	var oldID string
 303  	var oldCreatedAt nostr.Timestamp
 304  
 305  	row := tx.QueryRowContext(ctx, query.SQL, query.Args...)
 306  	err = row.Scan(&oldID, &oldCreatedAt)
 307  
 308  	if errors.Is(err, sql.ErrNoRows) {
 309  		// no old event found, insert the new one
 310  		_, err := tx.ExecContext(ctx, `INSERT OR IGNORE INTO events (id, pubkey, created_at, kind, tags, content, sig)
 311          VALUES ($1, $2, $3, $4, $5, $6, $7)`, event.ID, event.PubKey, event.CreatedAt, event.Kind, tags, event.Content, event.Sig)
 312  
 313  		if err != nil {
 314  			return false, err
 315  		}
 316  
 317  		if err = tx.Commit(); err != nil {
 318  			return false, err
 319  		}
 320  
 321  		s.checkOptimize(ctx)
 322  		return true, nil
 323  	}
 324  
 325  	if err != nil {
 326  		// fail on different non nil errors
 327  		return false, fmt.Errorf("failed to query for old event: %w", err)
 328  	}
 329  
 330  	if oldCreatedAt >= event.CreatedAt {
 331  		return false, nil
 332  	}
 333  
 334  	if _, err = tx.ExecContext(ctx, `INSERT OR IGNORE INTO events (id, pubkey, created_at, kind, tags, content, sig)
 335  	VALUES ($1, $2, $3, $4, $5, $6, $7)`, event.ID, event.PubKey, event.CreatedAt, event.Kind, tags, event.Content, event.Sig); err != nil {
 336  		return false, err
 337  	}
 338  
 339  	if _, err = tx.ExecContext(ctx, "DELETE FROM events WHERE id = $1", oldID); err != nil {
 340  		return false, err
 341  	}
 342  
 343  	if err := tx.Commit(); err != nil {
 344  		return false, err
 345  	}
 346  	return true, nil
 347  }
 348  
// Query stored events matching the provided filters.
// It calls QueryWithBuilder with the store's own query builder, so the same
// filter policy and error conventions apply.
func (s *Store) Query(ctx context.Context, filters ...nostr.Filter) ([]nostr.Event, error) {
	return s.QueryWithBuilder(ctx, s.queryBuilder, filters...)
}
 353  
 354  // QueryWithBuilder generates an sqlite query for the filters with the provided [QueryBuilder], and executes it.
 355  func (s *Store) QueryWithBuilder(ctx context.Context, build QueryBuilder, filters ...nostr.Filter) ([]nostr.Event, error) {
 356  	filters, err := s.filterPolicy(filters...)
 357  	if err != nil {
 358  		return nil, err
 359  	}
 360  
 361  	queries, err := build(filters...)
 362  	if err != nil {
 363  		return nil, fmt.Errorf("failed to build query: %w", err)
 364  	}
 365  
 366  	var events []nostr.Event
 367  	for i, query := range queries {
 368  		rows, err := s.DB.QueryContext(ctx, query.SQL, query.Args...)
 369  		if errors.Is(err, sql.ErrNoRows) {
 370  			continue
 371  		}
 372  		if err != nil {
 373  			return nil, fmt.Errorf("failed to fetch events with query %s: %w", queries[i], err)
 374  		}
 375  		defer rows.Close()
 376  
 377  		for rows.Next() {
 378  			var event nostr.Event
 379  			err = rows.Scan(&event.ID, &event.PubKey, &event.CreatedAt, &event.Kind, &event.Tags, &event.Content, &event.Sig)
 380  			if err != nil {
 381  				return events, fmt.Errorf("%w: failed to scan event row: %w", ErrInternalQuery, err)
 382  			}
 383  
 384  			events = append(events, event)
 385  		}
 386  
 387  		if err := rows.Err(); err != nil {
 388  			return events, fmt.Errorf("%w: failed to scan event row: %w", ErrInternalQuery, err)
 389  		}
 390  	}
 391  	return events, nil
 392  }
 393  
// Count stored events matching the provided filters.
// It calls CountWithBuilder with the store's own count builder.
func (s *Store) Count(ctx context.Context, filters ...nostr.Filter) (int64, error) {
	return s.CountWithBuilder(ctx, s.countBuilder, filters...)
}
 398  
 399  // CountWithBuilder generates an sqlite query for the filters with the provided [QueryBuilder], and executes it.
 400  func (s *Store) CountWithBuilder(ctx context.Context, build QueryBuilder, filters ...nostr.Filter) (int64, error) {
 401  	queries, err := build(filters...)
 402  	if err != nil {
 403  		return 0, fmt.Errorf("failed to build count query: %w", err)
 404  	}
 405  
 406  	var total int64
 407  	for i, query := range queries {
 408  		var count int64
 409  		row := s.DB.QueryRowContext(ctx, query.SQL, query.Args...)
 410  		err := row.Scan(&count)
 411  		if err != nil {
 412  			return 0, fmt.Errorf("failed to count events with query %s: %w", queries[i], err)
 413  		}
 414  
 415  		total += count
 416  	}
 417  	return total, nil
 418  }
 419  
 420  func DefaultQueryBuilder(filters ...nostr.Filter) ([]Query, error) {
 421  	switch len(filters) {
 422  	case 0:
 423  		return nil, nil
 424  
 425  	case 1:
 426  		query, args := buildQuery(filters[0])
 427  		query += " ORDER BY e.created_at DESC, e.id ASC LIMIT ?"
 428  		args = append(args, filters[0].Limit)
 429  		return []Query{{SQL: query, Args: args}}, nil
 430  
 431  	default:
 432  		subQueries := make([]string, 0, len(filters))
 433  		allArgs := make([]any, 0, len(filters))
 434  		limit := 0
 435  
 436  		for _, filter := range filters {
 437  			query, args := buildQuery(filter)
 438  			subQueries = append(subQueries, query)
 439  			allArgs = append(allArgs, args...)
 440  			limit += filter.Limit
 441  		}
 442  
 443  		query := "SELECT * FROM (" + strings.Join(subQueries, " UNION ALL ") + ")" +
 444  			" GROUP BY id ORDER BY created_at DESC, id ASC LIMIT ?"
 445  		allArgs = append(allArgs, limit)
 446  		return []Query{{SQL: query, Args: allArgs}}, nil
 447  	}
 448  }
 449  
 450  func DefaultCountBuilder(filters ...nostr.Filter) ([]Query, error) {
 451  	switch len(filters) {
 452  	case 0:
 453  		return nil, nil
 454  
 455  	case 1:
 456  		query, args := buildCount(filters[0])
 457  		return []Query{{SQL: query, Args: args}}, nil
 458  
 459  	default:
 460  		subQueries := make([]string, 0, len(filters))
 461  		allArgs := make([]any, 0, len(filters))
 462  
 463  		for _, filter := range filters {
 464  			query, args := buildCount(filter)
 465  			subQueries = append(subQueries, "("+query+")")
 466  			allArgs = append(allArgs, args...)
 467  		}
 468  
 469  		// TODO: we are summing all counts together, without any deduplication
 470  		query := "SELECT (" + strings.Join(subQueries, " + ") + ")"
 471  		return []Query{{SQL: query, Args: allArgs}}, nil
 472  	}
 473  }
 474  
 475  func buildQuery(filter nostr.Filter) (string, []any) {
 476  	sql := toSql(filter)
 477  	if sql.JoinTags {
 478  		query := "SELECT e.* FROM events AS e JOIN tags AS t ON t.event_id = e.id" +
 479  			" WHERE " + strings.Join(sql.Conditions, " AND ") + " GROUP BY e.id"
 480  		return query, sql.Args
 481  	}
 482  
 483  	query := "SELECT e.* FROM events AS e"
 484  	if len(sql.Conditions) > 0 {
 485  		query += " WHERE " + strings.Join(sql.Conditions, " AND ")
 486  	}
 487  	return query, sql.Args
 488  }
 489  
 490  func buildCount(filter nostr.Filter) (string, []any) {
 491  	sql := toSql(filter)
 492  	if sql.JoinTags {
 493  		query := "SELECT COUNT(DISTINCT e.id) FROM events AS e JOIN tags AS t ON t.event_id = e.id" +
 494  			" WHERE " + strings.Join(sql.Conditions, " AND ")
 495  		return query, sql.Args
 496  	}
 497  
 498  	query := "SELECT COUNT(e.id) FROM events AS e"
 499  	if len(sql.Conditions) > 0 {
 500  		query += " WHERE " + strings.Join(sql.Conditions, " AND ")
 501  	}
 502  	return query, sql.Args
 503  }
 504  
// sqlFilter is the SQL translation of a nostr filter, as produced by toSql.
type sqlFilter struct {
	Conditions []string // WHERE conditions, meant to be joined with " AND "
	Args       []any    // values for the placeholders of Conditions, in order
	JoinTags   bool     // whether the conditions reference the tags table (aliased as t)
}
 510  
 511  func toSql(filter nostr.Filter) sqlFilter {
 512  	s := sqlFilter{}
 513  	if len(filter.IDs) > 0 {
 514  		s.Conditions = append(s.Conditions, "e.id"+equalityClause(filter.IDs))
 515  		for _, id := range filter.IDs {
 516  			s.Args = append(s.Args, id)
 517  		}
 518  	}
 519  
 520  	if len(filter.Kinds) > 0 {
 521  		s.Conditions = append(s.Conditions, "e.kind"+equalityClause(filter.Kinds))
 522  		for _, kind := range filter.Kinds {
 523  			s.Args = append(s.Args, kind)
 524  		}
 525  	}
 526  
 527  	if len(filter.Authors) > 0 {
 528  		s.Conditions = append(s.Conditions, "e.pubkey"+equalityClause(filter.Authors))
 529  		for _, pk := range filter.Authors {
 530  			s.Args = append(s.Args, pk)
 531  		}
 532  	}
 533  
 534  	if filter.Until != nil {
 535  		s.Conditions = append(s.Conditions, "e.created_at <= ?")
 536  		s.Args = append(s.Args, filter.Until.Time().Unix())
 537  	}
 538  
 539  	if filter.Since != nil {
 540  		s.Conditions = append(s.Conditions, "e.created_at >= ?")
 541  		s.Args = append(s.Args, filter.Since.Time().Unix())
 542  	}
 543  
 544  	if len(filter.Tags) > 0 {
 545  		conds := make([]string, 0, len(filter.Tags))
 546  		args := make([]any, 0, len(filter.Tags))
 547  
 548  		for key, vals := range filter.Tags {
 549  			if len(vals) == 0 {
 550  				continue
 551  			}
 552  
 553  			conds = append(conds, "(t.key = ? AND t.value"+equalityClause(vals)+")")
 554  			args = append(args, key)
 555  			for _, v := range vals {
 556  				args = append(args, v)
 557  			}
 558  		}
 559  
 560  		if len(conds) > 0 {
 561  			s.JoinTags = true
 562  			s.Conditions = append(s.Conditions, strings.Join(conds, " OR "))
 563  			s.Args = append(s.Args, args...)
 564  		}
 565  	}
 566  	return s
 567  }
 568  
 569  // equalityClause returns the appropriate SQL comparison operator and placeholder(s)
 570  // for use in a WHERE clause, based on the number of values provided.
 571  // If the slice contains one value, it returns " = ?".
 572  // If it contains multiple values, it returns " IN (?, ?, ... )" with the correct number of placeholders.
 573  // It panics is vals is nil or empty.
 574  func equalityClause[T any](vals []T) string {
 575  	if len(vals) == 1 {
 576  		return " = ?"
 577  	}
 578  	return " IN (?" + strings.Repeat(",?", len(vals)-1) + ")"
 579  }
 580