purity.go raw

   1  package purity
   2  
   3  // TODO(dh): we should split this into two facts, one tracking actual purity, and one tracking side-effects. A function
   4  // that returns a heap allocation isn't pure, but it may be free of side effects.
   5  
   6  import (
   7  	"go/types"
   8  	"reflect"
   9  
  10  	"honnef.co/go/tools/go/ir"
  11  	"honnef.co/go/tools/go/ir/irutil"
  12  	"honnef.co/go/tools/internal/passes/buildir"
  13  
  14  	"golang.org/x/tools/go/analysis"
  15  )
  16  
  17  type IsPure struct{}
  18  
  19  func (*IsPure) AFact()           {}
  20  func (d *IsPure) String() string { return "is pure" }
  21  
  22  type Result map[*types.Func]*IsPure
  23  
  24  var Analyzer = &analysis.Analyzer{
  25  	Name:       "fact_purity",
  26  	Doc:        "Mark pure functions",
  27  	Run:        purity,
  28  	Requires:   []*analysis.Analyzer{buildir.Analyzer},
  29  	FactTypes:  []analysis.Fact{(*IsPure)(nil)},
  30  	ResultType: reflect.TypeOf(Result{}),
  31  }
  32  
  33  var pureStdlib = map[string]struct{}{
  34  	"errors.New":                      {},
  35  	"fmt.Errorf":                      {},
  36  	"fmt.Sprintf":                     {},
  37  	"fmt.Sprint":                      {},
  38  	"sort.Reverse":                    {},
  39  	"strings.Map":                     {},
  40  	"strings.Repeat":                  {},
  41  	"strings.Replace":                 {},
  42  	"strings.Title":                   {},
  43  	"strings.ToLower":                 {},
  44  	"strings.ToLowerSpecial":          {},
  45  	"strings.ToTitle":                 {},
  46  	"strings.ToTitleSpecial":          {},
  47  	"strings.ToUpper":                 {},
  48  	"strings.ToUpperSpecial":          {},
  49  	"strings.Trim":                    {},
  50  	"strings.TrimFunc":                {},
  51  	"strings.TrimLeft":                {},
  52  	"strings.TrimLeftFunc":            {},
  53  	"strings.TrimPrefix":              {},
  54  	"strings.TrimRight":               {},
  55  	"strings.TrimRightFunc":           {},
  56  	"strings.TrimSpace":               {},
  57  	"strings.TrimSuffix":              {},
  58  	"(*net/http.Request).WithContext": {},
  59  	"time.Now":                        {},
  60  	"time.Parse":                      {},
  61  	"time.ParseInLocation":            {},
  62  	"time.Unix":                       {},
  63  	"time.UnixMicro":                  {},
  64  	"time.UnixMilli":                  {},
  65  	"(time.Time).Add":                 {},
  66  	"(time.Time).AddDate":             {},
  67  	"(time.Time).After":               {},
  68  	"(time.Time).Before":              {},
  69  	"(time.Time).Clock":               {},
  70  	"(time.Time).Compare":             {},
  71  	"(time.Time).Date":                {},
  72  	"(time.Time).Day":                 {},
  73  	"(time.Time).Equal":               {},
  74  	"(time.Time).Format":              {},
  75  	"(time.Time).GoString":            {},
  76  	"(time.Time).GobEncode":           {},
  77  	"(time.Time).Hour":                {},
  78  	"(time.Time).ISOWeek":             {},
  79  	"(time.Time).In":                  {},
  80  	"(time.Time).IsDST":               {},
  81  	"(time.Time).IsZero":              {},
  82  	"(time.Time).Local":               {},
  83  	"(time.Time).Location":            {},
  84  	"(time.Time).MarshalBinary":       {},
  85  	"(time.Time).MarshalJSON":         {},
  86  	"(time.Time).MarshalText":         {},
  87  	"(time.Time).Minute":              {},
  88  	"(time.Time).Month":               {},
  89  	"(time.Time).Nanosecond":          {},
  90  	"(time.Time).Round":               {},
  91  	"(time.Time).Second":              {},
  92  	"(time.Time).String":              {},
  93  	"(time.Time).Sub":                 {},
  94  	"(time.Time).Truncate":            {},
  95  	"(time.Time).UTC":                 {},
  96  	"(time.Time).Unix":                {},
  97  	"(time.Time).UnixMicro":           {},
  98  	"(time.Time).UnixMilli":           {},
  99  	"(time.Time).UnixNano":            {},
 100  	"(time.Time).Weekday":             {},
 101  	"(time.Time).Year":                {},
 102  	"(time.Time).YearDay":             {},
 103  	"(time.Time).Zone":                {},
 104  	"(time.Time).ZoneBounds":          {},
 105  }
 106  
 107  func purity(pass *analysis.Pass) (interface{}, error) {
 108  	seen := map[*ir.Function]struct{}{}
 109  	irpkg := pass.ResultOf[buildir.Analyzer].(*buildir.IR).Pkg
 110  	var check func(fn *ir.Function) (ret bool)
 111  	check = func(fn *ir.Function) (ret bool) {
 112  		if fn.Object() == nil {
 113  			// TODO(dh): support closures
 114  			return false
 115  		}
 116  		if pass.ImportObjectFact(fn.Object(), new(IsPure)) {
 117  			return true
 118  		}
 119  		if fn.Pkg != irpkg {
 120  			// Function is in another package but wasn't marked as
 121  			// pure, ergo it isn't pure
 122  			return false
 123  		}
 124  		// Break recursion
 125  		if _, ok := seen[fn]; ok {
 126  			return false
 127  		}
 128  
 129  		seen[fn] = struct{}{}
 130  		defer func() {
 131  			if ret {
 132  				pass.ExportObjectFact(fn.Object(), &IsPure{})
 133  			}
 134  		}()
 135  
 136  		if irutil.IsStub(fn) {
 137  			return false
 138  		}
 139  
 140  		if _, ok := pureStdlib[fn.Object().(*types.Func).FullName()]; ok {
 141  			return true
 142  		}
 143  
 144  		if fn.Signature.Results().Len() == 0 {
 145  			// A function with no return values is empty or is doing some
 146  			// work we cannot see (for example because of build tags);
 147  			// don't consider it pure.
 148  			return false
 149  		}
 150  
 151  		var isBasic func(typ types.Type) bool
 152  		isBasic = func(typ types.Type) bool {
 153  			switch u := typ.Underlying().(type) {
 154  			case *types.Basic:
 155  				return true
 156  			case *types.Struct:
 157  				for i := 0; i < u.NumFields(); i++ {
 158  					if !isBasic(u.Field(i).Type()) {
 159  						return false
 160  					}
 161  				}
 162  				return true
 163  			default:
 164  				return false
 165  			}
 166  		}
 167  
 168  		for _, param := range fn.Params {
 169  			// TODO(dh): this may not be strictly correct. pure code can, to an extent, operate on non-basic types.
 170  			if !isBasic(param.Type()) {
 171  				return false
 172  			}
 173  		}
 174  
 175  		// Don't consider external functions pure.
 176  		if fn.Blocks == nil {
 177  			return false
 178  		}
 179  		checkCall := func(common *ir.CallCommon) bool {
 180  			if common.IsInvoke() {
 181  				return false
 182  			}
 183  			builtin, ok := common.Value.(*ir.Builtin)
 184  			if !ok {
 185  				if common.StaticCallee() != fn {
 186  					if common.StaticCallee() == nil {
 187  						return false
 188  					}
 189  					if !check(common.StaticCallee()) {
 190  						return false
 191  					}
 192  				}
 193  			} else {
 194  				switch builtin.Name() {
 195  				case "len", "cap":
 196  				default:
 197  					return false
 198  				}
 199  			}
 200  			return true
 201  		}
 202  
 203  		var isStackAddr func(ir.Value) bool
 204  		isStackAddr = func(v ir.Value) bool {
 205  			switch v := v.(type) {
 206  			case *ir.Alloc:
 207  				return !v.Heap
 208  			case *ir.FieldAddr:
 209  				return isStackAddr(v.X)
 210  			default:
 211  				return false
 212  			}
 213  		}
 214  		for _, b := range fn.Blocks {
 215  			for _, ins := range b.Instrs {
 216  				switch ins := ins.(type) {
 217  				case *ir.Call:
 218  					if !checkCall(ins.Common()) {
 219  						return false
 220  					}
 221  				case *ir.Defer:
 222  					if !checkCall(&ins.Call) {
 223  						return false
 224  					}
 225  				case *ir.Select:
 226  					return false
 227  				case *ir.Send:
 228  					return false
 229  				case *ir.Go:
 230  					return false
 231  				case *ir.Panic:
 232  					return false
 233  				case *ir.Store:
 234  					if !isStackAddr(ins.Addr) {
 235  						return false
 236  					}
 237  				case *ir.FieldAddr:
 238  					if !isStackAddr(ins.X) {
 239  						return false
 240  					}
 241  				case *ir.Alloc:
 242  					// TODO(dh): make use of proper escape analysis
 243  					if ins.Heap {
 244  						return false
 245  					}
 246  				case *ir.Load:
 247  					if !isStackAddr(ins.X) {
 248  						return false
 249  					}
 250  				}
 251  			}
 252  		}
 253  		return true
 254  	}
 255  	for _, fn := range pass.ResultOf[buildir.Analyzer].(*buildir.IR).SrcFuncs {
 256  		check(fn)
 257  	}
 258  
 259  	out := Result{}
 260  	for _, fact := range pass.AllObjectFacts() {
 261  		out[fact.Object.(*types.Func)] = fact.Fact.(*IsPure)
 262  	}
 263  	return out, nil
 264  }
 265