store.go raw

   1  /*
   2   * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
   3   * SPDX-License-Identifier: Apache-2.0
   4   */
   5  
   6  package ristretto
   7  
   8  import (
   9  	"sync"
  10  	"time"
  11  )
  12  
  13  type updateFn[V any] func(cur, prev V) bool
  14  
  15  // TODO: Do we need this to be a separate struct from Item?
  16  type storeItem[V any] struct {
  17  	key        uint64
  18  	conflict   uint64
  19  	value      V
  20  	expiration time.Time
  21  }
  22  
  23  // store is the interface fulfilled by all hash map implementations in this
  24  // file. Some hash map implementations are better suited for certain data
  25  // distributions than others, so this allows us to abstract that out for use
  26  // in Ristretto.
  27  //
  28  // Every store is safe for concurrent usage.
  29  type store[V any] interface {
  30  	// Get returns the value associated with the key parameter.
  31  	Get(uint64, uint64) (V, bool)
  32  	// Expiration returns the expiration time for this key.
  33  	Expiration(uint64) time.Time
  34  	// Set adds the key-value pair to the Map or updates the value if it's
  35  	// already present. The key-value pair is passed as a pointer to an
  36  	// item object.
  37  	Set(*Item[V])
  38  	// Del deletes the key-value pair from the Map.
  39  	Del(uint64, uint64) (uint64, V)
  40  	// Update attempts to update the key with a new value and returns true if
  41  	// successful.
  42  	Update(*Item[V]) (V, bool)
  43  	// Cleanup removes items that have an expired TTL.
  44  	Cleanup(policy *defaultPolicy[V], onEvict func(item *Item[V]))
  45  	// Clear clears all contents of the store.
  46  	Clear(onEvict func(item *Item[V]))
  47  	SetShouldUpdateFn(f updateFn[V])
  48  }
  49  
  50  // newStore returns the default store implementation.
  51  func newStore[V any]() store[V] {
  52  	return newShardedMap[V]()
  53  }
  54  
  55  const numShards uint64 = 256
  56  
  57  type shardedMap[V any] struct {
  58  	shards    []*lockedMap[V]
  59  	expiryMap *expirationMap[V]
  60  }
  61  
  62  func newShardedMap[V any]() *shardedMap[V] {
  63  	sm := &shardedMap[V]{
  64  		shards:    make([]*lockedMap[V], int(numShards)),
  65  		expiryMap: newExpirationMap[V](),
  66  	}
  67  	for i := range sm.shards {
  68  		sm.shards[i] = newLockedMap[V](sm.expiryMap)
  69  	}
  70  	return sm
  71  }
  72  
  73  func (m *shardedMap[V]) SetShouldUpdateFn(f updateFn[V]) {
  74  	for i := range m.shards {
  75  		m.shards[i].setShouldUpdateFn(f)
  76  	}
  77  }
  78  
  79  func (sm *shardedMap[V]) Get(key, conflict uint64) (V, bool) {
  80  	return sm.shards[key%numShards].get(key, conflict)
  81  }
  82  
  83  func (sm *shardedMap[V]) Expiration(key uint64) time.Time {
  84  	return sm.shards[key%numShards].Expiration(key)
  85  }
  86  
  87  func (sm *shardedMap[V]) Set(i *Item[V]) {
  88  	if i == nil {
  89  		// If item is nil make this Set a no-op.
  90  		return
  91  	}
  92  
  93  	sm.shards[i.Key%numShards].Set(i)
  94  }
  95  
  96  func (sm *shardedMap[V]) Del(key, conflict uint64) (uint64, V) {
  97  	return sm.shards[key%numShards].Del(key, conflict)
  98  }
  99  
 100  func (sm *shardedMap[V]) Update(newItem *Item[V]) (V, bool) {
 101  	return sm.shards[newItem.Key%numShards].Update(newItem)
 102  }
 103  
 104  func (sm *shardedMap[V]) Cleanup(policy *defaultPolicy[V], onEvict func(item *Item[V])) {
 105  	sm.expiryMap.cleanup(sm, policy, onEvict)
 106  }
 107  
 108  func (sm *shardedMap[V]) Clear(onEvict func(item *Item[V])) {
 109  	for i := uint64(0); i < numShards; i++ {
 110  		sm.shards[i].Clear(onEvict)
 111  	}
 112  	sm.expiryMap.clear()
 113  }
 114  
 115  type lockedMap[V any] struct {
 116  	sync.RWMutex
 117  	data         map[uint64]storeItem[V]
 118  	em           *expirationMap[V]
 119  	shouldUpdate updateFn[V]
 120  }
 121  
 122  func newLockedMap[V any](em *expirationMap[V]) *lockedMap[V] {
 123  	return &lockedMap[V]{
 124  		data: make(map[uint64]storeItem[V]),
 125  		em:   em,
 126  		shouldUpdate: func(cur, prev V) bool {
 127  			return true
 128  		},
 129  	}
 130  }
 131  
 132  func (m *lockedMap[V]) setShouldUpdateFn(f updateFn[V]) {
 133  	m.shouldUpdate = f
 134  }
 135  
 136  func (m *lockedMap[V]) get(key, conflict uint64) (V, bool) {
 137  	m.RLock()
 138  	item, ok := m.data[key]
 139  	m.RUnlock()
 140  	if !ok {
 141  		return zeroValue[V](), false
 142  	}
 143  	if conflict != 0 && (conflict != item.conflict) {
 144  		return zeroValue[V](), false
 145  	}
 146  
 147  	// Handle expired items.
 148  	if !item.expiration.IsZero() && time.Now().After(item.expiration) {
 149  		return zeroValue[V](), false
 150  	}
 151  	return item.value, true
 152  }
 153  
 154  func (m *lockedMap[V]) Expiration(key uint64) time.Time {
 155  	m.RLock()
 156  	defer m.RUnlock()
 157  	return m.data[key].expiration
 158  }
 159  
 160  func (m *lockedMap[V]) Set(i *Item[V]) {
 161  	if i == nil {
 162  		// If the item is nil make this Set a no-op.
 163  		return
 164  	}
 165  
 166  	m.Lock()
 167  	defer m.Unlock()
 168  	item, ok := m.data[i.Key]
 169  
 170  	if ok {
 171  		// The item existed already. We need to check the conflict key and reject the
 172  		// update if they do not match. Only after that the expiration map is updated.
 173  		if i.Conflict != 0 && (i.Conflict != item.conflict) {
 174  			return
 175  		}
 176  		if m.shouldUpdate != nil && !m.shouldUpdate(i.Value, item.value) {
 177  			return
 178  		}
 179  		m.em.update(i.Key, i.Conflict, item.expiration, i.Expiration)
 180  	} else {
 181  		// The value is not in the map already. There's no need to return anything.
 182  		// Simply add the expiration map.
 183  		m.em.add(i.Key, i.Conflict, i.Expiration)
 184  	}
 185  
 186  	m.data[i.Key] = storeItem[V]{
 187  		key:        i.Key,
 188  		conflict:   i.Conflict,
 189  		value:      i.Value,
 190  		expiration: i.Expiration,
 191  	}
 192  }
 193  
 194  func (m *lockedMap[V]) Del(key, conflict uint64) (uint64, V) {
 195  	m.Lock()
 196  	defer m.Unlock()
 197  	item, ok := m.data[key]
 198  	if !ok {
 199  		return 0, zeroValue[V]()
 200  	}
 201  	if conflict != 0 && (conflict != item.conflict) {
 202  		return 0, zeroValue[V]()
 203  	}
 204  
 205  	if !item.expiration.IsZero() {
 206  		m.em.del(key, item.expiration)
 207  	}
 208  
 209  	delete(m.data, key)
 210  	return item.conflict, item.value
 211  }
 212  
 213  func (m *lockedMap[V]) Update(newItem *Item[V]) (V, bool) {
 214  	m.Lock()
 215  	defer m.Unlock()
 216  	item, ok := m.data[newItem.Key]
 217  	if !ok {
 218  		return zeroValue[V](), false
 219  	}
 220  	if newItem.Conflict != 0 && (newItem.Conflict != item.conflict) {
 221  		return zeroValue[V](), false
 222  	}
 223  	if m.shouldUpdate != nil && !m.shouldUpdate(newItem.Value, item.value) {
 224  		return item.value, false
 225  	}
 226  
 227  	m.em.update(newItem.Key, newItem.Conflict, item.expiration, newItem.Expiration)
 228  	m.data[newItem.Key] = storeItem[V]{
 229  		key:        newItem.Key,
 230  		conflict:   newItem.Conflict,
 231  		value:      newItem.Value,
 232  		expiration: newItem.Expiration,
 233  	}
 234  
 235  	return item.value, true
 236  }
 237  
 238  func (m *lockedMap[V]) Clear(onEvict func(item *Item[V])) {
 239  	m.Lock()
 240  	defer m.Unlock()
 241  	i := &Item[V]{}
 242  	if onEvict != nil {
 243  		for _, si := range m.data {
 244  			i.Key = si.key
 245  			i.Conflict = si.conflict
 246  			i.Value = si.value
 247  			onEvict(i)
 248  		}
 249  	}
 250  	m.data = make(map[uint64]storeItem[V])
 251  }
 252