trie.go raw

   1  /*
   2   * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
   3   * SPDX-License-Identifier: Apache-2.0
   4   */
   5  
   6  package trie
   7  
   8  import (
   9  	"fmt"
  10  	"strconv"
  11  	"strings"
  12  
  13  	"github.com/dgraph-io/badger/v4/pb"
  14  	"github.com/dgraph-io/badger/v4/y"
  15  )
  16  
// node is a single vertex of the trie. A prefix position is consumed by following either an
// exact-byte edge in children or the wildcard edge in ignore.
type node struct {
	children map[byte]*node // edges taken when the next byte equals the map key
	ignore   *node          // edge for an ignored position, taken for any byte; may be nil
	ids      []uint64       // ids registered for the prefix that ends at this node
}
  22  
  23  func (n *node) isEmpty() bool {
  24  	return len(n.children) == 0 && len(n.ids) == 0 && n.ignore == nil
  25  }
  26  
  27  func newNode() *node {
  28  	return &node{
  29  		children: make(map[byte]*node),
  30  		ids:      []uint64{},
  31  	}
  32  }
  33  
// Trie is a byte-wise prefix tree that maps prefixes, optionally containing ignored positions,
// to ids. Use NewTrie to create one: the methods expect root to be non-nil.
type Trie struct {
	root *node
}
  38  
  39  // NewTrie returns Trie.
  40  func NewTrie() *Trie {
  41  	return &Trie{
  42  		root: newNode(),
  43  	}
  44  }
  45  
  46  // parseIgnoreBytes would parse the ignore string, and convert it into a list of bools, where
  47  // bool[idx] = true implies that key[idx] can be ignored during comparison.
  48  func parseIgnoreBytes(ig string) ([]bool, error) {
  49  	var out []bool
  50  	if ig == "" {
  51  		return out, nil
  52  	}
  53  
  54  	for _, each := range strings.Split(strings.TrimSpace(ig), ",") {
  55  		r := strings.Split(strings.TrimSpace(each), "-")
  56  		if len(r) == 0 || len(r) > 2 {
  57  			return out, fmt.Errorf("Invalid range: %s", each)
  58  		}
  59  		start, end := -1, -1 //nolint:ineffassign
  60  		if len(r) == 2 {
  61  			idx, err := strconv.Atoi(strings.TrimSpace(r[1]))
  62  			if err != nil {
  63  				return out, err
  64  			}
  65  			end = idx
  66  		}
  67  		{
  68  			// Always consider r[0]
  69  			idx, err := strconv.Atoi(strings.TrimSpace(r[0]))
  70  			if err != nil {
  71  				return out, err
  72  			}
  73  			start = idx
  74  		}
  75  		if start == -1 {
  76  			return out, fmt.Errorf("Invalid range: %s", each)
  77  		}
  78  		for start >= len(out) {
  79  			out = append(out, false)
  80  		}
  81  		for end >= len(out) { // end could be -1, so do have the start loop above.
  82  			out = append(out, false)
  83  		}
  84  		if end == -1 {
  85  			out[start] = true
  86  		} else {
  87  			for i := start; i <= end; i++ {
  88  				out[i] = true
  89  			}
  90  		}
  91  	}
  92  	return out, nil
  93  }
  94  
  95  // Add adds the id in the trie for the given prefix path.
  96  func (t *Trie) Add(prefix []byte, id uint64) {
  97  	m := pb.Match{
  98  		Prefix: prefix,
  99  	}
 100  	y.Check(t.AddMatch(m, id))
 101  }
 102  
// AddMatch allows you to send in a prefix match, with "holes" in the prefix. The holes are
// specified via IgnoreBytes in a comma-separated list of indices starting from 0. A dash can be
// used to denote a range. Valid example is "3, 5-8, 10, 12-15". Length of IgnoreBytes does not need
// to match the length of the Prefix passed: positions it does not mention are compared exactly, and
// indices at or beyond the end of the Prefix have no effect.
//
// Consider a prefix = "aaaa". If the IgnoreBytes is set to "0, 2", then along with key "aaaa...",
// a key "baba..." would also match.
//
// It returns an error if IgnoreBytes cannot be parsed.
func (t *Trie) AddMatch(m pb.Match, id uint64) error {
	return t.fix(m, id, set)
}
 113  
// Operations understood by Trie.fix.
const (
	set = iota // add an id at the node addressed by the match
	del        // remove an id from the node addressed by the match
)
 118  
 119  func (t *Trie) fix(m pb.Match, id uint64, op int) error {
 120  	curNode := t.root
 121  
 122  	ignore, err := parseIgnoreBytes(m.IgnoreBytes)
 123  	if err != nil {
 124  		return fmt.Errorf( "while parsing ignore bytes: %s: %w", m.IgnoreBytes,err)
 125  	}
 126  	for len(ignore) < len(m.Prefix) {
 127  		ignore = append(ignore, false)
 128  	}
 129  	for idx, byt := range m.Prefix {
 130  		var child *node
 131  		if ignore[idx] {
 132  			child = curNode.ignore
 133  			if child == nil {
 134  				if op == del {
 135  					// No valid node found for delete operation. Return immediately.
 136  					return nil
 137  				}
 138  				child = newNode()
 139  				curNode.ignore = child
 140  			}
 141  		} else {
 142  			child = curNode.children[byt]
 143  			if child == nil {
 144  				if op == del {
 145  					// No valid node found for delete operation. Return immediately.
 146  					return nil
 147  				}
 148  				child = newNode()
 149  				curNode.children[byt] = child
 150  			}
 151  		}
 152  		curNode = child
 153  	}
 154  
 155  	// We only need to add the id to the last node of the given prefix.
 156  	if op == set {
 157  		curNode.ids = append(curNode.ids, id)
 158  
 159  	} else if op == del {
 160  		out := curNode.ids[:0]
 161  		for _, cid := range curNode.ids {
 162  			if id != cid {
 163  				out = append(out, cid)
 164  			}
 165  		}
 166  		curNode.ids = out
 167  	} else {
 168  		y.AssertTrue(false)
 169  	}
 170  	return nil
 171  }
 172  
// Get returns the set of ids whose registered prefix matches key. It walks the trie from the
// root along key; the returned map is never nil, and is empty when nothing matches.
func (t *Trie) Get(key []byte) map[uint64]struct{} {
	return t.get(t.root, key)
}
 176  
 177  // Get returns prefix matched ids for the given key.
 178  func (t *Trie) get(curNode *node, key []byte) map[uint64]struct{} {
 179  	y.AssertTrue(curNode != nil)
 180  
 181  	out := make(map[uint64]struct{})
 182  	// If any node in the path of the key has ids, pick them up.
 183  	// This would also match nil prefixes.
 184  	for _, i := range curNode.ids {
 185  		out[i] = struct{}{}
 186  	}
 187  	if len(key) == 0 {
 188  		return out
 189  	}
 190  
 191  	// If we found an ignore node, traverse that path.
 192  	if curNode.ignore != nil {
 193  		res := t.get(curNode.ignore, key[1:])
 194  		for id := range res {
 195  			out[id] = struct{}{}
 196  		}
 197  	}
 198  
 199  	if child := curNode.children[key[0]]; child != nil {
 200  		res := t.get(child, key[1:])
 201  		for id := range res {
 202  			out[id] = struct{}{}
 203  		}
 204  	}
 205  	return out
 206  }
 207  
 208  func removeEmpty(curNode *node) bool {
 209  	// Go depth first.
 210  	if curNode.ignore != nil {
 211  		if empty := removeEmpty(curNode.ignore); empty {
 212  			curNode.ignore = nil
 213  		}
 214  	}
 215  
 216  	for byt, n := range curNode.children {
 217  		if empty := removeEmpty(n); empty {
 218  			delete(curNode.children, byt)
 219  		}
 220  	}
 221  
 222  	return curNode.isEmpty()
 223  }
 224  
// Delete removes id from the node reached by prefix, treating prefix as a plain prefix without
// ignored positions. It does nothing if that path or the id does not exist. It is DeleteMatch
// with an empty IgnoreBytes, for which parsing cannot fail.
func (t *Trie) Delete(prefix []byte, id uint64) error {
	return t.DeleteMatch(pb.Match{Prefix: prefix}, id)
}
 229  
 230  func (t *Trie) DeleteMatch(m pb.Match, id uint64) error {
 231  	if err := t.fix(m, id, del); err != nil {
 232  		return err
 233  	}
 234  	// Would recursively delete empty nodes.
 235  	// Do not remove the t.root even if its empty.
 236  	removeEmpty(t.root)
 237  	return nil
 238  }
 239  
 240  func numNodes(curNode *node) int {
 241  	if curNode == nil {
 242  		return 0
 243  	}
 244  
 245  	num := numNodes(curNode.ignore)
 246  	for _, n := range curNode.children {
 247  		num += numNodes(n)
 248  	}
 249  	return num + 1
 250  }
 251