tree.mx raw

   1  // Copyright 2025 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package token
   6  
   7  // tree is a self-balancing AVL tree; see
   8  // Lewis & Denenberg, Data Structures and Their Algorithms.
   9  //
  10  // An AVL tree is a binary tree in which the difference between the
  11  // heights of a node's two subtrees--the node's "balance factor"--is
  12  // at most one. It is more strictly balanced than a red/black tree,
  13  // and thus favors lookups at the expense of updates, which is the
  14  // appropriate trade-off for FileSet.
  15  //
// Insertion or deletion of a node may cause its ancestors' balance
// factors to temporarily reach ±2, requiring rebalancing of each such
// ancestor by a rotation.
  19  //
  20  // Each key is the pos-end range of a single File.
  21  // All Files in the tree must have disjoint ranges.
  22  //
  23  // The implementation is simplified from Russ Cox's github.com/rsc/omap.
  24  
  25  import (
  26  	"fmt"
  27  	"iter"
  28  )
  29  
// A tree is a tree-based ordered map:
// each value is a *File, keyed by its Pos range.
// All map entries cover disjoint ranges.
//
// The zero value of tree is an empty map ready to use.
type tree struct {
	root *node // topmost node, or nil if empty; root.parent is always nil
}
  38  
// A node is one entry in a tree. Besides its children, each node links
// to its parent; next and rebalanceUp follow those links upward.
type node struct {
	// We use the notation (parent left right) in many comments.
	parent  *node // nil for the root
	left    *node // subtree of entries ordered before this one, or nil
	right   *node // subtree of entries ordered after this one, or nil
	file    *File // the mapped value
	key     key   // = file.key(), but improves locality (25% faster)
	balance int32 // right subtree height minus left; at most ±2
	height  int32 // leaf = 0 (nil subtree = -1); also -1 while new or after deletion
}
  49  
// A key represents the Pos range of a File: start is the file's base
// and end is base+size (see (*File).key).
type key struct{ start, end int }
  52  
  53  func (f *File) key() key {
  54  	return key{f.base, f.base + f.size}
  55  }
  56  
  57  // compareKey reports whether x is before y (-1),
  58  // after y (+1), or overlapping y (0).
  59  // This is a total order so long as all
  60  // files in the tree have disjoint ranges.
  61  //
  62  // All files are separated by at least one unit.
  63  // This allows us to use strict < comparisons.
  64  // Use key{p, p} to search for a zero-width position
  65  // even at the start or end of a file.
  66  func compareKey(x, y key) int {
  67  	switch {
  68  	case x.end < y.start:
  69  		return -1
  70  	case y.end < x.start:
  71  		return +1
  72  	}
  73  	return 0
  74  }
  75  
  76  // check asserts that each node's height, subtree, and parent link is
  77  // correct.
  78  func (n *node) check(parent *node) {
  79  	const debugging = false
  80  	if debugging {
  81  		if n == nil {
  82  			return
  83  		}
  84  		if n.parent != parent {
  85  			panic("bad parent")
  86  		}
  87  		n.left.check(n)
  88  		n.right.check(n)
  89  		n.checkBalance()
  90  	}
  91  }
  92  
  93  func (n *node) checkBalance() {
  94  	lheight, rheight := n.left.safeHeight(), n.right.safeHeight()
  95  	balance := rheight - lheight
  96  	if balance != n.balance {
  97  		panic("bad node.balance")
  98  	}
  99  	if !(-2 <= balance && balance <= +2) {
 100  		panic(fmt.Sprintf("node.balance out of range: %d", balance))
 101  	}
 102  	h := 1 + max(lheight, rheight)
 103  	if h != n.height {
 104  		panic("bad node.height")
 105  	}
 106  }
 107  
 108  // locate returns a pointer to the variable that holds the node
 109  // identified by k, along with its parent, if any. If the key is not
 110  // present, it returns a pointer to the node where the key should be
 111  // inserted by a subsequent call to [tree.set].
 112  func (t *tree) locate(k key) (pos **node, parent *node) {
 113  	pos, x := &t.root, t.root
 114  	for x != nil {
 115  		sign := compareKey(k, x.key)
 116  		if sign < 0 {
 117  			pos, x, parent = &x.left, x.left, x
 118  		} else if sign > 0 {
 119  			pos, x, parent = &x.right, x.right, x
 120  		} else {
 121  			break
 122  		}
 123  	}
 124  	return pos, parent
 125  }
 126  
 127  // all returns an iterator over the tree t.
 128  // If t is modified during the iteration,
 129  // some files may not be visited.
 130  // No file will be visited multiple times.
 131  func (t *tree) all() iter.Seq[*File] {
 132  	return func(yield func(*File) bool) {
 133  		if t == nil {
 134  			return
 135  		}
 136  		x := t.root
 137  		if x != nil {
 138  			for x.left != nil {
 139  				x = x.left
 140  			}
 141  		}
 142  		for x != nil && yield(x.file) {
 143  			if x.height >= 0 {
 144  				// still in tree
 145  				x = x.next()
 146  			} else {
 147  				// deleted
 148  				x = t.nextAfter(t.locate(x.key))
 149  			}
 150  		}
 151  	}
 152  }
 153  
 154  // nextAfter returns the node in the key sequence following
 155  // (pos, parent), a result pair from [tree.locate].
 156  func (t *tree) nextAfter(pos **node, parent *node) *node {
 157  	switch {
 158  	case *pos != nil:
 159  		return (*pos).next()
 160  	case parent == nil:
 161  		return nil
 162  	case pos == &parent.left:
 163  		return parent
 164  	default:
 165  		return parent.next()
 166  	}
 167  }
 168  
 169  func (x *node) next() *node {
 170  	if x.right == nil {
 171  		for x.parent != nil && x.parent.right == x {
 172  			x = x.parent
 173  		}
 174  		return x.parent
 175  	}
 176  	x = x.right
 177  	for x.left != nil {
 178  		x = x.left
 179  	}
 180  	return x
 181  }
 182  
 183  func (t *tree) setRoot(x *node) {
 184  	t.root = x
 185  	if x != nil {
 186  		x.parent = nil
 187  	}
 188  }
 189  
 190  func (x *node) setLeft(y *node) {
 191  	x.left = y
 192  	if y != nil {
 193  		y.parent = x
 194  	}
 195  }
 196  
 197  func (x *node) setRight(y *node) {
 198  	x.right = y
 199  	if y != nil {
 200  		y.parent = x
 201  	}
 202  }
 203  
 204  func (n *node) safeHeight() int32 {
 205  	if n == nil {
 206  		return -1
 207  	}
 208  	return n.height
 209  }
 210  
 211  func (n *node) update() {
 212  	lheight, rheight := n.left.safeHeight(), n.right.safeHeight()
 213  	n.height = max(lheight, rheight) + 1
 214  	n.balance = rheight - lheight
 215  }
 216  
 217  func (t *tree) replaceChild(parent, old, new *node) {
 218  	switch {
 219  	case parent == nil:
 220  		if t.root != old {
 221  			panic("corrupt tree")
 222  		}
 223  		t.setRoot(new)
 224  	case parent.left == old:
 225  		parent.setLeft(new)
 226  	case parent.right == old:
 227  		parent.setRight(new)
 228  	default:
 229  		panic("corrupt tree")
 230  	}
 231  }
 232  
// rebalanceUp visits each excessively unbalanced ancestor
// of x, restoring balance by rotating it.
//
// x is a node that has just been mutated, and so the height and
// balance of x and its ancestors may be stale, but the children of x
// must be in a valid state.
//
// Each step refreshes one node with update and then moves to its
// parent. The walk stops as soon as a node's height comes out unchanged,
// because nothing above it can then be affected. A nil x does nothing.
func (t *tree) rebalanceUp(x *node) {
	for x != nil {
		h := x.height // cached height before this step, to detect a change
		x.update()
		switch x.balance {
		case -2:
			// Left-heavy by two. If the left child leans right, first
			// rotate it left so that a single right rotation of x
			// restores balance.
			if x.left.balance == 1 {
				t.rotateLeft(x.left)
			}
			x = t.rotateRight(x) // x is now the new root of this subtree

		case +2:
			// Right-heavy by two: mirror image of the case above.
			if x.right.balance == -1 {
				t.rotateRight(x.right)
			}
			x = t.rotateLeft(x) // x is now the new root of this subtree
		}
		// After a rotation, x's height is compared with the old height of
		// the node that previously occupied this position.
		if x.height == h {
			// x's height has not changed, so the height
			// and balance of its ancestors have not changed;
			// no further rebalancing is required.
			return
		}
		x = x.parent
	}
}
 265  
// rotateRight rotates the subtree rooted at node y,
// turning (y (x a b) c) into (x a (y b c)),
// and returns x, the new root of the subtree.
// y.left must be non-nil, and both y and y.left must have height and
// balance fields consistent with their children (checkBalance panics
// otherwise; these checks run unconditionally).
func (t *tree) rotateRight(y *node) *node {
	// p -> (y (x a b) c)
	p := y.parent
	x := y.left
	b := x.right // b moves from x's right to y's left

	x.checkBalance()
	y.checkBalance()

	x.setRight(y)
	y.setLeft(b)
	t.replaceChild(p, y, x)

	// y is now x's child, so its height must be refreshed first.
	y.update()
	x.update()
	return x
}
 285  
// rotateLeft rotates the subtree rooted at node x,
// turning (x a (y b c)) into (y (x a b) c),
// and returns y, the new root of the subtree.
// x.right must be non-nil, and both x and x.right must have height and
// balance fields consistent with their children (checkBalance panics
// otherwise; these checks run unconditionally).
func (t *tree) rotateLeft(x *node) *node {
	// p -> (x a (y b c))
	p := x.parent
	y := x.right
	b := y.left // b moves from y's left to x's right

	x.checkBalance()
	y.checkBalance()

	y.setLeft(x)
	x.setRight(b)
	t.replaceChild(p, x, y)

	// x is now y's child, so its height must be refreshed first.
	x.update()
	y.update()
	return y
}
 305  
 306  // add inserts file into the tree, if not present.
 307  // It panics if file overlaps with another.
 308  func (t *tree) add(file *File) {
 309  	pos, parent := t.locate(file.key())
 310  	if *pos == nil {
 311  		t.set(file, pos, parent) // missing; insert
 312  		return
 313  	}
 314  	if prev := (*pos).file; prev != file {
 315  		panic(fmt.Sprintf("file %s (%d-%d) overlaps with file %s (%d-%d)",
 316  			prev.Name(), prev.Base(), prev.Base()+prev.Size(),
 317  			file.Name(), file.Base(), file.Base()+file.Size()))
 318  	}
 319  }
 320  
 321  // set updates the existing node at (pos, parent) if present, or
 322  // inserts a new node if not, so that it refers to file.
 323  func (t *tree) set(file *File, pos **node, parent *node) {
 324  	if x := *pos; x != nil {
 325  		// This code path isn't currently needed
 326  		// because FileSet never updates an existing entry.
 327  		// Remove this assertion if things change.
 328  		if true {
 329  			panic("unreachable according to current FileSet requirements")
 330  		}
 331  		x.file = file
 332  		return
 333  	}
 334  	x := &node{file: file, key: file.key(), parent: parent, height: -1}
 335  	*pos = x
 336  	t.rebalanceUp(x)
 337  }
 338  
// delete deletes the node at pos. pos must hold a node, that is, come
// from a successful [tree.locate]; a nil *pos panics.
//
// The removed node's links are cleared and its height is set to -1.
// An in-progress iteration from [tree.all] checks that height to
// notice the deletion.
func (t *tree) delete(pos **node) {
	t.root.check(nil) // consistency check; a no-op unless debugging is enabled in check

	x := *pos
	switch {
	case x == nil:
		// This code path isn't currently needed because FileSet
		// only calls delete after a positive locate.
		// Remove this assertion if things change.
		if true {
			panic("unreachable according to current FileSet requirements")
		}
		return

	case x.left == nil:
		// At most a right child: splice it (possibly nil) into x's slot.
		if *pos = x.right; *pos != nil {
			(*pos).parent = x.parent
		}
		t.rebalanceUp(x.parent)

	case x.right == nil:
		// Only a left child: splice it into x's slot.
		*pos = x.left
		x.left.parent = x.parent
		t.rebalanceUp(x.parent)

	default:
		// Two children: replace x by its in-order successor.
		t.deleteSwap(pos)
	}

	// Detach x. -100 is not a valid balance; presumably chosen so that
	// any stale use of x stands out. height -1 is what all() tests.
	x.balance = -100
	x.parent = nil
	x.left = nil
	x.right = nil
	x.height = -1
	t.root.check(nil)
}
 376  
// deleteSwap deletes a node that has two children by replacing
// it by its in-order successor, then triggers a rebalance.
// pos holds the node to delete; on return *pos is the successor.
// The deleted node's own links are left for the caller to clear.
func (t *tree) deleteSwap(pos **node) {
	x := *pos
	// z is x's in-order successor: the minimum of x's right subtree,
	// now unlinked from that subtree.
	z := t.deleteMin(&x.right)

	*pos = z
	unbalanced := z.parent // lowest potentially unbalanced node
	if unbalanced == x {
		unbalanced = z // (x a (z nil b)) -> (z a b)
	}
	// z takes over x's parent, children, and cached height/balance.
	// The cached values remain correct unless a subtree height changes,
	// in which case rebalanceUp's walk reaches z and recomputes them.
	z.parent = x.parent
	z.height = x.height
	z.balance = x.balance
	z.setLeft(x.left)
	z.setRight(x.right)

	t.rebalanceUp(unbalanced)
}
 396  
 397  // deleteMin updates the subtree rooted at *zpos to delete its minimum
 398  // (leftmost) element, which may be *zpos itself. It returns the
 399  // deleted node.
 400  func (t *tree) deleteMin(zpos **node) (z *node) {
 401  	for (*zpos).left != nil {
 402  		zpos = &(*zpos).left
 403  	}
 404  	z = *zpos
 405  	*zpos = z.right
 406  	if *zpos != nil {
 407  		(*zpos).parent = z.parent
 408  	}
 409  	return z
 410  }
 411