golangflag.go raw

   1  // Copyright 2009 The Go Authors. All rights reserved.
   2  // Use of this source code is governed by a BSD-style
   3  // license that can be found in the LICENSE file.
   4  
   5  package pflag
   6  
   7  import (
   8  	goflag "flag"
   9  	"reflect"
  10  	"strings"
  11  )
  12  
  13  // go test flags prefixes
  14  func isGotestFlag(flag string) bool {
  15  	return strings.HasPrefix(flag, "-test.")
  16  }
  17  
  18  func isGotestShorthandFlag(flag string) bool {
  19  	return strings.HasPrefix(flag, "test.")
  20  }
  21  
  22  // flagValueWrapper implements pflag.Value around a flag.Value.  The main
  23  // difference here is the addition of the Type method that returns a string
  24  // name of the type.  As this is generally unknown, we approximate that with
  25  // reflection.
  26  type flagValueWrapper struct {
  27  	inner    goflag.Value
  28  	flagType string
  29  }
  30  
  31  // We are just copying the boolFlag interface out of goflag as that is what
  32  // they use to decide if a flag should get "true" when no arg is given.
  33  type goBoolFlag interface {
  34  	goflag.Value
  35  	IsBoolFlag() bool
  36  }
  37  
  38  func wrapFlagValue(v goflag.Value) Value {
  39  	// If the flag.Value happens to also be a pflag.Value, just use it directly.
  40  	if pv, ok := v.(Value); ok {
  41  		return pv
  42  	}
  43  
  44  	pv := &flagValueWrapper{
  45  		inner: v,
  46  	}
  47  
  48  	t := reflect.TypeOf(v)
  49  	if t.Kind() == reflect.Interface || t.Kind() == reflect.Ptr {
  50  		t = t.Elem()
  51  	}
  52  
  53  	pv.flagType = strings.TrimSuffix(t.Name(), "Value")
  54  	return pv
  55  }
  56  
  57  func (v *flagValueWrapper) String() string {
  58  	return v.inner.String()
  59  }
  60  
  61  func (v *flagValueWrapper) Set(s string) error {
  62  	return v.inner.Set(s)
  63  }
  64  
  65  func (v *flagValueWrapper) Type() string {
  66  	return v.flagType
  67  }
  68  
  69  // PFlagFromGoFlag will return a *pflag.Flag given a *flag.Flag
  70  // If the *flag.Flag.Name was a single character (ex: `v`) it will be accessiblei
  71  // with both `-v` and `--v` in flags. If the golang flag was more than a single
  72  // character (ex: `verbose`) it will only be accessible via `--verbose`
  73  func PFlagFromGoFlag(goflag *goflag.Flag) *Flag {
  74  	// Remember the default value as a string; it won't change.
  75  	flag := &Flag{
  76  		Name:  goflag.Name,
  77  		Usage: goflag.Usage,
  78  		Value: wrapFlagValue(goflag.Value),
  79  		// Looks like golang flags don't set DefValue correctly  :-(
  80  		//DefValue: goflag.DefValue,
  81  		DefValue: goflag.Value.String(),
  82  	}
  83  	// Ex: if the golang flag was -v, allow both -v and --v to work
  84  	if len(flag.Name) == 1 {
  85  		flag.Shorthand = flag.Name
  86  	}
  87  	if fv, ok := goflag.Value.(goBoolFlag); ok && fv.IsBoolFlag() {
  88  		flag.NoOptDefVal = "true"
  89  	}
  90  	return flag
  91  }
  92  
  93  // AddGoFlag will add the given *flag.Flag to the pflag.FlagSet
  94  func (f *FlagSet) AddGoFlag(goflag *goflag.Flag) {
  95  	if f.Lookup(goflag.Name) != nil {
  96  		return
  97  	}
  98  	newflag := PFlagFromGoFlag(goflag)
  99  	f.AddFlag(newflag)
 100  }
 101  
 102  // AddGoFlagSet will add the given *flag.FlagSet to the pflag.FlagSet
 103  func (f *FlagSet) AddGoFlagSet(newSet *goflag.FlagSet) {
 104  	if newSet == nil {
 105  		return
 106  	}
 107  	newSet.VisitAll(func(goflag *goflag.Flag) {
 108  		f.AddGoFlag(goflag)
 109  	})
 110  	if f.addedGoFlagSets == nil {
 111  		f.addedGoFlagSets = make([]*goflag.FlagSet, 0)
 112  	}
 113  	f.addedGoFlagSets = append(f.addedGoFlagSets, newSet)
 114  }
 115  
 116  // ParseSkippedFlags explicitly Parses go test flags (i.e. the one starting with '-test.') with goflag.Parse(),
 117  // since by default those are skipped by pflag.Parse().
 118  // Typical usage example: `ParseGoTestFlags(os.Args[1:], goflag.CommandLine)`
 119  func ParseSkippedFlags(osArgs []string, goFlagSet *goflag.FlagSet) error {
 120  	var skippedFlags []string
 121  	for _, f := range osArgs {
 122  		if isGotestFlag(f) {
 123  			skippedFlags = append(skippedFlags, f)
 124  		}
 125  	}
 126  	return goFlagSet.Parse(skippedFlags)
 127  }
 128