qf1002.go raw

   1  package qf1002
   2  
   3  import (
   4  	"fmt"
   5  	"go/ast"
   6  	"go/token"
   7  	"strings"
   8  
   9  	"honnef.co/go/tools/analysis/code"
  10  	"honnef.co/go/tools/analysis/edit"
  11  	"honnef.co/go/tools/analysis/lint"
  12  	"honnef.co/go/tools/analysis/report"
  13  	"honnef.co/go/tools/go/ast/astutil"
  14  
  15  	"golang.org/x/tools/go/analysis"
  16  	"golang.org/x/tools/go/analysis/passes/inspect"
  17  )
  18  
  19  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
  20  	Analyzer: &analysis.Analyzer{
  21  		Name:     "QF1002",
  22  		Run:      run,
  23  		Requires: []*analysis.Analyzer{inspect.Analyzer},
  24  	},
  25  	Doc: &lint.RawDocumentation{
  26  		Title: "Convert untagged switch to tagged switch",
  27  		Text: `
  28  An untagged switch that compares a single variable against a series of
  29  values can be replaced with a tagged switch.`,
  30  		Before: `
  31  switch {
  32  case x == 1 || x == 2, x == 3:
  33      ...
  34  case x == 4:
  35      ...
  36  default:
  37      ...
  38  }`,
  39  
  40  		After: `
  41  switch x {
  42  case 1, 2, 3:
  43      ...
  44  case 4:
  45      ...
  46  default:
  47      ...
  48  }`,
  49  		Since:    "2021.1",
  50  		Severity: lint.SeverityHint,
  51  	},
  52  })
  53  
  54  var Analyzer = SCAnalyzer.Analyzer
  55  
  56  func run(pass *analysis.Pass) (interface{}, error) {
  57  	fn := func(node ast.Node) {
  58  		swtch := node.(*ast.SwitchStmt)
  59  		if swtch.Tag != nil || len(swtch.Body.List) == 0 {
  60  			return
  61  		}
  62  
  63  		pairs := make([][]*ast.BinaryExpr, len(swtch.Body.List))
  64  		for i, stmt := range swtch.Body.List {
  65  			stmt := stmt.(*ast.CaseClause)
  66  			for _, cond := range stmt.List {
  67  				if !findSwitchPairs(pass, cond, &pairs[i]) {
  68  					return
  69  				}
  70  			}
  71  		}
  72  
  73  		var x ast.Expr
  74  		for _, pair := range pairs {
  75  			if len(pair) == 0 {
  76  				continue
  77  			}
  78  			if x == nil {
  79  				x = pair[0].X
  80  			} else {
  81  				if !astutil.Equal(x, pair[0].X) {
  82  					return
  83  				}
  84  			}
  85  		}
  86  		if x == nil {
  87  			// the switch only has a default case
  88  			if len(pairs) > 1 {
  89  				panic("found more than one case clause with no pairs")
  90  			}
  91  			return
  92  		}
  93  
  94  		edits := make([]analysis.TextEdit, 0, len(swtch.Body.List)+1)
  95  		for i, stmt := range swtch.Body.List {
  96  			stmt := stmt.(*ast.CaseClause)
  97  			if stmt.List == nil {
  98  				continue
  99  			}
 100  
 101  			var values []string
 102  			for _, binexpr := range pairs[i] {
 103  				y := binexpr.Y
 104  				if p, ok := y.(*ast.ParenExpr); ok {
 105  					y = p.X
 106  				}
 107  				values = append(values, report.Render(pass, y))
 108  			}
 109  
 110  			edits = append(edits, edit.ReplaceWithString(edit.Range{stmt.List[0].Pos(), stmt.Colon}, strings.Join(values, ", ")))
 111  		}
 112  		pos := swtch.Body.Lbrace
 113  		edits = append(edits, edit.ReplaceWithString(edit.Range{pos, pos}, " "+report.Render(pass, x)))
 114  		report.Report(pass, swtch, fmt.Sprintf("could use tagged switch on %s", report.Render(pass, x)),
 115  			report.Fixes(edit.Fix("Replace with tagged switch", edits...)))
 116  	}
 117  
 118  	code.Preorder(pass, fn, (*ast.SwitchStmt)(nil))
 119  	return nil, nil
 120  }
 121  
 122  func findSwitchPairs(pass *analysis.Pass, expr ast.Expr, pairs *[]*ast.BinaryExpr) bool {
 123  	binexpr, ok := astutil.Unparen(expr).(*ast.BinaryExpr)
 124  	if !ok {
 125  		return false
 126  	}
 127  	switch binexpr.Op {
 128  	case token.EQL:
 129  		if code.MayHaveSideEffects(pass, binexpr.X, nil) || code.MayHaveSideEffects(pass, binexpr.Y, nil) {
 130  			return false
 131  		}
 132  		// syntactic identity should suffice. we do not allow side
 133  		// effects in the case clauses, so there should be no way for
 134  		// values to change.
 135  		if len(*pairs) > 0 && !astutil.Equal(binexpr.X, (*pairs)[0].X) {
 136  			return false
 137  		}
 138  		*pairs = append(*pairs, binexpr)
 139  		return true
 140  	case token.LOR:
 141  		return findSwitchPairs(pass, binexpr.X, pairs) && findSwitchPairs(pass, binexpr.Y, pairs)
 142  	default:
 143  		return false
 144  	}
 145  }
 146