callcheck.go raw

   1  // Package callcheck provides a framework for validating arguments in function calls.
   2  package callcheck
   3  
   4  import (
   5  	"fmt"
   6  	"go/ast"
   7  	"go/constant"
   8  	"go/types"
   9  
  10  	"golang.org/x/tools/go/analysis"
  11  	"honnef.co/go/tools/analysis/report"
  12  	"honnef.co/go/tools/go/ir"
  13  	"honnef.co/go/tools/go/ir/irutil"
  14  	"honnef.co/go/tools/go/types/typeutil"
  15  	"honnef.co/go/tools/internal/passes/buildir"
  16  )
  17  
// Call describes a single statically resolved call site that is being checked
// by a rule. A rule inspects the call and records problems either for the call
// as a whole (Call.Invalid) or for an individual argument (Argument.Invalid).
type Call struct {
	// Pass is the analysis pass in which the call was found.
	Pass *analysis.Pass
	// Instr is the IR call instruction being checked.
	Instr ir.CallInstruction
	// Args holds the call's arguments. For method calls the receiver is not
	// included, and arguments wrapped in an *ir.MakeInterface are unwrapped
	// to the value that was converted (see checkCalls).
	Args []*Argument

	// Parent is the function containing Instr (taken from site.Parent()).
	Parent *ir.Function

	// invalids collects messages recorded via Invalid; checkCalls reports
	// each of them at Instr.
	invalids []string
}

// Invalid records msg as a problem with the call as a whole. The message is
// reported at the position of the call instruction.
func (c *Call) Invalid(msg string) {
	c.invalids = append(c.invalids, msg)
}
  31  
// Argument is a single argument of a Call. Rules record problems with it via
// Invalid.
type Argument struct {
	Value    Value
	invalids []string // messages recorded via Invalid
}

// Value wraps the IR value passed as an argument.
type Value struct {
	Value ir.Value
}

// Invalid records msg as a problem with this argument. checkCalls reports it
// at the matching argument expression in the source when one can be found,
// and at the call otherwise.
func (arg *Argument) Invalid(msg string) {
	arg.invalids = append(arg.invalids, msg)
}
  44  
// Check is a rule applied to a single call. It may record problems through
// Call.Invalid or Argument.Invalid; checkCalls reports them after the rule
// returns.
type Check func(call *Call)
  46  
  47  func Analyzer(rules map[string]Check) func(pass *analysis.Pass) (interface{}, error) {
  48  	return func(pass *analysis.Pass) (interface{}, error) {
  49  		return checkCalls(pass, rules)
  50  	}
  51  }
  52  
  53  func checkCalls(pass *analysis.Pass, rules map[string]Check) (interface{}, error) {
  54  	cb := func(caller *ir.Function, site ir.CallInstruction, callee *ir.Function) {
  55  		obj, ok := callee.Object().(*types.Func)
  56  		if !ok {
  57  			return
  58  		}
  59  
  60  		r, ok := rules[typeutil.FuncName(obj)]
  61  		if !ok {
  62  			return
  63  		}
  64  		var args []*Argument
  65  		irargs := site.Common().Args
  66  		if callee.Signature.Recv() != nil {
  67  			irargs = irargs[1:]
  68  		}
  69  		for _, arg := range irargs {
  70  			if iarg, ok := arg.(*ir.MakeInterface); ok {
  71  				arg = iarg.X
  72  			}
  73  			args = append(args, &Argument{Value: Value{arg}})
  74  		}
  75  		call := &Call{
  76  			Pass:   pass,
  77  			Instr:  site,
  78  			Args:   args,
  79  			Parent: site.Parent(),
  80  		}
  81  		r(call)
  82  
  83  		var astcall *ast.CallExpr
  84  		switch source := site.Source().(type) {
  85  		case *ast.CallExpr:
  86  			astcall = source
  87  		case *ast.DeferStmt:
  88  			astcall = source.Call
  89  		case *ast.GoStmt:
  90  			astcall = source.Call
  91  		case nil:
  92  			// TODO(dh): I am not sure this can actually happen. If it
  93  			// can't, we should remove this case, and also stop
  94  			// checking for astcall == nil in the code that follows.
  95  		default:
  96  			panic(fmt.Sprintf("unhandled case %T", source))
  97  		}
  98  
  99  		for idx, arg := range call.Args {
 100  			for _, e := range arg.invalids {
 101  				if astcall != nil {
 102  					if idx < len(astcall.Args) {
 103  						report.Report(pass, astcall.Args[idx], e)
 104  					} else {
 105  						// this is an instance of fn1(fn2()) where fn2
 106  						// returns multiple values. Report the error
 107  						// at the next-best position that we have, the
 108  						// first argument. An example of a check that
 109  						// triggers this is checkEncodingBinaryRules.
 110  						report.Report(pass, astcall.Args[0], e)
 111  					}
 112  				} else {
 113  					report.Report(pass, site, e)
 114  				}
 115  			}
 116  		}
 117  		for _, e := range call.invalids {
 118  			report.Report(pass, call.Instr, e)
 119  		}
 120  	}
 121  	for _, fn := range pass.ResultOf[buildir.Analyzer].(*buildir.IR).SrcFuncs {
 122  		eachCall(fn, cb)
 123  	}
 124  	return nil, nil
 125  }
 126  
 127  func eachCall(fn *ir.Function, cb func(caller *ir.Function, site ir.CallInstruction, callee *ir.Function)) {
 128  	for _, b := range fn.Blocks {
 129  		for _, instr := range b.Instrs {
 130  			if site, ok := instr.(ir.CallInstruction); ok {
 131  				if g := site.Common().StaticCallee(); g != nil {
 132  					cb(fn, site, g)
 133  				}
 134  			}
 135  		}
 136  	}
 137  }
 138  
 139  func ExtractConstExpectKind(v Value, kind constant.Kind) *ir.Const {
 140  	k := extractConst(v.Value)
 141  	if k == nil || k.Value == nil || k.Value.Kind() != kind {
 142  		return nil
 143  	}
 144  	return k
 145  }
 146  
 147  func ExtractConst(v Value) *ir.Const {
 148  	return extractConst(v.Value)
 149  }
 150  
 151  func extractConst(v ir.Value) *ir.Const {
 152  	v = irutil.Flatten(v)
 153  	switch v := v.(type) {
 154  	case *ir.Const:
 155  		return v
 156  	case *ir.MakeInterface:
 157  		return extractConst(v.X)
 158  	default:
 159  		return nil
 160  	}
 161  }
 162