transform.go raw
1 package main
2
3 import (
4 "go/ast"
5 "go/token"
6 "strconv"
7 "strings"
8 )
9
// xform holds the state shared by the rewriting passes of one transformAST
// call.
type xform struct {
	// fset is the file set passed to transformAST; no code in this file
	// reads it (NOTE(review): confirm whether it is still needed).
	fset *token.FileSet

	// todos counts rewrites that need manual follow-up, such as stripped
	// type assertions and any/interface{}/removed-package types replaced by
	// []byte. transformAST returns it.
	todos int

	// removed is the set of local names of dropped imports ("reflect");
	// types qualified by one of them are replaced with []byte.
	removed map[string]bool

	// renamed maps the local name of a redirected import ("strings") to the
	// qualifier its references must use instead ("bytes").
	renamed map[string]string
}
16
17 func transformAST(fset *token.FileSet, f *ast.File) int {
18 t := &xform{
19 fset: fset,
20 removed: map[string]bool{},
21 renamed: map[string]string{},
22 }
23 t.rewriteImports(f)
24 t.walkFile(f)
25 return t.todos
26 }
27
28 // --- imports ---
29
30 func (t *xform) rewriteImports(f *ast.File) {
31 for _, decl := range f.Decls {
32 gd, ok := decl.(*ast.GenDecl)
33 if !ok || gd.Tok != token.IMPORT {
34 continue
35 }
36 var kept []ast.Spec
37 for _, spec := range gd.Specs {
38 is := spec.(*ast.ImportSpec)
39 path, _ := strconv.Unquote(is.Path.Value)
40
41 switch {
42 case path == "reflect":
43 name := "reflect"
44 if is.Name != nil {
45 name = is.Name.Name
46 }
47 t.removed[name] = true
48 continue
49
50 case strings.HasPrefix(path, "golang.org/x/tools/"):
51 continue
52
53 case path == "strings":
54 name := "strings"
55 if is.Name != nil {
56 name = is.Name.Name
57 }
58 t.renamed[name] = "bytes"
59 is.Path.Value = `"mx/bytes"`
60 if is.Name != nil {
61 is.Name.Name = "bytes"
62 }
63
64 case path == "fmt", path == "unsafe":
65 // kept as-is, flagged in post-process
66
67 default:
68 if isStdlib(path) {
69 is.Path.Value = strconv.Quote("mx/" + path)
70 }
71 }
72 kept = append(kept, spec)
73 }
74 gd.Specs = kept
75 }
76 }
77
// isStdlib reports whether the import path looks like a standard-library
// package that should be redirected under the "mx/" tree.
//
// It uses the heuristic that standard-library paths contain no dot. That
// excludes domain-qualified paths ("github.com/...") and relative paths
// ("./x", "../x"), which always contain a dot.
//
// Three further cases are excluded:
//   - the empty path;
//   - the cgo pseudo-package "C", which must not be rewritten;
//   - paths already under "mx", so that converting a file twice does not
//     produce "mx/mx/...".
func isStdlib(path string) bool {
	switch {
	case path == "", path == "C":
		return false
	case path == "mx", strings.HasPrefix(path, "mx/"):
		return false
	}
	return !strings.Contains(path, ".")
}
83
84 // --- AST walk ---
85
86 func (t *xform) walkFile(f *ast.File) {
87 for _, decl := range f.Decls {
88 t.walkDecl(decl)
89 }
90 }
91
92 func (t *xform) walkDecl(decl ast.Decl) {
93 switch d := decl.(type) {
94 case *ast.FuncDecl:
95 if d.Recv != nil {
96 t.rewriteFieldList(d.Recv)
97 }
98 t.rewriteFieldList(d.Type.Params)
99 t.rewriteFieldList(d.Type.Results)
100 if d.Body != nil {
101 t.walkBlock(d.Body)
102 }
103 case *ast.GenDecl:
104 for _, spec := range d.Specs {
105 t.walkSpec(spec)
106 }
107 }
108 }
109
110 func (t *xform) walkSpec(spec ast.Spec) {
111 switch s := spec.(type) {
112 case *ast.TypeSpec:
113 t.rewriteTypeExpr(&s.Type)
114 case *ast.ValueSpec:
115 if s.Type != nil {
116 t.rewriteTypeExpr(&s.Type)
117 }
118 for i := range s.Values {
119 t.rewriteValueExpr(&s.Values[i])
120 }
121 }
122 }
123
124 func (t *xform) walkBlock(block *ast.BlockStmt) {
125 if block == nil {
126 return
127 }
128 for i := range block.List {
129 t.walkStmt(block.List[i])
130 }
131 }
132
133 func (t *xform) walkStmt(s ast.Stmt) {
134 if s == nil {
135 return
136 }
137 switch s := s.(type) {
138 case *ast.ExprStmt:
139 t.rewriteValueExpr(&s.X)
140
141 case *ast.AssignStmt:
142 for i := range s.Rhs {
143 if ta, ok := s.Rhs[i].(*ast.TypeAssertExpr); ok && ta.Type != nil {
144 t.rewriteValueExpr(&ta.X)
145 s.Rhs[i] = ta.X
146 if len(s.Lhs) == 2 && len(s.Rhs) == 1 {
147 s.Lhs = s.Lhs[:1]
148 }
149 t.todos++
150 } else {
151 t.rewriteValueExpr(&s.Rhs[i])
152 }
153 }
154 for i := range s.Lhs {
155 t.rewriteValueExpr(&s.Lhs[i])
156 }
157
158 case *ast.ReturnStmt:
159 for i := range s.Results {
160 t.rewriteValueExpr(&s.Results[i])
161 }
162
163 case *ast.IfStmt:
164 if s.Init != nil {
165 t.walkStmt(s.Init)
166 }
167 if s.Cond != nil {
168 t.rewriteValueExpr(&s.Cond)
169 }
170 t.walkBlock(s.Body)
171 if s.Else != nil {
172 t.walkStmt(s.Else)
173 }
174
175 case *ast.ForStmt:
176 if s.Init != nil {
177 t.walkStmt(s.Init)
178 }
179 if s.Cond != nil {
180 t.rewriteValueExpr(&s.Cond)
181 }
182 if s.Post != nil {
183 t.walkStmt(s.Post)
184 }
185 t.walkBlock(s.Body)
186
187 case *ast.RangeStmt:
188 if s.Key != nil {
189 t.rewriteValueExpr(&s.Key)
190 }
191 if s.Value != nil {
192 t.rewriteValueExpr(&s.Value)
193 }
194 t.rewriteValueExpr(&s.X)
195 t.walkBlock(s.Body)
196
197 case *ast.SwitchStmt:
198 if s.Init != nil {
199 t.walkStmt(s.Init)
200 }
201 if s.Tag != nil {
202 t.rewriteValueExpr(&s.Tag)
203 }
204 t.walkBlock(s.Body)
205
206 case *ast.TypeSwitchStmt:
207 if s.Init != nil {
208 t.walkStmt(s.Init)
209 }
210 if s.Assign != nil {
211 t.walkStmt(s.Assign)
212 }
213 t.walkBlock(s.Body)
214
215 case *ast.CaseClause:
216 for i := range s.List {
217 t.rewriteValueExpr(&s.List[i])
218 }
219 for i := range s.Body {
220 t.walkStmt(s.Body[i])
221 }
222
223 case *ast.SelectStmt:
224 t.walkBlock(s.Body)
225
226 case *ast.CommClause:
227 if s.Comm != nil {
228 t.walkStmt(s.Comm)
229 }
230 for i := range s.Body {
231 t.walkStmt(s.Body[i])
232 }
233
234 case *ast.BlockStmt:
235 t.walkBlock(s)
236
237 case *ast.DeclStmt:
238 t.walkDecl(s.Decl)
239
240 case *ast.DeferStmt:
241 t.rewriteCallExpr(s.Call)
242
243 case *ast.GoStmt:
244 t.rewriteCallExpr(s.Call)
245
246 case *ast.SendStmt:
247 t.rewriteValueExpr(&s.Chan)
248 t.rewriteValueExpr(&s.Value)
249
250 case *ast.IncDecStmt:
251 t.rewriteValueExpr(&s.X)
252
253 case *ast.LabeledStmt:
254 if s.Stmt != nil {
255 t.walkStmt(s.Stmt)
256 }
257 }
258 }
259
260 func (t *xform) rewriteCallExpr(call *ast.CallExpr) {
261 if call == nil {
262 return
263 }
264 if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
265 if id, ok := sel.X.(*ast.Ident); ok {
266 if n, ok := t.renamed[id.Name]; ok {
267 id.Name = n
268 }
269 }
270 }
271 if id, ok := call.Fun.(*ast.Ident); ok && id.Name == "string" {
272 call.Fun = byteSlice()
273 }
274 for i := range call.Args {
275 t.rewriteValueExpr(&call.Args[i])
276 }
277 }
278
279 // --- type expression rewriting ---
280
281 func (t *xform) rewriteTypeExpr(ep *ast.Expr) {
282 if ep == nil || *ep == nil {
283 return
284 }
285 switch x := (*ep).(type) {
286 case *ast.Ident:
287 switch x.Name {
288 case "string":
289 *ep = byteSlice()
290 case "any":
291 *ep = byteSlice()
292 t.todos++
293 }
294 case *ast.InterfaceType:
295 if x.Methods == nil || len(x.Methods.List) == 0 {
296 *ep = byteSlice()
297 t.todos++
298 } else {
299 t.rewriteFieldList(x.Methods)
300 }
301 case *ast.ArrayType:
302 t.rewriteTypeExpr(&x.Elt)
303 case *ast.MapType:
304 t.rewriteTypeExpr(&x.Key)
305 t.rewriteTypeExpr(&x.Value)
306 case *ast.StarExpr:
307 t.rewriteTypeExpr(&x.X)
308 case *ast.ChanType:
309 t.rewriteTypeExpr(&x.Value)
310 case *ast.FuncType:
311 t.rewriteFieldList(x.Params)
312 t.rewriteFieldList(x.Results)
313 case *ast.StructType:
314 t.rewriteFieldList(x.Fields)
315 case *ast.Ellipsis:
316 if x.Elt != nil {
317 t.rewriteTypeExpr(&x.Elt)
318 }
319 case *ast.SelectorExpr:
320 if id, ok := x.X.(*ast.Ident); ok {
321 if t.removed[id.Name] {
322 *ep = byteSlice()
323 t.todos++
324 } else if n, ok := t.renamed[id.Name]; ok {
325 id.Name = n
326 }
327 }
328 }
329 }
330
331 func (t *xform) rewriteFieldList(fl *ast.FieldList) {
332 if fl == nil {
333 return
334 }
335 for _, f := range fl.List {
336 t.rewriteTypeExpr(&f.Type)
337 }
338 }
339
// byteSlice returns a newly allocated AST for the type []byte.
// A fresh node is built on every call, so each rewritten site owns its own
// subtree and later in-place edits cannot affect other sites.
func byteSlice() ast.Expr {
	elem := ast.NewIdent("byte")
	return &ast.ArrayType{Elt: elem}
}
343
344 // --- value expression rewriting ---
345
346 func (t *xform) rewriteValueExpr(ep *ast.Expr) {
347 if ep == nil || *ep == nil {
348 return
349 }
350 switch x := (*ep).(type) {
351 case *ast.BinaryExpr:
352 t.rewriteValueExpr(&x.X)
353 t.rewriteValueExpr(&x.Y)
354 if x.Op == token.ADD && (isStringLit(x.X) || isStringLit(x.Y)) {
355 x.Op = token.OR
356 }
357
358 case *ast.TypeAssertExpr:
359 if x.Type != nil {
360 t.rewriteValueExpr(&x.X)
361 *ep = x.X
362 t.todos++
363 } else {
364 // x.(type) in type switch - leave alone
365 t.rewriteValueExpr(&x.X)
366 }
367
368 case *ast.CallExpr:
369 if id, ok := x.Fun.(*ast.Ident); ok && id.Name == "string" {
370 x.Fun = byteSlice()
371 }
372 if sel, ok := x.Fun.(*ast.SelectorExpr); ok {
373 if id, ok := sel.X.(*ast.Ident); ok {
374 if n, ok := t.renamed[id.Name]; ok {
375 id.Name = n
376 }
377 }
378 }
379 for i := range x.Args {
380 t.rewriteValueExpr(&x.Args[i])
381 }
382
383 case *ast.SelectorExpr:
384 if id, ok := x.X.(*ast.Ident); ok {
385 if n, ok := t.renamed[id.Name]; ok {
386 id.Name = n
387 }
388 }
389
390 case *ast.UnaryExpr:
391 t.rewriteValueExpr(&x.X)
392
393 case *ast.ParenExpr:
394 t.rewriteValueExpr(&x.X)
395
396 case *ast.IndexExpr:
397 t.rewriteValueExpr(&x.X)
398 t.rewriteValueExpr(&x.Index)
399
400 case *ast.SliceExpr:
401 t.rewriteValueExpr(&x.X)
402 if x.Low != nil {
403 t.rewriteValueExpr(&x.Low)
404 }
405 if x.High != nil {
406 t.rewriteValueExpr(&x.High)
407 }
408 if x.Max != nil {
409 t.rewriteValueExpr(&x.Max)
410 }
411
412 case *ast.CompositeLit:
413 if x.Type != nil {
414 t.rewriteTypeExpr(&x.Type)
415 }
416 for i := range x.Elts {
417 t.rewriteValueExpr(&x.Elts[i])
418 }
419
420 case *ast.KeyValueExpr:
421 t.rewriteValueExpr(&x.Key)
422 t.rewriteValueExpr(&x.Value)
423
424 case *ast.FuncLit:
425 t.rewriteFieldList(x.Type.Params)
426 t.rewriteFieldList(x.Type.Results)
427 if x.Body != nil {
428 t.walkBlock(x.Body)
429 }
430
431 case *ast.StarExpr:
432 t.rewriteValueExpr(&x.X)
433
434 // types appearing in value position (make args, composite lit types, etc.)
435 case *ast.ChanType:
436 t.rewriteTypeExpr(&x.Value)
437 case *ast.ArrayType:
438 t.rewriteTypeExpr(&x.Elt)
439 case *ast.MapType:
440 t.rewriteTypeExpr(&x.Key)
441 t.rewriteTypeExpr(&x.Value)
442 case *ast.InterfaceType:
443 if x.Methods == nil || len(x.Methods.List) == 0 {
444 *ep = byteSlice()
445 t.todos++
446 }
447 case *ast.FuncType:
448 t.rewriteFieldList(x.Params)
449 t.rewriteFieldList(x.Results)
450 case *ast.StructType:
451 t.rewriteFieldList(x.Fields)
452 }
453 }
454
// isStringLit reports whether e is a string literal token, either
// interpreted ("...") or raw (`...`). It returns false for nil and for every
// other expression, including char and number literals and named constants.
func isStringLit(e ast.Expr) bool {
	if lit, ok := e.(*ast.BasicLit); ok {
		return lit.Kind == token.STRING
	}
	return false
}
459