subst.go raw

   1  // Copyright 2022 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package ssa
   6  
   7  import (
   8  	"go/types"
   9  
  10  	"golang.org/x/tools/go/types/typeutil"
  11  	"golang.org/x/tools/internal/aliases"
  12  )
  13  
  14  // subster defines a type substitution operation of a set of type parameters
  15  // to type parameter free replacement types. Substitution is done within
  16  // the context of a package-level function instantiation. *Named types
  17  // declared in the function are unique to the instantiation.
  18  //
  19  // For example, given a parameterized function F
  20  //
  21  //	  func F[S, T any]() any {
  22  //	    type X struct{ s S; next *X }
  23  //		var p *X
  24  //	    return p
  25  //	  }
  26  //
  27  // calling the instantiation F[string, int]() returns an interface
  28  // value (*X[string,int], nil) where the underlying value of
  29  // X[string,int] is a struct{s string; next *X[string,int]}.
  30  //
  31  // A nil *subster is a valid, empty substitution map. It always acts as
  32  // the identity function. This allows for treating parameterized and
  33  // non-parameterized functions identically while compiling to ssa.
  34  //
  35  // Not concurrency-safe.
  36  //
  37  // Note: Some may find it helpful to think through some of the most
  38  // complex substitution cases using lambda calculus inspired notation.
  39  // subst.typ() solves evaluating a type expression E
  40  // within the body of a function Fn[m] with the type parameters m
  41  // once we have applied the type arguments N.
  42  // We can succinctly write this as a function application:
  43  //
  44  //	((λm. E) N)
  45  //
  46  // go/types does not provide this interface directly.
  47  // So what subster provides is a type substitution operation
  48  //
  49  //	E[m:=N]
  50  type subster struct {
  51  	replacements map[*types.TypeParam]types.Type // values should contain no type params
  52  	cache        map[types.Type]types.Type       // cache of subst results
  53  	origin       *types.Func                     // types.Objects declared within this origin function are unique within this context
  54  	ctxt         *types.Context                  // speeds up repeated instantiations
  55  	uniqueness   typeutil.Map                    // determines the uniqueness of the instantiations within the function
  56  	// TODO(taking): consider adding Pos
  57  }
  58  
  59  // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
  60  // targs should not contain any types in tparams.
  61  // fn is the generic function for which we are substituting.
  62  func makeSubster(ctxt *types.Context, fn *types.Func, tparams *types.TypeParamList, targs []types.Type) *subster {
  63  	assert(tparams.Len() == len(targs), "makeSubster argument count must match")
  64  
  65  	subst := &subster{
  66  		replacements: make(map[*types.TypeParam]types.Type, tparams.Len()),
  67  		cache:        make(map[types.Type]types.Type),
  68  		origin:       fn.Origin(),
  69  		ctxt:         ctxt,
  70  	}
  71  	for i := 0; i < tparams.Len(); i++ {
  72  		subst.replacements[tparams.At(i)] = targs[i]
  73  	}
  74  	return subst
  75  }
  76  
  77  // typ returns the type of t with the type parameter tparams[i] substituted
  78  // for the type targs[i] where subst was created using tparams and targs.
  79  func (subst *subster) typ(t types.Type) (res types.Type) {
  80  	if subst == nil {
  81  		return t // A nil subst is type preserving.
  82  	}
  83  	if r, ok := subst.cache[t]; ok {
  84  		return r
  85  	}
  86  	defer func() {
  87  		subst.cache[t] = res
  88  	}()
  89  
  90  	switch t := t.(type) {
  91  	case *types.TypeParam:
  92  		if r := subst.replacements[t]; r != nil {
  93  			return r
  94  		}
  95  		return t
  96  
  97  	case *types.Basic:
  98  		return t
  99  
 100  	case *types.Array:
 101  		if r := subst.typ(t.Elem()); r != t.Elem() {
 102  			return types.NewArray(r, t.Len())
 103  		}
 104  		return t
 105  
 106  	case *types.Slice:
 107  		if r := subst.typ(t.Elem()); r != t.Elem() {
 108  			return types.NewSlice(r)
 109  		}
 110  		return t
 111  
 112  	case *types.Pointer:
 113  		if r := subst.typ(t.Elem()); r != t.Elem() {
 114  			return types.NewPointer(r)
 115  		}
 116  		return t
 117  
 118  	case *types.Tuple:
 119  		return subst.tuple(t)
 120  
 121  	case *types.Struct:
 122  		return subst.struct_(t)
 123  
 124  	case *types.Map:
 125  		key := subst.typ(t.Key())
 126  		elem := subst.typ(t.Elem())
 127  		if key != t.Key() || elem != t.Elem() {
 128  			return types.NewMap(key, elem)
 129  		}
 130  		return t
 131  
 132  	case *types.Chan:
 133  		if elem := subst.typ(t.Elem()); elem != t.Elem() {
 134  			return types.NewChan(t.Dir(), elem)
 135  		}
 136  		return t
 137  
 138  	case *types.Signature:
 139  		return subst.signature(t)
 140  
 141  	case *types.Union:
 142  		return subst.union(t)
 143  
 144  	case *types.Interface:
 145  		return subst.interface_(t)
 146  
 147  	case *types.Alias:
 148  		return subst.alias(t)
 149  
 150  	case *types.Named:
 151  		return subst.named(t)
 152  
 153  	case *opaqueType:
 154  		return t // opaque types are never substituted
 155  
 156  	default:
 157  		panic("unreachable")
 158  	}
 159  }
 160  
 161  // types returns the result of {subst.typ(ts[i])}.
 162  func (subst *subster) types(ts []types.Type) []types.Type {
 163  	res := make([]types.Type, len(ts))
 164  	for i := range ts {
 165  		res[i] = subst.typ(ts[i])
 166  	}
 167  	return res
 168  }
 169  
 170  func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
 171  	if t != nil {
 172  		if vars := subst.varlist(t); vars != nil {
 173  			return types.NewTuple(vars...)
 174  		}
 175  	}
 176  	return t
 177  }
 178  
 179  type varlist interface {
 180  	At(i int) *types.Var
 181  	Len() int
 182  }
 183  
 184  // fieldlist is an adapter for structs for the varlist interface.
 185  type fieldlist struct {
 186  	str *types.Struct
 187  }
 188  
 189  func (fl fieldlist) At(i int) *types.Var { return fl.str.Field(i) }
 190  func (fl fieldlist) Len() int            { return fl.str.NumFields() }
 191  
 192  func (subst *subster) struct_(t *types.Struct) *types.Struct {
 193  	if t != nil {
 194  		if fields := subst.varlist(fieldlist{t}); fields != nil {
 195  			tags := make([]string, t.NumFields())
 196  			for i, n := 0, t.NumFields(); i < n; i++ {
 197  				tags[i] = t.Tag(i)
 198  			}
 199  			return types.NewStruct(fields, tags)
 200  		}
 201  	}
 202  	return t
 203  }
 204  
 205  // varlist returns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
 206  func (subst *subster) varlist(in varlist) []*types.Var {
 207  	var out []*types.Var // nil => no updates
 208  	for i, n := 0, in.Len(); i < n; i++ {
 209  		v := in.At(i)
 210  		w := subst.var_(v)
 211  		if v != w && out == nil {
 212  			out = make([]*types.Var, n)
 213  			for j := 0; j < i; j++ {
 214  				out[j] = in.At(j)
 215  			}
 216  		}
 217  		if out != nil {
 218  			out[i] = w
 219  		}
 220  	}
 221  	return out
 222  }
 223  
 224  func (subst *subster) var_(v *types.Var) *types.Var {
 225  	if v != nil {
 226  		if typ := subst.typ(v.Type()); typ != v.Type() {
 227  			if v.IsField() {
 228  				return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
 229  			}
 230  			return types.NewParam(v.Pos(), v.Pkg(), v.Name(), typ)
 231  		}
 232  	}
 233  	return v
 234  }
 235  
 236  func (subst *subster) union(u *types.Union) *types.Union {
 237  	var out []*types.Term // nil => no updates
 238  
 239  	for i, n := 0, u.Len(); i < n; i++ {
 240  		t := u.Term(i)
 241  		r := subst.typ(t.Type())
 242  		if r != t.Type() && out == nil {
 243  			out = make([]*types.Term, n)
 244  			for j := 0; j < i; j++ {
 245  				out[j] = u.Term(j)
 246  			}
 247  		}
 248  		if out != nil {
 249  			out[i] = types.NewTerm(t.Tilde(), r)
 250  		}
 251  	}
 252  
 253  	if out != nil {
 254  		return types.NewUnion(out)
 255  	}
 256  	return u
 257  }
 258  
 259  func (subst *subster) interface_(iface *types.Interface) *types.Interface {
 260  	if iface == nil {
 261  		return nil
 262  	}
 263  
 264  	// methods for the interface. Initially nil if there is no known change needed.
 265  	// Signatures for the method where recv is nil. NewInterfaceType fills in the receivers.
 266  	var methods []*types.Func
 267  	initMethods := func(n int) { // copy first n explicit methods
 268  		methods = make([]*types.Func, iface.NumExplicitMethods())
 269  		for i := range n {
 270  			f := iface.ExplicitMethod(i)
 271  			norecv := changeRecv(f.Type().(*types.Signature), nil)
 272  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
 273  		}
 274  	}
 275  	for i := 0; i < iface.NumExplicitMethods(); i++ {
 276  		f := iface.ExplicitMethod(i)
 277  		// On interfaces, we need to cycle break on anonymous interface types
 278  		// being in a cycle with their signatures being in cycles with their receivers
 279  		// that do not go through a Named.
 280  		norecv := changeRecv(f.Type().(*types.Signature), nil)
 281  		sig := subst.typ(norecv)
 282  		if sig != norecv && methods == nil {
 283  			initMethods(i)
 284  		}
 285  		if methods != nil {
 286  			methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
 287  		}
 288  	}
 289  
 290  	var embeds []types.Type
 291  	initEmbeds := func(n int) { // copy first n embedded types
 292  		embeds = make([]types.Type, iface.NumEmbeddeds())
 293  		for i := range n {
 294  			embeds[i] = iface.EmbeddedType(i)
 295  		}
 296  	}
 297  	for i := 0; i < iface.NumEmbeddeds(); i++ {
 298  		e := iface.EmbeddedType(i)
 299  		r := subst.typ(e)
 300  		if e != r && embeds == nil {
 301  			initEmbeds(i)
 302  		}
 303  		if embeds != nil {
 304  			embeds[i] = r
 305  		}
 306  	}
 307  
 308  	if methods == nil && embeds == nil {
 309  		return iface
 310  	}
 311  	if methods == nil {
 312  		initMethods(iface.NumExplicitMethods())
 313  	}
 314  	if embeds == nil {
 315  		initEmbeds(iface.NumEmbeddeds())
 316  	}
 317  	return types.NewInterfaceType(methods, embeds).Complete()
 318  }
 319  
 320  func (subst *subster) alias(t *types.Alias) types.Type {
 321  	// See subster.named. This follows the same strategy.
 322  	tparams := t.TypeParams()
 323  	targs := t.TypeArgs()
 324  	tname := t.Obj()
 325  	torigin := t.Origin()
 326  
 327  	if !declaredWithin(tname, subst.origin) {
 328  		// t is declared outside of the function origin. So t is a package level type alias.
 329  		if targs.Len() == 0 {
 330  			// No type arguments so no instantiation needed.
 331  			return t
 332  		}
 333  
 334  		// Instantiate with the substituted type arguments.
 335  		newTArgs := subst.typelist(targs)
 336  		return subst.instantiate(torigin, newTArgs)
 337  	}
 338  
 339  	if targs.Len() == 0 {
 340  		// t is declared within the function origin and has no type arguments.
 341  		//
 342  		// Example: This corresponds to A or B in F, but not A[int]:
 343  		//
 344  		//     func F[T any]() {
 345  		//       type A[S any] = struct{t T, s S}
 346  		//       type B = T
 347  		//       var x A[int]
 348  		//       ...
 349  		//     }
 350  		//
 351  		// This is somewhat different than *Named as *Alias cannot be created recursively.
 352  
 353  		// Copy and substitute type params.
 354  		var newTParams []*types.TypeParam
 355  		for cur := range tparams.TypeParams() {
 356  			cobj := cur.Obj()
 357  			cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
 358  			ntp := types.NewTypeParam(cname, nil)
 359  			subst.cache[cur] = ntp // See the comment "Note: Subtle" in subster.named.
 360  			newTParams = append(newTParams, ntp)
 361  		}
 362  
 363  		// Substitute rhs.
 364  		rhs := subst.typ(t.Rhs())
 365  
 366  		// Create the fresh alias.
 367  		obj := aliases.New(tname.Pos(), tname.Pkg(), tname.Name(), rhs, newTParams)
 368  
 369  		// Substitute into all of the constraints after they are created.
 370  		for i, ntp := range newTParams {
 371  			bound := tparams.At(i).Constraint()
 372  			ntp.SetConstraint(subst.typ(bound))
 373  		}
 374  		return obj.Type()
 375  	}
 376  
 377  	// t is declared within the function origin and has type arguments.
 378  	//
 379  	// Example: This corresponds to A[int] in F. Cases A and B are handled above.
 380  	//     func F[T any]() {
 381  	//       type A[S any] = struct{t T, s S}
 382  	//       type B = T
 383  	//       var x A[int]
 384  	//       ...
 385  	//     }
 386  	subOrigin := subst.typ(torigin)
 387  	subTArgs := subst.typelist(targs)
 388  	return subst.instantiate(subOrigin, subTArgs)
 389  }
 390  
 391  func (subst *subster) named(t *types.Named) types.Type {
 392  	// A Named type is a user defined type.
 393  	// Ignoring generics, Named types are canonical: they are identical if
 394  	// and only if they have the same defining symbol.
 395  	// Generics complicate things, both if the type definition itself is
 396  	// parameterized, and if the type is defined within the scope of a
 397  	// parameterized function. In this case, two named types are identical if
 398  	// and only if their identifying symbols are identical, and all type
 399  	// arguments bindings in scope of the named type definition (including the
 400  	// type parameters of the definition itself) are equivalent.
 401  	//
 402  	// Notably:
 403  	// 1. For type definition type T[P1 any] struct{}, T[A] and T[B] are identical
 404  	//    only if A and B are identical.
 405  	// 2. Inside the generic func Fn[m any]() any { type T struct{}; return T{} },
 406  	//    the result of Fn[A] and Fn[B] have identical type if and only if A and
 407  	//    B are identical.
 408  	// 3. Both 1 and 2 could apply, such as in
 409  	//    func F[m any]() any { type T[x any] struct{}; return T{} }
 410  	//
 411  	// A subster replaces type parameters within a function scope, and therefore must
 412  	// also replace free type parameters in the definitions of local types.
 413  	//
 414  	// Note: There are some detailed notes sprinkled throughout that borrow from
 415  	// lambda calculus notation. These contain some over simplifying math.
 416  	//
 417  	// LC: One way to think about subster is that it is  a way of evaluating
 418  	//   ((λm. E) N) as E[m:=N].
 419  	// Each Named type t has an object *TypeName within a scope S that binds an
 420  	// underlying type expression U. U can refer to symbols within S (+ S's ancestors).
 421  	// Let x = t.TypeParams() and A = t.TypeArgs().
 422  	// Each Named type t is then either:
 423  	//   U              where len(x) == 0 && len(A) == 0
 424  	//   λx. U          where len(x) != 0 && len(A) == 0
 425  	//   ((λx. U) A)    where len(x) == len(A)
 426  	// In each case, we will evaluate t[m:=N].
 427  	tparams := t.TypeParams() // x
 428  	targs := t.TypeArgs()     // A
 429  
 430  	if !declaredWithin(t.Obj(), subst.origin) {
 431  		// t is declared outside of Fn[m].
 432  		//
 433  		// In this case, we can skip substituting t.Underlying().
 434  		// The underlying type cannot refer to the type parameters.
 435  		//
 436  		// LC: Let free(E) be the set of free type parameters in an expression E.
 437  		// Then whenever m ∉ free(E), then E = E[m:=N].
 438  		// t ∉ Scope(fn) so therefore m ∉ free(U) and m ∩ x = ∅.
 439  		if targs.Len() == 0 {
 440  			// t has no type arguments. So it does not need to be instantiated.
 441  			//
 442  			// This is the normal case in real Go code, where t is not parameterized,
 443  			// declared at some package scope, and m is a TypeParam from a parameterized
 444  			// function F[m] or method.
 445  			//
 446  			// LC: m ∉ free(A) lets us conclude m ∉ free(t). So t=t[m:=N].
 447  			return t
 448  		}
 449  
 450  		// t is declared outside of Fn[m] and has type arguments.
 451  		// The type arguments may contain type parameters m so
 452  		// substitute the type arguments, and instantiate the substituted
 453  		// type arguments.
 454  		//
 455  		// LC: Evaluate this as ((λx. U) A') where A' = A[m := N].
 456  		newTArgs := subst.typelist(targs)
 457  		return subst.instantiate(t.Origin(), newTArgs)
 458  	}
 459  
 460  	// t is declared within Fn[m].
 461  
 462  	if targs.Len() == 0 { // no type arguments?
 463  		assert(t == t.Origin(), "local parameterized type abstraction must be an origin type")
 464  
 465  		// t has no type arguments.
 466  		// The underlying type of t may contain the function's type parameters,
 467  		// replace these, and create a new type.
 468  		//
 469  		// Subtle: We short circuit substitution and use a newly created type in
 470  		// subst, i.e. cache[t]=fresh, to preemptively replace t with fresh
 471  		// in recursive types during traversal. This both breaks infinite cycles
 472  		// and allows for constructing types with the replacement applied in
 473  		// subst.typ(U).
 474  		//
 475  		// A new copy of the Named and Typename (and constraints) per function
 476  		// instantiation matches the semantics of Go, which treats all function
 477  		// instantiations F[N] as having distinct local types.
 478  		//
 479  		// LC: x.Len()=0 can be thought of as a special case of λx. U.
 480  		// LC: Evaluate (λx. U)[m:=N] as (λx'. U') where U'=U[x:=x',m:=N].
 481  		tname := t.Obj()
 482  		obj := types.NewTypeName(tname.Pos(), tname.Pkg(), tname.Name(), nil)
 483  		fresh := types.NewNamed(obj, nil, nil)
 484  		var newTParams []*types.TypeParam
 485  		for cur := range tparams.TypeParams() {
 486  			cobj := cur.Obj()
 487  			cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
 488  			ntp := types.NewTypeParam(cname, nil)
 489  			subst.cache[cur] = ntp
 490  			newTParams = append(newTParams, ntp)
 491  		}
 492  		fresh.SetTypeParams(newTParams)
 493  		subst.cache[t] = fresh
 494  		subst.cache[fresh] = fresh
 495  		fresh.SetUnderlying(subst.typ(t.Underlying()))
 496  		// Substitute into all of the constraints after they are created.
 497  		for i, ntp := range newTParams {
 498  			bound := tparams.At(i).Constraint()
 499  			ntp.SetConstraint(subst.typ(bound))
 500  		}
 501  		return fresh
 502  	}
 503  
 504  	// t is defined within Fn[m] and t has type arguments (an instantiation).
 505  	// We reduce this to the two cases above:
 506  	// (1) substitute the function's type parameters into t.Origin().
 507  	// (2) substitute t's type arguments A and instantiate the updated t.Origin() with these.
 508  	//
 509  	// LC: Evaluate ((λx. U) A)[m:=N] as (t' A') where t' = (λx. U)[m:=N] and A'=A [m:=N]
 510  	subOrigin := subst.typ(t.Origin())
 511  	subTArgs := subst.typelist(targs)
 512  	return subst.instantiate(subOrigin, subTArgs)
 513  }
 514  
 515  func (subst *subster) instantiate(orig types.Type, targs []types.Type) types.Type {
 516  	i, err := types.Instantiate(subst.ctxt, orig, targs, false)
 517  	assert(err == nil, "failed to Instantiate named (Named or Alias) type")
 518  	if c, _ := subst.uniqueness.At(i).(types.Type); c != nil {
 519  		return c.(types.Type)
 520  	}
 521  	subst.uniqueness.Set(i, i)
 522  	return i
 523  }
 524  
 525  func (subst *subster) typelist(l *types.TypeList) []types.Type {
 526  	res := make([]types.Type, l.Len())
 527  	for i := 0; i < l.Len(); i++ {
 528  		res[i] = subst.typ(l.At(i))
 529  	}
 530  	return res
 531  }
 532  
 533  func (subst *subster) signature(t *types.Signature) types.Type {
 534  	tparams := t.TypeParams()
 535  
 536  	// We are choosing not to support tparams.Len() > 0 until a need has been observed in practice.
 537  	//
 538  	// There are some known usages for types.Types coming from types.{Eval,CheckExpr}.
 539  	// To support tparams.Len() > 0, we just need to do the following [pseudocode]:
 540  	//   targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false)
 541  
 542  	assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
 543  
 544  	// Either:
 545  	// (1)non-generic function.
 546  	//    no type params to substitute
 547  	// (2)generic method and recv needs to be substituted.
 548  
 549  	// Receivers can be either:
 550  	// named
 551  	// pointer to named
 552  	// interface
 553  	// nil
 554  	// interface is the problematic case. We need to cycle break there!
 555  	recv := subst.var_(t.Recv())
 556  	params := subst.tuple(t.Params())
 557  	results := subst.tuple(t.Results())
 558  	if recv != t.Recv() || params != t.Params() || results != t.Results() {
 559  		return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic())
 560  	}
 561  	return t
 562  }
 563