qf1003.go raw

   1  package qf1003
   2  
   3  import (
   4  	"fmt"
   5  	"go/ast"
   6  	"go/token"
   7  	"strings"
   8  
   9  	"honnef.co/go/tools/analysis/code"
  10  	"honnef.co/go/tools/analysis/edit"
  11  	"honnef.co/go/tools/analysis/lint"
  12  	"honnef.co/go/tools/analysis/report"
  13  	"honnef.co/go/tools/go/ast/astutil"
  14  
  15  	"golang.org/x/tools/go/analysis"
  16  	"golang.org/x/tools/go/analysis/passes/inspect"
  17  )
  18  
  19  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
  20  	Analyzer: &analysis.Analyzer{
  21  		Name:     "QF1003",
  22  		Run:      run,
  23  		Requires: []*analysis.Analyzer{inspect.Analyzer},
  24  	},
  25  	Doc: &lint.RawDocumentation{
  26  		Title: "Convert if/else-if chain to tagged switch",
  27  		Text: `
  28  A series of if/else-if checks comparing the same variable against
  29  values can be replaced with a tagged switch.`,
  30  		Before: `
  31  if x == 1 || x == 2 {
  32      ...
  33  } else if x == 3 {
  34      ...
  35  } else {
  36      ...
  37  }`,
  38  
  39  		After: `
  40  switch x {
  41  case 1, 2:
  42      ...
  43  case 3:
  44      ...
  45  default:
  46      ...
  47  }`,
  48  		Since:    "2021.1",
  49  		Severity: lint.SeverityInfo,
  50  	},
  51  })
  52  
  53  var Analyzer = SCAnalyzer.Analyzer
  54  
  55  func run(pass *analysis.Pass) (interface{}, error) {
  56  	fn := func(node ast.Node, stack []ast.Node) {
  57  		if _, ok := stack[len(stack)-2].(*ast.IfStmt); ok {
  58  			// this if statement is part of an if-else chain
  59  			return
  60  		}
  61  		ifstmt := node.(*ast.IfStmt)
  62  
  63  		m := map[ast.Expr][]*ast.BinaryExpr{}
  64  		for item := ifstmt; item != nil; {
  65  			if item.Init != nil {
  66  				return
  67  			}
  68  			if item.Body == nil {
  69  				return
  70  			}
  71  
  72  			skip := false
  73  			ast.Inspect(item.Body, func(node ast.Node) bool {
  74  				if branch, ok := node.(*ast.BranchStmt); ok && branch.Tok != token.GOTO {
  75  					skip = true
  76  					return false
  77  				}
  78  				return true
  79  			})
  80  			if skip {
  81  				return
  82  			}
  83  
  84  			var pairs []*ast.BinaryExpr
  85  			if !findSwitchPairs(pass, item.Cond, &pairs) {
  86  				return
  87  			}
  88  			m[item.Cond] = pairs
  89  			switch els := item.Else.(type) {
  90  			case *ast.IfStmt:
  91  				item = els
  92  			case *ast.BlockStmt, nil:
  93  				item = nil
  94  			default:
  95  				panic(fmt.Sprintf("unreachable: %T", els))
  96  			}
  97  		}
  98  
  99  		var x ast.Expr
 100  		for _, pair := range m {
 101  			if len(pair) == 0 {
 102  				continue
 103  			}
 104  			if x == nil {
 105  				x = pair[0].X
 106  			} else {
 107  				if !astutil.Equal(x, pair[0].X) {
 108  					return
 109  				}
 110  			}
 111  		}
 112  		if x == nil {
 113  			// shouldn't happen
 114  			return
 115  		}
 116  
 117  		// We require at least two 'if' to make this suggestion, to
 118  		// avoid clutter in the editor.
 119  		if len(m) < 2 {
 120  			return
 121  		}
 122  
 123  		// Note that we insert the switch statement as the first text edit instead of the last one so that gopls has an
 124  		// easier time converting it to an LSP-conforming edit.
 125  		//
 126  		// Specifically:
 127  		// > Text edits ranges must never overlap, that means no part of the original
 128  		// > document must be manipulated by more than one edit. However, it is
 129  		// > possible that multiple edits have the same start position: multiple
 130  		// > inserts, or any number of inserts followed by a single remove or replace
 131  		// > edit. If multiple inserts have the same position, the order in the array
 132  		// > defines the order in which the inserted strings appear in the resulting
 133  		// > text.
 134  		//
 135  		// See https://go.dev/issue/63930
 136  		//
 137  		// FIXME this edit forces the first case to begin in column 0 because we ignore indentation. try to fix that.
 138  		edits := []analysis.TextEdit{edit.ReplaceWithString(edit.Range{ifstmt.If, ifstmt.If}, fmt.Sprintf("switch %s {\n", report.Render(pass, x)))}
 139  		for item := ifstmt; item != nil; {
 140  			var end token.Pos
 141  			if item.Else != nil {
 142  				end = item.Else.Pos()
 143  			} else {
 144  				// delete up to but not including the closing brace.
 145  				end = item.Body.Rbrace
 146  			}
 147  
 148  			var conds []string
 149  			for _, cond := range m[item.Cond] {
 150  				y := cond.Y
 151  				if p, ok := y.(*ast.ParenExpr); ok {
 152  					y = p.X
 153  				}
 154  				conds = append(conds, report.Render(pass, y))
 155  			}
 156  			sconds := strings.Join(conds, ", ")
 157  			edits = append(edits,
 158  				edit.ReplaceWithString(edit.Range{item.If, item.Body.Lbrace + 1}, "case "+sconds+":"),
 159  				edit.Delete(edit.Range{item.Body.Rbrace, end}))
 160  
 161  			switch els := item.Else.(type) {
 162  			case *ast.IfStmt:
 163  				item = els
 164  			case *ast.BlockStmt:
 165  				edits = append(edits, edit.ReplaceWithString(edit.Range{els.Lbrace, els.Lbrace + 1}, "default:"))
 166  				item = nil
 167  			case nil:
 168  				item = nil
 169  			default:
 170  				panic(fmt.Sprintf("unreachable: %T", els))
 171  			}
 172  		}
 173  		report.Report(pass, ifstmt, fmt.Sprintf("could use tagged switch on %s", report.Render(pass, x)),
 174  			report.Fixes(edit.Fix("Replace with tagged switch", edits...)),
 175  			report.ShortRange())
 176  	}
 177  	code.PreorderStack(pass, fn, (*ast.IfStmt)(nil))
 178  	return nil, nil
 179  }
 180  
 181  func findSwitchPairs(pass *analysis.Pass, expr ast.Expr, pairs *[]*ast.BinaryExpr) bool {
 182  	binexpr, ok := astutil.Unparen(expr).(*ast.BinaryExpr)
 183  	if !ok {
 184  		return false
 185  	}
 186  	switch binexpr.Op {
 187  	case token.EQL:
 188  		if code.MayHaveSideEffects(pass, binexpr.X, nil) || code.MayHaveSideEffects(pass, binexpr.Y, nil) {
 189  			return false
 190  		}
 191  		// syntactic identity should suffice. we do not allow side
 192  		// effects in the case clauses, so there should be no way for
 193  		// values to change.
 194  		if len(*pairs) > 0 && !astutil.Equal(binexpr.X, (*pairs)[0].X) {
 195  			return false
 196  		}
 197  		*pairs = append(*pairs, binexpr)
 198  		return true
 199  	case token.LOR:
 200  		return findSwitchPairs(pass, binexpr.X, pairs) && findSwitchPairs(pass, binexpr.Y, pairs)
 201  	default:
 202  		return false
 203  	}
 204  }
 205