visit.go raw

   1  // Copyright 2018 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package packages
   6  
   7  import (
   8  	"cmp"
   9  	"fmt"
  10  	"iter"
  11  	"os"
  12  	"slices"
  13  )
  14  
  15  // Visit visits all the packages in the import graph whose roots are
  16  // pkgs, calling the optional pre function the first time each package
  17  // is encountered (preorder), and the optional post function after a
  18  // package's dependencies have been visited (postorder).
  19  // The boolean result of pre(pkg) determines whether
  20  // the imports of package pkg are visited.
  21  //
  22  // Example:
  23  //
  24  //	pkgs, err := Load(...)
  25  //	if err != nil { ... }
  26  //	Visit(pkgs, nil, func(pkg *Package) {
  27  //		log.Println(pkg)
  28  //	})
  29  //
  30  // In most cases, it is more convenient to use [Postorder]:
  31  //
  32  //	for pkg := range Postorder(pkgs) {
  33  //		log.Println(pkg)
  34  //	}
  35  func Visit(pkgs []*Package, pre func(*Package) bool, post func(*Package)) {
  36  	seen := make(map[*Package]bool)
  37  	var visit func(*Package)
  38  	visit = func(pkg *Package) {
  39  		if !seen[pkg] {
  40  			seen[pkg] = true
  41  
  42  			if pre == nil || pre(pkg) {
  43  				for _, imp := range sorted(pkg.Imports) { // for determinism
  44  					visit(imp)
  45  				}
  46  			}
  47  
  48  			if post != nil {
  49  				post(pkg)
  50  			}
  51  		}
  52  	}
  53  	for _, pkg := range pkgs {
  54  		visit(pkg)
  55  	}
  56  }
  57  
  58  // PrintErrors prints to os.Stderr the accumulated errors of all
  59  // packages in the import graph rooted at pkgs, dependencies first.
  60  // PrintErrors returns the number of errors printed.
  61  func PrintErrors(pkgs []*Package) int {
  62  	var n int
  63  	errModules := make(map[*Module]bool)
  64  	for pkg := range Postorder(pkgs) {
  65  		for _, err := range pkg.Errors {
  66  			fmt.Fprintln(os.Stderr, err)
  67  			n++
  68  		}
  69  
  70  		// Print pkg.Module.Error once if present.
  71  		mod := pkg.Module
  72  		if mod != nil && mod.Error != nil && !errModules[mod] {
  73  			errModules[mod] = true
  74  			fmt.Fprintln(os.Stderr, mod.Error.Err)
  75  			n++
  76  		}
  77  	}
  78  	return n
  79  }
  80  
  81  // Postorder returns an iterator over the packages in
  82  // the import graph whose roots are pkg.
  83  // Packages are enumerated in dependencies-first order.
  84  func Postorder(pkgs []*Package) iter.Seq[*Package] {
  85  	return func(yield func(*Package) bool) {
  86  		seen := make(map[*Package]bool)
  87  		var visit func(*Package) bool
  88  		visit = func(pkg *Package) bool {
  89  			if !seen[pkg] {
  90  				seen[pkg] = true
  91  				for _, imp := range sorted(pkg.Imports) { // for determinism
  92  					if !visit(imp) {
  93  						return false
  94  					}
  95  				}
  96  				if !yield(pkg) {
  97  					return false
  98  				}
  99  			}
 100  			return true
 101  		}
 102  		for _, pkg := range pkgs {
 103  			if !visit(pkg) {
 104  				break
 105  			}
 106  		}
 107  	}
 108  }
 109  
// -- copied from golang.org/x/tools/gopls/internal/util/moremaps --
 111  
 112  // sorted returns an iterator over the entries of m in key order.
 113  func sorted[M ~map[K]V, K cmp.Ordered, V any](m M) iter.Seq2[K, V] {
 114  	// TODO(adonovan): use maps.Sorted if proposal #68598 is accepted.
 115  	return func(yield func(K, V) bool) {
 116  		keys := keySlice(m)
 117  		slices.Sort(keys)
 118  		for _, k := range keys {
 119  			if !yield(k, m[k]) {
 120  				break
 121  			}
 122  		}
 123  	}
 124  }
 125  
 126  // KeySlice returns the keys of the map M, like slices.Collect(maps.Keys(m)).
 127  func keySlice[M ~map[K]V, K comparable, V any](m M) []K {
 128  	r := make([]K, 0, len(m))
 129  	for k := range m {
 130  		r = append(r, k)
 131  	}
 132  	return r
 133  }
 134