policy.go raw

   1  /*
   2   * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
   3   * SPDX-License-Identifier: Apache-2.0
   4   */
   5  
   6  package ristretto
   7  
   8  import (
   9  	"math"
  10  	"sync"
  11  	"sync/atomic"
  12  
  13  	"github.com/dgraph-io/ristretto/v2/z"
  14  )
  15  
  16  const (
  17  	// lfuSample is the number of items to sample when looking at eviction
  18  	// candidates. 5 seems to be the most optimal number [citation needed].
  19  	lfuSample = 5
  20  )
  21  
  22  func newPolicy[V any](numCounters, maxCost int64) *defaultPolicy[V] {
  23  	return newDefaultPolicy[V](numCounters, maxCost)
  24  }
  25  
  26  type defaultPolicy[V any] struct {
  27  	sync.Mutex
  28  	admit    *tinyLFU
  29  	evict    *sampledLFU
  30  	itemsCh  chan []uint64
  31  	stop     chan struct{}
  32  	done     chan struct{}
  33  	isClosed bool
  34  	metrics  *Metrics
  35  }
  36  
  37  func newDefaultPolicy[V any](numCounters, maxCost int64) *defaultPolicy[V] {
  38  	p := &defaultPolicy[V]{
  39  		admit:   newTinyLFU(numCounters),
  40  		evict:   newSampledLFU(maxCost),
  41  		itemsCh: make(chan []uint64, 3),
  42  		stop:    make(chan struct{}),
  43  		done:    make(chan struct{}),
  44  	}
  45  	go p.processItems()
  46  	return p
  47  }
  48  
  49  func (p *defaultPolicy[V]) CollectMetrics(metrics *Metrics) {
  50  	p.metrics = metrics
  51  	p.evict.metrics = metrics
  52  }
  53  
  54  type policyPair struct {
  55  	key  uint64
  56  	cost int64
  57  }
  58  
  59  func (p *defaultPolicy[V]) processItems() {
  60  	for {
  61  		select {
  62  		case items := <-p.itemsCh:
  63  			p.Lock()
  64  			p.admit.Push(items)
  65  			p.Unlock()
  66  		case <-p.stop:
  67  			p.done <- struct{}{}
  68  			return
  69  		}
  70  	}
  71  }
  72  
  73  func (p *defaultPolicy[V]) Push(keys []uint64) bool {
  74  	if p.isClosed {
  75  		return false
  76  	}
  77  
  78  	if len(keys) == 0 {
  79  		return true
  80  	}
  81  
  82  	select {
  83  	case p.itemsCh <- keys:
  84  		p.metrics.add(keepGets, keys[0], uint64(len(keys)))
  85  		return true
  86  	default:
  87  		p.metrics.add(dropGets, keys[0], uint64(len(keys)))
  88  		return false
  89  	}
  90  }
  91  
  92  // Add decides whether the item with the given key and cost should be accepted by
  93  // the policy. It returns the list of victims that have been evicted and a boolean
  94  // indicating whether the incoming item should be accepted.
  95  func (p *defaultPolicy[V]) Add(key uint64, cost int64) ([]*Item[V], bool) {
  96  	p.Lock()
  97  	defer p.Unlock()
  98  
  99  	// Cannot add an item bigger than entire cache.
 100  	if cost > p.evict.getMaxCost() {
 101  		return nil, false
 102  	}
 103  
 104  	// No need to go any further if the item is already in the cache.
 105  	if has := p.evict.updateIfHas(key, cost); has {
 106  		// An update does not count as an addition, so return false.
 107  		return nil, false
 108  	}
 109  
 110  	// If the execution reaches this point, the key doesn't exist in the cache.
 111  	// Calculate the remaining room in the cache (usually bytes).
 112  	room := p.evict.roomLeft(cost)
 113  	if room >= 0 {
 114  		// There's enough room in the cache to store the new item without
 115  		// overflowing. Do that now and stop here.
 116  		p.evict.add(key, cost)
 117  		p.metrics.add(costAdd, key, uint64(cost))
 118  		return nil, true
 119  	}
 120  
 121  	// incHits is the hit count for the incoming item.
 122  	incHits := p.admit.Estimate(key)
 123  	// sample is the eviction candidate pool to be filled via random sampling.
 124  	// TODO: perhaps we should use a min heap here. Right now our time
 125  	// complexity is N for finding the min. Min heap should bring it down to
 126  	// O(lg N).
 127  	sample := make([]*policyPair, 0, lfuSample)
 128  	// As items are evicted they will be appended to victims.
 129  	victims := make([]*Item[V], 0)
 130  
 131  	// Delete victims until there's enough space or a minKey is found that has
 132  	// more hits than incoming item.
 133  	for ; room < 0; room = p.evict.roomLeft(cost) {
 134  		// Fill up empty slots in sample.
 135  		sample = p.evict.fillSample(sample)
 136  
 137  		// Find minimally used item in sample.
 138  		minKey, minHits, minId, minCost := uint64(0), int64(math.MaxInt64), 0, int64(0)
 139  		for i, pair := range sample {
 140  			// Look up hit count for sample key.
 141  			if hits := p.admit.Estimate(pair.key); hits < minHits {
 142  				minKey, minHits, minId, minCost = pair.key, hits, i, pair.cost
 143  			}
 144  		}
 145  
 146  		// If the incoming item isn't worth keeping in the policy, reject.
 147  		if incHits < minHits {
 148  			p.metrics.add(rejectSets, key, 1)
 149  			return victims, false
 150  		}
 151  
 152  		// Delete the victim from metadata.
 153  		p.evict.del(minKey)
 154  
 155  		// Delete the victim from sample.
 156  		sample[minId] = sample[len(sample)-1]
 157  		sample = sample[:len(sample)-1]
 158  		// Store victim in evicted victims slice.
 159  		victims = append(victims, &Item[V]{
 160  			Key:      minKey,
 161  			Conflict: 0,
 162  			Cost:     minCost,
 163  		})
 164  	}
 165  
 166  	p.evict.add(key, cost)
 167  	p.metrics.add(costAdd, key, uint64(cost))
 168  	return victims, true
 169  }
 170  
 171  func (p *defaultPolicy[V]) Has(key uint64) bool {
 172  	p.Lock()
 173  	_, exists := p.evict.keyCosts[key]
 174  	p.Unlock()
 175  	return exists
 176  }
 177  
 178  func (p *defaultPolicy[V]) Del(key uint64) {
 179  	p.Lock()
 180  	p.evict.del(key)
 181  	p.Unlock()
 182  }
 183  
 184  func (p *defaultPolicy[V]) Cap() int64 {
 185  	p.Lock()
 186  	capacity := p.evict.getMaxCost() - p.evict.used
 187  	p.Unlock()
 188  	return capacity
 189  }
 190  
 191  func (p *defaultPolicy[V]) Update(key uint64, cost int64) {
 192  	p.Lock()
 193  	p.evict.updateIfHas(key, cost)
 194  	p.Unlock()
 195  }
 196  
 197  func (p *defaultPolicy[V]) Cost(key uint64) int64 {
 198  	p.Lock()
 199  	if cost, found := p.evict.keyCosts[key]; found {
 200  		p.Unlock()
 201  		return cost
 202  	}
 203  	p.Unlock()
 204  	return -1
 205  }
 206  
 207  func (p *defaultPolicy[V]) Clear() {
 208  	p.Lock()
 209  	p.admit.clear()
 210  	p.evict.clear()
 211  	p.Unlock()
 212  }
 213  
 214  func (p *defaultPolicy[V]) Close() {
 215  	if p.isClosed {
 216  		return
 217  	}
 218  
 219  	// Block until the p.processItems goroutine returns.
 220  	p.stop <- struct{}{}
 221  	<-p.done
 222  	close(p.stop)
 223  	close(p.done)
 224  	close(p.itemsCh)
 225  	p.isClosed = true
 226  }
 227  
 228  func (p *defaultPolicy[V]) MaxCost() int64 {
 229  	if p == nil || p.evict == nil {
 230  		return 0
 231  	}
 232  	return p.evict.getMaxCost()
 233  }
 234  
 235  func (p *defaultPolicy[V]) UpdateMaxCost(maxCost int64) {
 236  	if p == nil || p.evict == nil {
 237  		return
 238  	}
 239  	p.evict.updateMaxCost(maxCost)
 240  }
 241  
 242  // sampledLFU is an eviction helper storing key-cost pairs.
 243  type sampledLFU struct {
 244  	// NOTE: align maxCost to 64-bit boundary for use with atomic.
 245  	// As per https://golang.org/pkg/sync/atomic/: "On ARM, x86-32,
 246  	// and 32-bit MIPS, it is the caller’s responsibility to arrange
 247  	// for 64-bit alignment of 64-bit words accessed atomically.
 248  	// The first word in a variable or in an allocated struct, array,
 249  	// or slice can be relied upon to be 64-bit aligned."
 250  	maxCost  int64
 251  	used     int64
 252  	metrics  *Metrics
 253  	keyCosts map[uint64]int64
 254  }
 255  
 256  func newSampledLFU(maxCost int64) *sampledLFU {
 257  	return &sampledLFU{
 258  		keyCosts: make(map[uint64]int64),
 259  		maxCost:  maxCost,
 260  	}
 261  }
 262  
 263  func (p *sampledLFU) getMaxCost() int64 {
 264  	return atomic.LoadInt64(&p.maxCost)
 265  }
 266  
 267  func (p *sampledLFU) updateMaxCost(maxCost int64) {
 268  	atomic.StoreInt64(&p.maxCost, maxCost)
 269  }
 270  
 271  func (p *sampledLFU) roomLeft(cost int64) int64 {
 272  	return p.getMaxCost() - (p.used + cost)
 273  }
 274  
 275  func (p *sampledLFU) fillSample(in []*policyPair) []*policyPair {
 276  	if len(in) >= lfuSample {
 277  		return in
 278  	}
 279  	for key, cost := range p.keyCosts {
 280  		in = append(in, &policyPair{key, cost})
 281  		if len(in) >= lfuSample {
 282  			return in
 283  		}
 284  	}
 285  	return in
 286  }
 287  
 288  func (p *sampledLFU) del(key uint64) {
 289  	cost, ok := p.keyCosts[key]
 290  	if !ok {
 291  		return
 292  	}
 293  	p.used -= cost
 294  	delete(p.keyCosts, key)
 295  	p.metrics.add(costEvict, key, uint64(cost))
 296  	p.metrics.add(keyEvict, key, 1)
 297  }
 298  
 299  func (p *sampledLFU) add(key uint64, cost int64) {
 300  	p.keyCosts[key] = cost
 301  	p.used += cost
 302  }
 303  
 304  func (p *sampledLFU) updateIfHas(key uint64, cost int64) bool {
 305  	if prev, found := p.keyCosts[key]; found {
 306  		// Update the cost of an existing key, but don't worry about evicting.
 307  		// Evictions will be handled the next time a new item is added.
 308  		p.metrics.add(keyUpdate, key, 1)
 309  		if prev > cost {
 310  			diff := prev - cost
 311  			p.metrics.add(costAdd, key, ^(uint64(diff) - 1))
 312  		} else if cost > prev {
 313  			diff := cost - prev
 314  			p.metrics.add(costAdd, key, uint64(diff))
 315  		}
 316  		p.used += cost - prev
 317  		p.keyCosts[key] = cost
 318  		return true
 319  	}
 320  	return false
 321  }
 322  
 323  func (p *sampledLFU) clear() {
 324  	p.used = 0
 325  	p.keyCosts = make(map[uint64]int64)
 326  }
 327  
 328  // tinyLFU is an admission helper that keeps track of access frequency using
 329  // tiny (4-bit) counters in the form of a count-min sketch.
 330  // tinyLFU is NOT thread safe.
 331  type tinyLFU struct {
 332  	freq    *cmSketch
 333  	door    *z.Bloom
 334  	incrs   int64
 335  	resetAt int64
 336  }
 337  
 338  func newTinyLFU(numCounters int64) *tinyLFU {
 339  	return &tinyLFU{
 340  		freq:    newCmSketch(numCounters),
 341  		door:    z.NewBloomFilter(float64(numCounters), 0.01),
 342  		resetAt: numCounters,
 343  	}
 344  }
 345  
 346  func (p *tinyLFU) Push(keys []uint64) {
 347  	for _, key := range keys {
 348  		p.Increment(key)
 349  	}
 350  }
 351  
 352  func (p *tinyLFU) Estimate(key uint64) int64 {
 353  	hits := p.freq.Estimate(key)
 354  	if p.door.Has(key) {
 355  		hits++
 356  	}
 357  	return hits
 358  }
 359  
 360  func (p *tinyLFU) Increment(key uint64) {
 361  	// Flip doorkeeper bit if not already done.
 362  	if added := p.door.AddIfNotHas(key); !added {
 363  		// Increment count-min counter if doorkeeper bit is already set.
 364  		p.freq.Increment(key)
 365  	}
 366  	p.incrs++
 367  	if p.incrs >= p.resetAt {
 368  		p.reset()
 369  	}
 370  }
 371  
 372  func (p *tinyLFU) reset() {
 373  	// Zero out incrs.
 374  	p.incrs = 0
 375  	// clears doorkeeper bits
 376  	p.door.Clear()
 377  	// halves count-min counters
 378  	p.freq.Reset()
 379  }
 380  
 381  func (p *tinyLFU) clear() {
 382  	p.incrs = 0
 383  	p.door.Clear()
 384  	p.freq.Clear()
 385  }
 386