tree_math.mx raw
1 package mls
2
3 // MLS tree math (RFC 9420 Appendix C).
4 // Array-based complete balanced binary tree representation.
5
// numLeaves is the number of leaves in a tree. Its methods (width, root,
// parent, sibling, directPath, copath) answer index questions about a tree of
// that size.
//
// The leaf count is stored as a uint32, while width() reports the node count
// 2n-1 as a uint.
// NOTE(review): an earlier comment stated that the width is always
// representable in uint32; for n above 2^31 the value 2n-1 needs more than 32
// bits, so confirm that callers stay within the range nodeIndex can address.
type numLeaves uint32
9
10 func numLeavesFromWidth(w uint) numLeaves {
11 if w == 0 {
12 return 0
13 }
14 return numLeaves((w-1)/2 + 1)
15 }
16
17 // width returns the number of nodes (minimum array length).
18 func (n numLeaves) width() uint {
19 if n == 0 {
20 return 0
21 }
22 return 2*(uint(n)-1) + 1
23 }
24
25 // root returns the index of the root node.
26 func (n numLeaves) root() nodeIndex {
27 return nodeIndex((1 << log2(n.width())) - 1)
28 }
29
30 // parent returns the parent node index. Returns false for the root.
31 func (n numLeaves) parent(x nodeIndex) (nodeIndex, bool) {
32 if x == n.root() {
33 return 0, false
34 }
35 lvl := nodeIndex(x.level())
36 b := (x >> (lvl + 1)) & 1
37 p := (x | (1 << lvl)) ^ (b << (lvl + 1))
38 return p, true
39 }
40
41 // sibling returns the other child of the node's parent.
42 func (n numLeaves) sibling(x nodeIndex) (nodeIndex, bool) {
43 p, ok := n.parent(x)
44 if !ok {
45 return 0, false
46 }
47 if x < p {
48 return p.right()
49 } else {
50 return p.left()
51 }
52 }
53
54 // directPath returns the path from x to the root (excluding x, excluding root).
55 func (n numLeaves) directPath(x nodeIndex) []nodeIndex {
56 var path []nodeIndex
57 for {
58 p, ok := n.parent(x)
59 if !ok {
60 break
61 }
62 path = append(path, p)
63 x = p
64 }
65 return path
66 }
67
68 // copath returns the copath (siblings along the direct path).
69 func (n numLeaves) copath(x nodeIndex) []nodeIndex {
70 path := n.directPath(x)
71 if len(path) == 0 {
72 return nil
73 }
74 path = append([]nodeIndex{x}, path...)
75 path = path[:len(path)-1]
76
77 var copath []nodeIndex
78 for _, y := range path {
79 s, ok := n.sibling(y)
80 if !ok {
81 panic("unreachable")
82 }
83 copath = append(copath, s)
84 }
85 return copath
86 }
87
88 // --- nodeIndex ---
89
// nodeIndex addresses any node in the tree (leaf or parent).
// Leaves sit at even indices (2*leafIndex) and parents at odd indices.
// NOTE(review): the uint32 width is said to mirror RFC 9420 — confirm against
// the spec.
type nodeIndex uint32
93
94 func (x nodeIndex) isLeaf() bool {
95 return x%2 == 0
96 }
97
98 func (x nodeIndex) leafIndex() (leafIndex, bool) {
99 if !x.isLeaf() {
100 return 0, false
101 }
102 return leafIndex(x) >> 1, true
103 }
104
105 func (x nodeIndex) left() (nodeIndex, bool) {
106 lvl := x.level()
107 if lvl == 0 {
108 return 0, false
109 }
110 return x ^ (1 << (nodeIndex(lvl) - 1)), true
111 }
112
113 func (x nodeIndex) right() (nodeIndex, bool) {
114 lvl := x.level()
115 if lvl == 0 {
116 return 0, false
117 }
118 return x ^ (3 << (nodeIndex(lvl) - 1)), true
119 }
120
121 func (x nodeIndex) children() (left, right nodeIndex, ok bool) {
122 l, ok := x.left()
123 if !ok {
124 return 0, 0, false
125 }
126 r, _ := x.right()
127 return l, r, true
128 }
129
130 func (x nodeIndex) level() uint {
131 if x&1 == 0 {
132 return 0
133 }
134 lvl := uint(0)
135 for (x>>lvl)&1 == 1 {
136 lvl++
137 }
138 return lvl
139 }
140
141 // commonAncestor returns the lowest node in both direct paths.
142 func commonAncestor(x, y nodeIndex) nodeIndex {
143 lx, ly := x.level()+1, y.level()+1
144 if lx <= ly && x>>ly == y>>ly {
145 return y
146 } else if ly <= lx && x>>lx == y>>lx {
147 return x
148 }
149
150 xn, yn := x, y
151 k := uint(0)
152 for xn != yn {
153 xn, yn = xn>>1, yn>>1
154 k++
155 }
156 return (xn << k) + (1 << (k - 1)) - 1
157 }
158
159 // --- leafIndex ---
160
// leafIndex addresses a member leaf in the ratchet tree.
// Leaf i lives at node index 2*i (see the nodeIndex method).
// NOTE(review): stated to be wire-encoded as uint32 per RFC 9420 — the section
// reference has not been verified here.
type leafIndex uint32
164
165 func (li leafIndex) nodeIndex() nodeIndex {
166 return nodeIndex(2 * li)
167 }
168
169 // --- Helpers ---
170
// log2 returns floor(log2(x)), i.e. the position of the highest set bit of x.
// By convention log2(0) is 0.
func log2(x uint) uint {
	var exp uint
	// Count how many times x can be halved before nothing is left above bit 0.
	for rest := x >> 1; rest != 0; rest >>= 1 {
		exp++
	}
	return exp
}
181
// isPowerOf2 reports whether x has exactly one bit set. Zero is not a power
// of two.
func isPowerOf2(x uint) bool {
	if x == 0 {
		return false
	}
	// Clearing the lowest set bit leaves zero only when it was the sole bit.
	return x&(x-1) == 0
}
185