main.go raw

   1  // Copyright (c) 2020-2022 Uber Technologies, Inc.
   2  //
   3  // Permission is hereby granted, free of charge, to any person obtaining a copy
   4  // of this software and associated documentation files (the "Software"), to deal
   5  // in the Software without restriction, including without limitation the rights
   6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
   7  // copies of the Software, and to permit persons to whom the Software is
   8  // furnished to do so, subject to the following conditions:
   9  //
  10  // The above copyright notice and this permission notice shall be included in
  11  // all copies or substantial portions of the Software.
  12  //
  13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  19  // THE SOFTWARE.
  20  
  21  // gen-atomicwrapper generates wrapper types around other atomic types.
  22  //
  23  // It supports plugging in functions which convert the value inside the atomic
  24  // type to the user-facing value. For example,
  25  //
// Given atomic.Value and the functions,
  27  //
  28  //	func packString(string) enveloper{}
  29  //	func unpackString(enveloper{}) string
  30  //
  31  // We can run the following command:
  32  //
  33  //	gen-atomicwrapper -name String -wrapped Value \
//	  -type string -pack packString -unpack unpackString
  35  //
// This will generate approximately:
  37  //
  38  //	type String struct{ v Value }
  39  //
  40  //	func (s *String) Load() string {
//	  return unpackString(s.v.Load())
//	}
//
//	func (s *String) Store(val string) {
//	  s.v.Store(packString(val))
  46  //	}
  47  //
  48  // The packing/unpacking logic allows the stored value to be different from
  49  // the user-facing value.
  50  package main
  51  
  52  import (
  53  	"bytes"
  54  	"embed"
  55  	"errors"
  56  	"flag"
  57  	"fmt"
  58  	"go/format"
  59  	"io"
  60  	"log"
  61  	"os"
  62  	"sort"
  63  	"strings"
  64  	"text/template"
  65  	"time"
  66  )
  67  
  68  func main() {
  69  	log.SetFlags(0)
  70  	if err := run(os.Args[1:]); err != nil {
  71  		log.Fatalf("%+v", err)
  72  	}
  73  }
  74  
  75  type stringList []string
  76  
  77  func (sl *stringList) String() string {
  78  	return strings.Join(*sl, ",")
  79  }
  80  
  81  func (sl *stringList) Set(s string) error {
  82  	for _, i := range strings.Split(s, ",") {
  83  		*sl = append(*sl, strings.TrimSpace(i))
  84  	}
  85  	return nil
  86  }
  87  
  88  func run(args []string) error {
  89  	var opts struct {
  90  		Name    string
  91  		Wrapped string
  92  		Type    string
  93  
  94  		Imports      stringList
  95  		Pack, Unpack string
  96  
  97  		CAS            bool
  98  		CompareAndSwap bool
  99  		Swap           bool
 100  		JSON           bool
 101  
 102  		File   string
 103  		ToYear int
 104  	}
 105  
 106  	opts.ToYear = time.Now().Year()
 107  
 108  	fl := flag.NewFlagSet("gen-atomicwrapper", flag.ContinueOnError)
 109  
 110  	// Required flags
 111  	fl.StringVar(&opts.Name, "name", "",
 112  		"name of the generated type (e.g. Duration)")
 113  	fl.StringVar(&opts.Wrapped, "wrapped", "",
 114  		"name of the wrapped atomic (e.g. Int64)")
 115  	fl.StringVar(&opts.Type, "type", "",
 116  		"name of the type exposed by the atomic (e.g. time.Duration)")
 117  
 118  	// Optional flags
 119  	fl.Var(&opts.Imports, "imports",
 120  		"comma separated list of imports to add")
 121  	fl.StringVar(&opts.Pack, "pack", "",
 122  		"function to transform values with before storage")
 123  	fl.StringVar(&opts.Unpack, "unpack", "",
 124  		"function to reverse packing on loading")
 125  	fl.StringVar(&opts.File, "file", "",
 126  		"output file path (default: stdout)")
 127  
 128  	// Switches for individual methods. Underlying atomics must support
 129  	// these.
 130  	fl.BoolVar(&opts.CAS, "cas", false,
 131  		"generate a deprecated `CAS(old, new) bool` method; requires -pack")
 132  	fl.BoolVar(&opts.CompareAndSwap, "compareandswap", false,
 133  		"generate a `CompareAndSwap(old, new) bool` method; requires -pack")
 134  	fl.BoolVar(&opts.Swap, "swap", false,
 135  		"generate a `Swap(new) old` method; requires -pack and -unpack")
 136  	fl.BoolVar(&opts.JSON, "json", false,
 137  		"generate `Marshal/UnmarshJSON` methods")
 138  
 139  	if err := fl.Parse(args); err != nil {
 140  		return err
 141  	}
 142  
 143  	if len(opts.Name) == 0 ||
 144  		len(opts.Wrapped) == 0 ||
 145  		len(opts.Type) == 0 ||
 146  		len(opts.Pack) == 0 ||
 147  		len(opts.Unpack) == 0 {
 148  		return errors.New("flags -name, -wrapped, -pack, -unpack and -type are required")
 149  	}
 150  
 151  	if opts.CAS {
 152  		opts.CompareAndSwap = true
 153  	}
 154  
 155  	var w io.Writer = os.Stdout
 156  	if file := opts.File; len(file) > 0 {
 157  		f, err := os.Create(file)
 158  		if err != nil {
 159  			return fmt.Errorf("create %q: %v", file, err)
 160  		}
 161  		defer f.Close()
 162  
 163  		w = f
 164  	}
 165  
 166  	// Import encoding/json if needed.
 167  	if opts.JSON {
 168  		found := false
 169  		for _, imp := range opts.Imports {
 170  			if imp == "encoding/json" {
 171  				found = true
 172  				break
 173  			}
 174  		}
 175  
 176  		if !found {
 177  			opts.Imports = append(opts.Imports, "encoding/json")
 178  		}
 179  	}
 180  
 181  	sort.Strings(opts.Imports)
 182  
 183  	var buff bytes.Buffer
 184  	if err := _tmpl.ExecuteTemplate(&buff, "wrapper.tmpl", opts); err != nil {
 185  		return fmt.Errorf("render template: %v", err)
 186  	}
 187  
 188  	bs, err := format.Source(buff.Bytes())
 189  	if err != nil {
 190  		return fmt.Errorf("reformat source: %v", err)
 191  	}
 192  
 193  	io.WriteString(w, "// @generated Code generated by gen-atomicwrapper.\n\n")
 194  	_, err = w.Write(bs)
 195  	return err
 196  }
 197  
 198  var (
 199  	//go:embed *.tmpl
 200  	_tmplFS embed.FS
 201  
 202  	_tmpl = template.Must(template.New("atomicwrapper").ParseFS(_tmplFS, "*.tmpl"))
 203  )
 204