flags.go raw

   1  /*
   2   * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
   3   * SPDX-License-Identifier: Apache-2.0
   4   */
   5  
   6  package z
   7  
   8  import (
   9  	"errors"
  10  	"fmt"
  11  	"log"
  12  	"os"
  13  	"os/user"
  14  	"path/filepath"
  15  	"sort"
  16  	"strconv"
  17  	"strings"
  18  	"time"
  19  )
  20  
// SuperFlagHelp makes it really easy to generate command line `--help` output for a SuperFlag. For
// example:
//
//	const flagDefaults = `enabled=true; path=some/path;`
//
//	var help string = z.NewSuperFlagHelp(flagDefaults).
//		Flag("enabled", "Turns on <something>.").
//		Flag("path", "The path to <something>.").
//		Flag("another", "Not present in defaults, but still included.").
//		String()
//
// The `help` string would then start with the head line (set via Head; empty here, so the output
// begins with a bare newline) followed by one line per flag, each indented by four spaces:
//
//	    enabled=true; Turns on <something>.
//	    path=some/path; The path to <something>.
//	    another=; Not present in defaults, but still included.
//
// Flags whose names appear in the defaults are placed at the top and everything else goes under;
// each of the two groups is sorted alphabetically for consistent `--help` output.
//
// Fields: head is the line printed before the flag list, defaults supplies the value shown next
// to each flag name it contains, and flags maps every name registered with Flag to its
// description.
//
// Note that flag names are looked up in the defaults exactly as given to Flag, while the keys of
// the defaults are normalized when parsed (lower-cased, '_' replaced by '-'). A flag therefore
// has to be registered under its normalized name for its default value to be shown.
type SuperFlagHelp struct {
	head     string
	defaults *SuperFlag
	flags    map[string]string
}
  45  
  46  func NewSuperFlagHelp(defaults string) *SuperFlagHelp {
  47  	return &SuperFlagHelp{
  48  		defaults: NewSuperFlag(defaults),
  49  		flags:    make(map[string]string, 0),
  50  	}
  51  }
  52  
  53  func (h *SuperFlagHelp) Head(head string) *SuperFlagHelp {
  54  	h.head = head
  55  	return h
  56  }
  57  
  58  func (h *SuperFlagHelp) Flag(name, description string) *SuperFlagHelp {
  59  	h.flags[name] = description
  60  	return h
  61  }
  62  
  63  func (h *SuperFlagHelp) String() string {
  64  	defaultLines := make([]string, 0)
  65  	otherLines := make([]string, 0)
  66  	for name, help := range h.flags {
  67  		val, found := h.defaults.m[name]
  68  		line := fmt.Sprintf("    %s=%s; %s\n", name, val, help)
  69  		if found {
  70  			defaultLines = append(defaultLines, line)
  71  		} else {
  72  			otherLines = append(otherLines, line)
  73  		}
  74  	}
  75  	sort.Strings(defaultLines)
  76  	sort.Strings(otherLines)
  77  	dls := strings.Join(defaultLines, "")
  78  	ols := strings.Join(otherLines, "")
  79  	if len(h.defaults.m) == 0 && len(ols) == 0 {
  80  		// remove last newline
  81  		dls = dls[:len(dls)-1]
  82  	}
  83  	// remove last newline
  84  	if len(h.defaults.m) == 0 && len(ols) > 1 {
  85  		ols = ols[:len(ols)-1]
  86  	}
  87  	return h.head + "\n" + dls + ols
  88  }
  89  
  90  func parseFlag(flag string) (map[string]string, error) {
  91  	kvm := make(map[string]string)
  92  	for _, kv := range strings.Split(flag, ";") {
  93  		if strings.TrimSpace(kv) == "" {
  94  			continue
  95  		}
  96  		// For a non-empty separator, 0 < len(splits) ≤ 2.
  97  		splits := strings.SplitN(kv, "=", 2)
  98  		k := strings.TrimSpace(splits[0])
  99  		if len(splits) < 2 {
 100  			return nil, fmt.Errorf("superflag: missing value for '%s' in flag: %s", k, flag)
 101  		}
 102  		k = strings.ToLower(k)
 103  		k = strings.ReplaceAll(k, "_", "-")
 104  		kvm[k] = strings.TrimSpace(splits[1])
 105  	}
 106  	return kvm, nil
 107  }
 108  
// SuperFlag holds the options parsed from a "key=value; key=value" flag
// string. Keys are stored lower-cased with '_' replaced by '-' (see
// parseFlag), and values are kept as trimmed strings. The Get* methods convert
// a value to the requested type when it is read.
type SuperFlag struct {
	// m maps each normalized option name to its raw string value.
	m map[string]string
}
 112  
 113  func NewSuperFlag(flag string) *SuperFlag {
 114  	sf, err := newSuperFlagImpl(flag)
 115  	if err != nil {
 116  		log.Fatal(err)
 117  	}
 118  	return sf
 119  }
 120  
 121  func newSuperFlagImpl(flag string) (*SuperFlag, error) {
 122  	m, err := parseFlag(flag)
 123  	if err != nil {
 124  		return nil, err
 125  	}
 126  	return &SuperFlag{m}, nil
 127  }
 128  
 129  func (sf *SuperFlag) String() string {
 130  	if sf == nil {
 131  		return ""
 132  	}
 133  	kvs := make([]string, 0, len(sf.m))
 134  	for k, v := range sf.m {
 135  		kvs = append(kvs, fmt.Sprintf("%s=%s", k, v))
 136  	}
 137  	return strings.Join(kvs, "; ")
 138  }
 139  
 140  func (sf *SuperFlag) MergeAndCheckDefault(flag string) *SuperFlag {
 141  	sf, err := sf.mergeAndCheckDefaultImpl(flag)
 142  	if err != nil {
 143  		log.Fatal(err)
 144  	}
 145  	return sf
 146  }
 147  
 148  func (sf *SuperFlag) mergeAndCheckDefaultImpl(flag string) (*SuperFlag, error) {
 149  	if sf == nil {
 150  		m, err := parseFlag(flag)
 151  		if err != nil {
 152  			return nil, err
 153  		}
 154  		return &SuperFlag{m}, nil
 155  	}
 156  
 157  	src, err := parseFlag(flag)
 158  	if err != nil {
 159  		return nil, err
 160  	}
 161  
 162  	numKeys := len(sf.m)
 163  	for k := range src {
 164  		if _, ok := sf.m[k]; ok {
 165  			numKeys--
 166  		}
 167  	}
 168  	if numKeys != 0 {
 169  		return nil, fmt.Errorf("superflag: found invalid options: %s.\nvalid options: %v", sf, flag)
 170  	}
 171  	for k, v := range src {
 172  		if _, ok := sf.m[k]; !ok {
 173  			sf.m[k] = v
 174  		}
 175  	}
 176  	return sf, nil
 177  }
 178  
 179  func (sf *SuperFlag) Has(opt string) bool {
 180  	val := sf.GetString(opt)
 181  	return val != ""
 182  }
 183  
 184  func (sf *SuperFlag) GetDuration(opt string) time.Duration {
 185  	val := sf.GetString(opt)
 186  	if val == "" {
 187  		return time.Duration(0)
 188  	}
 189  	if strings.Contains(val, "d") {
 190  		val = strings.Replace(val, "d", "", 1)
 191  		days, err := strconv.ParseInt(val, 0, 64)
 192  		if err != nil {
 193  			return time.Duration(0)
 194  		}
 195  		return time.Hour * 24 * time.Duration(days)
 196  	}
 197  	d, err := time.ParseDuration(val)
 198  	if err != nil {
 199  		return time.Duration(0)
 200  	}
 201  	return d
 202  }
 203  
 204  func (sf *SuperFlag) GetBool(opt string) bool {
 205  	val := sf.GetString(opt)
 206  	if val == "" {
 207  		return false
 208  	}
 209  	b, err := strconv.ParseBool(val)
 210  	if err != nil {
 211  		err = errors.Join(err,
 212  			fmt.Errorf("Unable to parse %s as bool for key: %s. Options: %s\n", val, opt, sf))
 213  		log.Fatalf("%+v", err)
 214  	}
 215  	return b
 216  }
 217  
 218  func (sf *SuperFlag) GetFloat64(opt string) float64 {
 219  	val := sf.GetString(opt)
 220  	if val == "" {
 221  		return 0
 222  	}
 223  	f, err := strconv.ParseFloat(val, 64)
 224  	if err != nil {
 225  		err = errors.Join(err,
 226  			fmt.Errorf("Unable to parse %s as float64 for key: %s. Options: %s\n", val, opt, sf))
 227  		log.Fatalf("%+v", err)
 228  	}
 229  	return f
 230  }
 231  
 232  func (sf *SuperFlag) GetInt64(opt string) int64 {
 233  	val := sf.GetString(opt)
 234  	if val == "" {
 235  		return 0
 236  	}
 237  	i, err := strconv.ParseInt(val, 0, 64)
 238  	if err != nil {
 239  		err = errors.Join(err,
 240  			fmt.Errorf("Unable to parse %s as int64 for key: %s. Options: %s\n", val, opt, sf))
 241  		log.Fatalf("%+v", err)
 242  	}
 243  	return i
 244  }
 245  
 246  func (sf *SuperFlag) GetUint64(opt string) uint64 {
 247  	val := sf.GetString(opt)
 248  	if val == "" {
 249  		return 0
 250  	}
 251  	u, err := strconv.ParseUint(val, 0, 64)
 252  	if err != nil {
 253  		err = errors.Join(err,
 254  			fmt.Errorf("Unable to parse %s as uint64 for key: %s. Options: %s\n", val, opt, sf))
 255  		log.Fatalf("%+v", err)
 256  	}
 257  	return u
 258  }
 259  
 260  func (sf *SuperFlag) GetUint32(opt string) uint32 {
 261  	val := sf.GetString(opt)
 262  	if val == "" {
 263  		return 0
 264  	}
 265  	u, err := strconv.ParseUint(val, 0, 32)
 266  	if err != nil {
 267  		err = errors.Join(err,
 268  			fmt.Errorf("Unable to parse %s as uint32 for key: %s. Options: %s\n", val, opt, sf))
 269  		log.Fatalf("%+v", err)
 270  	}
 271  	return uint32(u)
 272  }
 273  
 274  func (sf *SuperFlag) GetString(opt string) string {
 275  	if sf == nil {
 276  		return ""
 277  	}
 278  	return sf.m[opt]
 279  }
 280  
 281  func (sf *SuperFlag) GetPath(opt string) string {
 282  	p := sf.GetString(opt)
 283  	path, err := expandPath(p)
 284  	if err != nil {
 285  		log.Fatalf("Failed to get path: %+v", err)
 286  	}
 287  	return path
 288  }
 289  
 290  // expandPath expands the paths containing ~ to /home/user. It also computes the absolute path
 291  // from the relative paths. For example: ~/abc/../cef will be transformed to /home/user/cef.
 292  func expandPath(path string) (string, error) {
 293  	if len(path) == 0 {
 294  		return "", nil
 295  	}
 296  	if path[0] == '~' && (len(path) == 1 || os.IsPathSeparator(path[1])) {
 297  		usr, err := user.Current()
 298  		if err != nil {
 299  			return "", errors.Join(err, errors.New("Failed to get the home directory of the user"))
 300  		}
 301  		path = filepath.Join(usr.HomeDir, path[1:])
 302  	}
 303  
 304  	var err error
 305  	path, err = filepath.Abs(path)
 306  	if err != nil {
 307  		return "", errors.Join(err, errors.New("Failed to generate absolute path"))
 308  	}
 309  	return path, nil
 310  }
 311