gcs.go raw

   1  package gcs
   2  
   3  import (
   4  	"bytes"
   5  	"fmt"
   6  	"io"
   7  	"sort"
   8  	
   9  	"github.com/aead/siphash"
  10  	"github.com/kkdai/bstream"
  11  	
  12  	"github.com/p9c/p9/pkg/wire"
  13  )
  14  
  15  // Inspired by https://github.com/rasky/gcs
  16  var (
  17  	// ErrNTooBig signifies that the filter can't handle N items.
  18  	ErrNTooBig = fmt.Errorf("N is too big to fit in uint32")
  19  	// ErrPTooBig signifies that the filter can't handle `1/2**P` collision probability.
  20  	ErrPTooBig = fmt.Errorf("P is too big to fit in uint32")
  21  )
  22  
  23  const (
  24  	// KeySize is the size of the byte array required for key material for the
  25  	// SipHash keyed hash function.
  26  	KeySize = 16
  27  	// varIntProtoVer is the protocol version to use for serializing N as a VarInt.
  28  	varIntProtoVer uint32 = 0
  29  )
  30  
  31  // fastReduction calculates a mapping that's more ore less equivalent to: x mod
  32  // N. However, instead of using a mod operation, which using a non-power of two
  33  // will lead to slowness on many processors due to unnecessary division, we
  34  // instead use a "multiply-and-shift" trick which eliminates all divisions,
  35  // described in:
  36  // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
  37  //  * v * N  >> log_2(N)
  38  //
  39  // In our case, using 64-bit integers, log_2 is 64. As most processors don't
  40  // support 128-bit arithmetic natively, we'll be super portable and unfold the
  41  // operation into several operations with 64-bit arithmetic. As inputs, we the
  42  // number to reduce, and our modulus N divided into its high 32-bits and lower
  43  // 32-bits.
  44  func fastReduction(v, nHi, nLo uint64) uint64 {
  45  	// First, we'll spit the item we need to reduce into its higher and lower bits.
  46  	vhi := v >> 32
  47  	vlo := uint64(uint32(v))
  48  	// Then, we distribute multiplication over each part.
  49  	vnphi := vhi * nHi
  50  	vnpmid := vhi * nLo
  51  	npvmid := nHi * vlo
  52  	vnplo := vlo * nLo
  53  	// We calculate the carry bit.
  54  	carry := (uint64(uint32(vnpmid)) + uint64(uint32(npvmid)) +
  55  		(vnplo >> 32)) >> 32
  56  	// Last, we add the high bits, the middle bits, and the carry.
  57  	v = vnphi + (vnpmid >> 32) + (npvmid >> 32) + carry
  58  	return v
  59  }
  60  
  61  // Filter describes an immutable filter that can be built from a set of data
  62  // elements, serialized, deserialized, and queried in a thread-safe manner. The
  63  // serialized form is compressed as a Golomb Coded Set (GCS), but does not
  64  // include N or P to allow the user to encode the metadata separately if
  65  // necessary. The hash function used is SipHash, a keyed function; the key used
  66  // in building the filter is required in order to match filter values and is not
  67  // included in the serialized form.
  68  type Filter struct {
  69  	n          uint32
  70  	p          uint8
  71  	modulusNP  uint64
  72  	filterData []byte
  73  }
  74  
  75  // BuildGCSFilter builds a new GCS filter with the collision probability of
  76  // `1/(2**P)`, key `key`, and including every `[]byte` in `data` as a member of
  77  // the set.
  78  func BuildGCSFilter(P uint8, M uint64, key [KeySize]byte, data [][]byte) (*Filter, error) {
  79  	// Some initial parameter checks: make sure we have data from which to podbuild the
  80  	// filter, and make sure our parameters will fit the hash function we're using.
  81  	if uint64(len(data)) >= (1 << 32) {
  82  		return nil, ErrNTooBig
  83  	}
  84  	if P > 32 {
  85  		return nil, ErrPTooBig
  86  	}
  87  	// Create the filter object and insert metadata.
  88  	f := Filter{
  89  		n: uint32(len(data)),
  90  		p: P,
  91  	}
  92  	// First we'll compute the value of m, which is the modulus we use within our
  93  	// finite field. We want to compute: mScalar * 2^P. We use math.Round in order
  94  	// to round the value up, rather than down.
  95  	f.modulusNP = uint64(f.n) * M
  96  	// Shortcut if the filter is empty.
  97  	if f.n == 0 {
  98  		return &f, nil
  99  	}
 100  	// Build the filter.
 101  	values := make(uint64Slice, 0, len(data))
 102  	b := bstream.NewBStreamWriter(0)
 103  	// Insert the hash (fast-ranged over a space of N*P) of each data element into a
 104  	// slice and txsort the slice. This can be greatly optimized with native 128-bit
 105  	// multiplication, but we're going to be fully portable for now.
 106  	//
 107  	// First, we cache the high and low bits of modulusNP for the multiplication of
 108  	// 2 64-bit integers into a 128-bit integer.
 109  	nphi := f.modulusNP >> 32
 110  	nplo := uint64(uint32(f.modulusNP))
 111  	for _, d := range data {
 112  		// For each datum, we assign the initial hash to a uint64.
 113  		v := siphash.Sum64(d, &key)
 114  		v = fastReduction(v, nphi, nplo)
 115  		values = append(values, v)
 116  	}
 117  	sort.Sort(values)
 118  	// Write the sorted list of values into the filter bitstream, compressing it
 119  	// using Golomb coding.
 120  	var value, lastValue, remainder uint64
 121  	for _, v := range values {
 122  		// Calculate the difference between this value and the last, modulo P.
 123  		remainder = (v - lastValue) & ((uint64(1) << f.p) - 1)
 124  		// Calculate the difference between this value and the last, divided by P.
 125  		value = (v - lastValue - remainder) >> f.p
 126  		lastValue = v
 127  		// Write the P multiple into the bitstream in unary; the average should be
 128  		// around 1 (2 bits - 0b10).
 129  		for value > 0 {
 130  			b.WriteBit(true)
 131  			value--
 132  		}
 133  		b.WriteBit(false)
 134  		// Write the remainder as a big-endian integer with enough bits to represent the
 135  		// appropriate collision probability.
 136  		b.WriteBits(remainder, int(f.p))
 137  	}
 138  	// Copy the bitstream into the filter object and return the object.
 139  	f.filterData = b.Bytes()
 140  	return &f, nil
 141  }
 142  
 143  // FromBytes deserializes a GCS filter from a known N, P, and serialized filter
 144  // as returned by Hash().
 145  func FromBytes(N uint32, P uint8, M uint64, d []byte) (*Filter, error) {
 146  	// Basic sanity check.
 147  	if P > 32 {
 148  		return nil, ErrPTooBig
 149  	}
 150  	// Create the filter object and insert metadata.
 151  	f := &Filter{
 152  		n: N,
 153  		p: P,
 154  	}
 155  	// First we'll compute the value of m, which is the modulus we use within our
 156  	// finite field. We want to compute: mScalar * 2^P. We use math.Round in order
 157  	// to round the value up, rather than down.
 158  	f.modulusNP = uint64(f.n) * M
 159  	// Copy the filter.
 160  	f.filterData = make([]byte, len(d))
 161  	copy(f.filterData, d)
 162  	return f, nil
 163  }
 164  
 165  // FromNBytes deserializes a GCS filter from a known P, and serialized N and
 166  // filter as returned by NBytes().
 167  func FromNBytes(P uint8, M uint64, d []byte) (*Filter, error) {
 168  	buffer := bytes.NewBuffer(d)
 169  	N, e := wire.ReadVarInt(buffer, varIntProtoVer)
 170  	if e != nil {
 171  		E.Ln(e)
 172  		return nil, e
 173  	}
 174  	if N >= (1 << 32) {
 175  		return nil, ErrNTooBig
 176  	}
 177  	return FromBytes(uint32(N), P, M, buffer.Bytes())
 178  }
 179  
 180  // Bytes returns the serialized format of the GCS filter, which does not include
 181  // N or P (returned by separate methods) or the key used by SipHash.
 182  func (f *Filter) Bytes() ([]byte, error) {
 183  	filterData := make([]byte, len(f.filterData))
 184  	copy(filterData, f.filterData)
 185  	return filterData, nil
 186  }
 187  
 188  // NBytes returns the serialized format of the GCS filter with N, which does not
 189  // include P (returned by a separate method) or the key used by SipHash.
 190  func (f *Filter) NBytes() ([]byte, error) {
 191  	var buffer bytes.Buffer
 192  	buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + len(f.filterData))
 193  	e := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
 194  	if e != nil {
 195  		E.Ln(e)
 196  		return nil, e
 197  	}
 198  	_, e = buffer.Write(f.filterData)
 199  	if e != nil {
 200  		E.Ln(e)
 201  		return nil, e
 202  	}
 203  	return buffer.Bytes(), nil
 204  }
 205  
 206  // PBytes returns the serialized format of the GCS filter with P, which does not
 207  // include N (returned by a separate method) or the key used by SipHash.
 208  func (f *Filter) PBytes() ([]byte, error) {
 209  	filterData := make([]byte, len(f.filterData)+1)
 210  	filterData[0] = f.p
 211  	copy(filterData[1:], f.filterData)
 212  	return filterData, nil
 213  }
 214  
 215  // NPBytes returns the serialized format of the GCS filter with N and P, which does not include the key used by SipHash.
 216  func (f *Filter) NPBytes() ([]byte, error) {
 217  	var buffer bytes.Buffer
 218  	buffer.Grow(wire.VarIntSerializeSize(uint64(f.n)) + 1 + len(f.filterData))
 219  	e := wire.WriteVarInt(&buffer, varIntProtoVer, uint64(f.n))
 220  	if e != nil {
 221  		E.Ln(e)
 222  		return nil, e
 223  	}
 224  	e = buffer.WriteByte(f.p)
 225  	if e != nil {
 226  		E.Ln(e)
 227  		return nil, e
 228  	}
 229  	_, e = buffer.Write(f.filterData)
 230  	if e != nil {
 231  		E.Ln(e)
 232  		return nil, e
 233  	}
 234  	return buffer.Bytes(), nil
 235  }
 236  
 237  // P returns the filter's collision probability as a negative power of 2 (that is, a collision probability of `1/2**20`
 238  // is represented as 20).
 239  func (f *Filter) P() uint8 {
 240  	return f.p
 241  }
 242  
 243  // N returns the size of the data set used to podbuild the filter.
 244  func (f *Filter) N() uint32 {
 245  	return f.n
 246  }
 247  
 248  // Match checks whether a []byte value is likely (within collision probability) to be a member of the set represented by
 249  // the filter.
 250  func (f *Filter) Match(key [KeySize]byte, data []byte) (yn bool, e error) {
 251  	// Create a filter bitstream.
 252  	var filterData []byte
 253  	if filterData, e = f.Bytes(); E.Chk(e) {
 254  		return false, e
 255  	}
 256  	b := bstream.NewBStreamReader(filterData)
 257  	// We take the high and low bits of modulusNP for the multiplication of 2 64-bit integers into a 128-bit integer.
 258  	nphi := f.modulusNP >> 32
 259  	nplo := uint64(uint32(f.modulusNP))
 260  	// Then we hash our search term with the same parameters as the filter.
 261  	term := siphash.Sum64(data, &key)
 262  	term = fastReduction(term, nphi, nplo)
 263  	// Go through the search filter and look for the desired value.
 264  	var lastValue uint64
 265  	for lastValue < term {
 266  		// Read the difference between previous and new value from bitstream.
 267  		value, e := f.readFullUint64(b)
 268  		if e != nil {
 269  			E.Ln(e)
 270  			if e == io.EOF {
 271  				return false, nil
 272  			}
 273  			return false, e
 274  		}
 275  		// Add the previous value to it.
 276  		value += lastValue
 277  		if value == term {
 278  			return true, nil
 279  		}
 280  		lastValue = value
 281  	}
 282  	return false, nil
 283  }
 284  
 285  // MatchAny returns checks whether any []byte value is likely (within collision probability) to be a member of the set
 286  // represented by the filter faster than calling Match() for each value individually.
 287  func (f *Filter) MatchAny(key [KeySize]byte, data [][]byte) (bool, error) {
 288  	// Basic sanity check.
 289  	if len(data) == 0 {
 290  		return false, nil
 291  	}
 292  	// Create a filter bitstream.
 293  	filterData, e := f.Bytes()
 294  	if e != nil {
 295  		E.Ln(e)
 296  		return false, e
 297  	}
 298  	b := bstream.NewBStreamReader(filterData)
 299  	// Create an uncompressed filter of the search values.
 300  	values := make(uint64Slice, 0, len(data))
 301  	// First, we cache the high and low bits of modulusNP for the multiplication of 2 64-bit integers into a 128-bit
 302  	// integer.
 303  	nphi := f.modulusNP >> 32
 304  	nplo := uint64(uint32(f.modulusNP))
 305  	for _, d := range data {
 306  		// For each datum, we assign the initial hash to a uint64.
 307  		v := siphash.Sum64(d, &key)
 308  		// We'll then reduce the value down to the range of our modulus.
 309  		v = fastReduction(v, nphi, nplo)
 310  		values = append(values, v)
 311  	}
 312  	sort.Sort(values)
 313  	// Zip down the filters, comparing values until we either run out of values to compare in one of the filters or we
 314  	// reach a matching value.
 315  	var lastValue1, lastValue2 uint64
 316  	lastValue2 = values[0]
 317  	i := 1
 318  	for lastValue1 != lastValue2 {
 319  		// Chk which filter to advance to make sure we're comparing the right values.
 320  		switch {
 321  		case lastValue1 > lastValue2:
 322  			// Advance filter created from search terms or return false if we're at the end because nothing matched.
 323  			if i < len(values) {
 324  				lastValue2 = values[i]
 325  				i++
 326  			} else {
 327  				return false, nil
 328  			}
 329  		case lastValue2 > lastValue1:
 330  			// Advance filter we're searching or return false if we're at the end because nothing matched.
 331  			value, e := f.readFullUint64(b)
 332  			if e != nil {
 333  				// E.Ln(e)
 334  				if e == io.EOF {
 335  					return false, nil
 336  				}
 337  				return false, e
 338  			}
 339  			lastValue1 += value
 340  		}
 341  	}
 342  	// If we've made it this far, an element matched between filters so we return true.
 343  	return true, nil
 344  }
 345  
 346  // readFullUint64 reads a value represented by the sum of a unary multiple of the filter's P modulus (`2**P`) and a
 347  // big-endian P-bit remainder.
 348  func (f *Filter) readFullUint64(b *bstream.BStream) (rv uint64, e error) {
 349  	var quotient uint64
 350  	// Count the 1s until we reach a 0.
 351  	c, e := b.ReadBit()
 352  	if e != nil {
 353  		T.Ln(e)
 354  		return 0, e
 355  	}
 356  	for c {
 357  		quotient++
 358  		c, e = b.ReadBit()
 359  		if e != nil {
 360  			D.Ln(e)
 361  			return 0, e
 362  		}
 363  	}
 364  	// Read P bits.
 365  	var remainder uint64
 366  	if remainder, e = b.ReadBits(int(f.p)); T.Chk(e) {
 367  		D.Ln(e)
 368  		return 0, e
 369  	}
 370  	// Add the multiple and the remainder.
 371  	v := (quotient << f.p) + remainder
 372  	return v, nil
 373  }
 374