rewrite.go raw

   1  // Copyright 2017 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package astutil
   6  
   7  import (
   8  	"fmt"
   9  	"go/ast"
  10  	"reflect"
  11  	"sort"
  12  )
  13  
  14  // An ApplyFunc is invoked by Apply for each node n, even if n is nil,
  15  // before and/or after the node's children, using a Cursor describing
  16  // the current node and providing operations on it.
  17  //
  18  // The return value of ApplyFunc controls the syntax tree traversal.
  19  // See Apply for details.
  20  type ApplyFunc func(*Cursor) bool
  21  
  22  // Apply traverses a syntax tree recursively, starting with root,
  23  // and calling pre and post for each node as described below.
  24  // Apply returns the syntax tree, possibly modified.
  25  //
  26  // If pre is not nil, it is called for each node before the node's
  27  // children are traversed (pre-order). If pre returns false, no
  28  // children are traversed, and post is not called for that node.
  29  //
  30  // If post is not nil, and a prior call of pre didn't return false,
  31  // post is called for each node after its children are traversed
  32  // (post-order). If post returns false, traversal is terminated and
  33  // Apply returns immediately.
  34  //
  35  // Only fields that refer to AST nodes are considered children;
  36  // i.e., token.Pos, Scopes, Objects, and fields of basic types
  37  // (strings, etc.) are ignored.
  38  //
  39  // Children are traversed in the order in which they appear in the
  40  // respective node's struct definition. A package's files are
  41  // traversed in the filenames' alphabetical order.
  42  func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
  43  	parent := &struct{ ast.Node }{root}
  44  	defer func() {
  45  		if r := recover(); r != nil && r != abort {
  46  			panic(r)
  47  		}
  48  		result = parent.Node
  49  	}()
  50  	a := &application{pre: pre, post: post}
  51  	a.apply(parent, "Node", nil, root)
  52  	return
  53  }
  54  
  55  var abort = new(int) // singleton, to signal termination of Apply
  56  
  57  // A Cursor describes a node encountered during Apply.
  58  // Information about the node and its parent is available
  59  // from the Node, Parent, Name, and Index methods.
  60  //
  61  // If p is a variable of type and value of the current parent node
  62  // c.Parent(), and f is the field identifier with name c.Name(),
  63  // the following invariants hold:
  64  //
  65  //	p.f            == c.Node()  if c.Index() <  0
  66  //	p.f[c.Index()] == c.Node()  if c.Index() >= 0
  67  //
  68  // The methods Replace, Delete, InsertBefore, and InsertAfter
  69  // can be used to change the AST without disrupting Apply.
  70  //
  71  // This type is not to be confused with [inspector.Cursor] from
  72  // package [golang.org/x/tools/go/ast/inspector], which provides
  73  // stateless navigation of immutable syntax trees.
  74  type Cursor struct {
  75  	parent ast.Node
  76  	name   string
  77  	iter   *iterator // valid if non-nil
  78  	node   ast.Node
  79  }
  80  
  81  // Node returns the current Node.
  82  func (c *Cursor) Node() ast.Node { return c.node }
  83  
  84  // Parent returns the parent of the current Node.
  85  func (c *Cursor) Parent() ast.Node { return c.parent }
  86  
  87  // Name returns the name of the parent Node field that contains the current Node.
  88  // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
  89  // the filename for the current Node.
  90  func (c *Cursor) Name() string { return c.name }
  91  
  92  // Index reports the index >= 0 of the current Node in the slice of Nodes that
  93  // contains it, or a value < 0 if the current Node is not part of a slice.
  94  // The index of the current node changes if InsertBefore is called while
  95  // processing the current node.
  96  func (c *Cursor) Index() int {
  97  	if c.iter != nil {
  98  		return c.iter.index
  99  	}
 100  	return -1
 101  }
 102  
 103  // field returns the current node's parent field value.
 104  func (c *Cursor) field() reflect.Value {
 105  	return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
 106  }
 107  
 108  // Replace replaces the current Node with n.
 109  // The replacement node is not walked by Apply.
 110  func (c *Cursor) Replace(n ast.Node) {
 111  	if _, ok := c.node.(*ast.File); ok {
 112  		file, ok := n.(*ast.File)
 113  		if !ok {
 114  			panic("attempt to replace *ast.File with non-*ast.File")
 115  		}
 116  		c.parent.(*ast.Package).Files[c.name] = file
 117  		return
 118  	}
 119  
 120  	v := c.field()
 121  	if i := c.Index(); i >= 0 {
 122  		v = v.Index(i)
 123  	}
 124  	v.Set(reflect.ValueOf(n))
 125  }
 126  
 127  // Delete deletes the current Node from its containing slice.
 128  // If the current Node is not part of a slice, Delete panics.
 129  // As a special case, if the current node is a package file,
 130  // Delete removes it from the package's Files map.
 131  func (c *Cursor) Delete() {
 132  	if _, ok := c.node.(*ast.File); ok {
 133  		delete(c.parent.(*ast.Package).Files, c.name)
 134  		return
 135  	}
 136  
 137  	i := c.Index()
 138  	if i < 0 {
 139  		panic("Delete node not contained in slice")
 140  	}
 141  	v := c.field()
 142  	l := v.Len()
 143  	reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
 144  	v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
 145  	v.SetLen(l - 1)
 146  	c.iter.step--
 147  }
 148  
 149  // InsertAfter inserts n after the current Node in its containing slice.
 150  // If the current Node is not part of a slice, InsertAfter panics.
 151  // Apply does not walk n.
 152  func (c *Cursor) InsertAfter(n ast.Node) {
 153  	i := c.Index()
 154  	if i < 0 {
 155  		panic("InsertAfter node not contained in slice")
 156  	}
 157  	v := c.field()
 158  	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
 159  	l := v.Len()
 160  	reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
 161  	v.Index(i + 1).Set(reflect.ValueOf(n))
 162  	c.iter.step++
 163  }
 164  
 165  // InsertBefore inserts n before the current Node in its containing slice.
 166  // If the current Node is not part of a slice, InsertBefore panics.
 167  // Apply will not walk n.
 168  func (c *Cursor) InsertBefore(n ast.Node) {
 169  	i := c.Index()
 170  	if i < 0 {
 171  		panic("InsertBefore node not contained in slice")
 172  	}
 173  	v := c.field()
 174  	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
 175  	l := v.Len()
 176  	reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
 177  	v.Index(i).Set(reflect.ValueOf(n))
 178  	c.iter.index++
 179  }
 180  
 181  // application carries all the shared data so we can pass it around cheaply.
 182  type application struct {
 183  	pre, post ApplyFunc
 184  	cursor    Cursor
 185  	iter      iterator
 186  }
 187  
// apply visits node n, which is held in the field called name of parent.
// When iter is non-nil, n is the element at iter.index of that field. iter
// is nil for non-slice fields and for the files of an *ast.Package.
//
// It points a.cursor at n and calls a.pre. Unless pre returns false, it then
// visits n's children recursively and calls a.post. If pre returns false,
// the children and post are skipped. If post returns false, apply panics
// with abort, which Apply recovers to end the whole traversal. On normal
// return the caller's cursor state is restored, so a parent's cursor is
// intact after its children have been visited.
//
// Node types not listed in the switch below cause a panic.
func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
	// convert typed nil into untyped nil
	// (a nil *ast.Ident, say, reaches the callbacks and the switch as plain nil)
	if v := reflect.ValueOf(n); v.Kind() == reflect.Pointer && v.IsNil() {
		n = nil
	}

	// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
	saved := a.cursor
	a.cursor.parent = parent
	a.cursor.name = name
	a.cursor.iter = iter
	a.cursor.node = n

	// pre returning false prunes this subtree and suppresses post for n.
	if a.pre != nil && !a.pre(&a.cursor) {
		a.cursor = saved
		return
	}

	// walk children
	// (the order of the cases matches the order of the corresponding node types in go/ast)
	// Within each case, fields are visited in struct-declaration order. Most
	// optional fields are visited even when nil, so callbacks see nil nodes.
	// The switch uses the n captured above, so a Replace done in pre does not
	// change which children are walked here.
	switch n := n.(type) {
	case nil:
		// nothing to do

	// Comments and fields
	case *ast.Comment:
		// nothing to do

	case *ast.CommentGroup:
		// Defensive: typed nils were already normalized to untyped nil above,
		// so n is non-nil here.
		if n != nil {
			a.applyList(n, "List")
		}

	case *ast.Field:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.FieldList:
		a.applyList(n, "List")

	// Expressions
	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
		// nothing to do

	case *ast.Ellipsis:
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.FuncLit:
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CompositeLit:
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Elts")

	case *ast.ParenExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.SelectorExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Sel", nil, n.Sel)

	case *ast.IndexExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Index", nil, n.Index)

	case *ast.IndexListExpr:
		a.apply(n, "X", nil, n.X)
		a.applyList(n, "Indices")

	case *ast.SliceExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Low", nil, n.Low)
		a.apply(n, "High", nil, n.High)
		a.apply(n, "Max", nil, n.Max)

	case *ast.TypeAssertExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Type", nil, n.Type)

	case *ast.CallExpr:
		a.apply(n, "Fun", nil, n.Fun)
		a.applyList(n, "Args")

	case *ast.StarExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.UnaryExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.BinaryExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Y", nil, n.Y)

	case *ast.KeyValueExpr:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	// Types
	case *ast.ArrayType:
		a.apply(n, "Len", nil, n.Len)
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.StructType:
		a.apply(n, "Fields", nil, n.Fields)

	case *ast.FuncType:
		// Unlike most optional fields, TypeParams is visited only when present.
		if tparams := n.TypeParams; tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Params", nil, n.Params)
		a.apply(n, "Results", nil, n.Results)

	case *ast.InterfaceType:
		a.apply(n, "Methods", nil, n.Methods)

	case *ast.MapType:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	case *ast.ChanType:
		a.apply(n, "Value", nil, n.Value)

	// Statements
	case *ast.BadStmt:
		// nothing to do

	case *ast.DeclStmt:
		a.apply(n, "Decl", nil, n.Decl)

	case *ast.EmptyStmt:
		// nothing to do

	case *ast.LabeledStmt:
		a.apply(n, "Label", nil, n.Label)
		a.apply(n, "Stmt", nil, n.Stmt)

	case *ast.ExprStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.SendStmt:
		a.apply(n, "Chan", nil, n.Chan)
		a.apply(n, "Value", nil, n.Value)

	case *ast.IncDecStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.AssignStmt:
		a.applyList(n, "Lhs")
		a.applyList(n, "Rhs")

	case *ast.GoStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.DeferStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.ReturnStmt:
		a.applyList(n, "Results")

	case *ast.BranchStmt:
		a.apply(n, "Label", nil, n.Label)

	case *ast.BlockStmt:
		a.applyList(n, "List")

	case *ast.IfStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Body", nil, n.Body)
		a.apply(n, "Else", nil, n.Else)

	case *ast.CaseClause:
		a.applyList(n, "List")
		a.applyList(n, "Body")

	case *ast.SwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Body", nil, n.Body)

	case *ast.TypeSwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Assign", nil, n.Assign)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CommClause:
		a.apply(n, "Comm", nil, n.Comm)
		a.applyList(n, "Body")

	case *ast.SelectStmt:
		a.apply(n, "Body", nil, n.Body)

	case *ast.ForStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Post", nil, n.Post)
		a.apply(n, "Body", nil, n.Body)

	case *ast.RangeStmt:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Body", nil, n.Body)

	// Declarations
	case *ast.ImportSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Path", nil, n.Path)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.ValueSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Values")
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.TypeSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		// As in *ast.FuncType, TypeParams is visited only when present.
		if tparams := n.TypeParams; tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.BadDecl:
		// nothing to do

	case *ast.GenDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Specs")

	case *ast.FuncDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Recv", nil, n.Recv)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	// Files and packages
	case *ast.File:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.applyList(n, "Decls")
		// Don't walk n.Comments; they have either been walked already if
		// they are Doc comments, or they can be easily walked explicitly.

	case *ast.Package:
		// collect and sort names for reproducible behavior
		// Each file is visited with its file name as the cursor's Name and no
		// iterator, so Index reports -1 for package files.
		var names []string
		for name := range n.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			a.apply(n, name, nil, n.Files[name])
		}

	default:
		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
	}

	// post returning false aborts the entire Apply; the panic is recovered there.
	if a.post != nil && !a.post(&a.cursor) {
		panic(abort)
	}

	a.cursor = saved
}
 462  
 463  // An iterator controls iteration over a slice of nodes.
 464  type iterator struct {
 465  	index, step int
 466  }
 467  
 468  func (a *application) applyList(parent ast.Node, name string) {
 469  	// avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
 470  	saved := a.iter
 471  	a.iter.index = 0
 472  	for {
 473  		// must reload parent.name each time, since cursor modifications might change it
 474  		v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
 475  		if a.iter.index >= v.Len() {
 476  			break
 477  		}
 478  
 479  		// element x may be nil in a bad AST - be cautious
 480  		var x ast.Node
 481  		if e := v.Index(a.iter.index); e.IsValid() {
 482  			x = e.Interface().(ast.Node)
 483  		}
 484  
 485  		a.iter.step = 1
 486  		a.apply(parent, name, &a.iter, x)
 487  		a.iter.index += a.iter.step
 488  	}
 489  	a.iter = saved
 490  }
 491