package main

import (
	"go/ast"
	"go/token"
	"strconv"
	"strings"
)

// xform holds the state of one transformAST pass over a single file.
type xform struct {
	fset *token.FileSet
	// todos counts rewrites that could only be approximated and need
	// manual follow-up.
	todos int
	// removed holds the local names of imports dropped from the file;
	// qualified references through these names no longer resolve.
	removed map[string]bool
	// renamed maps the local name of a rewritten import to its new name.
	renamed map[string]string
}

// transformAST rewrites f in place. Imports are remapped first (standard
// library paths move under "mx/", "strings" becomes "mx/bytes", "reflect"
// and golang.org/x/tools packages are dropped), then every declaration is
// walked to rewrite types (string -> []byte, empty interfaces -> []byte)
// and package-qualified references. It returns the number of spots flagged
// as needing manual follow-up.
func transformAST(fset *token.FileSet, f *ast.File) int {
	t := &xform{
		fset:    fset,
		removed: map[string]bool{},
		renamed: map[string]string{},
	}
	t.rewriteImports(f)
	t.walkFile(f)
	return t.todos
}

// --- imports ---

// rewriteImports applies the import policy to every import declaration in
// f, dropping removed imports and collapsing specs that end up binding the
// same name to the same path (e.g. "bytes" and "strings" both become
// "mx/bytes").
func (t *xform) rewriteImports(f *ast.File) {
	// seen spans all import declarations: duplicates may live in different
	// import blocks of the same file.
	seen := map[string]bool{}
	for _, decl := range f.Decls {
		gd, ok := decl.(*ast.GenDecl)
		if !ok || gd.Tok != token.IMPORT {
			continue
		}
		var kept []ast.Spec
		for _, spec := range gd.Specs {
			is := spec.(*ast.ImportSpec)
			if !t.rewriteImportSpec(is) {
				continue
			}
			if key := importBindingKey(is); key != "" {
				if seen[key] {
					continue
				}
				seen[key] = true
			}
			kept = append(kept, spec)
		}
		gd.Specs = kept
	}
}

// rewriteImportSpec rewrites a single import spec in place, recording the
// local names of dropped or renamed packages. It reports whether the spec
// should be kept in the file.
func (t *xform) rewriteImportSpec(is *ast.ImportSpec) bool {
	path, err := strconv.Unquote(is.Path.Value)
	if err != nil {
		// Malformed path literal: leave the spec untouched.
		return true
	}
	alias := ""
	if is.Name != nil {
		alias = is.Name.Name
	}
	switch {
	case path == "reflect":
		t.markRemovedImport(alias, "reflect")
		return false
	case strings.HasPrefix(path, "golang.org/x/tools/"):
		// Assumes the package name equals the last path element, which is
		// the convention for x/tools packages — NOTE(review): confirm for
		// any package this tool must handle.
		t.markRemovedImport(alias, path[strings.LastIndex(path, "/")+1:])
		return false
	case path == "strings":
		is.Path.Value = `"mx/bytes"`
		switch alias {
		case "_", ".":
			// Blank and dot imports bind no qualifier, so there is nothing
			// to rename at use sites; keep the alias as written.
		case "":
			t.renamed["strings"] = "bytes"
		default:
			t.renamed[alias] = "bytes"
			is.Name.Name = "bytes"
		}
		return true
	case path == "fmt", path == "unsafe":
		// kept as-is, flagged in post-process
		return true
	default:
		if isStdlib(path) {
			is.Path.Value = strconv.Quote("mx/" + path)
		}
		return true
	}
}

// markRemovedImport records the local name under which a dropped import was
// referenced: its alias if it had one, otherwise defaultName. Blank imports
// bind nothing and dot imports bind unqualified names this pass cannot
// track, so neither is recorded.
func (t *xform) markRemovedImport(alias, defaultName string) {
	switch alias {
	case "_", ".":
	case "":
		t.removed[defaultName] = true
	default:
		t.removed[alias] = true
	}
}

// importBindingKey identifies the (path, local name) pair an import spec
// introduces, so two specs binding the same name to the same path can be
// collapsed. Without an alias the name is taken to be the last path element.
// It returns "" for blank imports and unparsable paths, which are never
// collapsed.
func importBindingKey(is *ast.ImportSpec) string {
	path, err := strconv.Unquote(is.Path.Value)
	if err != nil {
		return ""
	}
	name := path[strings.LastIndex(path, "/")+1:]
	if is.Name != nil {
		if is.Name.Name == "_" {
			return ""
		}
		name = is.Name.Name
	}
	return path + " " + name
}

// isStdlib reports whether path looks like a standard-library import path:
// it has no dot (which also excludes "./" and "../" relative paths) and is
// not the cgo pseudo-package "C".
func isStdlib(path string) bool {
	return path != "C" && !strings.Contains(path, ".")
}

// --- AST walk ---

// walkFile walks every top-level declaration of f.
func (t *xform) walkFile(f *ast.File) {
	for _, decl := range f.Decls {
		t.walkDecl(decl)
	}
}

// walkDecl rewrites the receiver, parameters, results and body of a function
// declaration, or every spec of a general declaration. Type parameter lists
// are not rewritten.
func (t *xform) walkDecl(decl ast.Decl) {
	switch d := decl.(type) {
	case *ast.FuncDecl:
		t.rewriteFieldList(d.Recv)
		t.rewriteFieldList(d.Type.Params)
		t.rewriteFieldList(d.Type.Results)
		t.walkBlock(d.Body)
	case *ast.GenDecl:
		for _, spec := range d.Specs {
			t.walkSpec(spec)
		}
	}
}

// walkSpec rewrites a type or value spec; import specs are handled by
// rewriteImports and ignored here.
func (t *xform) walkSpec(spec ast.Spec) {
	switch s := spec.(type) {
	case *ast.TypeSpec:
		t.rewriteTypeExpr(&s.Type)
	case *ast.ValueSpec:
		t.rewriteTypeExpr(&s.Type)
		// var v, ok = x.(T): the assertion is stripped below, so the value
		// no longer yields an ok flag (same policy as for assignments).
		commaOk := len(s.Names) == 2 && len(s.Values) == 1 && isValueTypeAssert(s.Values[0])
		for i := range s.Values {
			t.rewriteValueExpr(&s.Values[i])
		}
		if commaOk {
			s.Names = s.Names[:1]
		}
	}
}

// walkBlock walks every statement of block; a nil block is a no-op.
func (t *xform) walkBlock(block *ast.BlockStmt) {
	if block == nil {
		return
	}
	for _, stmt := range block.List {
		t.walkStmt(stmt)
	}
}

// walkStmt rewrites the expressions of s and recurses into nested
// statements; a nil statement is a no-op.
func (t *xform) walkStmt(s ast.Stmt) {
	if s == nil {
		return
	}
	switch s := s.(type) {
	case *ast.ExprStmt:
		t.rewriteValueExpr(&s.X)
	case *ast.AssignStmt:
		// v, ok := x.(T): rewriteValueExpr strips the assertion, so the ok
		// result must be dropped as well.
		commaOk := len(s.Lhs) == 2 && len(s.Rhs) == 1 && isValueTypeAssert(s.Rhs[0])
		for i := range s.Rhs {
			t.rewriteValueExpr(&s.Rhs[i])
		}
		if commaOk {
			s.Lhs = s.Lhs[:1]
		}
		for i := range s.Lhs {
			t.rewriteValueExpr(&s.Lhs[i])
		}
	case *ast.ReturnStmt:
		for i := range s.Results {
			t.rewriteValueExpr(&s.Results[i])
		}
	case *ast.IfStmt:
		t.walkStmt(s.Init)
		t.rewriteValueExpr(&s.Cond)
		t.walkBlock(s.Body)
		t.walkStmt(s.Else)
	case *ast.ForStmt:
		t.walkStmt(s.Init)
		t.rewriteValueExpr(&s.Cond)
		t.walkStmt(s.Post)
		t.walkBlock(s.Body)
	case *ast.RangeStmt:
		t.rewriteValueExpr(&s.Key)
		t.rewriteValueExpr(&s.Value)
		t.rewriteValueExpr(&s.X)
		t.walkBlock(s.Body)
	case *ast.SwitchStmt:
		t.walkStmt(s.Init)
		t.rewriteValueExpr(&s.Tag)
		t.walkBlock(s.Body)
	case *ast.TypeSwitchStmt:
		t.walkStmt(s.Init)
		t.walkStmt(s.Assign)
		// NOTE(review): the case lists of a type switch hold types, but the
		// CaseClause branch below walks them as values, so e.g. `case string:`
		// is not rewritten. Rewriting them as types can produce duplicate
		// cases (`case string, []byte:`), so that needs a deliberate policy.
		t.walkBlock(s.Body)
	case *ast.CaseClause:
		for i := range s.List {
			t.rewriteValueExpr(&s.List[i])
		}
		for _, stmt := range s.Body {
			t.walkStmt(stmt)
		}
	case *ast.SelectStmt:
		t.walkBlock(s.Body)
	case *ast.CommClause:
		t.walkStmt(s.Comm)
		for _, stmt := range s.Body {
			t.walkStmt(stmt)
		}
	case *ast.BlockStmt:
		t.walkBlock(s)
	case *ast.DeclStmt:
		t.walkDecl(s.Decl)
	case *ast.DeferStmt:
		t.rewriteCallExpr(s.Call)
	case *ast.GoStmt:
		t.rewriteCallExpr(s.Call)
	case *ast.SendStmt:
		t.rewriteValueExpr(&s.Chan)
		t.rewriteValueExpr(&s.Value)
	case *ast.IncDecStmt:
		t.rewriteValueExpr(&s.X)
	case *ast.LabeledStmt:
		t.walkStmt(s.Stmt)
	}
}

// isValueTypeAssert reports whether e is a type assertion x.(T), as opposed
// to the x.(type) guard of a type switch.
func isValueTypeAssert(e ast.Expr) bool {
	ta, ok := e.(*ast.TypeAssertExpr)
	return ok && ta.Type != nil
}

// rewriteCallExpr rewrites a call in place: a string(...) conversion becomes
// []byte(...), and both the callee expression and every argument are walked.
// Walking the callee reaches function literals (go/defer closures,
// immediately invoked literals) and the receivers of chained calls such as
// strings.NewReader(s).Len(). A nil call is a no-op.
func (t *xform) rewriteCallExpr(call *ast.CallExpr) {
	if call == nil {
		return
	}
	if id, ok := call.Fun.(*ast.Ident); ok && id.Name == "string" {
		call.Fun = byteSlice()
	} else {
		t.rewriteValueExpr(&call.Fun)
	}
	for i := range call.Args {
		t.rewriteValueExpr(&call.Args[i])
	}
}

// --- type expression rewriting ---

// rewriteTypeExpr rewrites the type expression *ep in place: string becomes
// []byte, any and empty interfaces become []byte (counted as TODOs), types
// from dropped imports become []byte (TODO), renamed package qualifiers are
// updated, and composite types are rewritten recursively. A nil pointer or
// nil expression is a no-op.
func (t *xform) rewriteTypeExpr(ep *ast.Expr) {
	if ep == nil || *ep == nil {
		return
	}
	switch x := (*ep).(type) {
	case *ast.Ident:
		switch x.Name {
		case "string":
			*ep = byteSlice()
		case "any":
			*ep = byteSlice()
			t.todos++
		}
	case *ast.InterfaceType:
		if x.Methods == nil || len(x.Methods.List) == 0 {
			*ep = byteSlice()
			t.todos++
		} else {
			t.rewriteFieldList(x.Methods)
		}
	case *ast.ArrayType:
		t.rewriteTypeExpr(&x.Elt)
	case *ast.MapType:
		t.rewriteTypeExpr(&x.Key)
		t.rewriteTypeExpr(&x.Value)
	case *ast.StarExpr:
		t.rewriteTypeExpr(&x.X)
	case *ast.ChanType:
		t.rewriteTypeExpr(&x.Value)
	case *ast.FuncType:
		t.rewriteFieldList(x.Params)
		t.rewriteFieldList(x.Results)
	case *ast.StructType:
		t.rewriteFieldList(x.Fields)
	case *ast.Ellipsis:
		t.rewriteTypeExpr(&x.Elt)
	case *ast.ParenExpr:
		t.rewriteTypeExpr(&x.X)
	case *ast.IndexExpr, *ast.IndexListExpr:
		t.rewriteInstantiation(ep)
	case *ast.SelectorExpr:
		if id, ok := x.X.(*ast.Ident); ok {
			if t.removed[id.Name] {
				*ep = byteSlice()
				t.todos++
			} else if n, ok := t.renamed[id.Name]; ok {
				id.Name = n
			}
		}
	}
}

// rewriteInstantiation rewrites a generic type instantiation such as
// Set[string] or Pair[K, V]. Each type argument is rewritten. If the generic
// type itself comes from a dropped import, the whole instantiation collapses
// to []byte and is counted as a TODO, rather than producing []byte[...].
func (t *xform) rewriteInstantiation(ep *ast.Expr) {
	var base *ast.Expr
	var args []*ast.Expr
	switch x := (*ep).(type) {
	case *ast.IndexExpr:
		base, args = &x.X, []*ast.Expr{&x.Index}
	case *ast.IndexListExpr:
		base = &x.X
		for i := range x.Indices {
			args = append(args, &x.Indices[i])
		}
	default:
		return
	}
	if sel, ok := (*base).(*ast.SelectorExpr); ok {
		if id, ok := sel.X.(*ast.Ident); ok && t.removed[id.Name] {
			*ep = byteSlice()
			t.todos++
			return
		}
	}
	t.rewriteTypeExpr(base)
	for _, a := range args {
		t.rewriteTypeExpr(a)
	}
}

// rewriteFieldList rewrites the type of every field in fl; nil is a no-op.
func (t *xform) rewriteFieldList(fl *ast.FieldList) {
	if fl == nil {
		return
	}
	for _, field := range fl.List {
		t.rewriteTypeExpr(&field.Type)
	}
}

// byteSlice returns a fresh []byte type expression.
func byteSlice() ast.Expr {
	return &ast.ArrayType{Elt: &ast.Ident{Name: "byte"}}
}

// --- value expression rewriting ---

// rewriteValueExpr rewrites the value expression *ep in place.
//   - String concatenation with + becomes |.
//   - Type assertions x.(T) are stripped to x and counted as TODOs.
//   - Calls are handled by rewriteCallExpr.
//   - Renamed package qualifiers are updated; references into dropped
//     imports are counted as TODOs.
//   - Types appearing in value position are handed to rewriteTypeExpr.
//
// A nil pointer or nil expression is a no-op.
func (t *xform) rewriteValueExpr(ep *ast.Expr) {
	if ep == nil || *ep == nil {
		return
	}
	switch x := (*ep).(type) {
	case *ast.BinaryExpr:
		// Decide before recursing: rewriting the operands turns nested
		// concatenations into ORs, which would hide them from the check.
		concat := isStringConcat(x)
		t.rewriteValueExpr(&x.X)
		t.rewriteValueExpr(&x.Y)
		if concat {
			markStringConcat(x)
		}
	case *ast.TypeAssertExpr:
		t.rewriteValueExpr(&x.X)
		if x.Type != nil {
			*ep = x.X
			t.todos++
		}
		// x.(type) in a type switch is left in place.
	case *ast.CallExpr:
		t.rewriteCallExpr(x)
	case *ast.SelectorExpr:
		if id, ok := x.X.(*ast.Ident); ok {
			if t.removed[id.Name] {
				// Reference into a dropped import; it will not resolve.
				t.todos++
			} else if n, ok := t.renamed[id.Name]; ok {
				id.Name = n
			}
		} else {
			// Chained receiver, e.g. f(x).Field or a[i].Method.
			t.rewriteValueExpr(&x.X)
		}
	case *ast.UnaryExpr:
		t.rewriteValueExpr(&x.X)
	case *ast.ParenExpr:
		t.rewriteValueExpr(&x.X)
	case *ast.StarExpr:
		t.rewriteValueExpr(&x.X)
	case *ast.IndexExpr:
		t.rewriteValueExpr(&x.X)
		t.rewriteValueExpr(&x.Index)
	case *ast.SliceExpr:
		t.rewriteValueExpr(&x.X)
		t.rewriteValueExpr(&x.Low)
		t.rewriteValueExpr(&x.High)
		t.rewriteValueExpr(&x.Max)
	case *ast.CompositeLit:
		t.rewriteTypeExpr(&x.Type)
		for i := range x.Elts {
			t.rewriteValueExpr(&x.Elts[i])
		}
	case *ast.KeyValueExpr:
		t.rewriteValueExpr(&x.Key)
		t.rewriteValueExpr(&x.Value)
	case *ast.FuncLit:
		t.rewriteFieldList(x.Type.Params)
		t.rewriteFieldList(x.Type.Results)
		t.walkBlock(x.Body)
	case *ast.ArrayType, *ast.ChanType, *ast.FuncType,
		*ast.InterfaceType, *ast.MapType, *ast.StructType:
		// Types in value position (make arguments, conversions, ...) follow
		// the same rules as in type position.
		t.rewriteTypeExpr(ep)
	}
}

// isStringConcat reports whether b is an addition whose subtree of
// additions (looking through parentheses) has a string-literal operand.
// Both operands of + share the result type, so such an addition is a
// string concatenation.
func isStringConcat(b *ast.BinaryExpr) bool {
	if b.Op != token.ADD {
		return false
	}
	for _, operand := range []ast.Expr{b.X, b.Y} {
		operand = stripParens(operand)
		if isStringLit(operand) {
			return true
		}
		if inner, ok := operand.(*ast.BinaryExpr); ok && isStringConcat(inner) {
			return true
		}
	}
	return false
}

// markStringConcat rewrites the string concatenation b to use |, together
// with every addition feeding directly into it (possibly parenthesised).
// Those nested additions share b's string type, so in "a" + (b + c) the
// inner b + c is a concatenation too.
func markStringConcat(b *ast.BinaryExpr) {
	b.Op = token.OR
	for _, operand := range []ast.Expr{b.X, b.Y} {
		if inner, ok := stripParens(operand).(*ast.BinaryExpr); ok && inner.Op == token.ADD {
			markStringConcat(inner)
		}
	}
}

// stripParens returns e with any enclosing parentheses removed.
func stripParens(e ast.Expr) ast.Expr {
	for {
		p, ok := e.(*ast.ParenExpr)
		if !ok {
			return e
		}
		e = p.X
	}
}

// isStringLit reports whether e is a string literal.
func isStringLit(e ast.Expr) bool {
	lit, ok := e.(*ast.BasicLit)
	return ok && lit.Kind == token.STRING
}