transform.go raw

   1  package main
   2  
   3  import (
   4  	"go/ast"
   5  	"go/token"
   6  	"strconv"
   7  	"strings"
   8  )
   9  
  10  type xform struct {
  11  	fset    *token.FileSet
  12  	todos   int
  13  	removed map[string]bool
  14  	renamed map[string]string
  15  }
  16  
  17  func transformAST(fset *token.FileSet, f *ast.File) int {
  18  	t := &xform{
  19  		fset:    fset,
  20  		removed: map[string]bool{},
  21  		renamed: map[string]string{},
  22  	}
  23  	t.rewriteImports(f)
  24  	t.walkFile(f)
  25  	return t.todos
  26  }
  27  
  28  // --- imports ---
  29  
  30  func (t *xform) rewriteImports(f *ast.File) {
  31  	for _, decl := range f.Decls {
  32  		gd, ok := decl.(*ast.GenDecl)
  33  		if !ok || gd.Tok != token.IMPORT {
  34  			continue
  35  		}
  36  		var kept []ast.Spec
  37  		for _, spec := range gd.Specs {
  38  			is := spec.(*ast.ImportSpec)
  39  			path, _ := strconv.Unquote(is.Path.Value)
  40  
  41  			switch {
  42  			case path == "reflect":
  43  				name := "reflect"
  44  				if is.Name != nil {
  45  					name = is.Name.Name
  46  				}
  47  				t.removed[name] = true
  48  				continue
  49  
  50  			case strings.HasPrefix(path, "golang.org/x/tools/"):
  51  				continue
  52  
  53  			case path == "strings":
  54  				name := "strings"
  55  				if is.Name != nil {
  56  					name = is.Name.Name
  57  				}
  58  				t.renamed[name] = "bytes"
  59  				is.Path.Value = `"mx/bytes"`
  60  				if is.Name != nil {
  61  					is.Name.Name = "bytes"
  62  				}
  63  
  64  			case path == "fmt", path == "unsafe":
  65  				// kept as-is, flagged in post-process
  66  
  67  			default:
  68  				if isStdlib(path) {
  69  					is.Path.Value = strconv.Quote("mx/" + path)
  70  				}
  71  			}
  72  			kept = append(kept, spec)
  73  		}
  74  		gd.Specs = kept
  75  	}
  76  }
  77  
  78  func isStdlib(path string) bool {
  79  	return !strings.Contains(path, ".") &&
  80  		!strings.HasPrefix(path, "./") &&
  81  		!strings.HasPrefix(path, "../")
  82  }
  83  
  84  // --- AST walk ---
  85  
  86  func (t *xform) walkFile(f *ast.File) {
  87  	for _, decl := range f.Decls {
  88  		t.walkDecl(decl)
  89  	}
  90  }
  91  
  92  func (t *xform) walkDecl(decl ast.Decl) {
  93  	switch d := decl.(type) {
  94  	case *ast.FuncDecl:
  95  		if d.Recv != nil {
  96  			t.rewriteFieldList(d.Recv)
  97  		}
  98  		t.rewriteFieldList(d.Type.Params)
  99  		t.rewriteFieldList(d.Type.Results)
 100  		if d.Body != nil {
 101  			t.walkBlock(d.Body)
 102  		}
 103  	case *ast.GenDecl:
 104  		for _, spec := range d.Specs {
 105  			t.walkSpec(spec)
 106  		}
 107  	}
 108  }
 109  
 110  func (t *xform) walkSpec(spec ast.Spec) {
 111  	switch s := spec.(type) {
 112  	case *ast.TypeSpec:
 113  		t.rewriteTypeExpr(&s.Type)
 114  	case *ast.ValueSpec:
 115  		if s.Type != nil {
 116  			t.rewriteTypeExpr(&s.Type)
 117  		}
 118  		for i := range s.Values {
 119  			t.rewriteValueExpr(&s.Values[i])
 120  		}
 121  	}
 122  }
 123  
 124  func (t *xform) walkBlock(block *ast.BlockStmt) {
 125  	if block == nil {
 126  		return
 127  	}
 128  	for i := range block.List {
 129  		t.walkStmt(block.List[i])
 130  	}
 131  }
 132  
 133  func (t *xform) walkStmt(s ast.Stmt) {
 134  	if s == nil {
 135  		return
 136  	}
 137  	switch s := s.(type) {
 138  	case *ast.ExprStmt:
 139  		t.rewriteValueExpr(&s.X)
 140  
 141  	case *ast.AssignStmt:
 142  		for i := range s.Rhs {
 143  			if ta, ok := s.Rhs[i].(*ast.TypeAssertExpr); ok && ta.Type != nil {
 144  				t.rewriteValueExpr(&ta.X)
 145  				s.Rhs[i] = ta.X
 146  				if len(s.Lhs) == 2 && len(s.Rhs) == 1 {
 147  					s.Lhs = s.Lhs[:1]
 148  				}
 149  				t.todos++
 150  			} else {
 151  				t.rewriteValueExpr(&s.Rhs[i])
 152  			}
 153  		}
 154  		for i := range s.Lhs {
 155  			t.rewriteValueExpr(&s.Lhs[i])
 156  		}
 157  
 158  	case *ast.ReturnStmt:
 159  		for i := range s.Results {
 160  			t.rewriteValueExpr(&s.Results[i])
 161  		}
 162  
 163  	case *ast.IfStmt:
 164  		if s.Init != nil {
 165  			t.walkStmt(s.Init)
 166  		}
 167  		if s.Cond != nil {
 168  			t.rewriteValueExpr(&s.Cond)
 169  		}
 170  		t.walkBlock(s.Body)
 171  		if s.Else != nil {
 172  			t.walkStmt(s.Else)
 173  		}
 174  
 175  	case *ast.ForStmt:
 176  		if s.Init != nil {
 177  			t.walkStmt(s.Init)
 178  		}
 179  		if s.Cond != nil {
 180  			t.rewriteValueExpr(&s.Cond)
 181  		}
 182  		if s.Post != nil {
 183  			t.walkStmt(s.Post)
 184  		}
 185  		t.walkBlock(s.Body)
 186  
 187  	case *ast.RangeStmt:
 188  		if s.Key != nil {
 189  			t.rewriteValueExpr(&s.Key)
 190  		}
 191  		if s.Value != nil {
 192  			t.rewriteValueExpr(&s.Value)
 193  		}
 194  		t.rewriteValueExpr(&s.X)
 195  		t.walkBlock(s.Body)
 196  
 197  	case *ast.SwitchStmt:
 198  		if s.Init != nil {
 199  			t.walkStmt(s.Init)
 200  		}
 201  		if s.Tag != nil {
 202  			t.rewriteValueExpr(&s.Tag)
 203  		}
 204  		t.walkBlock(s.Body)
 205  
 206  	case *ast.TypeSwitchStmt:
 207  		if s.Init != nil {
 208  			t.walkStmt(s.Init)
 209  		}
 210  		if s.Assign != nil {
 211  			t.walkStmt(s.Assign)
 212  		}
 213  		t.walkBlock(s.Body)
 214  
 215  	case *ast.CaseClause:
 216  		for i := range s.List {
 217  			t.rewriteValueExpr(&s.List[i])
 218  		}
 219  		for i := range s.Body {
 220  			t.walkStmt(s.Body[i])
 221  		}
 222  
 223  	case *ast.SelectStmt:
 224  		t.walkBlock(s.Body)
 225  
 226  	case *ast.CommClause:
 227  		if s.Comm != nil {
 228  			t.walkStmt(s.Comm)
 229  		}
 230  		for i := range s.Body {
 231  			t.walkStmt(s.Body[i])
 232  		}
 233  
 234  	case *ast.BlockStmt:
 235  		t.walkBlock(s)
 236  
 237  	case *ast.DeclStmt:
 238  		t.walkDecl(s.Decl)
 239  
 240  	case *ast.DeferStmt:
 241  		t.rewriteCallExpr(s.Call)
 242  
 243  	case *ast.GoStmt:
 244  		t.rewriteCallExpr(s.Call)
 245  
 246  	case *ast.SendStmt:
 247  		t.rewriteValueExpr(&s.Chan)
 248  		t.rewriteValueExpr(&s.Value)
 249  
 250  	case *ast.IncDecStmt:
 251  		t.rewriteValueExpr(&s.X)
 252  
 253  	case *ast.LabeledStmt:
 254  		if s.Stmt != nil {
 255  			t.walkStmt(s.Stmt)
 256  		}
 257  	}
 258  }
 259  
 260  func (t *xform) rewriteCallExpr(call *ast.CallExpr) {
 261  	if call == nil {
 262  		return
 263  	}
 264  	if sel, ok := call.Fun.(*ast.SelectorExpr); ok {
 265  		if id, ok := sel.X.(*ast.Ident); ok {
 266  			if n, ok := t.renamed[id.Name]; ok {
 267  				id.Name = n
 268  			}
 269  		}
 270  	}
 271  	if id, ok := call.Fun.(*ast.Ident); ok && id.Name == "string" {
 272  		call.Fun = byteSlice()
 273  	}
 274  	for i := range call.Args {
 275  		t.rewriteValueExpr(&call.Args[i])
 276  	}
 277  }
 278  
 279  // --- type expression rewriting ---
 280  
 281  func (t *xform) rewriteTypeExpr(ep *ast.Expr) {
 282  	if ep == nil || *ep == nil {
 283  		return
 284  	}
 285  	switch x := (*ep).(type) {
 286  	case *ast.Ident:
 287  		switch x.Name {
 288  		case "string":
 289  			*ep = byteSlice()
 290  		case "any":
 291  			*ep = byteSlice()
 292  			t.todos++
 293  		}
 294  	case *ast.InterfaceType:
 295  		if x.Methods == nil || len(x.Methods.List) == 0 {
 296  			*ep = byteSlice()
 297  			t.todos++
 298  		} else {
 299  			t.rewriteFieldList(x.Methods)
 300  		}
 301  	case *ast.ArrayType:
 302  		t.rewriteTypeExpr(&x.Elt)
 303  	case *ast.MapType:
 304  		t.rewriteTypeExpr(&x.Key)
 305  		t.rewriteTypeExpr(&x.Value)
 306  	case *ast.StarExpr:
 307  		t.rewriteTypeExpr(&x.X)
 308  	case *ast.ChanType:
 309  		t.rewriteTypeExpr(&x.Value)
 310  	case *ast.FuncType:
 311  		t.rewriteFieldList(x.Params)
 312  		t.rewriteFieldList(x.Results)
 313  	case *ast.StructType:
 314  		t.rewriteFieldList(x.Fields)
 315  	case *ast.Ellipsis:
 316  		if x.Elt != nil {
 317  			t.rewriteTypeExpr(&x.Elt)
 318  		}
 319  	case *ast.SelectorExpr:
 320  		if id, ok := x.X.(*ast.Ident); ok {
 321  			if t.removed[id.Name] {
 322  				*ep = byteSlice()
 323  				t.todos++
 324  			} else if n, ok := t.renamed[id.Name]; ok {
 325  				id.Name = n
 326  			}
 327  		}
 328  	}
 329  }
 330  
 331  func (t *xform) rewriteFieldList(fl *ast.FieldList) {
 332  	if fl == nil {
 333  		return
 334  	}
 335  	for _, f := range fl.List {
 336  		t.rewriteTypeExpr(&f.Type)
 337  	}
 338  }
 339  
 340  func byteSlice() ast.Expr {
 341  	return &ast.ArrayType{Elt: &ast.Ident{Name: "byte"}}
 342  }
 343  
 344  // --- value expression rewriting ---
 345  
 346  func (t *xform) rewriteValueExpr(ep *ast.Expr) {
 347  	if ep == nil || *ep == nil {
 348  		return
 349  	}
 350  	switch x := (*ep).(type) {
 351  	case *ast.BinaryExpr:
 352  		t.rewriteValueExpr(&x.X)
 353  		t.rewriteValueExpr(&x.Y)
 354  		if x.Op == token.ADD && (isStringLit(x.X) || isStringLit(x.Y)) {
 355  			x.Op = token.OR
 356  		}
 357  
 358  	case *ast.TypeAssertExpr:
 359  		if x.Type != nil {
 360  			t.rewriteValueExpr(&x.X)
 361  			*ep = x.X
 362  			t.todos++
 363  		} else {
 364  			// x.(type) in type switch - leave alone
 365  			t.rewriteValueExpr(&x.X)
 366  		}
 367  
 368  	case *ast.CallExpr:
 369  		if id, ok := x.Fun.(*ast.Ident); ok && id.Name == "string" {
 370  			x.Fun = byteSlice()
 371  		}
 372  		if sel, ok := x.Fun.(*ast.SelectorExpr); ok {
 373  			if id, ok := sel.X.(*ast.Ident); ok {
 374  				if n, ok := t.renamed[id.Name]; ok {
 375  					id.Name = n
 376  				}
 377  			}
 378  		}
 379  		for i := range x.Args {
 380  			t.rewriteValueExpr(&x.Args[i])
 381  		}
 382  
 383  	case *ast.SelectorExpr:
 384  		if id, ok := x.X.(*ast.Ident); ok {
 385  			if n, ok := t.renamed[id.Name]; ok {
 386  				id.Name = n
 387  			}
 388  		}
 389  
 390  	case *ast.UnaryExpr:
 391  		t.rewriteValueExpr(&x.X)
 392  
 393  	case *ast.ParenExpr:
 394  		t.rewriteValueExpr(&x.X)
 395  
 396  	case *ast.IndexExpr:
 397  		t.rewriteValueExpr(&x.X)
 398  		t.rewriteValueExpr(&x.Index)
 399  
 400  	case *ast.SliceExpr:
 401  		t.rewriteValueExpr(&x.X)
 402  		if x.Low != nil {
 403  			t.rewriteValueExpr(&x.Low)
 404  		}
 405  		if x.High != nil {
 406  			t.rewriteValueExpr(&x.High)
 407  		}
 408  		if x.Max != nil {
 409  			t.rewriteValueExpr(&x.Max)
 410  		}
 411  
 412  	case *ast.CompositeLit:
 413  		if x.Type != nil {
 414  			t.rewriteTypeExpr(&x.Type)
 415  		}
 416  		for i := range x.Elts {
 417  			t.rewriteValueExpr(&x.Elts[i])
 418  		}
 419  
 420  	case *ast.KeyValueExpr:
 421  		t.rewriteValueExpr(&x.Key)
 422  		t.rewriteValueExpr(&x.Value)
 423  
 424  	case *ast.FuncLit:
 425  		t.rewriteFieldList(x.Type.Params)
 426  		t.rewriteFieldList(x.Type.Results)
 427  		if x.Body != nil {
 428  			t.walkBlock(x.Body)
 429  		}
 430  
 431  	case *ast.StarExpr:
 432  		t.rewriteValueExpr(&x.X)
 433  
 434  	// types appearing in value position (make args, composite lit types, etc.)
 435  	case *ast.ChanType:
 436  		t.rewriteTypeExpr(&x.Value)
 437  	case *ast.ArrayType:
 438  		t.rewriteTypeExpr(&x.Elt)
 439  	case *ast.MapType:
 440  		t.rewriteTypeExpr(&x.Key)
 441  		t.rewriteTypeExpr(&x.Value)
 442  	case *ast.InterfaceType:
 443  		if x.Methods == nil || len(x.Methods.List) == 0 {
 444  			*ep = byteSlice()
 445  			t.todos++
 446  		}
 447  	case *ast.FuncType:
 448  		t.rewriteFieldList(x.Params)
 449  		t.rewriteFieldList(x.Results)
 450  	case *ast.StructType:
 451  		t.rewriteFieldList(x.Fields)
 452  	}
 453  }
 454  
 455  func isStringLit(e ast.Expr) bool {
 456  	lit, ok := e.(*ast.BasicLit)
 457  	return ok && lit.Kind == token.STRING
 458  }
 459