1 package gcs
2 3 import (
4 "bytes"
5 "fmt"
6 "io"
7 "sort"
8 9 "github.com/aead/siphash"
10 "github.com/kkdai/bstream"
11 12 "github.com/p9c/p9/pkg/wire"
13 )
14 15 // Inspired by https://github.com/rasky/gcs
// Sentinel errors returned when filter parameters are out of range.
var (
	// ErrNTooBig is returned when the element count N does not fit in a
	// uint32.
	ErrNTooBig = fmt.Errorf("N is too big to fit in uint32")

	// ErrPTooBig is returned when the collision probability exponent P (for
	// a probability of `1/2**P`) is larger than the filter supports.
	ErrPTooBig = fmt.Errorf("P is too big to fit in uint32")
)
const (
	// KeySize is the length in bytes of the key material consumed by the
	// SipHash keyed hash function.
	KeySize = 16

	// varIntProtoVer is the protocol version passed to the wire package when
	// N is serialized or deserialized as a VarInt.
	varIntProtoVer uint32 = 0
)
// fastReduction maps v into the range [0, N) without a modulo operation by
// returning the upper 64 bits of the 128-bit product v * N, that is
// (v * N) >> 64. This is the "multiply-and-shift" technique described in:
// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
//
// To stay portable the 128-bit product is never formed directly; it is
// assembled from four partial products using 64-bit arithmetic only. The
// modulus N is passed already split into nHi (its upper 32 bits) and nLo (its
// lower 32 bits).
func fastReduction(v, nHi, nLo uint64) uint64 {
	const low32 = 0xffffffff

	// Split the value to reduce into its 32-bit halves.
	vHi, vLo := v>>32, v&low32

	// Schoolbook multiplication: one partial product per pair of halves.
	hiHi := vHi * nHi
	hiLo := vHi * nLo
	loHi := vLo * nHi
	loLo := vLo * nLo

	// Sum everything that lands in bits 32..63 of the full product; whatever
	// overflows past 32 bits is the carry into the upper 64-bit word.
	mid := (hiLo & low32) + (loHi & low32) + (loLo >> 32)

	// Upper word: high product, high halves of the cross terms, and the carry.
	return hiHi + (hiLo >> 32) + (loHi >> 32) + (mid >> 32)
}
// Filter is a Golomb Coded Set (GCS) built from a set of data elements. Once
// constructed it is only read: it can be serialized, deserialized and queried.
// The compressed filter data does not embed N or P, so callers may encode that
// metadata separately. Elements are hashed with SipHash, a keyed function; the
// key used to build the filter is not stored and must be supplied again when
// matching.
type Filter struct {
	// n is the number of elements the filter was built from.
	n uint32
	// p is the collision probability exponent (probability `1/2**p`); it is
	// also the bit width of each Golomb-coded remainder.
	p uint8
	// modulusNP is the size of the range that element hashes are reduced
	// into, computed as n multiplied by the caller-supplied M.
	modulusNP uint64
	// filterData holds the Golomb-coded bitstream.
	filterData []byte
}
74 75 // BuildGCSFilter builds a new GCS filter with the collision probability of
76 // `1/(2**P)`, key `key`, and including every `[]byte` in `data` as a member of
77 // the set.
78 func BuildGCSFilter(P uint8, M uint64, key [KeySize]byte, data [][]byte) (*Filter, error) {
79 // Some initial parameter checks: make sure we have data from which to podbuild the
80 // filter, and make sure our parameters will fit the hash function we're using.
81 if uint64(len(data)) >= (1 << 32) {
82 return nil, ErrNTooBig
83 }
84 if P > 32 {
85 return nil, ErrPTooBig
86 }
87 // Create the filter object and insert metadata.
88 f := Filter{
89 n: uint32(len(data)),
90 p: P,
91 }
92 // First we'll compute the value of m, which is the modulus we use within our
93 // finite field. We want to compute: mScalar * 2^P. We use math.Round in order
94 // to round the value up, rather than down.
95 f.modulusNP = uint64(f.n) * M
96 // Shortcut if the filter is empty.
97 if f.n == 0 {
98 return &f, nil
99 }
100 // Build the filter.
101 values := make(uint64Slice, 0, len(data))
102 b := bstream.NewBStreamWriter(0)
103 // Insert the hash (fast-ranged over a space of N*P) of each data element into a
104 // slice and txsort the slice. This can be greatly optimized with native 128-bit
105 // multiplication, but we're going to be fully portable for now.
106 //
107 // First, we cache the high and low bits of modulusNP for the multiplication of
108 // 2 64-bit integers into a 128-bit integer.
109 nphi := f.modulusNP >> 32
110 nplo := uint64(uint32(f.modulusNP))
111 for _, d := range data {
112 // For each datum, we assign the initial hash to a uint64.
113 v := siphash.Sum64(d, &key)
114 v = fastReduction(v, nphi, nplo)
115 values = append(values, v)
116 }
117 sort.Sort(values)
118 // Write the sorted list of values into the filter bitstream, compressing it
119 // using Golomb coding.
120 var value, lastValue, remainder uint64
121 for _, v := range values {
122 // Calculate the difference between this value and the last, modulo P.
123 remainder = (v - lastValue) & ((uint64(1) << f.p) - 1)
124 // Calculate the difference between this value and the last, divided by P.
125 value = (v - lastValue - remainder) >> f.p
126 lastValue = v
127 // Write the P multiple into the bitstream in unary; the average should be
128 // around 1 (2 bits - 0b10).
129 for value > 0 {
130 b.WriteBit(true)
131 value--
132 }
133 b.WriteBit(false)
134 // Write the remainder as a big-endian integer with enough bits to represent the
135 // appropriate collision probability.
136 b.WriteBits(remainder, int(f.p))
137 }
138 // Copy the bitstream into the filter object and return the object.
139 f.filterData = b.Bytes()
140 return &f, nil
141 }
142 143 // FromBytes deserializes a GCS filter from a known N, P, and serialized filter
144 // as returned by Hash().
145 func FromBytes(N uint32, P uint8, M uint64, d []byte) (*Filter, error) {
146 // Basic sanity check.
147 if P > 32 {
148 return nil, ErrPTooBig
149 }
150 // Create the filter object and insert metadata.
151 f := &Filter{
152 n: N,
153 p: P,
154 }
155 // First we'll compute the value of m, which is the modulus we use within our
156 // finite field. We want to compute: mScalar * 2^P. We use math.Round in order
157 // to round the value up, rather than down.
158 f.modulusNP = uint64(f.n) * M
159 // Copy the filter.
160 f.filterData = make([]byte, len(d))
161 copy(f.filterData, d)
162 return f, nil
163 }
164 165 // FromNBytes deserializes a GCS filter from a known P, and serialized N and
166 // filter as returned by NBytes().
167 func FromNBytes(P uint8, M uint64, d []byte) (*Filter, error) {
168 buffer := bytes.NewBuffer(d)
169 N, e := wire.ReadVarInt(buffer, varIntProtoVer)
170 if e != nil {
171 E.Ln(e)
172 return nil, e
173 }
174 if N >= (1 << 32) {
175 return nil, ErrNTooBig
176 }
177 return FromBytes(uint32(N), P, M, buffer.Bytes())
178 }
179 180 // Bytes returns the serialized format of the GCS filter, which does not include
181 // N or P (returned by separate methods) or the key used by SipHash.
182 func (f *Filter) Bytes() ([]byte, error) {
183 filterData := make([]byte, len(f.filterData))
184 copy(filterData, f.filterData)
185 return filterData, nil
186 }
187 188 // NBytes returns the serialized format of the GCS filter with N, which does not
189 // include P (returned by a separate method) or the key used by SipHash.
190 func (f *Filter) NBytes() ([]byte, error) {
191 var buffer bytes.Buffer
192 buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + len(f.filterData))
193 e := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
194 if e != nil {
195 E.Ln(e)
196 return nil, e
197 }
198 _, e = buffer.Write(f.filterData)
199 if e != nil {
200 E.Ln(e)
201 return nil, e
202 }
203 return buffer.Bytes(), nil
204 }
205 206 // PBytes returns the serialized format of the GCS filter with P, which does not
207 // include N (returned by a separate method) or the key used by SipHash.
208 func (f *Filter) PBytes() ([]byte, error) {
209 filterData := make([]byte, len(f.filterData)+1)
210 filterData[0] = f.p
211 copy(filterData[1:], f.filterData)
212 return filterData, nil
213 }
214 215 // NPBytes returns the serialized format of the GCS filter with N and P, which does not include the key used by SipHash.
216 func (f *Filter) NPBytes() ([]byte, error) {
217 var buffer bytes.Buffer
218 buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + 1 + len(f.filterData))
219 e := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
220 if e != nil {
221 E.Ln(e)
222 return nil, e
223 }
224 e = buffer.WriteByte(f.p)
225 if e != nil {
226 E.Ln(e)
227 return nil, e
228 }
229 _, e = buffer.Write(f.filterData)
230 if e != nil {
231 E.Ln(e)
232 return nil, e
233 }
234 return buffer.Bytes(), nil
235 }
236 237 // P returns the filter's collision probability as a negative power of 2 (that is, a collision probability of `1/2**20`
238 // is represented as 20).
239 func (f *Filter) P() uint8 {
240 return f.p
241 }
242 243 // N returns the size of the data set used to podbuild the filter.
244 func (f *Filter) N() uint32 {
245 return f.n
246 }
247 248 // Match checks whether a []byte value is likely (within collision probability) to be a member of the set represented by
249 // the filter.
250 func (f *Filter) Match(key [KeySize]byte, data []byte) (yn bool, e error) {
251 // Create a filter bitstream.
252 var filterData []byte
253 if filterData, e = f.Bytes(); E.Chk(e) {
254 return false, e
255 }
256 b := bstream.NewBStreamReader(filterData)
257 // We take the high and low bits of modulusNP for the multiplication of 2 64-bit integers into a 128-bit integer.
258 nphi := f.modulusNP >> 32
259 nplo := uint64(uint32(f.modulusNP))
260 // Then we hash our search term with the same parameters as the filter.
261 term := siphash.Sum64(data, &key)
262 term = fastReduction(term, nphi, nplo)
263 // Go through the search filter and look for the desired value.
264 var lastValue uint64
265 for lastValue < term {
266 // Read the difference between previous and new value from bitstream.
267 value, e := f.readFullUint64(b)
268 if e != nil {
269 E.Ln(e)
270 if e == io.EOF {
271 return false, nil
272 }
273 return false, e
274 }
275 // Add the previous value to it.
276 value += lastValue
277 if value == term {
278 return true, nil
279 }
280 lastValue = value
281 }
282 return false, nil
283 }
284 285 // MatchAny returns checks whether any []byte value is likely (within collision probability) to be a member of the set
286 // represented by the filter faster than calling Match() for each value individually.
287 func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
288 // Basic sanity check.
289 if len(data) == 0 {
290 return false, nil
291 }
292 // Create a filter bitstream.
293 filterData, e := f.Bytes()
294 if e != nil {
295 E.Ln(e)
296 return false, e
297 }
298 b := bstream.NewBStreamReader(filterData)
299 // Create an uncompressed filter of the search values.
300 values := make(uint64Slice, 0, len(data))
301 // First, we cache the high and low bits of modulusNP for the multiplication of 2 64-bit integers into a 128-bit
302 // integer.
303 nphi := f.modulusNP >> 32
304 nplo := uint64(uint32(f.modulusNP))
305 for _, d := range data {
306 // For each datum, we assign the initial hash to a uint64.
307 v := siphash.Sum64(d, &key)
308 // We'll then reduce the value down to the range of our modulus.
309 v = fastReduction(v, nphi, nplo)
310 values = append(values, v)
311 }
312 sort.Sort(values)
313 // Zip down the filters, comparing values until we either run out of values to compare in one of the filters or we
314 // reach a matching value.
315 var lastValue1, lastValue2 uint64
316 lastValue2 = values[0]
317 i := 1
318 for lastValue1 != lastValue2 {
319 // Chk which filter to advance to make sure we're comparing the right values.
320 switch {
321 case lastValue1 > lastValue2:
322 // Advance filter created from search terms or return false if we're at the end because nothing matched.
323 if i < len(values) {
324 lastValue2 = values[i]
325 i++
326 } else {
327 return false, nil
328 }
329 case lastValue2 > lastValue1:
330 // Advance filter we're searching or return false if we're at the end because nothing matched.
331 value, e := f.readFullUint64(b)
332 if e != nil {
333 // E.Ln(e)
334 if e == io.EOF {
335 return false, nil
336 }
337 return false, e
338 }
339 lastValue1 += value
340 }
341 }
342 // If we've made it this far, an element matched between filters so we return true.
343 return true, nil
344 }
345 346 // readFullUint64 reads a value represented by the sum of a unary multiple of the filter's P modulus (`2**P`) and a
347 // big-endian P-bit remainder.
348 func (f *Filter) readFullUint64(b *bstream.BStream) (rv uint64, e error) {
349 var quotient uint64
350 // Count the 1s until we reach a 0.
351 c, e := b.ReadBit()
352 if e != nil {
353 T.Ln(e)
354 return 0, e
355 }
356 for c {
357 quotient++
358 c, e = b.ReadBit()
359 if e != nil {
360 D.Ln(e)
361 return 0, e
362 }
363 }
364 // Read P bits.
365 var remainder uint64
366 if remainder, e = b.ReadBits(int(f.p)); T.Chk(e) {
367 D.Ln(e)
368 return 0, e
369 }
370 // Add the multiple and the remainder.
371 v := (quotient << f.p) + remainder
372 return v, nil
373 }
374