s1008.go raw

   1  package s1008
   2  
import (
	"fmt"
	"go/ast"
	"go/constant"
	"go/token"
	"go/types"
	"strings"

	"honnef.co/go/tools/analysis/code"
	"honnef.co/go/tools/analysis/facts/generated"
	"honnef.co/go/tools/analysis/lint"
	"honnef.co/go/tools/analysis/report"
	"honnef.co/go/tools/pattern"

	"golang.org/x/tools/go/analysis"
	"golang.org/x/tools/go/analysis/passes/inspect"
)
  19  
  20  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
  21  	Analyzer: &analysis.Analyzer{
  22  		Name:     "S1008",
  23  		Run:      run,
  24  		Requires: []*analysis.Analyzer{inspect.Analyzer, generated.Analyzer},
  25  	},
  26  	Doc: &lint.RawDocumentation{
  27  		Title: `Simplify returning boolean expression`,
  28  		Before: `
  29  if <expr> {
  30      return true
  31  }
  32  return false`,
  33  		After:   `return <expr>`,
  34  		Since:   "2017.1",
  35  		MergeIf: lint.MergeIfAny,
  36  	},
  37  })
  38  
  39  var Analyzer = SCAnalyzer.Analyzer
  40  
  41  var (
  42  	checkIfReturnQIf  = pattern.MustParse(`(IfStmt nil cond [(ReturnStmt [ret@(Builtin (Or "true" "false"))])] nil)`)
  43  	checkIfReturnQRet = pattern.MustParse(`(ReturnStmt [ret@(Builtin (Or "true" "false"))])`)
  44  )
  45  
  46  func run(pass *analysis.Pass) (any, error) {
  47  	var cm ast.CommentMap
  48  	fn := func(node ast.Node) {
  49  		if f, ok := node.(*ast.File); ok {
  50  			cm = ast.NewCommentMap(pass.Fset, f, f.Comments)
  51  			return
  52  		}
  53  
  54  		block := node.(*ast.BlockStmt)
  55  		l := len(block.List)
  56  		if l < 2 {
  57  			return
  58  		}
  59  		n1, n2 := block.List[l-2], block.List[l-1]
  60  
  61  		if len(block.List) >= 3 {
  62  			if _, ok := block.List[l-3].(*ast.IfStmt); ok {
  63  				// Do not flag a series of if statements
  64  				return
  65  			}
  66  		}
  67  		m1, ok := code.Match(pass, checkIfReturnQIf, n1)
  68  		if !ok {
  69  			return
  70  		}
  71  		m2, ok := code.Match(pass, checkIfReturnQRet, n2)
  72  		if !ok {
  73  			return
  74  		}
  75  
  76  		if op, ok := m1.State["cond"].(*ast.BinaryExpr); ok {
  77  			switch op.Op {
  78  			case token.EQL, token.LSS, token.GTR, token.NEQ, token.LEQ, token.GEQ:
  79  			default:
  80  				return
  81  			}
  82  		}
  83  
  84  		ret1 := m1.State["ret"].(*ast.Ident)
  85  		ret2 := m2.State["ret"].(*ast.Ident)
  86  		if ret1.Name == ret2.Name {
  87  			// we want the function to return true and false, not the
  88  			// same value both times.
  89  			return
  90  		}
  91  
  92  		hasComments := func(n ast.Node) bool {
  93  			cmf := cm.Filter(n)
  94  			for _, groups := range cmf {
  95  				for _, group := range groups {
  96  					for _, cmt := range group.List {
  97  						if strings.HasPrefix(cmt.Text, "//@ diag") {
  98  							// Staticcheck test cases use comments to mark
  99  							// expected diagnostics. Ignore these comments so we
 100  							// can test this check.
 101  							continue
 102  						}
 103  						return true
 104  					}
 105  				}
 106  			}
 107  			return false
 108  		}
 109  
 110  		// Don't flag if either branch is commented
 111  		if hasComments(n1) || hasComments(n2) {
 112  			return
 113  		}
 114  
 115  		cond := m1.State["cond"].(ast.Expr)
 116  		origCond := cond
 117  		if ret1.Name == "false" {
 118  			cond = negate(pass, cond)
 119  		}
 120  		report.Report(pass, n1,
 121  			fmt.Sprintf("should use 'return %s' instead of 'if %s { return %s }; return %s'",
 122  				report.Render(pass, cond),
 123  				report.Render(pass, origCond), report.Render(pass, ret1), report.Render(pass, ret2)),
 124  			report.FilterGenerated())
 125  	}
 126  	code.Preorder(pass, fn, (*ast.File)(nil), (*ast.BlockStmt)(nil))
 127  	return nil, nil
 128  }
 129  
 130  func negate(pass *analysis.Pass, expr ast.Expr) ast.Expr {
 131  	switch expr := expr.(type) {
 132  	case *ast.BinaryExpr:
 133  		out := *expr
 134  		switch expr.Op {
 135  		case token.EQL:
 136  			out.Op = token.NEQ
 137  		case token.LSS:
 138  			out.Op = token.GEQ
 139  		case token.GTR:
 140  			// Some builtins never return negative ints; "len(x) <= 0" should be "len(x) == 0".
 141  			if call, ok := expr.X.(*ast.CallExpr); ok &&
 142  				code.IsCallToAny(pass, call, "len", "cap", "copy") &&
 143  				code.IsIntegerLiteral(pass, expr.Y, constant.MakeInt64(0)) {
 144  				out.Op = token.EQL
 145  			} else {
 146  				out.Op = token.LEQ
 147  			}
 148  		case token.NEQ:
 149  			out.Op = token.EQL
 150  		case token.LEQ:
 151  			out.Op = token.GTR
 152  		case token.GEQ:
 153  			out.Op = token.LSS
 154  		}
 155  		return &out
 156  	case *ast.Ident, *ast.CallExpr, *ast.IndexExpr, *ast.StarExpr:
 157  		return &ast.UnaryExpr{
 158  			Op: token.NOT,
 159  			X:  expr,
 160  		}
 161  	case *ast.UnaryExpr:
 162  		if expr.Op == token.NOT {
 163  			return expr.X
 164  		}
 165  		return &ast.UnaryExpr{
 166  			Op: token.NOT,
 167  			X:  expr,
 168  		}
 169  	default:
 170  		return &ast.UnaryExpr{
 171  			Op: token.NOT,
 172  			X: &ast.ParenExpr{
 173  				X: expr,
 174  			},
 175  		}
 176  	}
 177  }
 178