sa4031.go raw

   1  package sa4031
   2  
   3  import (
   4  	"fmt"
   5  	"go/ast"
   6  	"go/token"
   7  	"go/types"
   8  	"sort"
   9  
  10  	"honnef.co/go/tools/analysis/code"
  11  	"honnef.co/go/tools/analysis/lint"
  12  	"honnef.co/go/tools/analysis/report"
  13  	"honnef.co/go/tools/go/ir"
  14  	"honnef.co/go/tools/internal/passes/buildir"
  15  	"honnef.co/go/tools/pattern"
  16  	"honnef.co/go/tools/staticcheck/sa4022"
  17  
  18  	"golang.org/x/tools/go/analysis"
  19  	"golang.org/x/tools/go/analysis/passes/inspect"
  20  )
  21  
  22  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
  23  	Analyzer: &analysis.Analyzer{
  24  		Name:     "SA4031",
  25  		Run:      run,
  26  		Requires: []*analysis.Analyzer{buildir.Analyzer, inspect.Analyzer},
  27  	},
  28  	Doc: &lint.RawDocumentation{
  29  		Title:    `Checking never-nil value against nil`,
  30  		Since:    "2022.1",
  31  		Severity: lint.SeverityWarning,
  32  		MergeIf:  lint.MergeIfAny,
  33  	},
  34  })
  35  
  36  var Analyzer = SCAnalyzer.Analyzer
  37  
  38  var allocationNilCheckQ = pattern.MustParse(`(IfStmt _ cond@(BinaryExpr lhs op@(Or "==" "!=") (Builtin "nil")) _ _)`)
  39  
  40  func run(pass *analysis.Pass) (interface{}, error) {
  41  	irpkg := pass.ResultOf[buildir.Analyzer].(*buildir.IR).Pkg
  42  
  43  	var path []ast.Node
  44  	fn := func(node ast.Node, stack []ast.Node) {
  45  		m, ok := code.Match(pass, allocationNilCheckQ, node)
  46  		if !ok {
  47  			return
  48  		}
  49  		cond := m.State["cond"].(ast.Node)
  50  		if _, ok := code.Match(pass, sa4022.CheckAddressIsNilQ, cond); ok {
  51  			// Don't duplicate diagnostics reported by SA4022
  52  			return
  53  		}
  54  		lhs := m.State["lhs"].(ast.Expr)
  55  		path = path[:0]
  56  		for i := len(stack) - 1; i >= 0; i-- {
  57  			path = append(path, stack[i])
  58  		}
  59  		irfn := ir.EnclosingFunction(irpkg, path)
  60  		if irfn == nil {
  61  			// For example for functions named "_", because we don't generate IR for them.
  62  			return
  63  		}
  64  		v, isAddr := irfn.ValueForExpr(lhs)
  65  		if isAddr {
  66  			return
  67  		}
  68  
  69  		seen := map[ir.Value]struct{}{}
  70  		var values []ir.Value
  71  		var neverNil func(v ir.Value, track bool) bool
  72  		neverNil = func(v ir.Value, track bool) bool {
  73  			if _, ok := seen[v]; ok {
  74  				return true
  75  			}
  76  			seen[v] = struct{}{}
  77  			switch v := v.(type) {
  78  			case *ir.MakeClosure, *ir.Function:
  79  				if track {
  80  					values = append(values, v)
  81  				}
  82  				return true
  83  			case *ir.MakeChan, *ir.MakeMap, *ir.MakeSlice, *ir.Alloc:
  84  				if track {
  85  					values = append(values, v)
  86  				}
  87  				return true
  88  			case *ir.Slice:
  89  				if track {
  90  					values = append(values, v)
  91  				}
  92  				return neverNil(v.X, false)
  93  			case *ir.FieldAddr:
  94  				if track {
  95  					values = append(values, v)
  96  				}
  97  				return neverNil(v.X, false)
  98  			case *ir.Sigma:
  99  				return neverNil(v.X, true)
 100  			case *ir.Phi:
 101  				for _, e := range v.Edges {
 102  					if !neverNil(e, true) {
 103  						return false
 104  					}
 105  				}
 106  				return true
 107  			default:
 108  				return false
 109  			}
 110  		}
 111  
 112  		if !neverNil(v, true) {
 113  			return
 114  		}
 115  
 116  		var qualifier string
 117  		if op := m.State["op"].(token.Token); op == token.EQL {
 118  			qualifier = "never"
 119  		} else {
 120  			qualifier = "always"
 121  		}
 122  		fallback := fmt.Sprintf("this nil check is %s true", qualifier)
 123  
 124  		sort.Slice(values, func(i, j int) bool { return values[i].Pos() < values[j].Pos() })
 125  
 126  		if ident, ok := m.State["lhs"].(*ast.Ident); ok {
 127  			if _, ok := pass.TypesInfo.ObjectOf(ident).(*types.Var); ok {
 128  				var opts []report.Option
 129  				if v.Parent() == irfn {
 130  					if len(values) == 1 {
 131  						opts = append(opts, report.Related(values[0], fmt.Sprintf("this is the value of %s", ident.Name)))
 132  					} else {
 133  						for _, vv := range values {
 134  							opts = append(opts, report.Related(vv, fmt.Sprintf("this is one of the value of %s", ident.Name)))
 135  						}
 136  					}
 137  				}
 138  
 139  				switch v.(type) {
 140  				case *ir.MakeClosure, *ir.Function:
 141  					report.Report(pass, cond, "the checked variable contains a function and is never nil; did you mean to call it?", opts...)
 142  				default:
 143  					report.Report(pass, cond, fallback, opts...)
 144  				}
 145  			} else {
 146  				if _, ok := v.(*ir.Function); ok {
 147  					report.Report(pass, cond, "functions are never nil; did you mean to call it?")
 148  				} else {
 149  					report.Report(pass, cond, fallback)
 150  				}
 151  			}
 152  		} else {
 153  			if _, ok := v.(*ir.Function); ok {
 154  				report.Report(pass, cond, "functions are never nil; did you mean to call it?")
 155  			} else {
 156  				report.Report(pass, cond, fallback)
 157  			}
 158  		}
 159  	}
 160  	code.PreorderStack(pass, fn, (*ast.IfStmt)(nil))
 161  	return nil, nil
 162  }
 163