1 // Copyright 2022 The Go Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style
3 // license that can be found in the LICENSE file.
4 5 package ssa
6 7 import (
8 "go/types"
9 10 "golang.org/x/tools/go/types/typeutil"
11 "golang.org/x/tools/internal/aliases"
12 )
// subster defines a type substitution operation of a set of type parameters
// to type parameter free replacement types. Substitution is done within
// the context of a package-level function instantiation. *Named types
// declared in the function are unique to the instantiation.
//
// For example, given a parameterized function F
//
//	func F[S, T any]() any {
//		type X struct{ s S; next *X }
//		var p *X
//		return p
//	}
//
// calling the instantiation F[string, int]() returns an interface
// value (*X[string,int], nil) where the underlying value of
// X[string,int] is a struct{s string; next *X[string,int]}.
//
// A nil *subster is a valid, empty substitution map. It always acts as
// the identity function. This allows for treating parameterized and
// non-parameterized functions identically while compiling to ssa.
//
// Not concurrency-safe.
//
// Note: Some may find it helpful to think through some of the most
// complex substitution cases using lambda calculus inspired notation.
// subst.typ() solves evaluating a type expression E
// within the body of a function Fn[m] with the type parameters m
// once we have applied the type arguments N.
// We can succinctly write this as a function application:
//
//	((λm. E) N)
//
// go/types does not provide this interface directly.
// So what subster provides is a type substitution operation
//
//	E[m:=N]
type subster struct {
	// replacements maps each type parameter of the function being
	// instantiated to its type argument. Values should contain no type params.
	replacements map[*types.TypeParam]types.Type

	// cache memoizes the results of typ. The named and alias methods also
	// write entries into it ahead of recursing (for freshly created named
	// types and type parameters), which is what terminates substitution
	// through recursive local types.
	cache map[types.Type]types.Type

	// origin is the generic function being instantiated. types.Objects
	// declared within this origin function are unique within this context.
	origin *types.Func

	// ctxt is passed to types.Instantiate; it speeds up repeated instantiations.
	ctxt *types.Context

	// uniqueness determines the uniqueness of the instantiations within the
	// function: instantiate looks each new instantiation up here and returns
	// the previously recorded type when one is found.
	uniqueness typeutil.Map

	// TODO(taking): consider adding Pos
}
58 59 // Returns a subster that replaces tparams[i] with targs[i]. Uses ctxt as a cache.
60 // targs should not contain any types in tparams.
61 // fn is the generic function for which we are substituting.
62 func makeSubster(ctxt *types.Context, fn *types.Func, tparams *types.TypeParamList, targs []types.Type) *subster {
63 assert(tparams.Len() == len(targs), "makeSubster argument count must match")
64 65 subst := &subster{
66 replacements: make(map[*types.TypeParam]types.Type, tparams.Len()),
67 cache: make(map[types.Type]types.Type),
68 origin: fn.Origin(),
69 ctxt: ctxt,
70 }
71 for i := 0; i < tparams.Len(); i++ {
72 subst.replacements[tparams.At(i)] = targs[i]
73 }
74 return subst
75 }
76 77 // typ returns the type of t with the type parameter tparams[i] substituted
78 // for the type targs[i] where subst was created using tparams and targs.
79 func (subst *subster) typ(t types.Type) (res types.Type) {
80 if subst == nil {
81 return t // A nil subst is type preserving.
82 }
83 if r, ok := subst.cache[t]; ok {
84 return r
85 }
86 defer func() {
87 subst.cache[t] = res
88 }()
89 90 switch t := t.(type) {
91 case *types.TypeParam:
92 if r := subst.replacements[t]; r != nil {
93 return r
94 }
95 return t
96 97 case *types.Basic:
98 return t
99 100 case *types.Array:
101 if r := subst.typ(t.Elem()); r != t.Elem() {
102 return types.NewArray(r, t.Len())
103 }
104 return t
105 106 case *types.Slice:
107 if r := subst.typ(t.Elem()); r != t.Elem() {
108 return types.NewSlice(r)
109 }
110 return t
111 112 case *types.Pointer:
113 if r := subst.typ(t.Elem()); r != t.Elem() {
114 return types.NewPointer(r)
115 }
116 return t
117 118 case *types.Tuple:
119 return subst.tuple(t)
120 121 case *types.Struct:
122 return subst.struct_(t)
123 124 case *types.Map:
125 key := subst.typ(t.Key())
126 elem := subst.typ(t.Elem())
127 if key != t.Key() || elem != t.Elem() {
128 return types.NewMap(key, elem)
129 }
130 return t
131 132 case *types.Chan:
133 if elem := subst.typ(t.Elem()); elem != t.Elem() {
134 return types.NewChan(t.Dir(), elem)
135 }
136 return t
137 138 case *types.Signature:
139 return subst.signature(t)
140 141 case *types.Union:
142 return subst.union(t)
143 144 case *types.Interface:
145 return subst.interface_(t)
146 147 case *types.Alias:
148 return subst.alias(t)
149 150 case *types.Named:
151 return subst.named(t)
152 153 case *opaqueType:
154 return t // opaque types are never substituted
155 156 default:
157 panic("unreachable")
158 }
159 }
160 161 // types returns the result of {subst.typ(ts[i])}.
162 func (subst *subster) types(ts []types.Type) []types.Type {
163 res := make([]types.Type, len(ts))
164 for i := range ts {
165 res[i] = subst.typ(ts[i])
166 }
167 return res
168 }
169 170 func (subst *subster) tuple(t *types.Tuple) *types.Tuple {
171 if t != nil {
172 if vars := subst.varlist(t); vars != nil {
173 return types.NewTuple(vars...)
174 }
175 }
176 return t
177 }
// varlist is an indexable sequence of variables. It lets subster.varlist
// walk a *types.Tuple and, through the fieldlist adapter, the fields of a
// *types.Struct with the same code.
type varlist interface {
	// At returns the i'th variable of the sequence.
	At(i int) *types.Var
	// Len returns the number of variables in the sequence.
	Len() int
}
// fieldlist wraps a *types.Struct so that its fields can be traversed
// through the varlist interface.
type fieldlist struct {
	str *types.Struct
}

// At returns the i'th field of the wrapped struct.
func (f fieldlist) At(i int) *types.Var {
	return f.str.Field(i)
}

// Len returns the number of fields of the wrapped struct.
func (f fieldlist) Len() int {
	return f.str.NumFields()
}
191 192 func (subst *subster) struct_(t *types.Struct) *types.Struct {
193 if t != nil {
194 if fields := subst.varlist(fieldlist{t}); fields != nil {
195 tags := make([]string, t.NumFields())
196 for i, n := 0, t.NumFields(); i < n; i++ {
197 tags[i] = t.Tag(i)
198 }
199 return types.NewStruct(fields, tags)
200 }
201 }
202 return t
203 }
204 205 // varlist returns subst(in[i]) or return nils if subst(v[i]) == v[i] for all i.
206 func (subst *subster) varlist(in varlist) []*types.Var {
207 var out []*types.Var // nil => no updates
208 for i, n := 0, in.Len(); i < n; i++ {
209 v := in.At(i)
210 w := subst.var_(v)
211 if v != w && out == nil {
212 out = make([]*types.Var, n)
213 for j := 0; j < i; j++ {
214 out[j] = in.At(j)
215 }
216 }
217 if out != nil {
218 out[i] = w
219 }
220 }
221 return out
222 }
223 224 func (subst *subster) var_(v *types.Var) *types.Var {
225 if v != nil {
226 if typ := subst.typ(v.Type()); typ != v.Type() {
227 if v.IsField() {
228 return types.NewField(v.Pos(), v.Pkg(), v.Name(), typ, v.Embedded())
229 }
230 return types.NewParam(v.Pos(), v.Pkg(), v.Name(), typ)
231 }
232 }
233 return v
234 }
235 236 func (subst *subster) union(u *types.Union) *types.Union {
237 var out []*types.Term // nil => no updates
238 239 for i, n := 0, u.Len(); i < n; i++ {
240 t := u.Term(i)
241 r := subst.typ(t.Type())
242 if r != t.Type() && out == nil {
243 out = make([]*types.Term, n)
244 for j := 0; j < i; j++ {
245 out[j] = u.Term(j)
246 }
247 }
248 if out != nil {
249 out[i] = types.NewTerm(t.Tilde(), r)
250 }
251 }
252 253 if out != nil {
254 return types.NewUnion(out)
255 }
256 return u
257 }
258 259 func (subst *subster) interface_(iface *types.Interface) *types.Interface {
260 if iface == nil {
261 return nil
262 }
263 264 // methods for the interface. Initially nil if there is no known change needed.
265 // Signatures for the method where recv is nil. NewInterfaceType fills in the receivers.
266 var methods []*types.Func
267 initMethods := func(n int) { // copy first n explicit methods
268 methods = make([]*types.Func, iface.NumExplicitMethods())
269 for i := range n {
270 f := iface.ExplicitMethod(i)
271 norecv := changeRecv(f.Type().(*types.Signature), nil)
272 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), norecv)
273 }
274 }
275 for i := 0; i < iface.NumExplicitMethods(); i++ {
276 f := iface.ExplicitMethod(i)
277 // On interfaces, we need to cycle break on anonymous interface types
278 // being in a cycle with their signatures being in cycles with their receivers
279 // that do not go through a Named.
280 norecv := changeRecv(f.Type().(*types.Signature), nil)
281 sig := subst.typ(norecv)
282 if sig != norecv && methods == nil {
283 initMethods(i)
284 }
285 if methods != nil {
286 methods[i] = types.NewFunc(f.Pos(), f.Pkg(), f.Name(), sig.(*types.Signature))
287 }
288 }
289 290 var embeds []types.Type
291 initEmbeds := func(n int) { // copy first n embedded types
292 embeds = make([]types.Type, iface.NumEmbeddeds())
293 for i := range n {
294 embeds[i] = iface.EmbeddedType(i)
295 }
296 }
297 for i := 0; i < iface.NumEmbeddeds(); i++ {
298 e := iface.EmbeddedType(i)
299 r := subst.typ(e)
300 if e != r && embeds == nil {
301 initEmbeds(i)
302 }
303 if embeds != nil {
304 embeds[i] = r
305 }
306 }
307 308 if methods == nil && embeds == nil {
309 return iface
310 }
311 if methods == nil {
312 initMethods(iface.NumExplicitMethods())
313 }
314 if embeds == nil {
315 initEmbeds(iface.NumEmbeddeds())
316 }
317 return types.NewInterfaceType(methods, embeds).Complete()
318 }
// alias returns the result of applying subst to the alias type t.
//
// Three cases are distinguished, mirroring subster.named:
//   - t is not declared within subst.origin (per declaredWithin): t is
//     returned unchanged if it has no type arguments; otherwise t.Origin()
//     is instantiated with the substituted type arguments.
//   - t is declared within subst.origin and has no type arguments: a fresh
//     alias is created with fresh type parameters and a substituted
//     right-hand side.
//   - t is declared within subst.origin and has type arguments: the
//     substituted origin is instantiated with the substituted type arguments.
func (subst *subster) alias(t *types.Alias) types.Type {
	// See subster.named. This follows the same strategy.
	tparams := t.TypeParams()
	targs := t.TypeArgs()
	tname := t.Obj()
	torigin := t.Origin()

	if !declaredWithin(tname, subst.origin) {
		// t is declared outside of the function origin. So t is a package level type alias.
		if targs.Len() == 0 {
			// No type arguments so no instantiation needed.
			return t
		}

		// Instantiate with the substituted type arguments.
		newTArgs := subst.typelist(targs)
		return subst.instantiate(torigin, newTArgs)
	}

	if targs.Len() == 0 {
		// t is declared within the function origin and has no type arguments.
		//
		// Example: This corresponds to A or B in F, but not A[int]:
		//
		//     func F[T any]() {
		//       type A[S any] = struct{t T; s S}
		//       type B = T
		//       var x A[int]
		//       ...
		//     }
		//
		// This is somewhat different than *Named as *Alias cannot be created recursively.

		// Copy and substitute type params.
		var newTParams []*types.TypeParam
		for cur := range tparams.TypeParams() {
			cobj := cur.Obj()
			cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
			ntp := types.NewTypeParam(cname, nil)
			// Seed the cache so that substituting the rhs and the constraints
			// below maps cur to ntp. See the comment starting "Subtle:" in
			// subster.named.
			subst.cache[cur] = ntp
			newTParams = append(newTParams, ntp)
		}

		// Substitute rhs.
		rhs := subst.typ(t.Rhs())

		// Create the fresh alias.
		obj := aliases.New(tname.Pos(), tname.Pkg(), tname.Name(), rhs, newTParams)

		// Substitute into all of the constraints after they are created.
		for i, ntp := range newTParams {
			bound := tparams.At(i).Constraint()
			ntp.SetConstraint(subst.typ(bound))
		}
		return obj.Type()
	}

	// t is declared within the function origin and has type arguments.
	//
	// Example: This corresponds to A[int] in F. Cases A and B are handled above.
	//
	//     func F[T any]() {
	//       type A[S any] = struct{t T; s S}
	//       type B = T
	//       var x A[int]
	//       ...
	//     }
	subOrigin := subst.typ(torigin)
	subTArgs := subst.typelist(targs)
	return subst.instantiate(subOrigin, subTArgs)
}
// named returns the result of applying subst to the named type t.
//
// Three cases are distinguished:
//   - t is not declared within subst.origin (per declaredWithin): t is
//     returned unchanged if it has no type arguments; otherwise t.Origin()
//     is instantiated with the substituted type arguments.
//   - t is declared within subst.origin and has no type arguments: a fresh
//     *types.Named is created with fresh type parameters and a substituted
//     underlying type.
//   - t is declared within subst.origin and has type arguments: the
//     substituted origin is instantiated with the substituted type arguments.
func (subst *subster) named(t *types.Named) types.Type {
	// A Named type is a user defined type.
	// Ignoring generics, Named types are canonical: they are identical if
	// and only if they have the same defining symbol.
	// Generics complicate things, both if the type definition itself is
	// parameterized, and if the type is defined within the scope of a
	// parameterized function. In this case, two named types are identical if
	// and only if their identifying symbols are identical, and all type
	// arguments bindings in scope of the named type definition (including the
	// type parameters of the definition itself) are equivalent.
	//
	// Notably:
	// 1. For type definition type T[P1 any] struct{}, T[A] and T[B] are identical
	//    only if A and B are identical.
	// 2. Inside the generic func Fn[m any]() any { type T struct{}; return T{} },
	//    the result of Fn[A] and Fn[B] have identical type if and only if A and
	//    B are identical.
	// 3. Both 1 and 2 could apply, such as in
	//    func F[m any]() any { type T[x any] struct{}; return T{} }
	//
	// A subster replaces type parameters within a function scope, and therefore must
	// also replace free type parameters in the definitions of local types.
	//
	// Note: There are some detailed notes sprinkled throughout that borrow from
	// lambda calculus notation. These contain some over simplifying math.
	//
	// LC: One way to think about subster is that it is a way of evaluating
	//   ((λm. E) N) as E[m:=N].
	// Each Named type t has an object *TypeName within a scope S that binds an
	// underlying type expression U. U can refer to symbols within S (+ S's ancestors).
	// Let x = t.TypeParams() and A = t.TypeArgs().
	// Each Named type t is then either:
	//   U              where len(x) == 0 && len(A) == 0
	//   λx. U          where len(x) != 0 && len(A) == 0
	//   ((λx. U) A)    where len(x) == len(A)
	// In each case, we will evaluate t[m:=N].
	tparams := t.TypeParams() // x
	targs := t.TypeArgs()     // A

	if !declaredWithin(t.Obj(), subst.origin) {
		// t is declared outside of Fn[m].
		//
		// In this case, we can skip substituting t.Underlying().
		// The underlying type cannot refer to the type parameters.
		//
		// LC: Let free(E) be the set of free type parameters in an expression E.
		// Then whenever m ∉ free(E), then E = E[m:=N].
		// t ∉ Scope(fn) so therefore m ∉ free(U) and m ∩ x = ∅.
		if targs.Len() == 0 {
			// t has no type arguments. So it does not need to be instantiated.
			//
			// This is the normal case in real Go code, where t is not parameterized,
			// declared at some package scope, and m is a TypeParam from a parameterized
			// function F[m] or method.
			//
			// LC: m ∉ free(A) lets us conclude m ∉ free(t). So t=t[m:=N].
			return t
		}

		// t is declared outside of Fn[m] and has type arguments.
		// The type arguments may contain type parameters m so
		// substitute the type arguments, and instantiate the substituted
		// type arguments.
		//
		// LC: Evaluate this as ((λx. U) A') where A' = A[m := N].
		newTArgs := subst.typelist(targs)
		return subst.instantiate(t.Origin(), newTArgs)
	}

	// t is declared within Fn[m].

	if targs.Len() == 0 { // no type arguments?
		assert(t == t.Origin(), "local parameterized type abstraction must be an origin type")

		// t has no type arguments.
		// The underlying type of t may contain the function's type parameters,
		// replace these, and create a new type.
		//
		// Subtle: We short circuit substitution and use a newly created type in
		// subst, i.e. cache[t]=fresh, to preemptively replace t with fresh
		// in recursive types during traversal. This both breaks infinite cycles
		// and allows for constructing types with the replacement applied in
		// subst.typ(U).
		//
		// A new copy of the Named and Typename (and constraints) per function
		// instantiation matches the semantics of Go, which treats all function
		// instantiations F[N] as having distinct local types.
		//
		// LC: x.Len()=0 can be thought of as a special case of λx. U.
		// LC: Evaluate (λx. U)[m:=N] as (λx'. U') where U'=U[x:=x',m:=N].
		tname := t.Obj()
		obj := types.NewTypeName(tname.Pos(), tname.Pkg(), tname.Name(), nil)
		fresh := types.NewNamed(obj, nil, nil)
		var newTParams []*types.TypeParam
		for cur := range tparams.TypeParams() {
			cobj := cur.Obj()
			cname := types.NewTypeName(cobj.Pos(), cobj.Pkg(), cobj.Name(), nil)
			ntp := types.NewTypeParam(cname, nil)
			// Map the old type parameter to its fresh copy (x := x') before
			// the underlying type and the constraints are substituted below.
			subst.cache[cur] = ntp
			newTParams = append(newTParams, ntp)
		}
		fresh.SetTypeParams(newTParams)
		// Both cache entries must be in place before substituting into the
		// underlying type: references to t found while traversing U resolve to
		// fresh, and subst.typ(fresh) is a cache hit that returns fresh as-is.
		subst.cache[t] = fresh
		subst.cache[fresh] = fresh
		fresh.SetUnderlying(subst.typ(t.Underlying()))
		// Substitute into all of the constraints after they are created.
		for i, ntp := range newTParams {
			bound := tparams.At(i).Constraint()
			ntp.SetConstraint(subst.typ(bound))
		}
		return fresh
	}

	// t is defined within Fn[m] and t has type arguments (an instantiation).
	// We reduce this to the two cases above:
	// (1) substitute the function's type parameters into t.Origin().
	// (2) substitute t's type arguments A and instantiate the updated t.Origin() with these.
	//
	// LC: Evaluate ((λx. U) A)[m:=N] as (t' A') where t' = (λx. U)[m:=N] and A'=A [m:=N]
	subOrigin := subst.typ(t.Origin())
	subTArgs := subst.typelist(targs)
	return subst.instantiate(subOrigin, subTArgs)
}
514 515 func (subst *subster) instantiate(orig types.Type, targs []types.Type) types.Type {
516 i, err := types.Instantiate(subst.ctxt, orig, targs, false)
517 assert(err == nil, "failed to Instantiate named (Named or Alias) type")
518 if c, _ := subst.uniqueness.At(i).(types.Type); c != nil {
519 return c.(types.Type)
520 }
521 subst.uniqueness.Set(i, i)
522 return i
523 }
524 525 func (subst *subster) typelist(l *types.TypeList) []types.Type {
526 res := make([]types.Type, l.Len())
527 for i := 0; i < l.Len(); i++ {
528 res[i] = subst.typ(l.At(i))
529 }
530 return res
531 }
532 533 func (subst *subster) signature(t *types.Signature) types.Type {
534 tparams := t.TypeParams()
535 536 // We are choosing not to support tparams.Len() > 0 until a need has been observed in practice.
537 //
538 // There are some known usages for types.Types coming from types.{Eval,CheckExpr}.
539 // To support tparams.Len() > 0, we just need to do the following [pseudocode]:
540 // targs := {subst.replacements[tparams[i]]]}; Instantiate(ctxt, t, targs, false)
541 542 assert(tparams.Len() == 0, "Substituting types.Signatures with generic functions are currently unsupported.")
543 544 // Either:
545 // (1)non-generic function.
546 // no type params to substitute
547 // (2)generic method and recv needs to be substituted.
548 549 // Receivers can be either:
550 // named
551 // pointer to named
552 // interface
553 // nil
554 // interface is the problematic case. We need to cycle break there!
555 recv := subst.var_(t.Recv())
556 params := subst.tuple(t.Params())
557 results := subst.tuple(t.Results())
558 if recv != t.Recv() || params != t.Params() || results != t.Results() {
559 return types.NewSignatureType(recv, nil, nil, params, results, t.Variadic())
560 }
561 return t
562 }
563