enclosing.go raw

   1  // Copyright 2013 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package astutil
   6  
   7  // This file defines utilities for working with source positions.
   8  
   9  import (
  10  	"fmt"
  11  	"go/ast"
  12  	"go/token"
  13  	"sort"
  14  )
  15  
  16  // PathEnclosingInterval returns the node that encloses the source
  17  // interval [start, end), and all its ancestors up to the AST root.
  18  //
  19  // The definition of "enclosing" used by this function considers
  20  // additional whitespace abutting a node to be enclosed by it.
  21  // In this example:
  22  //
  23  //	z := x + y // add them
  24  //	     <-A->
  25  //	    <----B----->
  26  //
  27  // the ast.BinaryExpr(+) node is considered to enclose interval B
  28  // even though its [Pos()..End()) is actually only interval A.
  29  // This behaviour makes user interfaces more tolerant of imperfect
  30  // input.
  31  //
  32  // This function treats tokens as nodes, though they are not included
  33  // in the result. e.g. PathEnclosingInterval("+") returns the
  34  // enclosing ast.BinaryExpr("x + y").
  35  //
  36  // If start==end, the 1-char interval following start is used instead.
  37  //
  38  // The 'exact' result is true if the interval contains only path[0]
  39  // and perhaps some adjacent whitespace.  It is false if the interval
  40  // overlaps multiple children of path[0], or if it contains only
  41  // interior whitespace of path[0].
  42  // In this example:
  43  //
  44  //	z := x + y // add them
  45  //	  <--C-->     <---E-->
  46  //	    ^
  47  //	    D
  48  //
  49  // intervals C, D and E are inexact.  C is contained by the
  50  // z-assignment statement, because it spans three of its children (:=,
  51  // x, +).  So too is the 1-char interval D, because it contains only
  52  // interior whitespace of the assignment.  E is considered interior
  53  // whitespace of the BlockStmt containing the assignment.
  54  //
  55  // The resulting path is never empty; it always contains at least the
  56  // 'root' *ast.File.  Ideally PathEnclosingInterval would reject
  57  // intervals that lie wholly or partially outside the range of the
  58  // file, but unfortunately ast.File records only the token.Pos of
  59  // the 'package' keyword, but not of the start of the file itself.
  60  func PathEnclosingInterval(root *ast.File, start, end token.Pos) (path []ast.Node, exact bool) {
  61  	// fmt.Printf("EnclosingInterval %d %d\n", start, end) // debugging
  62  
  63  	// Precondition: node.[Pos..End) and adjoining whitespace contain [start, end).
  64  	var visit func(node ast.Node) bool
  65  	visit = func(node ast.Node) bool {
  66  		path = append(path, node)
  67  
  68  		nodePos := node.Pos()
  69  		nodeEnd := node.End()
  70  
  71  		// fmt.Printf("visit(%T, %d, %d)\n", node, nodePos, nodeEnd) // debugging
  72  
  73  		// Intersect [start, end) with interval of node.
  74  		if start < nodePos {
  75  			start = nodePos
  76  		}
  77  		if end > nodeEnd {
  78  			end = nodeEnd
  79  		}
  80  
  81  		// Find sole child that contains [start, end).
  82  		children := childrenOf(node)
  83  		l := len(children)
  84  		for i, child := range children {
  85  			// [childPos, childEnd) is unaugmented interval of child.
  86  			childPos := child.Pos()
  87  			childEnd := child.End()
  88  
  89  			// [augPos, augEnd) is whitespace-augmented interval of child.
  90  			augPos := childPos
  91  			augEnd := childEnd
  92  			if i > 0 {
  93  				augPos = children[i-1].End() // start of preceding whitespace
  94  			}
  95  			if i < l-1 {
  96  				nextChildPos := children[i+1].Pos()
  97  				// Does [start, end) lie between child and next child?
  98  				if start >= augEnd && end <= nextChildPos {
  99  					return false // inexact match
 100  				}
 101  				augEnd = nextChildPos // end of following whitespace
 102  			}
 103  
 104  			// fmt.Printf("\tchild %d: [%d..%d)\tcontains interval [%d..%d)?\n",
 105  			// 	i, augPos, augEnd, start, end) // debugging
 106  
 107  			// Does augmented child strictly contain [start, end)?
 108  			if augPos <= start && end <= augEnd {
 109  				if is[tokenNode](child) {
 110  					return true
 111  				}
 112  
 113  				// childrenOf elides the FuncType node beneath FuncDecl.
 114  				// Add it back here for TypeParams, Params, Results,
 115  				// all FieldLists). But we don't add it back for the "func" token
 116  				// even though it is the tree at FuncDecl.Type.Func.
 117  				if decl, ok := node.(*ast.FuncDecl); ok {
 118  					if fields, ok := child.(*ast.FieldList); ok && fields != decl.Recv {
 119  						path = append(path, decl.Type)
 120  					}
 121  				}
 122  
 123  				return visit(child)
 124  			}
 125  
 126  			// Does [start, end) overlap multiple children?
 127  			// i.e. left-augmented child contains start
 128  			// but LR-augmented child does not contain end.
 129  			if start < childEnd && end > augEnd {
 130  				break
 131  			}
 132  		}
 133  
 134  		// No single child contained [start, end),
 135  		// so node is the result.  Is it exact?
 136  
 137  		// (It's tempting to put this condition before the
 138  		// child loop, but it gives the wrong result in the
 139  		// case where a node (e.g. ExprStmt) and its sole
 140  		// child have equal intervals.)
 141  		if start == nodePos && end == nodeEnd {
 142  			return true // exact match
 143  		}
 144  
 145  		return false // inexact: overlaps multiple children
 146  	}
 147  
 148  	// Ensure [start,end) is nondecreasing.
 149  	if start > end {
 150  		start, end = end, start
 151  	}
 152  
 153  	if start < root.End() && end > root.Pos() {
 154  		if start == end {
 155  			end = start + 1 // empty interval => interval of size 1
 156  		}
 157  		exact = visit(root)
 158  
 159  		// Reverse the path:
 160  		for i, l := 0, len(path); i < l/2; i++ {
 161  			path[i], path[l-1-i] = path[l-1-i], path[i]
 162  		}
 163  	} else {
 164  		// Selection lies within whitespace preceding the
 165  		// first (or following the last) declaration in the file.
 166  		// The result nonetheless always includes the ast.File.
 167  		path = append(path, root)
 168  	}
 169  
 170  	return
 171  }
 172  
 173  // tokenNode is a dummy implementation of ast.Node for a single token.
 174  // They are used transiently by PathEnclosingInterval but never escape
 175  // this package.
 176  type tokenNode struct {
 177  	pos token.Pos
 178  	end token.Pos
 179  }
 180  
 181  func (n tokenNode) Pos() token.Pos {
 182  	return n.pos
 183  }
 184  
 185  func (n tokenNode) End() token.Pos {
 186  	return n.end
 187  }
 188  
 189  func tok(pos token.Pos, len int) ast.Node {
 190  	return tokenNode{pos, pos + token.Pos(len)}
 191  }
 192  
 193  // childrenOf returns the direct non-nil children of ast.Node n.
 194  // It may include fake ast.Node implementations for bare tokens.
 195  // it is not safe to call (e.g.) ast.Walk on such nodes.
 196  func childrenOf(n ast.Node) []ast.Node {
 197  	var children []ast.Node
 198  
 199  	// First add nodes for all true subtrees.
 200  	ast.Inspect(n, func(node ast.Node) bool {
 201  		if node == n { // push n
 202  			return true // recur
 203  		}
 204  		if node != nil { // push child
 205  			children = append(children, node)
 206  		}
 207  		return false // no recursion
 208  	})
 209  
 210  	// TODO(adonovan): be more careful about missing (!Pos.Valid)
 211  	// tokens in trees produced from invalid input.
 212  
 213  	// Then add fake Nodes for bare tokens.
 214  	switch n := n.(type) {
 215  	case *ast.ArrayType:
 216  		children = append(children,
 217  			tok(n.Lbrack, len("[")),
 218  			tok(n.Elt.End(), len("]")))
 219  
 220  	case *ast.AssignStmt:
 221  		children = append(children,
 222  			tok(n.TokPos, len(n.Tok.String())))
 223  
 224  	case *ast.BasicLit:
 225  		children = append(children,
 226  			tok(n.ValuePos, len(n.Value)))
 227  
 228  	case *ast.BinaryExpr:
 229  		children = append(children, tok(n.OpPos, len(n.Op.String())))
 230  
 231  	case *ast.BlockStmt:
 232  		if n.Lbrace.IsValid() {
 233  			children = append(children, tok(n.Lbrace, len("{")))
 234  		}
 235  		if n.Rbrace.IsValid() {
 236  			children = append(children, tok(n.Rbrace, len("}")))
 237  		}
 238  
 239  	case *ast.BranchStmt:
 240  		children = append(children,
 241  			tok(n.TokPos, len(n.Tok.String())))
 242  
 243  	case *ast.CallExpr:
 244  		children = append(children,
 245  			tok(n.Lparen, len("(")),
 246  			tok(n.Rparen, len(")")))
 247  		if n.Ellipsis != 0 {
 248  			children = append(children, tok(n.Ellipsis, len("...")))
 249  		}
 250  
 251  	case *ast.CaseClause:
 252  		if n.List == nil {
 253  			children = append(children,
 254  				tok(n.Case, len("default")))
 255  		} else {
 256  			children = append(children,
 257  				tok(n.Case, len("case")))
 258  		}
 259  		children = append(children, tok(n.Colon, len(":")))
 260  
 261  	case *ast.ChanType:
 262  		switch n.Dir {
 263  		case ast.RECV:
 264  			children = append(children, tok(n.Begin, len("<-chan")))
 265  		case ast.SEND:
 266  			children = append(children, tok(n.Begin, len("chan<-")))
 267  		case ast.RECV | ast.SEND:
 268  			children = append(children, tok(n.Begin, len("chan")))
 269  		}
 270  
 271  	case *ast.CommClause:
 272  		if n.Comm == nil {
 273  			children = append(children,
 274  				tok(n.Case, len("default")))
 275  		} else {
 276  			children = append(children,
 277  				tok(n.Case, len("case")))
 278  		}
 279  		children = append(children, tok(n.Colon, len(":")))
 280  
 281  	case *ast.Comment:
 282  		// nop
 283  
 284  	case *ast.CommentGroup:
 285  		// nop
 286  
 287  	case *ast.CompositeLit:
 288  		children = append(children,
 289  			tok(n.Lbrace, len("{")),
 290  			tok(n.Rbrace, len("{")))
 291  
 292  	case *ast.DeclStmt:
 293  		// nop
 294  
 295  	case *ast.DeferStmt:
 296  		children = append(children,
 297  			tok(n.Defer, len("defer")))
 298  
 299  	case *ast.Ellipsis:
 300  		children = append(children,
 301  			tok(n.Ellipsis, len("...")))
 302  
 303  	case *ast.EmptyStmt:
 304  		// nop
 305  
 306  	case *ast.ExprStmt:
 307  		// nop
 308  
 309  	case *ast.Field:
 310  		// TODO(adonovan): Field.{Doc,Comment,Tag}?
 311  
 312  	case *ast.FieldList:
 313  		if n.Opening.IsValid() {
 314  			children = append(children, tok(n.Opening, len("(")))
 315  		}
 316  		if n.Closing.IsValid() {
 317  			children = append(children, tok(n.Closing, len(")")))
 318  		}
 319  
 320  	case *ast.File:
 321  		// TODO test: Doc
 322  		children = append(children,
 323  			tok(n.Package, len("package")))
 324  
 325  	case *ast.ForStmt:
 326  		children = append(children,
 327  			tok(n.For, len("for")))
 328  
 329  	case *ast.FuncDecl:
 330  		// TODO(adonovan): FuncDecl.Comment?
 331  
 332  		// Uniquely, FuncDecl breaks the invariant that
 333  		// preorder traversal yields tokens in lexical order:
 334  		// in fact, FuncDecl.Recv precedes FuncDecl.Type.Func.
 335  		//
 336  		// As a workaround, we inline the case for FuncType
 337  		// here and order things correctly.
 338  		// We also need to insert the elided FuncType just
 339  		// before the 'visit' recursion.
 340  		//
 341  		children = nil // discard ast.Walk(FuncDecl) info subtrees
 342  		children = append(children, tok(n.Type.Func, len("func")))
 343  		if n.Recv != nil {
 344  			children = append(children, n.Recv)
 345  		}
 346  		children = append(children, n.Name)
 347  		if tparams := n.Type.TypeParams; tparams != nil {
 348  			children = append(children, tparams)
 349  		}
 350  		if n.Type.Params != nil {
 351  			children = append(children, n.Type.Params)
 352  		}
 353  		if n.Type.Results != nil {
 354  			children = append(children, n.Type.Results)
 355  		}
 356  		if n.Body != nil {
 357  			children = append(children, n.Body)
 358  		}
 359  
 360  	case *ast.FuncLit:
 361  		// nop
 362  
 363  	case *ast.FuncType:
 364  		if n.Func != 0 {
 365  			children = append(children,
 366  				tok(n.Func, len("func")))
 367  		}
 368  
 369  	case *ast.GenDecl:
 370  		children = append(children,
 371  			tok(n.TokPos, len(n.Tok.String())))
 372  		if n.Lparen != 0 {
 373  			children = append(children,
 374  				tok(n.Lparen, len("(")),
 375  				tok(n.Rparen, len(")")))
 376  		}
 377  
 378  	case *ast.GoStmt:
 379  		children = append(children,
 380  			tok(n.Go, len("go")))
 381  
 382  	case *ast.Ident:
 383  		children = append(children,
 384  			tok(n.NamePos, len(n.Name)))
 385  
 386  	case *ast.IfStmt:
 387  		children = append(children,
 388  			tok(n.If, len("if")))
 389  
 390  	case *ast.ImportSpec:
 391  		// TODO(adonovan): ImportSpec.{Doc,EndPos}?
 392  
 393  	case *ast.IncDecStmt:
 394  		children = append(children,
 395  			tok(n.TokPos, len(n.Tok.String())))
 396  
 397  	case *ast.IndexExpr:
 398  		children = append(children,
 399  			tok(n.Lbrack, len("[")),
 400  			tok(n.Rbrack, len("]")))
 401  
 402  	case *ast.IndexListExpr:
 403  		children = append(children,
 404  			tok(n.Lbrack, len("[")),
 405  			tok(n.Rbrack, len("]")))
 406  
 407  	case *ast.InterfaceType:
 408  		children = append(children,
 409  			tok(n.Interface, len("interface")))
 410  
 411  	case *ast.KeyValueExpr:
 412  		children = append(children,
 413  			tok(n.Colon, len(":")))
 414  
 415  	case *ast.LabeledStmt:
 416  		children = append(children,
 417  			tok(n.Colon, len(":")))
 418  
 419  	case *ast.MapType:
 420  		children = append(children,
 421  			tok(n.Map, len("map")))
 422  
 423  	case *ast.ParenExpr:
 424  		children = append(children,
 425  			tok(n.Lparen, len("(")),
 426  			tok(n.Rparen, len(")")))
 427  
 428  	case *ast.RangeStmt:
 429  		children = append(children,
 430  			tok(n.For, len("for")),
 431  			tok(n.TokPos, len(n.Tok.String())))
 432  
 433  	case *ast.ReturnStmt:
 434  		children = append(children,
 435  			tok(n.Return, len("return")))
 436  
 437  	case *ast.SelectStmt:
 438  		children = append(children,
 439  			tok(n.Select, len("select")))
 440  
 441  	case *ast.SelectorExpr:
 442  		// nop
 443  
 444  	case *ast.SendStmt:
 445  		children = append(children,
 446  			tok(n.Arrow, len("<-")))
 447  
 448  	case *ast.SliceExpr:
 449  		children = append(children,
 450  			tok(n.Lbrack, len("[")),
 451  			tok(n.Rbrack, len("]")))
 452  
 453  	case *ast.StarExpr:
 454  		children = append(children, tok(n.Star, len("*")))
 455  
 456  	case *ast.StructType:
 457  		children = append(children, tok(n.Struct, len("struct")))
 458  
 459  	case *ast.SwitchStmt:
 460  		children = append(children, tok(n.Switch, len("switch")))
 461  
 462  	case *ast.TypeAssertExpr:
 463  		children = append(children,
 464  			tok(n.Lparen-1, len(".")),
 465  			tok(n.Lparen, len("(")),
 466  			tok(n.Rparen, len(")")))
 467  
 468  	case *ast.TypeSpec:
 469  		// TODO(adonovan): TypeSpec.{Doc,Comment}?
 470  
 471  	case *ast.TypeSwitchStmt:
 472  		children = append(children, tok(n.Switch, len("switch")))
 473  
 474  	case *ast.UnaryExpr:
 475  		children = append(children, tok(n.OpPos, len(n.Op.String())))
 476  
 477  	case *ast.ValueSpec:
 478  		// TODO(adonovan): ValueSpec.{Doc,Comment}?
 479  
 480  	case *ast.BadDecl, *ast.BadExpr, *ast.BadStmt:
 481  		// nop
 482  	}
 483  
 484  	// TODO(adonovan): opt: merge the logic of ast.Inspect() into
 485  	// the switch above so we can make interleaved callbacks for
 486  	// both Nodes and Tokens in the right order and avoid the need
 487  	// to sort.
 488  	sort.Sort(byPos(children))
 489  
 490  	return children
 491  }
 492  
 493  type byPos []ast.Node
 494  
 495  func (sl byPos) Len() int {
 496  	return len(sl)
 497  }
 498  func (sl byPos) Less(i, j int) bool {
 499  	return sl[i].Pos() < sl[j].Pos()
 500  }
 501  func (sl byPos) Swap(i, j int) {
 502  	sl[i], sl[j] = sl[j], sl[i]
 503  }
 504  
 505  // NodeDescription returns a description of the concrete type of n suitable
 506  // for a user interface.
 507  //
 508  // TODO(adonovan): in some cases (e.g. Field, FieldList, Ident,
 509  // StarExpr) we could be much more specific given the path to the AST
 510  // root.  Perhaps we should do that.
 511  func NodeDescription(n ast.Node) string {
 512  	switch n := n.(type) {
 513  	case *ast.ArrayType:
 514  		return "array type"
 515  	case *ast.AssignStmt:
 516  		return "assignment"
 517  	case *ast.BadDecl:
 518  		return "bad declaration"
 519  	case *ast.BadExpr:
 520  		return "bad expression"
 521  	case *ast.BadStmt:
 522  		return "bad statement"
 523  	case *ast.BasicLit:
 524  		return "basic literal"
 525  	case *ast.BinaryExpr:
 526  		return fmt.Sprintf("binary %s operation", n.Op)
 527  	case *ast.BlockStmt:
 528  		return "block"
 529  	case *ast.BranchStmt:
 530  		switch n.Tok {
 531  		case token.BREAK:
 532  			return "break statement"
 533  		case token.CONTINUE:
 534  			return "continue statement"
 535  		case token.GOTO:
 536  			return "goto statement"
 537  		case token.FALLTHROUGH:
 538  			return "fall-through statement"
 539  		}
 540  	case *ast.CallExpr:
 541  		if len(n.Args) == 1 && !n.Ellipsis.IsValid() {
 542  			return "function call (or conversion)"
 543  		}
 544  		return "function call"
 545  	case *ast.CaseClause:
 546  		return "case clause"
 547  	case *ast.ChanType:
 548  		return "channel type"
 549  	case *ast.CommClause:
 550  		return "communication clause"
 551  	case *ast.Comment:
 552  		return "comment"
 553  	case *ast.CommentGroup:
 554  		return "comment group"
 555  	case *ast.CompositeLit:
 556  		return "composite literal"
 557  	case *ast.DeclStmt:
 558  		return NodeDescription(n.Decl) + " statement"
 559  	case *ast.DeferStmt:
 560  		return "defer statement"
 561  	case *ast.Ellipsis:
 562  		return "ellipsis"
 563  	case *ast.EmptyStmt:
 564  		return "empty statement"
 565  	case *ast.ExprStmt:
 566  		return "expression statement"
 567  	case *ast.Field:
 568  		// Can be any of these:
 569  		// struct {x, y int}  -- struct field(s)
 570  		// struct {T}         -- anon struct field
 571  		// interface {I}      -- interface embedding
 572  		// interface {f()}    -- interface method
 573  		// func (A) func(B) C -- receiver, param(s), result(s)
 574  		return "field/method/parameter"
 575  	case *ast.FieldList:
 576  		return "field/method/parameter list"
 577  	case *ast.File:
 578  		return "source file"
 579  	case *ast.ForStmt:
 580  		return "for loop"
 581  	case *ast.FuncDecl:
 582  		return "function declaration"
 583  	case *ast.FuncLit:
 584  		return "function literal"
 585  	case *ast.FuncType:
 586  		return "function type"
 587  	case *ast.GenDecl:
 588  		switch n.Tok {
 589  		case token.IMPORT:
 590  			return "import declaration"
 591  		case token.CONST:
 592  			return "constant declaration"
 593  		case token.TYPE:
 594  			return "type declaration"
 595  		case token.VAR:
 596  			return "variable declaration"
 597  		}
 598  	case *ast.GoStmt:
 599  		return "go statement"
 600  	case *ast.Ident:
 601  		return "identifier"
 602  	case *ast.IfStmt:
 603  		return "if statement"
 604  	case *ast.ImportSpec:
 605  		return "import specification"
 606  	case *ast.IncDecStmt:
 607  		if n.Tok == token.INC {
 608  			return "increment statement"
 609  		}
 610  		return "decrement statement"
 611  	case *ast.IndexExpr:
 612  		return "index expression"
 613  	case *ast.IndexListExpr:
 614  		return "index list expression"
 615  	case *ast.InterfaceType:
 616  		return "interface type"
 617  	case *ast.KeyValueExpr:
 618  		return "key/value association"
 619  	case *ast.LabeledStmt:
 620  		return "statement label"
 621  	case *ast.MapType:
 622  		return "map type"
 623  	case *ast.Package:
 624  		return "package"
 625  	case *ast.ParenExpr:
 626  		return "parenthesized " + NodeDescription(n.X)
 627  	case *ast.RangeStmt:
 628  		return "range loop"
 629  	case *ast.ReturnStmt:
 630  		return "return statement"
 631  	case *ast.SelectStmt:
 632  		return "select statement"
 633  	case *ast.SelectorExpr:
 634  		return "selector"
 635  	case *ast.SendStmt:
 636  		return "channel send"
 637  	case *ast.SliceExpr:
 638  		return "slice expression"
 639  	case *ast.StarExpr:
 640  		return "*-operation" // load/store expr or pointer type
 641  	case *ast.StructType:
 642  		return "struct type"
 643  	case *ast.SwitchStmt:
 644  		return "switch statement"
 645  	case *ast.TypeAssertExpr:
 646  		return "type assertion"
 647  	case *ast.TypeSpec:
 648  		return "type specification"
 649  	case *ast.TypeSwitchStmt:
 650  		return "type switch"
 651  	case *ast.UnaryExpr:
 652  		return fmt.Sprintf("unary %s operation", n.Op)
 653  	case *ast.ValueSpec:
 654  		return "value specification"
 655  
 656  	}
 657  	panic(fmt.Sprintf("unexpected node type: %T", n))
 658  }
 659  
 660  func is[T any](x any) bool {
 661  	_, ok := x.(T)
 662  	return ok
 663  }
 664