iter.go raw

   1  // Copyright 2024 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  //go:build go1.23
   6  
   7  package inspector
   8  
   9  import (
  10  	"go/ast"
  11  	"iter"
  12  )
  13  
  14  // PreorderSeq returns an iterator that visits all the
  15  // nodes of the files supplied to New in depth-first order.
  16  // It visits each node n before n's children.
  17  // The complete traversal sequence is determined by ast.Inspect.
  18  //
  19  // The types argument, if non-empty, enables type-based
  20  // filtering of events: only nodes whose type matches an
  21  // element of the types slice are included in the sequence.
  22  func (in *Inspector) PreorderSeq(types ...ast.Node) iter.Seq[ast.Node] {
  23  
  24  	// This implementation is identical to Preorder,
  25  	// except that it supports breaking out of the loop.
  26  
  27  	return func(yield func(ast.Node) bool) {
  28  		mask := maskOf(types)
  29  		for i := int32(0); i < int32(len(in.events)); {
  30  			ev := in.events[i]
  31  			if ev.index > i {
  32  				// push
  33  				if ev.typ&mask != 0 {
  34  					if !yield(ev.node) {
  35  						break
  36  					}
  37  				}
  38  				pop := ev.index
  39  				if in.events[pop].typ&mask == 0 {
  40  					// Subtrees do not contain types: skip them and pop.
  41  					i = pop + 1
  42  					continue
  43  				}
  44  			}
  45  			i++
  46  		}
  47  	}
  48  }
  49  
  50  // All[N] returns an iterator over all the nodes of type N.
  51  // N must be a pointer-to-struct type that implements ast.Node.
  52  //
  53  // Example:
  54  //
  55  //	for call := range All[*ast.CallExpr](in) { ... }
  56  func All[N interface {
  57  	*S
  58  	ast.Node
  59  }, S any](in *Inspector) iter.Seq[N] {
  60  
  61  	// To avoid additional dynamic call overheads,
  62  	// we duplicate rather than call the logic of PreorderSeq.
  63  
  64  	mask := typeOf((N)(nil))
  65  	return func(yield func(N) bool) {
  66  		for i := int32(0); i < int32(len(in.events)); {
  67  			ev := in.events[i]
  68  			if ev.index > i {
  69  				// push
  70  				if ev.typ&mask != 0 {
  71  					if !yield(ev.node.(N)) {
  72  						break
  73  					}
  74  				}
  75  				pop := ev.index
  76  				if in.events[pop].typ&mask == 0 {
  77  					// Subtrees do not contain types: skip them and pop.
  78  					i = pop + 1
  79  					continue
  80  				}
  81  			}
  82  			i++
  83  		}
  84  	}
  85  }
  86