// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
4 5 package astutil
6 7 import (
8 "fmt"
9 "go/ast"
10 "reflect"
11 "sort"
12 )
13 14 // An ApplyFunc is invoked by Apply for each node n, even if n is nil,
15 // before and/or after the node's children, using a Cursor describing
16 // the current node and providing operations on it.
17 //
18 // The return value of ApplyFunc controls the syntax tree traversal.
19 // See Apply for details.
20 type ApplyFunc func(*Cursor) bool
21 22 // Apply traverses a syntax tree recursively, starting with root,
23 // and calling pre and post for each node as described below.
24 // Apply returns the syntax tree, possibly modified.
25 //
26 // If pre is not nil, it is called for each node before the node's
27 // children are traversed (pre-order). If pre returns false, no
28 // children are traversed, and post is not called for that node.
29 //
30 // If post is not nil, and a prior call of pre didn't return false,
31 // post is called for each node after its children are traversed
32 // (post-order). If post returns false, traversal is terminated and
33 // Apply returns immediately.
34 //
35 // Only fields that refer to AST nodes are considered children;
36 // i.e., token.Pos, Scopes, Objects, and fields of basic types
37 // (strings, etc.) are ignored.
38 //
39 // Children are traversed in the order in which they appear in the
40 // respective node's struct definition. A package's files are
41 // traversed in the filenames' alphabetical order.
42 func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
43 parent := &struct{ ast.Node }{root}
44 defer func() {
45 if r := recover(); r != nil && r != abort {
46 panic(r)
47 }
48 result = parent.Node
49 }()
50 a := &application{pre: pre, post: post}
51 a.apply(parent, "Node", nil, root)
52 return
53 }
// abort is a private sentinel: apply panics with it when a post callback
// returns false, and Apply recovers exactly this value to end the walk
// early. Only its pointer identity is significant.
var abort = new(int)
56 57 // A Cursor describes a node encountered during Apply.
58 // Information about the node and its parent is available
59 // from the Node, Parent, Name, and Index methods.
60 //
61 // If p is a variable of type and value of the current parent node
62 // c.Parent(), and f is the field identifier with name c.Name(),
63 // the following invariants hold:
64 //
65 // p.f == c.Node() if c.Index() < 0
66 // p.f[c.Index()] == c.Node() if c.Index() >= 0
67 //
68 // The methods Replace, Delete, InsertBefore, and InsertAfter
69 // can be used to change the AST without disrupting Apply.
70 //
71 // This type is not to be confused with [inspector.Cursor] from
72 // package [golang.org/x/tools/go/ast/inspector], which provides
73 // stateless navigation of immutable syntax trees.
74 type Cursor struct {
75 parent ast.Node
76 name string
77 iter *iterator // valid if non-nil
78 node ast.Node
79 }
80 81 // Node returns the current Node.
82 func (c *Cursor) Node() ast.Node { return c.node }
83 84 // Parent returns the parent of the current Node.
85 func (c *Cursor) Parent() ast.Node { return c.parent }
86 87 // Name returns the name of the parent Node field that contains the current Node.
88 // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
89 // the filename for the current Node.
90 func (c *Cursor) Name() string { return c.name }
91 92 // Index reports the index >= 0 of the current Node in the slice of Nodes that
93 // contains it, or a value < 0 if the current Node is not part of a slice.
94 // The index of the current node changes if InsertBefore is called while
95 // processing the current node.
96 func (c *Cursor) Index() int {
97 if c.iter != nil {
98 return c.iter.index
99 }
100 return -1
101 }
102 103 // field returns the current node's parent field value.
104 func (c *Cursor) field() reflect.Value {
105 return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
106 }
107 108 // Replace replaces the current Node with n.
109 // The replacement node is not walked by Apply.
110 func (c *Cursor) Replace(n ast.Node) {
111 if _, ok := c.node.(*ast.File); ok {
112 file, ok := n.(*ast.File)
113 if !ok {
114 panic("attempt to replace *ast.File with non-*ast.File")
115 }
116 c.parent.(*ast.Package).Files[c.name] = file
117 return
118 }
119 120 v := c.field()
121 if i := c.Index(); i >= 0 {
122 v = v.Index(i)
123 }
124 v.Set(reflect.ValueOf(n))
125 }
126 127 // Delete deletes the current Node from its containing slice.
128 // If the current Node is not part of a slice, Delete panics.
129 // As a special case, if the current node is a package file,
130 // Delete removes it from the package's Files map.
131 func (c *Cursor) Delete() {
132 if _, ok := c.node.(*ast.File); ok {
133 delete(c.parent.(*ast.Package).Files, c.name)
134 return
135 }
136 137 i := c.Index()
138 if i < 0 {
139 panic("Delete node not contained in slice")
140 }
141 v := c.field()
142 l := v.Len()
143 reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
144 v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
145 v.SetLen(l - 1)
146 c.iter.step--
147 }
148 149 // InsertAfter inserts n after the current Node in its containing slice.
150 // If the current Node is not part of a slice, InsertAfter panics.
151 // Apply does not walk n.
152 func (c *Cursor) InsertAfter(n ast.Node) {
153 i := c.Index()
154 if i < 0 {
155 panic("InsertAfter node not contained in slice")
156 }
157 v := c.field()
158 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
159 l := v.Len()
160 reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
161 v.Index(i + 1).Set(reflect.ValueOf(n))
162 c.iter.step++
163 }
164 165 // InsertBefore inserts n before the current Node in its containing slice.
166 // If the current Node is not part of a slice, InsertBefore panics.
167 // Apply will not walk n.
168 func (c *Cursor) InsertBefore(n ast.Node) {
169 i := c.Index()
170 if i < 0 {
171 panic("InsertBefore node not contained in slice")
172 }
173 v := c.field()
174 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
175 l := v.Len()
176 reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
177 v.Index(i).Set(reflect.ValueOf(n))
178 c.iter.index++
179 }
180 181 // application carries all the shared data so we can pass it around cheaply.
182 type application struct {
183 pre, post ApplyFunc
184 cursor Cursor
185 iter iterator
186 }
187 188 func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
189 // convert typed nil into untyped nil
190 if v := reflect.ValueOf(n); v.Kind() == reflect.Pointer && v.IsNil() {
191 n = nil
192 }
193 194 // avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
195 saved := a.cursor
196 a.cursor.parent = parent
197 a.cursor.name = name
198 a.cursor.iter = iter
199 a.cursor.node = n
200 201 if a.pre != nil && !a.pre(&a.cursor) {
202 a.cursor = saved
203 return
204 }
205 206 // walk children
207 // (the order of the cases matches the order of the corresponding node types in go/ast)
208 switch n := n.(type) {
209 case nil:
210 // nothing to do
211 212 // Comments and fields
213 case *ast.Comment:
214 // nothing to do
215 216 case *ast.CommentGroup:
217 if n != nil {
218 a.applyList(n, "List")
219 }
220 221 case *ast.Field:
222 a.apply(n, "Doc", nil, n.Doc)
223 a.applyList(n, "Names")
224 a.apply(n, "Type", nil, n.Type)
225 a.apply(n, "Tag", nil, n.Tag)
226 a.apply(n, "Comment", nil, n.Comment)
227 228 case *ast.FieldList:
229 a.applyList(n, "List")
230 231 // Expressions
232 case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
233 // nothing to do
234 235 case *ast.Ellipsis:
236 a.apply(n, "Elt", nil, n.Elt)
237 238 case *ast.FuncLit:
239 a.apply(n, "Type", nil, n.Type)
240 a.apply(n, "Body", nil, n.Body)
241 242 case *ast.CompositeLit:
243 a.apply(n, "Type", nil, n.Type)
244 a.applyList(n, "Elts")
245 246 case *ast.ParenExpr:
247 a.apply(n, "X", nil, n.X)
248 249 case *ast.SelectorExpr:
250 a.apply(n, "X", nil, n.X)
251 a.apply(n, "Sel", nil, n.Sel)
252 253 case *ast.IndexExpr:
254 a.apply(n, "X", nil, n.X)
255 a.apply(n, "Index", nil, n.Index)
256 257 case *ast.IndexListExpr:
258 a.apply(n, "X", nil, n.X)
259 a.applyList(n, "Indices")
260 261 case *ast.SliceExpr:
262 a.apply(n, "X", nil, n.X)
263 a.apply(n, "Low", nil, n.Low)
264 a.apply(n, "High", nil, n.High)
265 a.apply(n, "Max", nil, n.Max)
266 267 case *ast.TypeAssertExpr:
268 a.apply(n, "X", nil, n.X)
269 a.apply(n, "Type", nil, n.Type)
270 271 case *ast.CallExpr:
272 a.apply(n, "Fun", nil, n.Fun)
273 a.applyList(n, "Args")
274 275 case *ast.StarExpr:
276 a.apply(n, "X", nil, n.X)
277 278 case *ast.UnaryExpr:
279 a.apply(n, "X", nil, n.X)
280 281 case *ast.BinaryExpr:
282 a.apply(n, "X", nil, n.X)
283 a.apply(n, "Y", nil, n.Y)
284 285 case *ast.KeyValueExpr:
286 a.apply(n, "Key", nil, n.Key)
287 a.apply(n, "Value", nil, n.Value)
288 289 // Types
290 case *ast.ArrayType:
291 a.apply(n, "Len", nil, n.Len)
292 a.apply(n, "Elt", nil, n.Elt)
293 294 case *ast.StructType:
295 a.apply(n, "Fields", nil, n.Fields)
296 297 case *ast.FuncType:
298 if tparams := n.TypeParams; tparams != nil {
299 a.apply(n, "TypeParams", nil, tparams)
300 }
301 a.apply(n, "Params", nil, n.Params)
302 a.apply(n, "Results", nil, n.Results)
303 304 case *ast.InterfaceType:
305 a.apply(n, "Methods", nil, n.Methods)
306 307 case *ast.MapType:
308 a.apply(n, "Key", nil, n.Key)
309 a.apply(n, "Value", nil, n.Value)
310 311 case *ast.ChanType:
312 a.apply(n, "Value", nil, n.Value)
313 314 // Statements
315 case *ast.BadStmt:
316 // nothing to do
317 318 case *ast.DeclStmt:
319 a.apply(n, "Decl", nil, n.Decl)
320 321 case *ast.EmptyStmt:
322 // nothing to do
323 324 case *ast.LabeledStmt:
325 a.apply(n, "Label", nil, n.Label)
326 a.apply(n, "Stmt", nil, n.Stmt)
327 328 case *ast.ExprStmt:
329 a.apply(n, "X", nil, n.X)
330 331 case *ast.SendStmt:
332 a.apply(n, "Chan", nil, n.Chan)
333 a.apply(n, "Value", nil, n.Value)
334 335 case *ast.IncDecStmt:
336 a.apply(n, "X", nil, n.X)
337 338 case *ast.AssignStmt:
339 a.applyList(n, "Lhs")
340 a.applyList(n, "Rhs")
341 342 case *ast.GoStmt:
343 a.apply(n, "Call", nil, n.Call)
344 345 case *ast.DeferStmt:
346 a.apply(n, "Call", nil, n.Call)
347 348 case *ast.ReturnStmt:
349 a.applyList(n, "Results")
350 351 case *ast.BranchStmt:
352 a.apply(n, "Label", nil, n.Label)
353 354 case *ast.BlockStmt:
355 a.applyList(n, "List")
356 357 case *ast.IfStmt:
358 a.apply(n, "Init", nil, n.Init)
359 a.apply(n, "Cond", nil, n.Cond)
360 a.apply(n, "Body", nil, n.Body)
361 a.apply(n, "Else", nil, n.Else)
362 363 case *ast.CaseClause:
364 a.applyList(n, "List")
365 a.applyList(n, "Body")
366 367 case *ast.SwitchStmt:
368 a.apply(n, "Init", nil, n.Init)
369 a.apply(n, "Tag", nil, n.Tag)
370 a.apply(n, "Body", nil, n.Body)
371 372 case *ast.TypeSwitchStmt:
373 a.apply(n, "Init", nil, n.Init)
374 a.apply(n, "Assign", nil, n.Assign)
375 a.apply(n, "Body", nil, n.Body)
376 377 case *ast.CommClause:
378 a.apply(n, "Comm", nil, n.Comm)
379 a.applyList(n, "Body")
380 381 case *ast.SelectStmt:
382 a.apply(n, "Body", nil, n.Body)
383 384 case *ast.ForStmt:
385 a.apply(n, "Init", nil, n.Init)
386 a.apply(n, "Cond", nil, n.Cond)
387 a.apply(n, "Post", nil, n.Post)
388 a.apply(n, "Body", nil, n.Body)
389 390 case *ast.RangeStmt:
391 a.apply(n, "Key", nil, n.Key)
392 a.apply(n, "Value", nil, n.Value)
393 a.apply(n, "X", nil, n.X)
394 a.apply(n, "Body", nil, n.Body)
395 396 // Declarations
397 case *ast.ImportSpec:
398 a.apply(n, "Doc", nil, n.Doc)
399 a.apply(n, "Name", nil, n.Name)
400 a.apply(n, "Path", nil, n.Path)
401 a.apply(n, "Comment", nil, n.Comment)
402 403 case *ast.ValueSpec:
404 a.apply(n, "Doc", nil, n.Doc)
405 a.applyList(n, "Names")
406 a.apply(n, "Type", nil, n.Type)
407 a.applyList(n, "Values")
408 a.apply(n, "Comment", nil, n.Comment)
409 410 case *ast.TypeSpec:
411 a.apply(n, "Doc", nil, n.Doc)
412 a.apply(n, "Name", nil, n.Name)
413 if tparams := n.TypeParams; tparams != nil {
414 a.apply(n, "TypeParams", nil, tparams)
415 }
416 a.apply(n, "Type", nil, n.Type)
417 a.apply(n, "Comment", nil, n.Comment)
418 419 case *ast.BadDecl:
420 // nothing to do
421 422 case *ast.GenDecl:
423 a.apply(n, "Doc", nil, n.Doc)
424 a.applyList(n, "Specs")
425 426 case *ast.FuncDecl:
427 a.apply(n, "Doc", nil, n.Doc)
428 a.apply(n, "Recv", nil, n.Recv)
429 a.apply(n, "Name", nil, n.Name)
430 a.apply(n, "Type", nil, n.Type)
431 a.apply(n, "Body", nil, n.Body)
432 433 // Files and packages
434 case *ast.File:
435 a.apply(n, "Doc", nil, n.Doc)
436 a.apply(n, "Name", nil, n.Name)
437 a.applyList(n, "Decls")
438 // Don't walk n.Comments; they have either been walked already if
439 // they are Doc comments, or they can be easily walked explicitly.
440 441 case *ast.Package:
442 // collect and sort names for reproducible behavior
443 var names []string
444 for name := range n.Files {
445 names = append(names, name)
446 }
447 sort.Strings(names)
448 for _, name := range names {
449 a.apply(n, name, nil, n.Files[name])
450 }
451 452 default:
453 panic(fmt.Sprintf("Apply: unexpected node type %T", n))
454 }
455 456 if a.post != nil && !a.post(&a.cursor) {
457 panic(abort)
458 }
459 460 a.cursor = saved
461 }
// An iterator tracks progress through one slice-valued field during
// applyList. index is the position of the element being visited; step is
// how far applyList advances after the visit (reset to 1 before each visit
// and adjusted by Cursor.Delete and Cursor.InsertAfter; Cursor.InsertBefore
// bumps index instead).
type iterator struct {
	index int
	step  int
}
467 468 func (a *application) applyList(parent ast.Node, name string) {
469 // avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
470 saved := a.iter
471 a.iter.index = 0
472 for {
473 // must reload parent.name each time, since cursor modifications might change it
474 v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
475 if a.iter.index >= v.Len() {
476 break
477 }
478 479 // element x may be nil in a bad AST - be cautious
480 var x ast.Node
481 if e := v.Index(a.iter.index); e.IsValid() {
482 x = e.Interface().(ast.Node)
483 }
484 485 a.iter.step = 1
486 a.apply(parent, name, &a.iter, x)
487 a.iter.index += a.iter.step
488 }
489 a.iter = saved
490 }
491