store.go raw
1 /*
2 * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6 package ristretto
7
8 import (
9 "sync"
10 "time"
11 )
12
// updateFn decides whether an existing value may be overwritten. cur is the
// incoming value and prev is the value currently stored; returning false
// vetoes the write.
type updateFn[V any] func(cur V, prev V) bool

// storeItem is the internal record a lockedMap keeps for each entry.
//
// TODO: Do we need this to be a separate struct from Item?
type storeItem[V any] struct {
	// key identifies the entry; conflict disambiguates entries that share a
	// key (a non-zero conflict supplied by callers must match it).
	key, conflict uint64
	// value is the cached payload.
	value V
	// expiration is the entry's deadline; the zero time means it never expires.
	expiration time.Time
}
22
23 // store is the interface fulfilled by all hash map implementations in this
24 // file. Some hash map implementations are better suited for certain data
25 // distributions than others, so this allows us to abstract that out for use
26 // in Ristretto.
27 //
28 // Every store is safe for concurrent usage.
29 type store[V any] interface {
30 // Get returns the value associated with the key parameter.
31 Get(uint64, uint64) (V, bool)
32 // Expiration returns the expiration time for this key.
33 Expiration(uint64) time.Time
34 // Set adds the key-value pair to the Map or updates the value if it's
35 // already present. The key-value pair is passed as a pointer to an
36 // item object.
37 Set(*Item[V])
38 // Del deletes the key-value pair from the Map.
39 Del(uint64, uint64) (uint64, V)
40 // Update attempts to update the key with a new value and returns true if
41 // successful.
42 Update(*Item[V]) (V, bool)
43 // Cleanup removes items that have an expired TTL.
44 Cleanup(policy *defaultPolicy[V], onEvict func(item *Item[V]))
45 // Clear clears all contents of the store.
46 Clear(onEvict func(item *Item[V]))
47 SetShouldUpdateFn(f updateFn[V])
48 }
49
50 // newStore returns the default store implementation.
51 func newStore[V any]() store[V] {
52 return newShardedMap[V]()
53 }
54
// numShards is the number of independently locked partitions in a
// shardedMap; a key is routed to shard key % numShards.
const numShards = uint64(256)
56
57 type shardedMap[V any] struct {
58 shards []*lockedMap[V]
59 expiryMap *expirationMap[V]
60 }
61
62 func newShardedMap[V any]() *shardedMap[V] {
63 sm := &shardedMap[V]{
64 shards: make([]*lockedMap[V], int(numShards)),
65 expiryMap: newExpirationMap[V](),
66 }
67 for i := range sm.shards {
68 sm.shards[i] = newLockedMap[V](sm.expiryMap)
69 }
70 return sm
71 }
72
73 func (m *shardedMap[V]) SetShouldUpdateFn(f updateFn[V]) {
74 for i := range m.shards {
75 m.shards[i].setShouldUpdateFn(f)
76 }
77 }
78
79 func (sm *shardedMap[V]) Get(key, conflict uint64) (V, bool) {
80 return sm.shards[key%numShards].get(key, conflict)
81 }
82
83 func (sm *shardedMap[V]) Expiration(key uint64) time.Time {
84 return sm.shards[key%numShards].Expiration(key)
85 }
86
87 func (sm *shardedMap[V]) Set(i *Item[V]) {
88 if i == nil {
89 // If item is nil make this Set a no-op.
90 return
91 }
92
93 sm.shards[i.Key%numShards].Set(i)
94 }
95
96 func (sm *shardedMap[V]) Del(key, conflict uint64) (uint64, V) {
97 return sm.shards[key%numShards].Del(key, conflict)
98 }
99
100 func (sm *shardedMap[V]) Update(newItem *Item[V]) (V, bool) {
101 return sm.shards[newItem.Key%numShards].Update(newItem)
102 }
103
104 func (sm *shardedMap[V]) Cleanup(policy *defaultPolicy[V], onEvict func(item *Item[V])) {
105 sm.expiryMap.cleanup(sm, policy, onEvict)
106 }
107
108 func (sm *shardedMap[V]) Clear(onEvict func(item *Item[V])) {
109 for i := uint64(0); i < numShards; i++ {
110 sm.shards[i].Clear(onEvict)
111 }
112 sm.expiryMap.clear()
113 }
114
115 type lockedMap[V any] struct {
116 sync.RWMutex
117 data map[uint64]storeItem[V]
118 em *expirationMap[V]
119 shouldUpdate updateFn[V]
120 }
121
122 func newLockedMap[V any](em *expirationMap[V]) *lockedMap[V] {
123 return &lockedMap[V]{
124 data: make(map[uint64]storeItem[V]),
125 em: em,
126 shouldUpdate: func(cur, prev V) bool {
127 return true
128 },
129 }
130 }
131
132 func (m *lockedMap[V]) setShouldUpdateFn(f updateFn[V]) {
133 m.shouldUpdate = f
134 }
135
136 func (m *lockedMap[V]) get(key, conflict uint64) (V, bool) {
137 m.RLock()
138 item, ok := m.data[key]
139 m.RUnlock()
140 if !ok {
141 return zeroValue[V](), false
142 }
143 if conflict != 0 && (conflict != item.conflict) {
144 return zeroValue[V](), false
145 }
146
147 // Handle expired items.
148 if !item.expiration.IsZero() && time.Now().After(item.expiration) {
149 return zeroValue[V](), false
150 }
151 return item.value, true
152 }
153
154 func (m *lockedMap[V]) Expiration(key uint64) time.Time {
155 m.RLock()
156 defer m.RUnlock()
157 return m.data[key].expiration
158 }
159
160 func (m *lockedMap[V]) Set(i *Item[V]) {
161 if i == nil {
162 // If the item is nil make this Set a no-op.
163 return
164 }
165
166 m.Lock()
167 defer m.Unlock()
168 item, ok := m.data[i.Key]
169
170 if ok {
171 // The item existed already. We need to check the conflict key and reject the
172 // update if they do not match. Only after that the expiration map is updated.
173 if i.Conflict != 0 && (i.Conflict != item.conflict) {
174 return
175 }
176 if m.shouldUpdate != nil && !m.shouldUpdate(i.Value, item.value) {
177 return
178 }
179 m.em.update(i.Key, i.Conflict, item.expiration, i.Expiration)
180 } else {
181 // The value is not in the map already. There's no need to return anything.
182 // Simply add the expiration map.
183 m.em.add(i.Key, i.Conflict, i.Expiration)
184 }
185
186 m.data[i.Key] = storeItem[V]{
187 key: i.Key,
188 conflict: i.Conflict,
189 value: i.Value,
190 expiration: i.Expiration,
191 }
192 }
193
194 func (m *lockedMap[V]) Del(key, conflict uint64) (uint64, V) {
195 m.Lock()
196 defer m.Unlock()
197 item, ok := m.data[key]
198 if !ok {
199 return 0, zeroValue[V]()
200 }
201 if conflict != 0 && (conflict != item.conflict) {
202 return 0, zeroValue[V]()
203 }
204
205 if !item.expiration.IsZero() {
206 m.em.del(key, item.expiration)
207 }
208
209 delete(m.data, key)
210 return item.conflict, item.value
211 }
212
213 func (m *lockedMap[V]) Update(newItem *Item[V]) (V, bool) {
214 m.Lock()
215 defer m.Unlock()
216 item, ok := m.data[newItem.Key]
217 if !ok {
218 return zeroValue[V](), false
219 }
220 if newItem.Conflict != 0 && (newItem.Conflict != item.conflict) {
221 return zeroValue[V](), false
222 }
223 if m.shouldUpdate != nil && !m.shouldUpdate(newItem.Value, item.value) {
224 return item.value, false
225 }
226
227 m.em.update(newItem.Key, newItem.Conflict, item.expiration, newItem.Expiration)
228 m.data[newItem.Key] = storeItem[V]{
229 key: newItem.Key,
230 conflict: newItem.Conflict,
231 value: newItem.Value,
232 expiration: newItem.Expiration,
233 }
234
235 return item.value, true
236 }
237
238 func (m *lockedMap[V]) Clear(onEvict func(item *Item[V])) {
239 m.Lock()
240 defer m.Unlock()
241 i := &Item[V]{}
242 if onEvict != nil {
243 for _, si := range m.data {
244 i.Key = si.key
245 i.Conflict = si.conflict
246 i.Value = si.value
247 onEvict(i)
248 }
249 }
250 m.data = make(map[uint64]storeItem[V])
251 }
252