btree.go raw
1 /*
2 * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6 package z
7
8 import (
9 "fmt"
10 "math"
11 "os"
12 "reflect"
13 "strings"
14 "unsafe"
15
16 "github.com/dgraph-io/ristretto/v2/z/simd"
17 )
18
var (
	// pageSize is the size in bytes of one tree node: every node occupies exactly one OS page.
	pageSize = os.Getpagesize()
	// maxKeys is the number of (key, value) slots in a node. Each slot is two uint64 words
	// (16 bytes), and the last 16 bytes of a page are reserved for node metadata.
	maxKeys = pageSize/16 - 1
	// oneThird is a third of maxKeys; only shareWithSiblingXXX uses it.
	//nolint:unused
	oneThird = maxKeys / 3
)
25
const (
	// absoluteMax is the largest key Set accepts. initRootNode inserts it as a sentinel so that
	// every key stored in the tree is <= some key of the root.
	absoluteMax uint64 = math.MaxUint64 - 1
	// minSize is the buffer size, in bytes, requested when a tree's buffer is created or reset.
	minSize = 1 << 20
)
30
31 // Tree represents the structure for custom mmaped B+ tree.
32 // It supports keys in range [1, math.MaxUint64-1] and values [1, math.Uint64].
33 type Tree struct {
34 buffer *Buffer
35 data []byte
36 nextPage uint64
37 freePage uint64
38 stats TreeStats
39 }
40
41 func (t *Tree) initRootNode() {
42 // This is the root node.
43 t.newNode(0)
44 // This acts as the rightmost pointer (all the keys are <= this key).
45 t.Set(absoluteMax, 0)
46 }
47
48 // NewTree returns an in-memory B+ tree.
49 func NewTree(tag string) *Tree {
50 const defaultTag = "tree"
51 if tag == "" {
52 tag = defaultTag
53 }
54 t := &Tree{buffer: NewBuffer(minSize, tag)}
55 t.Reset()
56 return t
57 }
58
59 // NewTree returns a persistent on-disk B+ tree.
60 func NewTreePersistent(path string) (*Tree, error) {
61 t := &Tree{}
62 var err error
63
64 // Open the buffer from disk and set it to the maximum allocated size.
65 t.buffer, err = NewBufferPersistent(path, minSize)
66 if err != nil {
67 return nil, err
68 }
69 t.buffer.offset = uint64(len(t.buffer.buf))
70 t.data = t.buffer.Bytes()
71
72 // pageID can never be 0 if the tree has been initialized.
73 root := t.node(1)
74 isInitialized := root.pageID() != 0
75
76 if !isInitialized {
77 t.nextPage = 1
78 t.freePage = 0
79 t.initRootNode()
80 } else {
81 t.reinit()
82 }
83
84 return t, nil
85 }
86
87 // reinit sets the internal variables of a Tree, which are normally stored
88 // in-memory, but are lost when loading from disk.
89 func (t *Tree) reinit() {
90 // Calculate t.nextPage by finding the first node whose pageID is not set.
91 t.nextPage = 1
92 for int(t.nextPage)*pageSize < len(t.data) {
93 n := t.node(t.nextPage)
94 if n.pageID() == 0 {
95 break
96 }
97 t.nextPage++
98 }
99 maxPageId := t.nextPage - 1
100
101 // Calculate t.freePage by finding the page to which no other page points.
102 // This would be the head of the page linked list.
103 // tailPages[i] is true if pageId i+1 is not the head of the list.
104 tailPages := make([]bool, maxPageId)
105 // Mark all pages containing nodes as tail pages.
106 t.Iterate(func(n node) {
107 i := n.pageID() - 1
108 tailPages[i] = true
109 // If this is a leaf node, increment the stats.
110 if n.isLeaf() {
111 t.stats.NumLeafKeys += n.numKeys()
112 }
113 })
114 // pointedPages is a list of page IDs that the tail pages point to.
115 pointedPages := make([]uint64, 0)
116 for i, isTail := range tailPages {
117 if !isTail {
118 pageId := uint64(i) + 1
119 // Skip if nextPageId = 0, as that is equivalent to null page.
120 if nextPageId := t.node(pageId).uint64(0); nextPageId != 0 {
121 pointedPages = append(pointedPages, nextPageId)
122 }
123 t.stats.NumPagesFree++
124 }
125 }
126
127 // Mark all pages being pointed to as tail pages.
128 for _, pageId := range pointedPages {
129 i := pageId - 1
130 tailPages[i] = true
131 }
132 // There should only be one head page left.
133 for i, isTail := range tailPages {
134 if !isTail {
135 pageId := uint64(i) + 1
136 t.freePage = pageId
137 break
138 }
139 }
140 }
141
142 // Reset resets the tree and truncates it to maxSz.
143 func (t *Tree) Reset() {
144 // Tree relies on uninitialized data being zeroed out, so we need to Memclr
145 // the data before using it again.
146 Memclr(t.buffer.buf)
147 t.buffer.Reset()
148 t.buffer.AllocateOffset(minSize)
149 t.data = t.buffer.Bytes()
150 t.stats = TreeStats{}
151 t.nextPage = 1
152 t.freePage = 0
153 t.initRootNode()
154 }
155
156 // Close releases the memory used by the tree.
157 func (t *Tree) Close() error {
158 if t == nil {
159 return nil
160 }
161 return t.buffer.Release()
162 }
163
// TreeStats describes the size and fill level of a Tree, as returned by Tree.Stats.
// "Derived" fields are computed by Stats from the tree's layout. "Calculated" fields are
// maintained incrementally as the tree is modified.
type TreeStats struct {
	Allocated    int     // Derived. Bytes currently backing the tree (len of its data).
	Bytes        int     // Derived. NumPages * PageSize.
	NumLeafKeys  int     // Calculated. Entries stored in leaf nodes, including sentinels.
	NumPages     int     // Derived. Pages handed out so far, including freed ones.
	NumPagesFree int     // Calculated. Pages currently on the free list.
	Occupancy    float64 // Derived. 100 * NumLeafKeys / (maxKeys * NumPages).
	PageSize     int     // Derived. Size of one node in bytes.
}
173
174 // Stats returns stats about the tree.
175 func (t *Tree) Stats() TreeStats {
176 numPages := int(t.nextPage - 1)
177 out := TreeStats{
178 Bytes: numPages * pageSize,
179 Allocated: len(t.data),
180 NumLeafKeys: t.stats.NumLeafKeys,
181 NumPages: numPages,
182 NumPagesFree: t.stats.NumPagesFree,
183 PageSize: pageSize,
184 }
185 out.Occupancy = 100.0 * float64(out.NumLeafKeys) / float64(maxKeys*numPages)
186 return out
187 }
188
// BytesToUint64Slice reinterprets b as a []uint64 without copying. The result aliases b, so
// writes through either slice are visible in the other. Its length and capacity are
// len(b)/8; any trailing bytes are ignored. An empty b yields nil.
// NOTE(review): &b[0] is not checked for 8-byte alignment. Callers pass page-aligned buffer
// memory here, but arbitrary input may be misaligned.
func BytesToUint64Slice(b []byte) []uint64 {
	var words []uint64
	if len(b) == 0 {
		return words
	}
	n := len(b) / 8
	sh := (*reflect.SliceHeader)(unsafe.Pointer(&words))
	sh.Data = uintptr(unsafe.Pointer(&b[0]))
	sh.Len, sh.Cap = n, n
	return words
}
201
202 func (t *Tree) newNode(bit uint64) node {
203 var pageId uint64
204 if t.freePage > 0 {
205 pageId = t.freePage
206 t.stats.NumPagesFree--
207 } else {
208 pageId = t.nextPage
209 t.nextPage++
210 offset := int(pageId) * pageSize
211 reqSize := offset + pageSize
212 if reqSize > len(t.data) {
213 t.buffer.AllocateOffset(reqSize - len(t.data))
214 t.data = t.buffer.Bytes()
215 }
216 }
217 n := t.node(pageId)
218 if t.freePage > 0 {
219 t.freePage = n.uint64(0)
220 }
221 zeroOut(n)
222 n.setBit(bit)
223 n.setAt(keyOffset(maxKeys), pageId)
224 return n
225 }
226
227 func getNode(data []byte) node {
228 return node(BytesToUint64Slice(data))
229 }
230
// zeroOut sets every element of data to zero, in place.
func zeroOut(data []uint64) {
	for i := range data {
		data[i] = 0
	}
}
236
237 func (t *Tree) node(pid uint64) node {
238 // page does not exist
239 if pid == 0 {
240 return nil
241 }
242 start := pageSize * int(pid)
243 return getNode(t.data[start : start+pageSize])
244 }
245
246 // Set sets the key-value pair in the tree.
247 func (t *Tree) Set(k, v uint64) {
248 if k == math.MaxUint64 || k == 0 {
249 panic("Error setting zero or MaxUint64")
250 }
251 root := t.set(1, k, v)
252 if root.isFull() {
253 right := t.split(1)
254 left := t.newNode(root.bits())
255 // Re-read the root as the underlying buffer for tree might have changed during split.
256 root = t.node(1)
257 copy(left[:keyOffset(maxKeys)], root)
258 left.setNumKeys(root.numKeys())
259
260 // reset the root node.
261 zeroOut(root[:keyOffset(maxKeys)])
262 root.setNumKeys(0)
263
264 // set the pointers for left and right child in the root node.
265 root.set(left.maxKey(), left.pageID())
266 root.set(right.maxKey(), right.pageID())
267 }
268 }
269
// set inserts (k, v) into the subtree rooted at page pid. It returns that page's node,
// re-read after any allocation so the slice is current.
//
// Internal nodes hold <key, ptr> entries: all keys <= key live in the subtree at ptr.
//   - In a leaf, the pair is stored directly and t.stats.NumLeafKeys is updated.
//   - In an internal node, set routes to the first entry whose key is >= k. If that slot is
//     empty, it creates the entry with key k and an empty leaf child. It then recurses, and
//     splits the child if the child became full.
//
// The returned node may itself be full; splitting it is the caller's job (see Set).
func (t *Tree) set(pid, k, v uint64) node {
	n := t.node(pid)
	if n.isLeaf() {
		t.stats.NumLeafKeys += n.set(k, v)
		return n
	}

	// This is an internal node.
	idx := n.search(k)
	if idx >= maxKeys {
		panic("search returned index >= maxKeys")
	}
	// If no key at idx, k is larger than every key here: claim the empty slot for k.
	if n.key(idx) == 0 {
		n.setAt(keyOffset(idx), k)
		n.setNumKeys(n.numKeys() + 1)
	}
	child := t.node(n.val(idx))
	if child == nil {
		// No child yet for this entry: give it an empty leaf. newNode may grow the buffer,
		// so n is re-read before it is written to.
		child = t.newNode(bitLeaf)
		n = t.node(pid)
		n.setAt(valOffset(idx), child.pageID())
	}
	child = t.set(child.pageID(), k, v)
	// Re-read n as the underlying buffer for tree might have changed during set.
	n = t.node(pid)
	if child.isFull() {
		// Just consider the left sibling for simplicity.
		// if t.shareWithSibling(n, idx) {
		// 	return n
		// }

		nn := t.split(child.pageID())
		// Re-read n and child as the underlying buffer for tree might have changed during split.
		n = t.node(pid)
		child = t.node(n.uint64(valOffset(idx)))
		// Set child pointers in the node n.
		// Note that key for right node (nn) already exist in node n, but the
		// pointer is updated.
		n.set(child.maxKey(), child.pageID())
		n.set(nn.maxKey(), nn.pageID())
	}
	return n
}
316
317 // Get looks for key and returns the corresponding value.
318 // If key is not found, 0 is returned.
319 func (t *Tree) Get(k uint64) uint64 {
320 if k == math.MaxUint64 || k == 0 {
321 panic("Does not support getting MaxUint64/Zero")
322 }
323 root := t.node(1)
324 return t.get(root, k)
325 }
326
327 func (t *Tree) get(n node, k uint64) uint64 {
328 if n.isLeaf() {
329 return n.get(k)
330 }
331 // This is internal node
332 idx := n.search(k)
333 if idx == n.numKeys() || n.key(idx) == 0 {
334 return 0
335 }
336 child := t.node(n.uint64(valOffset(idx)))
337 assert(child != nil)
338 return t.get(child, k)
339 }
340
341 // DeleteBelow deletes all keys with value under ts.
342 func (t *Tree) DeleteBelow(ts uint64) {
343 root := t.node(1)
344 t.stats.NumLeafKeys = 0
345 t.compact(root, ts)
346 assert(root.numKeys() >= 1)
347 }
348
// compact removes, from the subtree rooted at n, every entry whose value is below ts (see
// node.compact for which entries survive). It returns the result of node.compact on n:
// the number of entries left, or 0 when n holds nothing but its max key and may be dropped.
//
// For a leaf, the surviving entry count is added to t.stats.NumLeafKeys.
//
// For an internal node, each child is compacted recursively. A child that reports 0 is
// dropped, except the last child, so the node's max key keeps a valid pointer. Dropping
// a child means:
//   - pushing its page onto the free list (word 0 becomes the link to the old head),
//   - clearing its pointer here.
//
// The cleared entries are then removed by n.compact(1).
func (t *Tree) compact(n node, ts uint64) int {
	if n.isLeaf() {
		numKeys := n.compact(ts)
		t.stats.NumLeafKeys += n.numKeys()
		return numKeys
	}
	// Not leaf.
	N := n.numKeys()
	for i := 0; i < N; i++ {
		assert(n.key(i) > 0)
		childID := n.uint64(valOffset(i))
		child := t.node(childID)
		if rem := t.compact(child, ts); rem == 0 && i < N-1 {
			// If no valid key is remaining we can drop this child. However, don't do that if this
			// is the max key.
			// Undo the count the recursive call added for the child's surviving entry.
			t.stats.NumLeafKeys -= child.numKeys()
			// Link the child page onto the free list.
			child.setAt(0, t.freePage)
			t.freePage = childID
			n.setAt(valOffset(i), 0)
			t.stats.NumPagesFree++
		}
	}
	// We use ts=1 here because we want to delete all the keys whose value is 0, which means they no
	// longer have a valid page for that key.
	return n.compact(1)
}
375
376 func (t *Tree) iterate(n node, fn func(node)) {
377 fn(n)
378 if n.isLeaf() {
379 return
380 }
381 // Explore children.
382 for i := 0; i < maxKeys; i++ {
383 if n.key(i) == 0 {
384 return
385 }
386 childID := n.uint64(valOffset(i))
387 assert(childID > 0)
388
389 child := t.node(childID)
390 t.iterate(child, fn)
391 }
392 }
393
394 // Iterate iterates over the tree and executes the fn on each node.
395 func (t *Tree) Iterate(fn func(node)) {
396 root := t.node(1)
397 t.iterate(root, fn)
398 }
399
400 // IterateKV iterates through all keys and values in the tree.
401 // If newVal is non-zero, it will be set in the tree.
402 func (t *Tree) IterateKV(f func(key, val uint64) (newVal uint64)) {
403 t.Iterate(func(n node) {
404 // Only leaf nodes contain keys.
405 if !n.isLeaf() {
406 return
407 }
408
409 for i := 0; i < n.numKeys(); i++ {
410 key := n.key(i)
411 val := n.val(i)
412
413 // A zero value here means that this is a bogus entry.
414 if val == 0 {
415 continue
416 }
417
418 newVal := f(key, val)
419 if newVal != 0 {
420 n.setAt(valOffset(i), newVal)
421 }
422 }
423 })
424 }
425
426 func (t *Tree) print(n node, parentID uint64) {
427 n.print(parentID)
428 if n.isLeaf() {
429 return
430 }
431 pid := n.pageID()
432 for i := 0; i < maxKeys; i++ {
433 if n.key(i) == 0 {
434 return
435 }
436 childID := n.uint64(valOffset(i))
437 child := t.node(childID)
438 t.print(child, pid)
439 }
440 }
441
442 // Print iterates over the tree and prints all valid KVs.
443 func (t *Tree) Print() {
444 root := t.node(1)
445 t.print(root, 0)
446 }
447
448 // Splits the node into two. It moves right half of the keys from the original node to a newly
449 // created right node. It returns the right node.
450 func (t *Tree) split(pid uint64) node {
451 n := t.node(pid)
452 if !n.isFull() {
453 panic("This should be called only when n is full")
454 }
455
456 // Create a new node nn, copy over half the keys from n, and set the parent to n's parent.
457 nn := t.newNode(n.bits())
458 // Re-read n as the underlying buffer for tree might have changed during newNode.
459 n = t.node(pid)
460 rightHalf := n[keyOffset(maxKeys/2):keyOffset(maxKeys)]
461 copy(nn, rightHalf)
462 nn.setNumKeys(maxKeys - maxKeys/2)
463
464 // Remove entries from node n.
465 zeroOut(rightHalf)
466 n.setNumKeys(maxKeys / 2)
467 return nn
468 }
469
// shareWithSiblingXXX is unused for now. The idea is to move some keys to
// sibling when a node is full. But, I don't see any special benefits in our
// access pattern. It doesn't result in better occupancy ratios.
//
// When it applies, it moves the first oneThird entries of the child at entry idx of n (the
// "right" child) to the end of the child at entry idx-1 (the "left" child). It then updates
// the left child's key in n and returns true. It returns false when idx is 0 or the left
// child already holds maxKeys/2 or more entries.
// NOTE(review): the right child is treated as full. It is shifted using maxKeys rather than
// right.numKeys(), and is assumed to hold at least oneThird entries.
//
//nolint:unused
func (t *Tree) shareWithSiblingXXX(n node, idx int) bool {
	if idx == 0 {
		return false
	}
	left := t.node(n.val(idx - 1))
	ns := left.numKeys()
	if ns >= maxKeys/2 {
		// Sibling is already getting full.
		return false
	}

	right := t.node(n.val(idx))
	// Copy over keys from right child to left child.
	copied := copy(left[keyOffset(ns):], right[:keyOffset(oneThird)])
	copied /= 2 // Considering that key-val constitute one key.
	left.setNumKeys(ns + copied)

	// Update the max key in parent node n for the left sibling.
	n.setAt(keyOffset(idx-1), left.maxKey())

	// Now move keys to left for the right sibling.
	until := copy(right, right[keyOffset(oneThird):keyOffset(maxKeys)])
	right.setNumKeys(until / 2)
	zeroOut(right[until:keyOffset(maxKeys)])
	return true
}
501
// node is one pageSize-sized page of the tree, viewed as a slice of uint64 words. The page
// holds maxKeys (key, value) entries followed by one metadata pair:
//
//	n[keyOffset(i)], n[valOffset(i)]  key and value of entry i, for 0 <= i < maxKeys
//	n[keyOffset(maxKeys)]             page ID of this node
//	n[valOffset(maxKeys)]             metadata word: bits 56-63 hold the node flags
//	                                  (bitLeaf), bits 0-31 hold numKeys
//
// Entries are kept sorted by key, and unused entries are zero.
//   - Leaf nodes hold the user's <key, value> pairs.
//   - Internal nodes hold <key, child page ID> pairs. Every key in a child's subtree is <= the
//     entry's key, e.g. <1000, child A>, <5000, child B>, ...
type node []uint64
511
512 func (n node) uint64(start int) uint64 { return n[start] }
513
514 // func (n node) uint32(start int) uint32 { return *(*uint32)(unsafe.Pointer(&n[start])) }
515
// keyOffset returns the word index of entry i's key within a node.
func keyOffset(i int) int { return i << 1 }

// valOffset returns the word index of entry i's value, which directly follows its key.
func valOffset(i int) int { return keyOffset(i) + 1 }
518 func (n node) numKeys() int { return int(n.uint64(valOffset(maxKeys)) & 0xFFFFFFFF) }
519 func (n node) pageID() uint64 { return n.uint64(keyOffset(maxKeys)) }
520 func (n node) key(i int) uint64 { return n.uint64(keyOffset(i)) }
521 func (n node) val(i int) uint64 { return n.uint64(valOffset(i)) }
522 func (n node) data(i int) []uint64 { return n[keyOffset(i):keyOffset(i+1)] }
523
524 func (n node) setAt(start int, k uint64) {
525 n[start] = k
526 }
527
528 func (n node) setNumKeys(num int) {
529 idx := valOffset(maxKeys)
530 val := n[idx]
531 val &= 0xFFFFFFFF00000000
532 val |= uint64(num)
533 n[idx] = val
534 }
535
536 func (n node) moveRight(lo int) {
537 hi := n.numKeys()
538 assert(hi != maxKeys)
539 // copy works despite of overlap in src and dst.
540 // See https://golang.org/pkg/builtin/#copy
541 copy(n[keyOffset(lo+1):keyOffset(hi+1)], n[keyOffset(lo):keyOffset(hi)])
542 }
543
// bitLeaf is the flag bit (the top bit of a node's metadata word) that marks a leaf node.
const bitLeaf uint64 = 1 << 63
547
548 func (n node) setBit(b uint64) {
549 vo := valOffset(maxKeys)
550 val := n[vo]
551 val &= 0xFFFFFFFF
552 val |= b
553 n[vo] = val
554 }
555 func (n node) bits() uint64 {
556 return n.val(maxKeys) & 0xFF00000000000000
557 }
558 func (n node) isLeaf() bool {
559 return n.bits()&bitLeaf > 0
560 }
561
562 // isFull checks that the node is already full.
563 func (n node) isFull() bool {
564 return n.numKeys() == maxKeys
565 }
566
567 // Search returns the index of a smallest key >= k in a node.
568 func (n node) search(k uint64) int {
569 N := n.numKeys()
570 if N < 4 {
571 for i := 0; i < N; i++ {
572 if ki := n.key(i); ki >= k {
573 return i
574 }
575 }
576 return N
577 }
578 return int(simd.Search(n[:2*N], k))
579 // lo, hi := 0, N
580 // // Reduce the search space using binary seach and then do linear search.
581 // for hi-lo > 32 {
582 // mid := (hi + lo) / 2
583 // km := n.key(mid)
584 // if k == km {
585 // return mid
586 // }
587 // if k > km {
588 // // key is greater than the key at mid, so move right.
589 // lo = mid + 1
590 // } else {
591 // // else move left.
592 // hi = mid
593 // }
594 // }
595 // for i := lo; i <= hi; i++ {
596 // if ki := n.key(i); ki >= k {
597 // return i
598 // }
599 // }
600 // return N
601 }
602 func (n node) maxKey() uint64 {
603 idx := n.numKeys()
604 // idx points to the first key which is zero.
605 if idx > 0 {
606 idx--
607 }
608 return n.key(idx)
609 }
610
// compact removes, in place, every entry whose value is < lo. The entry holding the node's
// max key is always kept, whatever its value. Surviving entries keep their order, are packed
// to the front, and the freed tail is zeroed; numKeys is updated to match.
//
// It returns the number of entries left, or 0 when nothing but the max key survives and the
// max key's value is also < lo. That 0 tells the caller the node holds no live data and may
// be dropped. An empty node also yields 0.
func (n node) compact(lo uint64) int {
	N := n.numKeys()
	mk := n.maxKey()
	var left, right int
	for right = 0; right < N; right++ {
		if n.val(right) < lo && n.key(right) < mk {
			// Skip over this key. Don't copy it.
			continue
		}
		// Valid data. Copy it from right to left. Advance left.
		if left != right {
			copy(n.data(left), n.data(right))
		}
		left++
	}
	// zero out rest of the kv pairs (right == N here).
	zeroOut(n[keyOffset(left):keyOffset(right)])
	n.setNumKeys(left)

	// If the only key we have is the max key, and its value is less than lo, then we can indicate
	// to the caller by returning a zero that it's OK to drop the node.
	if left == 1 && n.key(0) == mk && n.val(0) < lo {
		return 0
	}
	return left
}
639
640 func (n node) get(k uint64) uint64 {
641 idx := n.search(k)
642 // key is not found
643 if idx == n.numKeys() {
644 return 0
645 }
646 if ki := n.key(idx); ki == k {
647 return n.val(idx)
648 }
649 return 0
650 }
651
// set inserts or updates the entry for key k with value v, keeping entries sorted by key.
// It returns 1 if k was not present before (a new entry was added) and 0 if an existing
// entry was overwritten.
//
// If the node is already full, k must already be present; this is asserted.
func (n node) set(k, v uint64) (numAdded int) {
	idx := n.search(k)
	ki := n.key(idx)
	if n.numKeys() == maxKeys {
		// This happens during split of non-root node, when we are updating the child pointer of
		// right node. Hence, the key should already exist.
		assert(ki == k)
	}
	if ki > k {
		// Found the first entry which is greater than k. So, we need to fit k
		// just before it. For that, we should move the rest of the data in the
		// node to the right to make space for k.
		n.moveRight(idx)
	}
	// If the k does not exist already, increment the number of keys.
	if ki != k {
		n.setNumKeys(n.numKeys() + 1)
		numAdded = 1
	}
	// ki == 0: idx is past the last occupied entry (unused entries are zero), so append.
	// ki >= k: overwrite the existing entry, or fill the gap opened by moveRight.
	if ki == 0 || ki >= k {
		n.setAt(keyOffset(idx), k)
		n.setAt(valOffset(idx), v)
		return
	}
	panic("shouldn't reach here")
}
679
680 func (n node) iterate(fn func(node, int)) {
681 for i := 0; i < maxKeys; i++ {
682 if k := n.key(i); k > 0 {
683 fn(n, i)
684 } else {
685 break
686 }
687 }
688 }
689
690 func (n node) print(parentID uint64) {
691 var keys []string
692 n.iterate(func(n node, i int) {
693 keys = append(keys, fmt.Sprintf("%d", n.key(i)))
694 })
695 if len(keys) > 8 {
696 copy(keys[4:], keys[len(keys)-4:])
697 keys[3] = "..."
698 keys = keys[:8]
699 }
700 fmt.Printf("%d Child of: %d num keys: %d keys: %s\n",
701 n.pageID(), parentID, n.numKeys(), strings.Join(keys, " "))
702 }
703