skl.go raw

   1  /*
   2   * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
   3   * SPDX-License-Identifier: Apache-2.0
   4   */
   5  
   6  /*
   7  Adapted from RocksDB inline skiplist.
   8  
   9  Key differences:
  10  - No optimization for sequential inserts (no "prev").
  11  - No custom comparator.
  12  - Support overwrites. This requires care when we see the same key when inserting.
  13    For RocksDB or LevelDB, overwrites are implemented as a newer sequence number in the key, so
  14  	there is no need for values. We don't intend to support versioning. In-place updates of values
  15  	would be more efficient.
  16  - We discard all non-concurrent code.
  17  - We do not support Splices. This simplifies the code a lot.
  18  - No AllocateNode or other pointer arithmetic.
  19  - We combine the findLessThan, findGreaterOrEqual, etc into one function.
  20  */
  21  
  22  package skl
  23  
  24  import (
  25  	"math"
  26  	"sync/atomic"
  27  	"unsafe"
  28  
  29  	"github.com/dgraph-io/badger/v4/y"
  30  	"github.com/dgraph-io/ristretto/v2/z"
  31  )
  32  
  33  const (
  34  	maxHeight      = 20
  35  	heightIncrease = math.MaxUint32 / 3
  36  )
  37  
// MaxNodeSize is the memory footprint of a node of maximum height, i.e. the
// size in bytes of the full node struct including all maxHeight tower slots.
const MaxNodeSize = int(unsafe.Sizeof(node{}))
  40  
// node is one entry of the skiplist. It stores offsets and sizes that locate
// its key and value (resolved through the Arena) plus a tower of per-level
// links to the next node, also expressed as offsets.
type node struct {
	// Multiple parts of the value are encoded as a single uint64 so that it
	// can be atomically loaded and stored:
	//   value offset: uint32 (bits 0-31)
	//   value size  : uint32 (bits 32-63)
	value atomic.Uint64

	// A byte slice is 24 bytes. We are trying to save space here.
	keyOffset uint32 // Immutable. No need to lock to access key.
	keySize   uint16 // Immutable. No need to lock to access key.

	// Height of the tower.
	height uint16

	// Most nodes do not need to use the full height of the tower, since the
	// probability of each successive level decreases exponentially. Because
	// these elements are never accessed, they do not need to be allocated.
	// Therefore, when a node is allocated in the arena, its memory footprint
	// is deliberately truncated to not include unneeded tower elements.
	//
	// All accesses to elements go through atomic operations (Load, Store,
	// CompareAndSwap), with no need to lock.
	tower [maxHeight]atomic.Uint32
}
  64  
// Skiplist is a reference-counted skiplist whose nodes are resolved through an
// Arena. Use NewSkiplist to create one; release it with DecrRef.
type Skiplist struct {
	height  atomic.Int32 // Current height. 1 <= height <= maxHeight. Raised via CAS in Put.
	head    *node        // Sentinel node of height maxHeight; never returned by the find functions.
	ref     atomic.Int32 // Reference count; starts at 1 in NewSkiplist.
	arena   *Arena       // Backing storage for nodes, keys and values.
	OnClose func()       // Optional hook invoked by DecrRef once the refcount drops to zero.
}
  72  
  73  // IncrRef increases the refcount
  74  func (s *Skiplist) IncrRef() {
  75  	s.ref.Add(1)
  76  }
  77  
  78  // DecrRef decrements the refcount, deallocating the Skiplist when done using it
  79  func (s *Skiplist) DecrRef() {
  80  	newRef := s.ref.Add(-1)
  81  	if newRef > 0 {
  82  		return
  83  	}
  84  	if s.OnClose != nil {
  85  		s.OnClose()
  86  	}
  87  
  88  	// Indicate we are closed. Good for testing.  Also, lets GC reclaim memory. Race condition
  89  	// here would suggest we are accessing skiplist when we are supposed to have no reference!
  90  	s.arena = nil
  91  	// Since the head references the arena's buf, as long as the head is kept around
  92  	// GC can't release the buf.
  93  	s.head = nil
  94  }
  95  
  96  func newNode(arena *Arena, key []byte, v y.ValueStruct, height int) *node {
  97  	// The base level is already allocated in the node struct.
  98  	offset := arena.putNode(height)
  99  	node := arena.getNode(offset)
 100  	node.keyOffset = arena.putKey(key)
 101  	node.keySize = uint16(len(key))
 102  	node.height = uint16(height)
 103  	node.value.Store(encodeValue(arena.putVal(v), v.EncodedSize()))
 104  	return node
 105  }
 106  
 107  func encodeValue(valOffset uint32, valSize uint32) uint64 {
 108  	return uint64(valSize)<<32 | uint64(valOffset)
 109  }
 110  
 111  func decodeValue(value uint64) (valOffset uint32, valSize uint32) {
 112  	valOffset = uint32(value)
 113  	valSize = uint32(value >> 32)
 114  	return
 115  }
 116  
 117  // NewSkiplist makes a new empty skiplist, with a given arena size
 118  func NewSkiplist(arenaSize int64) *Skiplist {
 119  	arena := newArena(arenaSize)
 120  	head := newNode(arena, nil, y.ValueStruct{}, maxHeight)
 121  	s := &Skiplist{head: head, arena: arena}
 122  	s.height.Store(1)
 123  	s.ref.Store(1)
 124  	return s
 125  }
 126  
 127  func (s *node) getValueOffset() (uint32, uint32) {
 128  	value := s.value.Load()
 129  	return decodeValue(value)
 130  }
 131  
 132  func (s *node) key(arena *Arena) []byte {
 133  	return arena.getKey(s.keyOffset, s.keySize)
 134  }
 135  
 136  func (s *node) setValue(arena *Arena, v y.ValueStruct) {
 137  	valOffset := arena.putVal(v)
 138  	value := encodeValue(valOffset, v.EncodedSize())
 139  	s.value.Store(value)
 140  }
 141  
 142  func (s *node) getNextOffset(h int) uint32 {
 143  	return s.tower[h].Load()
 144  }
 145  
 146  func (s *node) casNextOffset(h int, old, val uint32) bool {
 147  	return s.tower[h].CompareAndSwap(old, val)
 148  }
 149  
 150  // Returns true if key is strictly > n.key.
 151  // If n is nil, this is an "end" marker and we return false.
 152  //func (s *Skiplist) keyIsAfterNode(key []byte, n *node) bool {
 153  //	y.AssertTrue(n != s.head)
 154  //	return n != nil && y.CompareKeys(key, n.key) > 0
 155  //}
 156  
 157  func (s *Skiplist) randomHeight() int {
 158  	h := 1
 159  	for h < maxHeight && z.FastRand() <= heightIncrease {
 160  		h++
 161  	}
 162  	return h
 163  }
 164  
 165  func (s *Skiplist) getNext(nd *node, height int) *node {
 166  	return s.arena.getNode(nd.getNextOffset(height))
 167  }
 168  
// findNear finds the node near to key.
// If less=true, it finds rightmost node such that node.key < key (if allowEqual=false) or
// node.key <= key (if allowEqual=true).
// If less=false, it finds leftmost node such that node.key > key (if allowEqual=false) or
// node.key >= key (if allowEqual=true).
// Returns the node found. The bool returned is true if the node has key equal to given key.
// A nil node is returned when no node satisfies the request; the head node is never returned.
func (s *Skiplist) findNear(key []byte, less bool, allowEqual bool) (*node, bool) {
	// Start at the head on the highest level currently in use and walk right/down.
	x := s.head
	level := int(s.getHeight() - 1)
	for {
		// Assume x.key < key.
		next := s.getNext(x, level)
		if next == nil {
			// x.key < key < END OF LIST
			if level > 0 {
				// Can descend further to iterate closer to the end.
				level--
				continue
			}
			// Level=0. Cannot descend further. Let's return something that makes sense.
			if !less {
				// Nothing to the right of x, so no node is >= / > key.
				return nil, false
			}
			// Try to return x. Make sure it is not a head node.
			if x == s.head {
				return nil, false
			}
			return x, false
		}

		nextKey := next.key(s.arena)
		cmp := y.CompareKeys(key, nextKey)
		if cmp > 0 {
			// x.key < next.key < key. We can continue to move right.
			x = next
			continue
		}
		if cmp == 0 {
			// x.key < key == next.key.
			if allowEqual {
				return next, true
			}
			if !less {
				// We want >, so go to base level to grab the next bigger node.
				return s.getNext(next, 0), false
			}
			// We want <. If not base level, we should go closer in the next level.
			if level > 0 {
				level--
				continue
			}
			// On base level. Return x.
			if x == s.head {
				return nil, false
			}
			return x, false
		}
		// cmp < 0. In other words, x.key < key < next.
		if level > 0 {
			level--
			continue
		}
		// At base level. Need to return something.
		if !less {
			return next, false
		}
		// Try to return x. Make sure it is not a head node.
		if x == s.head {
			return nil, false
		}
		return x, false
	}
}
 242  
 243  // findSpliceForLevel returns (outBefore, outAfter) with outBefore.key <= key <= outAfter.key.
 244  // The input "before" tells us where to start looking.
 245  // If we found a node with the same key, then we return outBefore = outAfter.
 246  // Otherwise, outBefore.key < key < outAfter.key.
 247  func (s *Skiplist) findSpliceForLevel(key []byte, before *node, level int) (*node, *node) {
 248  	for {
 249  		// Assume before.key < key.
 250  		next := s.getNext(before, level)
 251  		if next == nil {
 252  			return before, next
 253  		}
 254  		nextKey := next.key(s.arena)
 255  		cmp := y.CompareKeys(key, nextKey)
 256  		if cmp == 0 {
 257  			// Equality case.
 258  			return next, next
 259  		}
 260  		if cmp < 0 {
 261  			// before.key < key < next.key. We are done for this level.
 262  			return before, next
 263  		}
 264  		before = next // Keep moving right on this level.
 265  	}
 266  }
 267  
 268  func (s *Skiplist) getHeight() int32 {
 269  	return s.height.Load()
 270  }
 271  
// Put inserts the key-value pair. If a node with an equal key already exists,
// its value is replaced in place and no new node is created. Insertion of a
// new node links it level by level (base level first) using CAS, retrying the
// search at a level whenever the CAS loses to a concurrent writer.
func (s *Skiplist) Put(key []byte, v y.ValueStruct) {
	// Since we allow overwrite, we may not need to create a new node. We might not even need to
	// increase the height. Let's defer these actions.

	listHeight := s.getHeight()
	// prev[i]/next[i] are the neighbours of key on level i. One extra slot lets
	// level listHeight act as the seed (head, nil) for the top-down search.
	var prev [maxHeight + 1]*node
	var next [maxHeight + 1]*node
	prev[listHeight] = s.head
	next[listHeight] = nil
	for i := int(listHeight) - 1; i >= 0; i-- {
		// Use higher level to speed up for current level.
		prev[i], next[i] = s.findSpliceForLevel(key, prev[i+1], i)
		if prev[i] == next[i] {
			// Same key already present: overwrite the value and stop.
			prev[i].setValue(s.arena, v)
			return
		}
	}

	// We do need to create a new node.
	height := s.randomHeight()
	x := newNode(s.arena, key, v, height)

	// Try to increase s.height via CAS.
	listHeight = s.getHeight()
	for height > int(listHeight) {
		if s.height.CompareAndSwap(listHeight, int32(height)) {
			// Successfully increased skiplist.height.
			break
		}
		// Lost the race; reload and re-check whether we still need to raise it.
		listHeight = s.getHeight()
	}

	// We always insert from the base level and up. After you add a node in base level, we cannot
	// create a node in the level above because it would have discovered the node in the base level.
	for i := 0; i < height; i++ {
		for {
			if prev[i] == nil {
				// prev was seeded with head at index listHeight (>= 1), so an unset
				// entry can only occur strictly above that, never on levels 0 or 1.
				y.AssertTrue(i > 1) // This cannot happen in base level.
				// We haven't computed prev, next for this level because height exceeds old listHeight.
				// For these levels, we expect the lists to be sparse, so we can just search from head.
				prev[i], next[i] = s.findSpliceForLevel(key, s.head, i)
				// Someone adds the exact same key before we are able to do so. This can only happen on
				// the base level. But we know we are not on the base level.
				y.AssertTrue(prev[i] != next[i])
			}
			// Point x at its successor first, then publish x by swinging prev's link.
			nextOffset := s.arena.getNodeOffset(next[i])
			x.tower[i].Store(nextOffset)
			if prev[i].casNextOffset(i, nextOffset, s.arena.getNodeOffset(x)) {
				// Managed to insert x between prev[i] and next[i]. Go to the next level.
				break
			}
			// CAS failed. We need to recompute prev and next.
			// It is unlikely to be helpful to try to use a different level as we redo the search,
			// because it is unlikely that lots of nodes are inserted between prev[i] and next[i].
			prev[i], next[i] = s.findSpliceForLevel(key, prev[i], i)
			if prev[i] == next[i] {
				// A concurrent writer inserted the same key; update its value instead.
				y.AssertTruef(i == 0, "Equality can happen only on base level: %d", i)
				prev[i].setValue(s.arena, v)
				return
			}
		}
	}
}
 336  
 337  // Empty returns if the Skiplist is empty.
 338  func (s *Skiplist) Empty() bool {
 339  	return s.findLast() == nil
 340  }
 341  
 342  // findLast returns the last element. If head (empty list), we return nil. All the find functions
 343  // will NEVER return the head nodes.
 344  func (s *Skiplist) findLast() *node {
 345  	n := s.head
 346  	level := int(s.getHeight()) - 1
 347  	for {
 348  		next := s.getNext(n, level)
 349  		if next != nil {
 350  			n = next
 351  			continue
 352  		}
 353  		if level == 0 {
 354  			if n == s.head {
 355  				return nil
 356  			}
 357  			return n
 358  		}
 359  		level--
 360  	}
 361  }
 362  
 363  // Get gets the value associated with the key. It returns a valid value if it finds equal or earlier
 364  // version of the same key.
 365  func (s *Skiplist) Get(key []byte) y.ValueStruct {
 366  	n, _ := s.findNear(key, false, true) // findGreaterOrEqual.
 367  	if n == nil {
 368  		return y.ValueStruct{}
 369  	}
 370  
 371  	nextKey := s.arena.getKey(n.keyOffset, n.keySize)
 372  	if !y.SameKey(key, nextKey) {
 373  		return y.ValueStruct{}
 374  	}
 375  
 376  	valOffset, valSize := n.getValueOffset()
 377  	vs := s.arena.getVal(valOffset, valSize)
 378  	vs.Version = y.ParseTs(nextKey)
 379  	return vs
 380  }
 381  
 382  // NewIterator returns a skiplist iterator.  You have to Close() the iterator.
 383  func (s *Skiplist) NewIterator() *Iterator {
 384  	s.IncrRef()
 385  	return &Iterator{list: s}
 386  }
 387  
 388  // MemSize returns the size of the Skiplist in terms of how much memory is used within its internal
 389  // arena.
 390  func (s *Skiplist) MemSize() int64 { return s.arena.size() }
 391  
// Iterator is an iterator over skiplist object. For new objects, you just
// need to initialize Iterator.list.
type Iterator struct {
	list *Skiplist // Skiplist being iterated.
	n    *node     // Current position; nil means the iterator is not valid.
}
 398  
 399  // Close frees the resources held by the iterator
 400  func (s *Iterator) Close() error {
 401  	s.list.DecrRef()
 402  	return nil
 403  }
 404  
 405  // Valid returns true iff the iterator is positioned at a valid node.
 406  func (s *Iterator) Valid() bool { return s.n != nil }
 407  
 408  // Key returns the key at the current position.
 409  func (s *Iterator) Key() []byte {
 410  	return s.list.arena.getKey(s.n.keyOffset, s.n.keySize)
 411  }
 412  
 413  // Value returns value.
 414  func (s *Iterator) Value() y.ValueStruct {
 415  	valOffset, valSize := s.n.getValueOffset()
 416  	return s.list.arena.getVal(valOffset, valSize)
 417  }
 418  
 419  // ValueUint64 returns the uint64 value of the current node.
 420  func (s *Iterator) ValueUint64() uint64 {
 421  	return s.n.value.Load()
 422  }
 423  
 424  // Next advances to the next position.
 425  func (s *Iterator) Next() {
 426  	y.AssertTrue(s.Valid())
 427  	s.n = s.list.getNext(s.n, 0)
 428  }
 429  
 430  // Prev advances to the previous position.
 431  func (s *Iterator) Prev() {
 432  	y.AssertTrue(s.Valid())
 433  	s.n, _ = s.list.findNear(s.Key(), true, false) // find <. No equality allowed.
 434  }
 435  
 436  // Seek advances to the first entry with a key >= target.
 437  func (s *Iterator) Seek(target []byte) {
 438  	s.n, _ = s.list.findNear(target, false, true) // find >=.
 439  }
 440  
 441  // SeekForPrev finds an entry with key <= target.
 442  func (s *Iterator) SeekForPrev(target []byte) {
 443  	s.n, _ = s.list.findNear(target, true, true) // find <=.
 444  }
 445  
 446  // SeekToFirst seeks position at the first entry in list.
 447  // Final state of iterator is Valid() iff list is not empty.
 448  func (s *Iterator) SeekToFirst() {
 449  	s.n = s.list.getNext(s.list.head, 0)
 450  }
 451  
 452  // SeekToLast seeks position at the last entry in list.
 453  // Final state of iterator is Valid() iff list is not empty.
 454  func (s *Iterator) SeekToLast() {
 455  	s.n = s.list.findLast()
 456  }
 457  
// UniIterator is a unidirectional memtable iterator. It is a thin wrapper around
// Iterator. We like to keep Iterator as before, because it is more powerful and
// we might support bidirectional iterators in the future.
type UniIterator struct {
	iter     *Iterator // Underlying bidirectional iterator.
	reversed bool      // When true, Next/Rewind/Seek traverse from last to first.
}
 465  
 466  // NewUniIterator returns a UniIterator.
 467  func (s *Skiplist) NewUniIterator(reversed bool) *UniIterator {
 468  	return &UniIterator{
 469  		iter:     s.NewIterator(),
 470  		reversed: reversed,
 471  	}
 472  }
 473  
 474  // Next implements y.Interface
 475  func (s *UniIterator) Next() {
 476  	if !s.reversed {
 477  		s.iter.Next()
 478  	} else {
 479  		s.iter.Prev()
 480  	}
 481  }
 482  
 483  // Rewind implements y.Interface
 484  func (s *UniIterator) Rewind() {
 485  	if !s.reversed {
 486  		s.iter.SeekToFirst()
 487  	} else {
 488  		s.iter.SeekToLast()
 489  	}
 490  }
 491  
 492  // Seek implements y.Interface
 493  func (s *UniIterator) Seek(key []byte) {
 494  	if !s.reversed {
 495  		s.iter.Seek(key)
 496  	} else {
 497  		s.iter.SeekForPrev(key)
 498  	}
 499  }
 500  
 501  // Key implements y.Interface
 502  func (s *UniIterator) Key() []byte { return s.iter.Key() }
 503  
 504  // Value implements y.Interface
 505  func (s *UniIterator) Value() y.ValueStruct { return s.iter.Value() }
 506  
 507  // Valid implements y.Interface
 508  func (s *UniIterator) Valid() bool { return s.iter.Valid() }
 509  
 510  // Close implements y.Interface (and frees up the iter's resources)
 511  func (s *UniIterator) Close() error { return s.iter.Close() }
 512