sketch.go raw
1 /*
2 * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6 package ristretto
7
8 import (
9 "fmt"
10 "math/rand"
11 "time"
12 )
13
14 // cmSketch is a Count-Min sketch implementation with 4-bit counters, heavily
15 // based on Damian Gryski's CM4 [1].
16 //
17 // [1]: https://github.com/dgryski/go-tinylfu/blob/master/cm4.go
18 type cmSketch struct {
19 rows [cmDepth]cmRow
20 seed [cmDepth]uint64
21 mask uint64
22 }
23
// cmDepth is how many independent copies of the counters the sketch keeps;
// each copy is one row.
const cmDepth = 4
28
29 func newCmSketch(numCounters int64) *cmSketch {
30 if numCounters == 0 {
31 panic("cmSketch: bad numCounters")
32 }
33 // Get the next power of 2 for better cache performance.
34 numCounters = next2Power(numCounters)
35 sketch := &cmSketch{mask: uint64(numCounters - 1)}
36 // Initialize rows of counters and seeds.
37 // Cryptographic precision not needed
38 source := rand.New(rand.NewSource(time.Now().UnixNano())) //nolint:gosec
39 for i := 0; i < cmDepth; i++ {
40 sketch.seed[i] = source.Uint64()
41 sketch.rows[i] = newCmRow(numCounters)
42 }
43 return sketch
44 }
45
46 // Increment increments the count(ers) for the specified key.
47 func (s *cmSketch) Increment(hashed uint64) {
48 for i := range s.rows {
49 s.rows[i].increment((hashed ^ s.seed[i]) & s.mask)
50 }
51 }
52
53 // Estimate returns the value of the specified key.
54 func (s *cmSketch) Estimate(hashed uint64) int64 {
55 min := byte(255)
56 for i := range s.rows {
57 val := s.rows[i].get((hashed ^ s.seed[i]) & s.mask)
58 if val < min {
59 min = val
60 }
61 }
62 return int64(min)
63 }
64
65 // Reset halves all counter values.
66 func (s *cmSketch) Reset() {
67 for _, r := range s.rows {
68 r.reset()
69 }
70 }
71
72 // Clear zeroes all counters.
73 func (s *cmSketch) Clear() {
74 for _, r := range s.rows {
75 r.clear()
76 }
77 }
78
// cmRow is one row of the sketch: a byte slice in which every byte packs two
// 4-bit counters (even-numbered counter in the low nibble, odd-numbered
// counter in the high nibble).
type cmRow []byte
81
82 func newCmRow(numCounters int64) cmRow {
83 return make(cmRow, numCounters/2)
84 }
85
86 func (r cmRow) get(n uint64) byte {
87 return (r[n/2] >> ((n & 1) * 4)) & 0x0f
88 }
89
90 func (r cmRow) increment(n uint64) {
91 // Index of the counter.
92 i := n / 2
93 // Shift distance (even 0, odd 4).
94 s := (n & 1) * 4
95 // Counter value.
96 v := (r[i] >> s) & 0x0f
97 // Only increment if not max value (overflow wrap is bad for LFU).
98 if v < 15 {
99 r[i] += 1 << s
100 }
101 }
102
103 func (r cmRow) reset() {
104 // Halve each counter.
105 for i := range r {
106 r[i] = (r[i] >> 1) & 0x77
107 }
108 }
109
110 func (r cmRow) clear() {
111 // Zero each counter.
112 for i := range r {
113 r[i] = 0
114 }
115 }
116
117 func (r cmRow) string() string {
118 s := ""
119 for i := uint64(0); i < uint64(len(r)*2); i++ {
120 s += fmt.Sprintf("%02d ", (r[(i/2)]>>((i&1)*4))&0x0f)
121 }
122 s = s[:len(s)-1]
123 return s
124 }
125
// next2Power rounds x up to the next power of 2, if it's not already one.
// Values of x that are zero or negative yield 1, the smallest power of two,
// rather than 0. The result overflows int64 (and comes back negative) when x
// is greater than 1<<62.
func next2Power(x int64) int64 {
	if x <= 1 {
		return 1
	}
	x--
	// Copy the highest set bit into every lower position, so that adding one
	// carries into the next power of two.
	for shift := uint(1); shift <= 32; shift <<= 1 {
		x |= x >> shift
	}
	return x + 1
}
138