sketch.go raw

   1  /*
   2   * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
   3   * SPDX-License-Identifier: Apache-2.0
   4   */
   5  
   6  package ristretto
   7  
   8  import (
   9  	"fmt"
  10  	"math/rand"
  11  	"time"
  12  )
  13  
  14  // cmSketch is a Count-Min sketch implementation with 4-bit counters, heavily
  15  // based on Damian Gryski's CM4 [1].
  16  //
  17  // [1]: https://github.com/dgryski/go-tinylfu/blob/master/cm4.go
  18  type cmSketch struct {
  19  	rows [cmDepth]cmRow
  20  	seed [cmDepth]uint64
  21  	mask uint64
  22  }
  23  
  24  const (
  25  	// cmDepth is the number of counter copies to store (think of it as rows).
  26  	cmDepth = 4
  27  )
  28  
  29  func newCmSketch(numCounters int64) *cmSketch {
  30  	if numCounters == 0 {
  31  		panic("cmSketch: bad numCounters")
  32  	}
  33  	// Get the next power of 2 for better cache performance.
  34  	numCounters = next2Power(numCounters)
  35  	sketch := &cmSketch{mask: uint64(numCounters - 1)}
  36  	// Initialize rows of counters and seeds.
  37  	// Cryptographic precision not needed
  38  	source := rand.New(rand.NewSource(time.Now().UnixNano())) //nolint:gosec
  39  	for i := 0; i < cmDepth; i++ {
  40  		sketch.seed[i] = source.Uint64()
  41  		sketch.rows[i] = newCmRow(numCounters)
  42  	}
  43  	return sketch
  44  }
  45  
  46  // Increment increments the count(ers) for the specified key.
  47  func (s *cmSketch) Increment(hashed uint64) {
  48  	for i := range s.rows {
  49  		s.rows[i].increment((hashed ^ s.seed[i]) & s.mask)
  50  	}
  51  }
  52  
  53  // Estimate returns the value of the specified key.
  54  func (s *cmSketch) Estimate(hashed uint64) int64 {
  55  	min := byte(255)
  56  	for i := range s.rows {
  57  		val := s.rows[i].get((hashed ^ s.seed[i]) & s.mask)
  58  		if val < min {
  59  			min = val
  60  		}
  61  	}
  62  	return int64(min)
  63  }
  64  
  65  // Reset halves all counter values.
  66  func (s *cmSketch) Reset() {
  67  	for _, r := range s.rows {
  68  		r.reset()
  69  	}
  70  }
  71  
  72  // Clear zeroes all counters.
  73  func (s *cmSketch) Clear() {
  74  	for _, r := range s.rows {
  75  		r.clear()
  76  	}
  77  }
  78  
  79  // cmRow is a row of bytes, with each byte holding two counters.
  80  type cmRow []byte
  81  
  82  func newCmRow(numCounters int64) cmRow {
  83  	return make(cmRow, numCounters/2)
  84  }
  85  
  86  func (r cmRow) get(n uint64) byte {
  87  	return (r[n/2] >> ((n & 1) * 4)) & 0x0f
  88  }
  89  
  90  func (r cmRow) increment(n uint64) {
  91  	// Index of the counter.
  92  	i := n / 2
  93  	// Shift distance (even 0, odd 4).
  94  	s := (n & 1) * 4
  95  	// Counter value.
  96  	v := (r[i] >> s) & 0x0f
  97  	// Only increment if not max value (overflow wrap is bad for LFU).
  98  	if v < 15 {
  99  		r[i] += 1 << s
 100  	}
 101  }
 102  
 103  func (r cmRow) reset() {
 104  	// Halve each counter.
 105  	for i := range r {
 106  		r[i] = (r[i] >> 1) & 0x77
 107  	}
 108  }
 109  
 110  func (r cmRow) clear() {
 111  	// Zero each counter.
 112  	for i := range r {
 113  		r[i] = 0
 114  	}
 115  }
 116  
 117  func (r cmRow) string() string {
 118  	s := ""
 119  	for i := uint64(0); i < uint64(len(r)*2); i++ {
 120  		s += fmt.Sprintf("%02d ", (r[(i/2)]>>((i&1)*4))&0x0f)
 121  	}
 122  	s = s[:len(s)-1]
 123  	return s
 124  }
 125  
 126  // next2Power rounds x up to the next power of 2, if it's not already one.
 127  func next2Power(x int64) int64 {
 128  	x--
 129  	x |= x >> 1
 130  	x |= x >> 2
 131  	x |= x >> 4
 132  	x |= x >> 8
 133  	x |= x >> 16
 134  	x |= x >> 32
 135  	x++
 136  	return x
 137  }
 138