sa5001.go raw

   1  package sa5001
   2  
   3  import (
   4  	"fmt"
   5  	"go/ast"
   6  	"go/types"
   7  
   8  	"honnef.co/go/tools/analysis/code"
   9  	"honnef.co/go/tools/analysis/lint"
  10  	"honnef.co/go/tools/analysis/report"
  11  
  12  	"golang.org/x/tools/go/analysis"
  13  	"golang.org/x/tools/go/analysis/passes/inspect"
  14  )
  15  
  16  var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
  17  	Analyzer: &analysis.Analyzer{
  18  		Name:     "SA5001",
  19  		Run:      run,
  20  		Requires: []*analysis.Analyzer{inspect.Analyzer},
  21  	},
  22  	Doc: &lint.RawDocumentation{
  23  		Title:    `Deferring \'Close\' before checking for a possible error`,
  24  		Since:    "2017.1",
  25  		Severity: lint.SeverityWarning,
  26  		MergeIf:  lint.MergeIfAny,
  27  	},
  28  })
  29  
  30  var Analyzer = SCAnalyzer.Analyzer
  31  
  32  func run(pass *analysis.Pass) (interface{}, error) {
  33  	fn := func(node ast.Node) {
  34  		block := node.(*ast.BlockStmt)
  35  		if len(block.List) < 2 {
  36  			return
  37  		}
  38  		for i, stmt := range block.List {
  39  			if i == len(block.List)-1 {
  40  				break
  41  			}
  42  			assign, ok := stmt.(*ast.AssignStmt)
  43  			if !ok {
  44  				continue
  45  			}
  46  			if len(assign.Rhs) != 1 {
  47  				continue
  48  			}
  49  			if len(assign.Lhs) < 2 {
  50  				continue
  51  			}
  52  			if lhs, ok := assign.Lhs[len(assign.Lhs)-1].(*ast.Ident); ok && lhs.Name == "_" {
  53  				continue
  54  			}
  55  			call, ok := assign.Rhs[0].(*ast.CallExpr)
  56  			if !ok {
  57  				continue
  58  			}
  59  			sig, ok := pass.TypesInfo.TypeOf(call.Fun).(*types.Signature)
  60  			if !ok {
  61  				continue
  62  			}
  63  			if sig.Results().Len() < 2 {
  64  				continue
  65  			}
  66  			last := sig.Results().At(sig.Results().Len() - 1)
  67  			// FIXME(dh): check that it's error from universe, not
  68  			// another type of the same name
  69  			if last.Type().String() != "error" {
  70  				continue
  71  			}
  72  			lhs, ok := assign.Lhs[0].(*ast.Ident)
  73  			if !ok {
  74  				continue
  75  			}
  76  			def, ok := block.List[i+1].(*ast.DeferStmt)
  77  			if !ok {
  78  				continue
  79  			}
  80  			sel, ok := def.Call.Fun.(*ast.SelectorExpr)
  81  			if !ok {
  82  				continue
  83  			}
  84  			ident, ok := selectorX(sel).(*ast.Ident)
  85  			if !ok {
  86  				continue
  87  			}
  88  			if pass.TypesInfo.ObjectOf(ident) != pass.TypesInfo.ObjectOf(lhs) {
  89  				continue
  90  			}
  91  			if sel.Sel.Name != "Close" {
  92  				continue
  93  			}
  94  			report.Report(pass, def, fmt.Sprintf("should check error returned from %s() before deferring %s",
  95  				report.Render(pass, call.Fun), report.Render(pass, def.Call)))
  96  		}
  97  	}
  98  	code.Preorder(pass, fn, (*ast.BlockStmt)(nil))
  99  	return nil, nil
 100  }
 101  
 102  func selectorX(sel *ast.SelectorExpr) ast.Node {
 103  	switch x := sel.X.(type) {
 104  	case *ast.SelectorExpr:
 105  		return selectorX(x)
 106  	default:
 107  		return x
 108  	}
 109  }
 110