cursor.go raw

   1  // Copyright 2025 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package inspector
   6  
   7  import (
   8  	"fmt"
   9  	"go/ast"
  10  	"go/token"
  11  	"iter"
  12  	"reflect"
  13  
  14  	"golang.org/x/tools/go/ast/edge"
  15  )
  16  
  17  // A Cursor represents an [ast.Node]. It is immutable.
  18  //
  19  // Two Cursors compare equal if they represent the same node.
  20  //
  21  // The zero value of Cursor is not valid.
  22  //
  23  // Call [Inspector.Root] to obtain a cursor for the virtual root node
  24  // of the traversal. This is the sole valid cursor for which [Cursor.Node]
  25  // returns nil.
  26  //
  27  // Use the following methods to navigate efficiently around the tree:
  28  //   - for ancestors, use [Cursor.Parent] and [Cursor.Enclosing];
  29  //   - for children, use [Cursor.Child], [Cursor.Children],
  30  //     [Cursor.FirstChild], and [Cursor.LastChild];
  31  //   - for siblings, use [Cursor.PrevSibling] and [Cursor.NextSibling];
  32  //   - for descendants, use [Cursor.FindByPos], [Cursor.FindNode],
  33  //     [Cursor.Inspect], and [Cursor.Preorder].
  34  //
  35  // Use the [Cursor.ChildAt] and [Cursor.ParentEdge] methods for
  36  // information about the edges in a tree: which field (and slice
  37  // element) of the parent node holds the child.
  38  type Cursor struct {
  39  	in    *Inspector
  40  	index int32 // index of push node; -1 for virtual root node
  41  }
  42  
  43  // Root returns a valid cursor for the virtual root node,
  44  // whose children are the files provided to [New].
  45  //
  46  // Its [Cursor.Node] method return nil.
  47  func (in *Inspector) Root() Cursor {
  48  	return Cursor{in, -1}
  49  }
  50  
  51  // At returns the cursor at the specified index in the traversal,
  52  // which must have been obtained from [Cursor.Index] on a Cursor
  53  // belonging to the same Inspector (see [Cursor.Inspector]).
  54  func (in *Inspector) At(index int32) Cursor {
  55  	if index < 0 {
  56  		panic("negative index")
  57  	}
  58  	if int(index) >= len(in.events) {
  59  		panic("index out of range for this inspector")
  60  	}
  61  	if in.events[index].index < index {
  62  		panic("invalid index") // (a push, not a pop)
  63  	}
  64  	return Cursor{in, index}
  65  }
  66  
  67  // Valid reports whether the cursor is valid.
  68  // The zero value of cursor is invalid.
  69  // Unless otherwise documented, it is not safe to call
  70  // any other method on an invalid cursor.
  71  func (c Cursor) Valid() bool {
  72  	return c.in != nil
  73  }
  74  
  75  // Inspector returns the cursor's Inspector.
  76  // It returns nil if the Cursor is not valid.
  77  func (c Cursor) Inspector() *Inspector { return c.in }
  78  
  79  // Index returns the index of this cursor position within the package.
  80  //
  81  // Clients should not assume anything about the numeric Index value
  82  // except that it increases monotonically throughout the traversal.
  83  // It is provided for use with [Inspector.At].
  84  //
  85  // Index must not be called on the Root node.
  86  func (c Cursor) Index() int32 {
  87  	if c.index < 0 {
  88  		panic("Index called on Root node")
  89  	}
  90  	return c.index
  91  }
  92  
  93  // Node returns the node at the current cursor position,
  94  // or nil for the cursor returned by [Inspector.Root].
  95  func (c Cursor) Node() ast.Node {
  96  	if c.index < 0 {
  97  		return nil
  98  	}
  99  	return c.in.events[c.index].node
 100  }
 101  
 102  // String returns information about the cursor's node, if any.
 103  func (c Cursor) String() string {
 104  	if !c.Valid() {
 105  		return "(invalid)"
 106  	}
 107  	if c.index < 0 {
 108  		return "(root)"
 109  	}
 110  	return reflect.TypeOf(c.Node()).String()
 111  }
 112  
 113  // indices return the [start, end) half-open interval of event indices.
 114  func (c Cursor) indices() (int32, int32) {
 115  	if c.index < 0 {
 116  		return 0, int32(len(c.in.events)) // root: all events
 117  	} else {
 118  		return c.index, c.in.events[c.index].index + 1 // just one subtree
 119  	}
 120  }
 121  
 122  // Preorder returns an iterator over the nodes of the subtree
 123  // represented by c in depth-first order. Each node in the sequence is
 124  // represented by a Cursor that allows access to the Node, but may
 125  // also be used to start a new traversal, or to obtain the stack of
 126  // nodes enclosing the cursor.
 127  //
 128  // The traversal sequence is determined by [ast.Inspect]. The types
 129  // argument, if non-empty, enables type-based filtering of events. The
 130  // function f if is called only for nodes whose type matches an
 131  // element of the types slice.
 132  //
 133  // If you need control over descent into subtrees,
 134  // or need both pre- and post-order notifications, use [Cursor.Inspect]
 135  func (c Cursor) Preorder(types ...ast.Node) iter.Seq[Cursor] {
 136  	mask := maskOf(types)
 137  
 138  	return func(yield func(Cursor) bool) {
 139  		events := c.in.events
 140  
 141  		for i, limit := c.indices(); i < limit; {
 142  			ev := events[i]
 143  			if ev.index > i { // push?
 144  				if ev.typ&mask != 0 && !yield(Cursor{c.in, i}) {
 145  					break
 146  				}
 147  				pop := ev.index
 148  				if events[pop].typ&mask == 0 {
 149  					// Subtree does not contain types: skip.
 150  					i = pop + 1
 151  					continue
 152  				}
 153  			}
 154  			i++
 155  		}
 156  	}
 157  }
 158  
 159  // Inspect visits the nodes of the subtree represented by c in
 160  // depth-first order. It calls f(n) for each node n before it
 161  // visits n's children. If f returns true, Inspect invokes f
 162  // recursively for each of the non-nil children of the node.
 163  //
 164  // Each node is represented by a Cursor that allows access to the
 165  // Node, but may also be used to start a new traversal, or to obtain
 166  // the stack of nodes enclosing the cursor.
 167  //
 168  // The complete traversal sequence is determined by [ast.Inspect].
 169  // The types argument, if non-empty, enables type-based filtering of
 170  // events. The function f if is called only for nodes whose type
 171  // matches an element of the types slice.
 172  func (c Cursor) Inspect(types []ast.Node, f func(c Cursor) (descend bool)) {
 173  	mask := maskOf(types)
 174  	events := c.in.events
 175  	for i, limit := c.indices(); i < limit; {
 176  		ev := events[i]
 177  		if ev.index > i {
 178  			// push
 179  			pop := ev.index
 180  			if ev.typ&mask != 0 && !f(Cursor{c.in, i}) ||
 181  				events[pop].typ&mask == 0 {
 182  				// The user opted not to descend, or the
 183  				// subtree does not contain types:
 184  				// skip past the pop.
 185  				i = pop + 1
 186  				continue
 187  			}
 188  		}
 189  		i++
 190  	}
 191  }
 192  
 193  // Enclosing returns an iterator over the nodes enclosing the current
 194  // current node, starting with the Cursor itself.
 195  //
 196  // Enclosing must not be called on the Root node (whose [Cursor.Node] returns nil).
 197  //
 198  // The types argument, if non-empty, enables type-based filtering of
 199  // events: the sequence includes only enclosing nodes whose type
 200  // matches an element of the types slice.
 201  func (c Cursor) Enclosing(types ...ast.Node) iter.Seq[Cursor] {
 202  	if c.index < 0 {
 203  		panic("Cursor.Enclosing called on Root node")
 204  	}
 205  
 206  	mask := maskOf(types)
 207  
 208  	return func(yield func(Cursor) bool) {
 209  		events := c.in.events
 210  		for i := c.index; i >= 0; i = events[i].parent {
 211  			if events[i].typ&mask != 0 && !yield(Cursor{c.in, i}) {
 212  				break
 213  			}
 214  		}
 215  	}
 216  }
 217  
 218  // Parent returns the parent of the current node.
 219  //
 220  // Parent must not be called on the Root node (whose [Cursor.Node] returns nil).
 221  func (c Cursor) Parent() Cursor {
 222  	if c.index < 0 {
 223  		panic("Cursor.Parent called on Root node")
 224  	}
 225  
 226  	return Cursor{c.in, c.in.events[c.index].parent}
 227  }
 228  
 229  // ParentEdge returns the identity of the field in the parent node
 230  // that holds this cursor's node, and if it is a list, the index within it.
 231  //
 232  // For example, f(x, y) is a CallExpr whose three children are Idents.
 233  // f has edge kind [edge.CallExpr_Fun] and index -1.
 234  // x and y have kind [edge.CallExpr_Args] and indices 0 and 1, respectively.
 235  //
 236  // If called on a child of the Root node, it returns ([edge.Invalid], -1).
 237  //
 238  // ParentEdge must not be called on the Root node (whose [Cursor.Node] returns nil).
 239  func (c Cursor) ParentEdge() (edge.Kind, int) {
 240  	if c.index < 0 {
 241  		panic("Cursor.ParentEdge called on Root node")
 242  	}
 243  	events := c.in.events
 244  	pop := events[c.index].index
 245  	return unpackEdgeKindAndIndex(events[pop].parent)
 246  }
 247  
 248  // ParentEdgeKind returns the kind component of the result of [Cursor.ParentEdge].
 249  func (c Cursor) ParentEdgeKind() edge.Kind {
 250  	ek, _ := c.ParentEdge()
 251  	return ek
 252  }
 253  
 254  // ParentEdgeIndex returns the index component of the result of [Cursor.ParentEdge].
 255  func (c Cursor) ParentEdgeIndex() int {
 256  	_, index := c.ParentEdge()
 257  	return index
 258  }
 259  
 260  // ChildAt returns the cursor for the child of the
 261  // current node identified by its edge and index.
 262  // The index must be -1 if the edge.Kind is not a slice.
 263  // The indicated child node must exist.
 264  //
 265  // ChildAt must not be called on the Root node (whose [Cursor.Node] returns nil).
 266  //
 267  // Invariant: c.Parent().ChildAt(c.ParentEdge()) == c.
 268  func (c Cursor) ChildAt(k edge.Kind, idx int) Cursor {
 269  	target := packEdgeKindAndIndex(k, idx)
 270  
 271  	// Unfortunately there's no shortcut to looping.
 272  	events := c.in.events
 273  	i := c.index + 1
 274  	for {
 275  		pop := events[i].index
 276  		if pop < i {
 277  			break
 278  		}
 279  		if events[pop].parent == target {
 280  			return Cursor{c.in, i}
 281  		}
 282  		i = pop + 1
 283  	}
 284  	panic(fmt.Sprintf("ChildAt(%v, %d): no such child of %v", k, idx, c))
 285  }
 286  
 287  // Child returns the cursor for n, which must be a direct child of c's Node.
 288  //
 289  // Child must not be called on the Root node (whose [Cursor.Node] returns nil).
 290  func (c Cursor) Child(n ast.Node) Cursor {
 291  	if c.index < 0 {
 292  		panic("Cursor.Child called on Root node")
 293  	}
 294  
 295  	if false {
 296  		// reference implementation
 297  		for child := range c.Children() {
 298  			if child.Node() == n {
 299  				return child
 300  			}
 301  		}
 302  
 303  	} else {
 304  		// optimized implementation
 305  		events := c.in.events
 306  		for i := c.index + 1; events[i].index > i; i = events[i].index + 1 {
 307  			if events[i].node == n {
 308  				return Cursor{c.in, i}
 309  			}
 310  		}
 311  	}
 312  	panic(fmt.Sprintf("Child(%T): not a child of %v", n, c))
 313  }
 314  
 315  // NextSibling returns the cursor for the next sibling node in the same list
 316  // (for example, of files, decls, specs, statements, fields, or expressions) as
 317  // the current node. It returns (zero, false) if the node is the last node in
 318  // the list, or is not part of a list.
 319  //
 320  // NextSibling must not be called on the Root node.
 321  //
 322  // See note at [Cursor.Children].
 323  func (c Cursor) NextSibling() (Cursor, bool) {
 324  	if c.index < 0 {
 325  		panic("Cursor.NextSibling called on Root node")
 326  	}
 327  
 328  	events := c.in.events
 329  	i := events[c.index].index + 1 // after corresponding pop
 330  	if i < int32(len(events)) {
 331  		if events[i].index > i { // push?
 332  			return Cursor{c.in, i}, true
 333  		}
 334  	}
 335  	return Cursor{}, false
 336  }
 337  
 338  // PrevSibling returns the cursor for the previous sibling node in the
 339  // same list (for example, of files, decls, specs, statements, fields,
 340  // or expressions) as the current node. It returns zero if the node is
 341  // the first node in the list, or is not part of a list.
 342  //
 343  // It must not be called on the Root node.
 344  //
 345  // See note at [Cursor.Children].
 346  func (c Cursor) PrevSibling() (Cursor, bool) {
 347  	if c.index < 0 {
 348  		panic("Cursor.PrevSibling called on Root node")
 349  	}
 350  
 351  	events := c.in.events
 352  	i := c.index - 1
 353  	if i >= 0 {
 354  		if j := events[i].index; j < i { // pop?
 355  			return Cursor{c.in, j}, true
 356  		}
 357  	}
 358  	return Cursor{}, false
 359  }
 360  
 361  // FirstChild returns the first direct child of the current node,
 362  // or zero if it has no children.
 363  func (c Cursor) FirstChild() (Cursor, bool) {
 364  	events := c.in.events
 365  	i := c.index + 1                                   // i=0 if c is root
 366  	if i < int32(len(events)) && events[i].index > i { // push?
 367  		return Cursor{c.in, i}, true
 368  	}
 369  	return Cursor{}, false
 370  }
 371  
 372  // LastChild returns the last direct child of the current node,
 373  // or zero if it has no children.
 374  func (c Cursor) LastChild() (Cursor, bool) {
 375  	events := c.in.events
 376  	if c.index < 0 { // root?
 377  		if len(events) > 0 {
 378  			// return push of final event (a pop)
 379  			return Cursor{c.in, events[len(events)-1].index}, true
 380  		}
 381  	} else {
 382  		j := events[c.index].index - 1 // before corresponding pop
 383  		// Inv: j == c.index if c has no children
 384  		//  or  j is last child's pop.
 385  		if j > c.index { // c has children
 386  			return Cursor{c.in, events[j].index}, true
 387  		}
 388  	}
 389  	return Cursor{}, false
 390  }
 391  
 392  // Children returns an iterator over the direct children of the
 393  // current node, if any.
 394  //
 395  // When using Children, NextChild, and PrevChild, bear in mind that a
 396  // Node's children may come from different fields, some of which may
 397  // be lists of nodes without a distinguished intervening container
 398  // such as [ast.BlockStmt].
 399  //
 400  // For example, [ast.CaseClause] has a field List of expressions and a
 401  // field Body of statements, so the children of a CaseClause are a mix
 402  // of expressions and statements. Other nodes that have "uncontained"
 403  // list fields include:
 404  //
 405  //   - [ast.ValueSpec] (Names, Values)
 406  //   - [ast.CompositeLit] (Type, Elts)
 407  //   - [ast.IndexListExpr] (X, Indices)
 408  //   - [ast.CallExpr] (Fun, Args)
 409  //   - [ast.AssignStmt] (Lhs, Rhs)
 410  //
 411  // So, do not assume that the previous sibling of an ast.Stmt is also
 412  // an ast.Stmt, or if it is, that they are executed sequentially,
 413  // unless you have established that, say, its parent is a BlockStmt
 414  // or its [Cursor.ParentEdge] is [edge.BlockStmt_List].
 415  // For example, given "for S1; ; S2 {}", the predecessor of S2 is S1,
 416  // even though they are not executed in sequence.
 417  func (c Cursor) Children() iter.Seq[Cursor] {
 418  	return func(yield func(Cursor) bool) {
 419  		c, ok := c.FirstChild()
 420  		for ok && yield(c) {
 421  			c, ok = c.NextSibling()
 422  		}
 423  	}
 424  }
 425  
 426  // Contains reports whether c contains or is equal to c2.
 427  //
 428  // Both Cursors must belong to the same [Inspector];
 429  // neither may be its Root node.
 430  func (c Cursor) Contains(c2 Cursor) bool {
 431  	if c.in != c2.in {
 432  		panic("different inspectors")
 433  	}
 434  	events := c.in.events
 435  	return c.index <= c2.index && events[c2.index].index <= events[c.index].index
 436  }
 437  
 438  // FindNode returns the cursor for node n if it belongs to the subtree
 439  // rooted at c. It returns zero if n is not found.
 440  func (c Cursor) FindNode(n ast.Node) (Cursor, bool) {
 441  
 442  	// FindNode is equivalent to this code,
 443  	// but more convenient and 15-20% faster:
 444  	if false {
 445  		for candidate := range c.Preorder(n) {
 446  			if candidate.Node() == n {
 447  				return candidate, true
 448  			}
 449  		}
 450  		return Cursor{}, false
 451  	}
 452  
 453  	// TODO(adonovan): opt: should we assume Node.Pos is accurate
 454  	// and combine type-based filtering with position filtering
 455  	// like FindByPos?
 456  
 457  	mask := maskOf([]ast.Node{n})
 458  	events := c.in.events
 459  
 460  	for i, limit := c.indices(); i < limit; i++ {
 461  		ev := events[i]
 462  		if ev.index > i { // push?
 463  			if ev.typ&mask != 0 && ev.node == n {
 464  				return Cursor{c.in, i}, true
 465  			}
 466  			pop := ev.index
 467  			if events[pop].typ&mask == 0 {
 468  				// Subtree does not contain type of n: skip.
 469  				i = pop
 470  			}
 471  		}
 472  	}
 473  	return Cursor{}, false
 474  }
 475  
 476  // FindByPos returns the cursor for the innermost node n in the tree
 477  // rooted at c such that n.Pos() <= start && end <= n.End().
 478  // (For an *ast.File, it uses the bounds n.FileStart-n.FileEnd.)
 479  //
 480  // An empty range (start == end) between two adjacent nodes is
 481  // considered to belong to the first node.
 482  //
 483  // It returns zero if none is found.
 484  // Precondition: start <= end.
 485  //
 486  // See also [astutil.PathEnclosingInterval], which
 487  // tolerates adjoining whitespace.
 488  func (c Cursor) FindByPos(start, end token.Pos) (Cursor, bool) {
 489  	if end < start {
 490  		panic("end < start")
 491  	}
 492  	events := c.in.events
 493  
 494  	// This algorithm could be implemented using c.Inspect,
 495  	// but it is about 2.5x slower.
 496  
 497  	// best is the push-index of the latest (=innermost) node containing range.
 498  	// (Beware: latest is not always innermost because FuncDecl.{Name,Type} overlap.)
 499  	best := int32(-1)
 500  	for i, limit := c.indices(); i < limit; i++ {
 501  		ev := events[i]
 502  		if ev.index > i { // push?
 503  			n := ev.node
 504  			var nodeEnd token.Pos
 505  			if file, ok := n.(*ast.File); ok {
 506  				nodeEnd = file.FileEnd
 507  				// Note: files may be out of Pos order.
 508  				if file.FileStart > start {
 509  					i = ev.index // disjoint, after; skip to next file
 510  					continue
 511  				}
 512  			} else {
 513  				// Edge case: FuncDecl.Name and .Type overlap:
 514  				// Don't update best from Name to FuncDecl.Type.
 515  				//
 516  				// The condition can be read as:
 517  				// - n is FuncType
 518  				// - n.parent is FuncDecl
 519  				// - best is strictly beneath the FuncDecl
 520  				if ev.typ == 1<<nFuncType &&
 521  					events[ev.parent].typ == 1<<nFuncDecl &&
 522  					best > ev.parent {
 523  					continue
 524  				}
 525  
 526  				nodeEnd = n.End()
 527  				if n.Pos() > start {
 528  					break // disjoint, after; stop
 529  				}
 530  			}
 531  
 532  			// Inv: node.{Pos,FileStart} <= start
 533  			if end <= nodeEnd {
 534  				// node fully contains target range
 535  				best = i
 536  
 537  				// Don't search beyond end of the first match.
 538  				// This is important only for an empty range (start=end)
 539  				// between two adjoining nodes, which would otherwise
 540  				// match both nodes; we want to match only the first.
 541  				limit = ev.index
 542  			} else if nodeEnd < start {
 543  				i = ev.index // disjoint, before; skip forward
 544  			}
 545  		}
 546  	}
 547  	if best >= 0 {
 548  		return Cursor{c.in, best}, true
 549  	}
 550  	return Cursor{}, false
 551  }
 552