tree_math.mx raw

   1  package mls
   2  
   3  // MLS tree math (RFC 9420 Appendix C).
   4  // Array-based complete balanced binary tree representation.
   5  
   6  // numLeaves exposes operations on a tree with a given number of leaves.
   7  // RFC 9420 §7 bounds trees to 2^32 leaves, so width is always representable in uint32.
   8  type numLeaves uint32
   9  
  10  func numLeavesFromWidth(w uint) numLeaves {
  11  	if w == 0 {
  12  		return 0
  13  	}
  14  	return numLeaves((w-1)/2 + 1)
  15  }
  16  
  17  // width returns the number of nodes (minimum array length).
  18  func (n numLeaves) width() uint {
  19  	if n == 0 {
  20  		return 0
  21  	}
  22  	return 2*(uint(n)-1) + 1
  23  }
  24  
  25  // root returns the index of the root node.
  26  func (n numLeaves) root() nodeIndex {
  27  	return nodeIndex((1 << log2(n.width())) - 1)
  28  }
  29  
  30  // parent returns the parent node index. Returns false for the root.
  31  func (n numLeaves) parent(x nodeIndex) (nodeIndex, bool) {
  32  	if x == n.root() {
  33  		return 0, false
  34  	}
  35  	lvl := nodeIndex(x.level())
  36  	b := (x >> (lvl + 1)) & 1
  37  	p := (x | (1 << lvl)) ^ (b << (lvl + 1))
  38  	return p, true
  39  }
  40  
  41  // sibling returns the other child of the node's parent.
  42  func (n numLeaves) sibling(x nodeIndex) (nodeIndex, bool) {
  43  	p, ok := n.parent(x)
  44  	if !ok {
  45  		return 0, false
  46  	}
  47  	if x < p {
  48  		return p.right()
  49  	} else {
  50  		return p.left()
  51  	}
  52  }
  53  
  54  // directPath returns the path from x to the root (excluding x, excluding root).
  55  func (n numLeaves) directPath(x nodeIndex) []nodeIndex {
  56  	var path []nodeIndex
  57  	for {
  58  		p, ok := n.parent(x)
  59  		if !ok {
  60  			break
  61  		}
  62  		path = append(path, p)
  63  		x = p
  64  	}
  65  	return path
  66  }
  67  
  68  // copath returns the copath (siblings along the direct path).
  69  func (n numLeaves) copath(x nodeIndex) []nodeIndex {
  70  	path := n.directPath(x)
  71  	if len(path) == 0 {
  72  		return nil
  73  	}
  74  	path = append([]nodeIndex{x}, path...)
  75  	path = path[:len(path)-1]
  76  
  77  	var copath []nodeIndex
  78  	for _, y := range path {
  79  		s, ok := n.sibling(y)
  80  		if !ok {
  81  			panic("unreachable")
  82  		}
  83  		copath = append(copath, s)
  84  	}
  85  	return copath
  86  }
  87  
  88  // --- nodeIndex ---
  89  
  90  // nodeIndex addresses any node in the tree (leaf or parent).
  91  // Internally 2*leafIndex for leaves, odd for parents; RFC 9420 keeps this in u32.
  92  type nodeIndex uint32
  93  
  94  func (x nodeIndex) isLeaf() bool {
  95  	return x%2 == 0
  96  }
  97  
  98  func (x nodeIndex) leafIndex() (leafIndex, bool) {
  99  	if !x.isLeaf() {
 100  		return 0, false
 101  	}
 102  	return leafIndex(x) >> 1, true
 103  }
 104  
 105  func (x nodeIndex) left() (nodeIndex, bool) {
 106  	lvl := x.level()
 107  	if lvl == 0 {
 108  		return 0, false
 109  	}
 110  	return x ^ (1 << (nodeIndex(lvl) - 1)), true
 111  }
 112  
 113  func (x nodeIndex) right() (nodeIndex, bool) {
 114  	lvl := x.level()
 115  	if lvl == 0 {
 116  		return 0, false
 117  	}
 118  	return x ^ (3 << (nodeIndex(lvl) - 1)), true
 119  }
 120  
 121  func (x nodeIndex) children() (left, right nodeIndex, ok bool) {
 122  	l, ok := x.left()
 123  	if !ok {
 124  		return 0, 0, false
 125  	}
 126  	r, _ := x.right()
 127  	return l, r, true
 128  }
 129  
 130  func (x nodeIndex) level() uint {
 131  	if x&1 == 0 {
 132  		return 0
 133  	}
 134  	lvl := uint(0)
 135  	for (x>>lvl)&1 == 1 {
 136  		lvl++
 137  	}
 138  	return lvl
 139  }
 140  
 141  // commonAncestor returns the lowest node in both direct paths.
 142  func commonAncestor(x, y nodeIndex) nodeIndex {
 143  	lx, ly := x.level()+1, y.level()+1
 144  	if lx <= ly && x>>ly == y>>ly {
 145  		return y
 146  	} else if ly <= lx && x>>lx == y>>lx {
 147  		return x
 148  	}
 149  
 150  	xn, yn := x, y
 151  	k := uint(0)
 152  	for xn != yn {
 153  		xn, yn = xn>>1, yn>>1
 154  		k++
 155  	}
 156  	return (xn << k) + (1 << (k - 1)) - 1
 157  }
 158  
 159  // --- leafIndex ---
 160  
 161  // leafIndex addresses a member leaf in the ratchet tree.
 162  // Wire-encoded as uint32 (RFC 9420 §7).
 163  type leafIndex uint32
 164  
 165  func (li leafIndex) nodeIndex() nodeIndex {
 166  	return nodeIndex(2 * li)
 167  }
 168  
 169  // --- Helpers ---
 170  
 171  func log2(x uint) uint {
 172  	if x == 0 {
 173  		return 0
 174  	}
 175  	k := uint(0)
 176  	for x>>k > 0 {
 177  		k++
 178  	}
 179  	return k - 1
 180  }
 181  
 182  func isPowerOf2(x uint) bool {
 183  	return x != 0 && x&(x-1) == 0
 184  }
 185