package mls // MLS tree math (RFC 9420 Appendix C). // Array-based complete balanced binary tree representation. // numLeaves exposes operations on a tree with a given number of leaves. // RFC 9420 §7 bounds trees to 2^32 leaves, so width is always representable in uint32. type numLeaves uint32 func numLeavesFromWidth(w uint) numLeaves { if w == 0 { return 0 } return numLeaves((w-1)/2 + 1) } // width returns the number of nodes (minimum array length). func (n numLeaves) width() uint { if n == 0 { return 0 } return 2*(uint(n)-1) + 1 } // root returns the index of the root node. func (n numLeaves) root() nodeIndex { return nodeIndex((1 << log2(n.width())) - 1) } // parent returns the parent node index. Returns false for the root. func (n numLeaves) parent(x nodeIndex) (nodeIndex, bool) { if x == n.root() { return 0, false } lvl := nodeIndex(x.level()) b := (x >> (lvl + 1)) & 1 p := (x | (1 << lvl)) ^ (b << (lvl + 1)) return p, true } // sibling returns the other child of the node's parent. func (n numLeaves) sibling(x nodeIndex) (nodeIndex, bool) { p, ok := n.parent(x) if !ok { return 0, false } if x < p { return p.right() } else { return p.left() } } // directPath returns the path from x to the root (excluding x, excluding root). func (n numLeaves) directPath(x nodeIndex) []nodeIndex { var path []nodeIndex for { p, ok := n.parent(x) if !ok { break } path = append(path, p) x = p } return path } // copath returns the copath (siblings along the direct path). func (n numLeaves) copath(x nodeIndex) []nodeIndex { path := n.directPath(x) if len(path) == 0 { return nil } path = append([]nodeIndex{x}, path...) path = path[:len(path)-1] var copath []nodeIndex for _, y := range path { s, ok := n.sibling(y) if !ok { panic("unreachable") } copath = append(copath, s) } return copath } // --- nodeIndex --- // nodeIndex addresses any node in the tree (leaf or parent). // Internally 2*leafIndex for leaves, odd for parents; RFC 9420 keeps this in u32. 
type nodeIndex uint32

// isLeaf reports whether x addresses a leaf; leaves occupy the even indices.
func (x nodeIndex) isLeaf() bool {
	return x&1 == 0
}

// leafIndex maps a leaf's node index back to its leaf position (node 2i is
// leaf i). The boolean is false when x is a parent node.
func (x nodeIndex) leafIndex() (leafIndex, bool) {
	if x.isLeaf() {
		return leafIndex(x / 2), true
	}
	return 0, false
}

// left returns the left child of x; the boolean is false for leaves.
func (x nodeIndex) left() (nodeIndex, bool) {
	k := x.level()
	if k == 0 {
		return 0, false
	}
	// A level-k node has its low k bits set; clearing bit k-1 descends left.
	return x ^ (nodeIndex(1) << (k - 1)), true
}

// right returns the right child of x; the boolean is false for leaves.
func (x nodeIndex) right() (nodeIndex, bool) {
	k := x.level()
	if k == 0 {
		return 0, false
	}
	// Clear bit k-1 and set bit k to descend right.
	return x ^ (nodeIndex(3) << (k - 1)), true
}

// children returns both children of x; ok is false when x is a leaf.
func (x nodeIndex) children() (left, right nodeIndex, ok bool) {
	if x.isLeaf() {
		return 0, 0, false
	}
	l, _ := x.left()
	r, _ := x.right()
	return l, r, true
}

// level returns the height of x above the leaves, which equals the number of
// trailing one bits in its index (0 for leaves).
func (x nodeIndex) level() uint {
	var k uint
	for x&1 == 1 {
		x >>= 1
		k++
	}
	return k
}

// commonAncestor returns the lowest node that lies on the paths from both x
// and y to the root, where each node counts as lying on its own path.
func commonAncestor(x, y nodeIndex) nodeIndex {
	lx := x.level() + 1
	ly := y.level() + 1
	switch {
	case lx <= ly && x>>ly == y>>ly:
		// y is x itself or one of x's ancestors.
		return y
	case ly <= lx && x>>lx == y>>lx:
		// x is one of y's ancestors.
		return x
	}
	// Strip low bits until both indices agree; the shared prefix plus the
	// number of stripped bits pins down the ancestor's position.
	shift := uint(0)
	for x != y {
		x >>= 1
		y >>= 1
		shift++
	}
	return (x << shift) + (nodeIndex(1) << (shift - 1)) - 1
}

// --- leafIndex ---

// leafIndex is the position of a member leaf in the ratchet tree, encoded on
// the wire as a uint32.
type leafIndex uint32

// nodeIndex returns the node index of leaf li, i.e. 2*li computed in uint32
// arithmetic (so it wraps for li >= 2^31).
func (li leafIndex) nodeIndex() nodeIndex {
	return nodeIndex(li) << 1
}

// --- Helpers ---

// log2 returns floor(log2(x)), with log2(0) defined as 0.
func log2(x uint) uint {
	var k uint
	for x > 1 {
		x >>= 1
		k++
	}
	return k
}

// isPowerOf2 reports whether x is a positive power of two.
func isPowerOf2(x uint) bool {
	return x > 0 && x&(x-1) == 0
}