1 // Copyright 2013 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 // Package astutil contains common utilities for working with the Go AST.
6 package astutil // import "golang.org/x/tools/go/ast/astutil"
7 8 import (
9 "fmt"
10 "go/ast"
11 "go/token"
12 "reflect"
13 "slices"
14 "strconv"
15 "strings"
16 )
17 18 // AddImport adds the import path to the file f, if absent.
19 func AddImport(fset *token.FileSet, f *ast.File, path string) (added bool) {
20 return AddNamedImport(fset, f, "", path)
21 }
// AddNamedImport adds the import with the given name and path to the file f, if absent.
// If name is not empty, it is used to rename the import.
//
// For example, calling
//
//	AddNamedImport(fset, f, "pathpkg", "path")
//
// adds
//
//	import pathpkg "path"
//
// It reports whether an import was added. It returns false, leaving f
// unchanged, when f already has an import with exactly this name and path.
//
// The new spec is placed in the existing import decl containing the spec
// whose path shares the longest path-segment prefix with path. It is never
// placed in a decl that imports "C". If no suitable decl exists, a new
// import decl is created after the last existing one. Afterwards, when f
// has more than one decl, all import decls that do not import "C" are
// merged into the first such decl. The new spec is also appended to
// f.Imports.
func AddNamedImport(fset *token.FileSet, f *ast.File, name, path string) (added bool) {
	if imports(f, name, path) {
		return false
	}

	newImport := &ast.ImportSpec{
		Path: &ast.BasicLit{
			Kind:  token.STRING,
			Value: strconv.Quote(path),
		},
	}
	if name != "" {
		newImport.Name = &ast.Ident{Name: name}
	}

	// Find an import decl to add to.
	// The goal is to find an existing import
	// whose import path has the longest shared
	// prefix with path.
	var (
		bestMatch  = -1         // length of longest shared prefix
		lastImport = -1         // index in f.Decls of the file's final import decl
		impDecl    *ast.GenDecl // import decl containing the best match
		impIndex   = -1         // spec index in impDecl containing the best match

		isThirdPartyPath = isThirdParty(path)
	)
	for i, decl := range f.Decls {
		gen, ok := decl.(*ast.GenDecl)
		if ok && gen.Tok == token.IMPORT {
			lastImport = i
			// Do not add to import "C", to avoid disrupting the
			// association with its doc comment, breaking cgo.
			if declImports(gen, "C") {
				continue
			}

			// Match an empty import decl if that's all that is available.
			if len(gen.Specs) == 0 && bestMatch == -1 {
				impDecl = gen
			}

			// Compute longest shared prefix with imports in this group and find best
			// matched import spec.
			// 1. Always prefer import spec with longest shared prefix.
			// 2. While match length is 0,
			// - for stdlib package: prefer first import spec.
			// - for third party package: prefer first third party import spec.
			// We cannot use last import spec as best match for third party package
			// because grouped imports are usually placed last by goimports -local
			// flag.
			// See issue #19190.
			seenAnyThirdParty := false
			for j, spec := range gen.Specs {
				impspec := spec.(*ast.ImportSpec)
				p := importPath(impspec)
				// matchLen counts the '/' bytes in the common prefix of p and path.
				n := matchLen(p, path)
				if n > bestMatch || (bestMatch == 0 && !seenAnyThirdParty && isThirdPartyPath) {
					bestMatch = n
					impDecl = gen
					impIndex = j
				}
				seenAnyThirdParty = seenAnyThirdParty || isThirdParty(p)
			}
		}
	}

	// If no import decl found, add one after the last import.
	if impDecl == nil {
		impDecl = &ast.GenDecl{
			Tok: token.IMPORT,
		}
		if lastImport >= 0 {
			impDecl.TokPos = f.Decls[lastImport].End()
		} else {
			// There are no existing imports.
			// Our new import, preceded by a blank line, goes after the package declaration
			// and after the comment, if any, that starts on the same line as the
			// package declaration.
			impDecl.TokPos = f.Package

			// NOTE(review): fset.File returns nil when f.Package is not a
			// position in fset; file.Line would then presumably panic. This path
			// assumes f was parsed using fset.
			file := fset.File(f.Package)
			pkgLine := file.Line(f.Package)
			for _, c := range f.Comments {
				if file.Line(c.Pos()) > pkgLine {
					break
				}
				// +2 for a blank line
				impDecl.TokPos = c.End() + 2
			}
		}
		// Insert impDecl into f.Decls at index lastImport+1: directly after
		// the last import decl, or at the front when there is none.
		f.Decls = append(f.Decls, nil)
		copy(f.Decls[lastImport+2:], f.Decls[lastImport+1:])
		f.Decls[lastImport+1] = impDecl
	}

	// Insert new import at insertAt.
	insertAt := 0
	if impIndex >= 0 {
		// insert after the found import
		insertAt = impIndex + 1
	}
	impDecl.Specs = append(impDecl.Specs, nil)
	copy(impDecl.Specs[insertAt+1:], impDecl.Specs[insertAt:])
	impDecl.Specs[insertAt] = newImport
	// pos is the position given to the new spec; it defaults to the decl's
	// own position when the spec is first in the decl.
	pos := impDecl.Pos()
	if insertAt > 0 {
		// If there is a comment after an existing import, preserve the comment
		// position by adding the new import after the comment.
		if spec, ok := impDecl.Specs[insertAt-1].(*ast.ImportSpec); ok && spec.Comment != nil {
			pos = spec.Comment.End()
		} else {
			// Assign same position as the previous import,
			// so that the sorter sees it as being in the same block.
			pos = impDecl.Specs[insertAt-1].Pos()
		}
	}
	// Name, path, and end all share pos, so the new spec has a zero-width
	// extent at that position.
	if newImport.Name != nil {
		newImport.Name.NamePos = pos
	}
	updateBasicLitPos(newImport.Path, pos)
	newImport.EndPos = pos

	// Clean up parens. impDecl contains at least one spec.
	if len(impDecl.Specs) == 1 {
		// Remove unneeded parens.
		impDecl.Lparen = token.NoPos
	} else if !impDecl.Lparen.IsValid() {
		// impDecl needs parens added.
		impDecl.Lparen = impDecl.Specs[0].Pos()
	}

	f.Imports = append(f.Imports, newImport)

	if len(f.Decls) <= 1 {
		return true
	}

	// Merge all the import declarations into the first one.
	// This applies to every import decl not importing "C", not only the one
	// modified above.
	var first *ast.GenDecl
	for i := 0; i < len(f.Decls); i++ {
		decl := f.Decls[i]
		gen, ok := decl.(*ast.GenDecl)
		if !ok || gen.Tok != token.IMPORT || declImports(gen, "C") {
			continue
		}
		if first == nil {
			first = gen
			continue // Don't touch the first one.
		}
		// We now know there is more than one package in this import
		// declaration. Ensure that it ends up parenthesized.
		first.Lparen = first.Pos()
		// Move the imports of the other import declaration to the first one.
		// Moved specs take the first decl's position, presumably so they are
		// treated as part of that decl's block (as with the new spec above).
		for _, spec := range gen.Specs {
			updateBasicLitPos(spec.(*ast.ImportSpec).Path, first.Pos())
			first.Specs = append(first.Specs, spec)
		}
		f.Decls = slices.Delete(f.Decls, i, i+1)
		// Revisit index i, which now holds the decl that followed the deleted one.
		i--
	}

	return true
}
// isThirdParty reports whether importPath looks like a third-party
// (non-standard-library) package path. The heuristic, taken from the
// golang.org/x/tools/imports package, is that such paths contain a dot,
// as a domain name such as "example.com/pkg" does.
func isThirdParty(importPath string) bool {
	dot := strings.IndexByte(importPath, '.')
	return dot >= 0
}
203 204 // DeleteImport deletes the import path from the file f, if present.
205 // If there are duplicate import declarations, all matching ones are deleted.
206 func DeleteImport(fset *token.FileSet, f *ast.File, path string) (deleted bool) {
207 return DeleteNamedImport(fset, f, "", path)
208 }
// DeleteNamedImport deletes the import with the given name and path from the file f, if present.
// If there are duplicate import declarations, all matching ones are deleted.
//
// An empty name matches only unnamed imports. It reports whether any import
// was deleted.
//
// Matching specs are removed from their import decls and from f.Imports. A
// decl left with no specs is removed from f.Decls. When a decl is left with
// exactly one spec, comments attached to (or ending on the same line before)
// the deleted spec are removed from f.Comments. Line information in fset may
// be adjusted via MergeLine to close the gap left by a deleted spec.
//
// It panics if the number of specs removed from f.Imports differs from the
// number removed from f.Decls, which means f.Decls and f.Imports were
// inconsistent.
func DeleteNamedImport(fset *token.FileSet, f *ast.File, name, path string) (deleted bool) {
	var (
		delspecs    = make(map[*ast.ImportSpec]bool)
		delcomments = make(map[*ast.CommentGroup]bool)
	)

	// Find the import nodes that import path, if any.
	for i := 0; i < len(f.Decls); i++ {
		gen, ok := f.Decls[i].(*ast.GenDecl)
		if !ok || gen.Tok != token.IMPORT {
			continue
		}
		for j := 0; j < len(gen.Specs); j++ {
			impspec := gen.Specs[j].(*ast.ImportSpec)
			if importName(impspec) != name || importPath(impspec) != path {
				continue
			}

			// We found an import spec that imports path.
			// Delete it.
			delspecs[impspec] = true
			deleted = true
			gen.Specs = slices.Delete(gen.Specs, j, j+1)

			// If this was the last import spec in this decl,
			// delete the decl, too.
			if len(gen.Specs) == 0 {
				f.Decls = slices.Delete(f.Decls, i, i+1)
				// Revisit index i, which now holds the following decl.
				i--
				break
			} else if len(gen.Specs) == 1 {
				if impspec.Doc != nil {
					delcomments[impspec.Doc] = true
				}
				if impspec.Comment != nil {
					delcomments[impspec.Comment] = true
				}
				for _, cg := range f.Comments {
					// Found comment on the same line as the import spec.
					if cg.End() < impspec.Pos() && fset.Position(cg.End()).Line == fset.Position(impspec.Pos()).Line {
						delcomments[cg] = true
						break
					}
				}

				// The sole remaining spec. The loops below merge lines in fset
				// until it (or its doc comment) sits on the line right after the
				// import keyword, presumably so the decl prints compactly.
				spec := gen.Specs[0].(*ast.ImportSpec)

				// Move the documentation right after the import decl.
				if spec.Doc != nil {
					for fset.Position(gen.TokPos).Line+1 < fset.Position(spec.Doc.Pos()).Line {
						fset.File(gen.TokPos).MergeLine(fset.Position(gen.TokPos).Line)
					}
				}
				for _, cg := range f.Comments {
					if cg.End() < spec.Pos() && fset.Position(cg.End()).Line == fset.Position(spec.Pos()).Line {
						for fset.Position(gen.TokPos).Line+1 < fset.Position(spec.Pos()).Line {
							fset.File(gen.TokPos).MergeLine(fset.Position(gen.TokPos).Line)
						}
						break
					}
				}
			}
			if j > 0 {
				// gen.Specs[j-1] is the spec that preceded the deleted one.
				// PositionFor(..., false) uses unadjusted positions, ignoring
				// //line directives.
				lastImpspec := gen.Specs[j-1].(*ast.ImportSpec)
				lastLine := fset.PositionFor(lastImpspec.Path.ValuePos, false).Line
				line := fset.PositionFor(impspec.Path.ValuePos, false).Line

				// We deleted an entry but now there may be
				// a blank line-sized hole where the import was.
				if line-lastLine > 1 || !gen.Rparen.IsValid() {
					// There was a blank line immediately preceding the deleted import,
					// so there's no need to close the hole. The right parenthesis is
					// invalid after AddImport to an import statement without parenthesis.
					// Do nothing.
				} else if line != fset.File(gen.Rparen).LineCount() {
					// There was no blank line. Close the hole.
					fset.File(gen.Rparen).MergeLine(line)
				}
			}
			// Re-examine index j, which now holds the spec that followed the deleted one.
			j--
		}
	}

	// Delete imports from f.Imports.
	before := len(f.Imports)
	f.Imports = slices.DeleteFunc(f.Imports, func(imp *ast.ImportSpec) bool {
		_, ok := delspecs[imp]
		return ok
	})
	if len(f.Imports)+len(delspecs) != before {
		// This can happen when the AST is invalid (i.e. imports differ between f.Decls and f.Imports).
		panic(fmt.Sprintf("deleted specs from Decls but not Imports: %v", delspecs))
	}

	// Delete comments from f.Comments.
	f.Comments = slices.DeleteFunc(f.Comments, func(cg *ast.CommentGroup) bool {
		_, ok := delcomments[cg]
		return ok
	})

	return
}
314 315 // RewriteImport rewrites any import of path oldPath to path newPath.
316 func RewriteImport(fset *token.FileSet, f *ast.File, oldPath, newPath string) (rewrote bool) {
317 for _, imp := range f.Imports {
318 if importPath(imp) == oldPath {
319 rewrote = true
320 // record old End, because the default is to compute
321 // it using the length of imp.Path.Value.
322 imp.EndPos = imp.End()
323 imp.Path.Value = strconv.Quote(newPath)
324 }
325 }
326 return
327 }
328 329 // UsesImport reports whether a given import is used.
330 // The provided File must have been parsed with syntactic object resolution
331 // (not using go/parser.SkipObjectResolution).
332 func UsesImport(f *ast.File, path string) (used bool) {
333 if f.Scope == nil {
334 panic("file f was not parsed with syntactic object resolution")
335 }
336 spec := importSpec(f, path)
337 if spec == nil {
338 return
339 }
340 341 name := spec.Name.String()
342 switch name {
343 case "<nil>":
344 // If the package name is not explicitly specified,
345 // make an educated guess. This is not guaranteed to be correct.
346 lastSlash := strings.LastIndex(path, "/")
347 if lastSlash == -1 {
348 name = path
349 } else {
350 name = path[lastSlash+1:]
351 }
352 case "_", ".":
353 // Not sure if this import is used - err on the side of caution.
354 return true
355 }
356 357 ast.Walk(visitFn(func(n ast.Node) {
358 sel, ok := n.(*ast.SelectorExpr)
359 if ok && isTopName(sel.X, name) {
360 used = true
361 }
362 }), f)
363 364 return
365 }
// visitFn adapts a plain function to the ast.Visitor interface. ast.Walk
// calls the function for every node it visits. Because Visit always returns
// the same visitor, the walk always descends into children. Per the go/ast
// documentation, Walk also passes nil after visiting a node's children.
type visitFn func(node ast.Node)

// Visit invokes v on n and returns v, so the traversal continues.
func (v visitFn) Visit(n ast.Node) ast.Visitor {
	v(n)
	return v
}
373 374 // imports reports whether f has an import with the specified name and path.
375 func imports(f *ast.File, name, path string) bool {
376 for _, s := range f.Imports {
377 if importName(s) == name && importPath(s) == path {
378 return true
379 }
380 }
381 return false
382 }
383 384 // importSpec returns the import spec if f imports path,
385 // or nil otherwise.
386 func importSpec(f *ast.File, path string) *ast.ImportSpec {
387 for _, s := range f.Imports {
388 if importPath(s) == path {
389 return s
390 }
391 }
392 return nil
393 }
// importName returns the explicit name given to the import s (an alias,
// "_" or "."), or "" when s has no explicit name.
func importName(s *ast.ImportSpec) string {
	if ident := s.Name; ident != nil {
		return ident.Name
	}
	return ""
}
// importPath returns the import path of s with its quotes removed, or ""
// if s.Path.Value is not a valid quoted Go string literal.
func importPath(s *ast.ImportSpec) string {
	if unquoted, err := strconv.Unquote(s.Path.Value); err == nil {
		return unquoted
	}
	return ""
}
413 414 // declImports reports whether gen contains an import of path.
415 func declImports(gen *ast.GenDecl, path string) bool {
416 if gen.Tok != token.IMPORT {
417 return false
418 }
419 for _, spec := range gen.Specs {
420 impspec := spec.(*ast.ImportSpec)
421 if importPath(impspec) == path {
422 return true
423 }
424 }
425 return false
426 }
// matchLen returns the number of '/' bytes in the longest common byte
// prefix of x and y. This approximates how many leading path segments the
// two import paths share. A final segment shared in full is not counted,
// because no '/' follows it.
func matchLen(x, y string) int {
	limit := min(len(x), len(y))
	slashes := 0
	for i := 0; i < limit; i++ {
		if x[i] != y[i] {
			break
		}
		if x[i] == '/' {
			slashes++
		}
	}
	return slashes
}
// isTopName reports whether n is an identifier spelled name that syntactic
// object resolution left unresolved (its Obj is nil). UsesImport relies on
// this to recognize references to an imported package name.
func isTopName(n ast.Expr, name string) bool {
	id, isIdent := n.(*ast.Ident)
	if !isIdent {
		return false
	}
	return id.Obj == nil && id.Name == name
}
// Imports returns the specs of f's leading import declarations, grouped by
// paragraph. Each import declaration starts a new group. Within a
// declaration, a new group starts wherever consecutive specs are more than
// one line apart, i.e. separated by a blank line. Scanning stops at the
// first declaration that is not an import declaration.
func Imports(fset *token.FileSet, f *ast.File) [][]*ast.ImportSpec {
	var groups [][]*ast.ImportSpec

	for _, decl := range f.Decls {
		gen, isGen := decl.(*ast.GenDecl)
		if !isGen || gen.Tok != token.IMPORT {
			return groups
		}

		paragraph := []*ast.ImportSpec{}
		prevLine := 0
		for _, s := range gen.Specs {
			imp := s.(*ast.ImportSpec)
			at := imp.Path.ValuePos
			line := fset.Position(at).Line
			// A spec with no position never starts a new paragraph, and
			// nothing is split off before the first spec.
			startsNew := prevLine > 0 && at > 0 && line-prevLine > 1
			if startsNew {
				groups = append(groups, paragraph)
				paragraph = []*ast.ImportSpec{}
			}
			paragraph = append(paragraph, imp)
			prevLine = line
		}
		groups = append(groups, paragraph)
	}

	return groups
}
// updateBasicLitPos moves lit so that it starts at pos. If lit records an
// explicit end position, that end is shifted by the same amount, so the
// literal keeps its width.
// (See https://go.dev/issue/76395.)
//
// The explicit end is the ValueEnd field. It is looked up by reflection
// because, per the TODO below, it is only present in ast.BasicLit from
// go1.26. It is updated only when the field exists, is integer-kinded and
// settable, and is non-zero (set).
func updateBasicLitPos(lit *ast.BasicLit, pos token.Pos) {
	// Named width rather than len, so the builtin len is not shadowed.
	width := lit.End() - lit.Pos()
	lit.ValuePos = pos
	// TODO(adonovan): after go1.26, simplify to:
	// lit.ValueEnd = pos + width
	v := reflect.ValueOf(lit).Elem().FieldByName("ValueEnd")
	// Guard the kind and settability so Int/SetInt cannot panic if the
	// field's shape is ever unexpected.
	if v.IsValid() && v.CanInt() && v.CanSet() && v.Int() != 0 {
		v.SetInt(int64(pos + width))
	}
}
488