negentropy.mx raw

   1  // Package negentropy implements the negentropy set reconciliation protocol
   2  // for efficient event sync between relays.
   3  //
   4  // The protocol works by exchanging XOR fingerprints over sorted ranges of
   5  // (timestamp, id) pairs. Matching fingerprints mean the range is identical;
   6  // mismatched ranges are subdivided recursively until individual items are
   7  // exchanged.
   8  package negentropy
   9  
  10  import (
  11  	"bytes"
  12  	"crypto/sha256"
  13  	"math"
  14  	"sort"
  15  
  16  	"smesh.lol/pkg/nostr/event"
  17  	"smesh.lol/pkg/nostr/filter"
  18  	"smesh.lol/pkg/store"
  19  )
  20  
// Item is a (timestamp, id) pair for set reconciliation.
//
// Items are ordered by Timestamp and then by the raw bytes of ID
// (compareItems). Diff requires both of its inputs to be sorted in that
// order, and NewReconciler sorts its items into it. ItemsFromEvents always
// produces exactly 32-byte IDs; Fingerprint reads at most the first 32 bytes
// of an ID.
type Item struct {
	Timestamp int64
	ID        []byte // 32 bytes
}
  26  
  27  func compareItems(a, b Item) int {
  28  	if a.Timestamp != b.Timestamp {
  29  		if a.Timestamp < b.Timestamp {
  30  			return -1
  31  		}
  32  		return 1
  33  	}
  34  	return bytes.Compare(a.ID, b.ID)
  35  }
  36  
  37  // ItemsFromEvents extracts sorted items from events.
  38  func ItemsFromEvents(events []*event.E) []Item {
  39  	items := []Item{:len(events)}
  40  	for i, ev := range events {
  41  		id := []byte{:32}
  42  		copy(id, ev.ID)
  43  		items[i] = Item{Timestamp: ev.CreatedAt, ID: id}
  44  	}
  45  	sortItems(items)
  46  	return items
  47  }
  48  
  49  func sortItems(items []Item) {
  50  	sort.Slice(items, func(i, j int) bool {
  51  		return compareItems(items[i], items[j]) < 0
  52  	})
  53  }
  54  
  55  // Fingerprint computes an XOR fingerprint over a range of items.
  56  func Fingerprint(items []Item) [32]byte {
  57  	var fp [32]byte
  58  	for _, it := range items {
  59  		for j := 0; j < 32 && j < len(it.ID); j++ {
  60  			fp[j] ^= it.ID[j]
  61  		}
  62  	}
  63  	return fp
  64  }
  65  
  66  // Diff computes set differences between two sorted item lists.
  67  // Returns have (items in local but not remote) and need (items in remote but not local).
  68  func Diff(local, remote []Item) (have, need []Item) {
  69  	i, j := 0, 0
  70  	for i < len(local) && j < len(remote) {
  71  		cmp := compareItems(local[i], remote[j])
  72  		if cmp < 0 {
  73  			have = append(have, local[i])
  74  			i++
  75  		} else if cmp > 0 {
  76  			need = append(need, remote[j])
  77  			j++
  78  		} else {
  79  			i++
  80  			j++
  81  		}
  82  	}
  83  	for ; i < len(local); i++ {
  84  		have = append(have, local[i])
  85  	}
  86  	for ; j < len(remote); j++ {
  87  		need = append(need, remote[j])
  88  	}
  89  	return
  90  }
  91  
// Reconciler performs fingerprint-based reconciliation.
//
// It holds the item set that Split partitions into fingerprinted ranges.
// NewReconciler sorts the items before storing them, so items is expected to
// be in compareItems order. Split takes each segment's last item as the
// segment's upper bound, which is only correct when items is sorted.
type Reconciler struct {
	items []Item
}
  96  
  97  // NewReconciler creates a reconciler from a sorted item list.
  98  func NewReconciler(items []Item) *Reconciler {
  99  	sortItems(items)
 100  	return &Reconciler{items: items}
 101  }
 102  
// Range represents a segment with a fingerprint.
//
// Split produces one Range per contiguous segment of the reconciler's sorted
// items:
//   - UpperTimestamp and UpperID come from the segment's last item. UpperID
//     shares that item's ID slice rather than copying it.
//   - Fingerprint is the XOR Fingerprint of the segment.
//   - Count is the number of items in the segment.
type Range struct {
	UpperTimestamp int64
	UpperID        []byte
	Fingerprint    [32]byte
	Count          int
}
 110  
 111  // Split divides the item set into n ranges with fingerprints.
 112  func (r *Reconciler) Split(n int) []Range {
 113  	if n <= 0 || len(r.items) == 0 {
 114  		return nil
 115  	}
 116  	if n > len(r.items) {
 117  		n = len(r.items)
 118  	}
 119  	segSize := len(r.items) / n
 120  	ranges := []Range{:0:n}
 121  	for i := 0; i < n; i++ {
 122  		lo := i * segSize
 123  		hi := (i + 1) * segSize
 124  		if i == n-1 {
 125  			hi = len(r.items)
 126  		}
 127  		upper := r.items[hi-1]
 128  		ranges = append(ranges, Range{
 129  			UpperTimestamp: upper.Timestamp,
 130  			UpperID:       upper.ID,
 131  			Fingerprint:   Fingerprint(r.items[lo:hi]),
 132  			Count:         hi - lo,
 133  		})
 134  	}
 135  	return ranges
 136  }
 137  
 138  // FindMismatches compares local ranges against remote ranges and
 139  // returns indices of ranges that differ.
 140  func FindMismatches(local, remote []Range) []int {
 141  	n := len(local)
 142  	if len(remote) < n {
 143  		n = len(remote)
 144  	}
 145  	var mismatches []int
 146  	for i := 0; i < n; i++ {
 147  		if local[i].Fingerprint != remote[i].Fingerprint {
 148  			mismatches = append(mismatches, i)
 149  		}
 150  	}
 151  	return mismatches
 152  }
 153  
// Syncer orchestrates sync between local store and a remote item set.
//
// It reads the local (timestamp, id) set from store via QueryEvents
// (LocalItems). It loads full events by ID via GetByID when returning events
// the remote side lacks (FindHave).
type Syncer struct {
	store *store.Engine
}
 158  
 159  // NewSyncer creates a syncer.
 160  func NewSyncer(s *store.Engine) *Syncer { return &Syncer{store: s} }
 161  
 162  // LocalItems returns all (timestamp, id) pairs from the local store.
 163  func (s *Syncer) LocalItems() []Item {
 164  	f := &filter.F{}
 165  	events, err := s.store.QueryEvents(f)
 166  	if err != nil {
 167  		return nil
 168  	}
 169  	return ItemsFromEvents(events)
 170  }
 171  
 172  // FindNeeded compares local items against remote items and returns
 173  // IDs of events we need to fetch.
 174  func (s *Syncer) FindNeeded(remoteItems []Item) [][]byte {
 175  	local := s.LocalItems()
 176  	_, need := Diff(local, remoteItems)
 177  	ids := [][]byte{:len(need)}
 178  	for i, it := range need {
 179  		ids[i] = it.ID
 180  	}
 181  	return ids
 182  }
 183  
 184  // FindHave compares local items against remote items and returns
 185  // events we have that the remote doesn't.
 186  func (s *Syncer) FindHave(remoteItems []Item) []*event.E {
 187  	local := s.LocalItems()
 188  	have, _ := Diff(local, remoteItems)
 189  	var events []*event.E
 190  	for _, it := range have {
 191  		ev, err := s.store.GetByID(it.ID)
 192  		if err == nil && ev != nil {
 193  			events = append(events, ev)
 194  		}
 195  	}
 196  	return events
 197  }
 198  
 199  func sha256Hash(data []byte) []byte {
 200  	h := sha256.Sum256(data)
 201  	return h[:]
 202  }
 203  
 204  // EstimateRanges returns the optimal number of ranges for a given item count.
 205  func EstimateRanges(n int) int {
 206  	if n <= 0 {
 207  		return 0
 208  	}
 209  	r := int(math.Ceil(math.Sqrt(float64(n))))
 210  	if r < 2 {
 211  		return 2
 212  	}
 213  	if r > 128 {
 214  		return 128
 215  	}
 216  	return r
 217  }
 218