trie.go raw
1 /*
2 * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6 package trie
7
8 import (
9 "fmt"
10 "strconv"
11 "strings"
12
13 "github.com/dgraph-io/badger/v4/pb"
14 "github.com/dgraph-io/badger/v4/y"
15 )
16
// node is one vertex of the trie.
type node struct {
	// children holds the next node for each exact key byte at this depth.
	children map[byte]*node
	// ignore is the next node for a wildcard ("ignored") position: it is
	// followed regardless of the key byte at this depth.
	ignore *node
	// ids are the ids registered for the match path that ends at this node.
	ids []uint64
}
22
23 func (n *node) isEmpty() bool {
24 return len(n.children) == 0 && len(n.ids) == 0 && n.ignore == nil
25 }
26
27 func newNode() *node {
28 return &node{
29 children: make(map[byte]*node),
30 ids: []uint64{},
31 }
32 }
33
34 // Trie datastructure.
35 type Trie struct {
36 root *node
37 }
38
39 // NewTrie returns Trie.
40 func NewTrie() *Trie {
41 return &Trie{
42 root: newNode(),
43 }
44 }
45
// parseIgnoreBytes parses an IgnoreBytes specification into a mask in which
// out[idx] == true means the key byte at idx is ignored during comparison.
//
// ig is a comma-separated list of 0-based indices or inclusive ranges written
// "start-end", for example "3, 5-8, 10, 12-15". Whitespace around the whole
// string and around each item is ignored. The mask is only as long as the
// largest mentioned index plus one, so callers must treat indices at or beyond
// len(out) as false. An empty or all-whitespace ig yields a nil mask.
//
// An error is returned for a non-numeric bound, an empty item, an item with
// more than one dash, or an inverted range such as "8-5". On error the
// returned mask is nil.
func parseIgnoreBytes(ig string) ([]bool, error) {
	ig = strings.TrimSpace(ig)
	if ig == "" {
		return nil, nil
	}

	var out []bool
	for _, each := range strings.Split(ig, ",") {
		bounds := strings.Split(strings.TrimSpace(each), "-")
		if len(bounds) > 2 {
			return nil, fmt.Errorf("Invalid range: %s", each)
		}

		start, err := strconv.Atoi(strings.TrimSpace(bounds[0]))
		if err != nil {
			return nil, err
		}
		// A single index is treated as the one-element range start-start.
		end := start
		if len(bounds) == 2 {
			if end, err = strconv.Atoi(strings.TrimSpace(bounds[1])); err != nil {
				return nil, err
			}
		}
		// Bounds cannot be negative because '-' is consumed as the range
		// separator, but an inverted range would otherwise be silently dropped.
		if end < start {
			return nil, fmt.Errorf("Invalid range: %s", each)
		}

		if end >= len(out) {
			out = append(out, make([]bool, end+1-len(out))...)
		}
		for i := start; i <= end; i++ {
			out[i] = true
		}
	}
	return out, nil
}
94
95 // Add adds the id in the trie for the given prefix path.
96 func (t *Trie) Add(prefix []byte, id uint64) {
97 m := pb.Match{
98 Prefix: prefix,
99 }
100 y.Check(t.AddMatch(m, id))
101 }
102
103 // AddMatch allows you to send in a prefix match, with "holes" in the prefix. The holes are
104 // specified via IgnoreBytes in a comma-separated list of indices starting from 0. A dash can be
105 // used to denote a range. Valid example is "3, 5-8, 10, 12-15". Length of IgnoreBytes does not need
106 // to match the length of the Prefix passed.
107 //
108 // Consider a prefix = "aaaa". If the IgnoreBytes is set to "0, 2", then along with key "aaaa...",
109 // a key "baba..." would also match.
110 func (t *Trie) AddMatch(m pb.Match, id uint64) error {
111 return t.fix(m, id, set)
112 }
113
// Operations accepted by (*Trie).fix.
const (
	set = 0 // append the id at the end of the match path
	del = 1 // remove the id from the end of the match path
)
118
119 func (t *Trie) fix(m pb.Match, id uint64, op int) error {
120 curNode := t.root
121
122 ignore, err := parseIgnoreBytes(m.IgnoreBytes)
123 if err != nil {
124 return fmt.Errorf( "while parsing ignore bytes: %s: %w", m.IgnoreBytes,err)
125 }
126 for len(ignore) < len(m.Prefix) {
127 ignore = append(ignore, false)
128 }
129 for idx, byt := range m.Prefix {
130 var child *node
131 if ignore[idx] {
132 child = curNode.ignore
133 if child == nil {
134 if op == del {
135 // No valid node found for delete operation. Return immediately.
136 return nil
137 }
138 child = newNode()
139 curNode.ignore = child
140 }
141 } else {
142 child = curNode.children[byt]
143 if child == nil {
144 if op == del {
145 // No valid node found for delete operation. Return immediately.
146 return nil
147 }
148 child = newNode()
149 curNode.children[byt] = child
150 }
151 }
152 curNode = child
153 }
154
155 // We only need to add the id to the last node of the given prefix.
156 if op == set {
157 curNode.ids = append(curNode.ids, id)
158
159 } else if op == del {
160 out := curNode.ids[:0]
161 for _, cid := range curNode.ids {
162 if id != cid {
163 out = append(out, cid)
164 }
165 }
166 curNode.ids = out
167 } else {
168 y.AssertTrue(false)
169 }
170 return nil
171 }
172
173 func (t *Trie) Get(key []byte) map[uint64]struct{} {
174 return t.get(t.root, key)
175 }
176
177 // Get returns prefix matched ids for the given key.
178 func (t *Trie) get(curNode *node, key []byte) map[uint64]struct{} {
179 y.AssertTrue(curNode != nil)
180
181 out := make(map[uint64]struct{})
182 // If any node in the path of the key has ids, pick them up.
183 // This would also match nil prefixes.
184 for _, i := range curNode.ids {
185 out[i] = struct{}{}
186 }
187 if len(key) == 0 {
188 return out
189 }
190
191 // If we found an ignore node, traverse that path.
192 if curNode.ignore != nil {
193 res := t.get(curNode.ignore, key[1:])
194 for id := range res {
195 out[id] = struct{}{}
196 }
197 }
198
199 if child := curNode.children[key[0]]; child != nil {
200 res := t.get(child, key[1:])
201 for id := range res {
202 out[id] = struct{}{}
203 }
204 }
205 return out
206 }
207
208 func removeEmpty(curNode *node) bool {
209 // Go depth first.
210 if curNode.ignore != nil {
211 if empty := removeEmpty(curNode.ignore); empty {
212 curNode.ignore = nil
213 }
214 }
215
216 for byt, n := range curNode.children {
217 if empty := removeEmpty(n); empty {
218 delete(curNode.children, byt)
219 }
220 }
221
222 return curNode.isEmpty()
223 }
224
225 // Delete will delete the id if the id exist in the given index path.
226 func (t *Trie) Delete(prefix []byte, id uint64) error {
227 return t.DeleteMatch(pb.Match{Prefix: prefix}, id)
228 }
229
230 func (t *Trie) DeleteMatch(m pb.Match, id uint64) error {
231 if err := t.fix(m, id, del); err != nil {
232 return err
233 }
234 // Would recursively delete empty nodes.
235 // Do not remove the t.root even if its empty.
236 removeEmpty(t.root)
237 return nil
238 }
239
240 func numNodes(curNode *node) int {
241 if curNode == nil {
242 return 0
243 }
244
245 num := numNodes(curNode.ignore)
246 for _, n := range curNode.children {
247 num += numNodes(n)
248 }
249 return num + 1
250 }
251