flags.go raw
1 /*
2 * SPDX-FileCopyrightText: © Hypermode Inc. <hello@hypermode.com>
3 * SPDX-License-Identifier: Apache-2.0
4 */
5
6 package z
7
8 import (
9 "errors"
10 "fmt"
11 "log"
12 "os"
13 "os/user"
14 "path/filepath"
15 "sort"
16 "strconv"
17 "strings"
18 "time"
19 )
20
21 // SuperFlagHelp makes it really easy to generate command line `--help` output for a SuperFlag. For
22 // example:
23 //
24 // const flagDefaults = `enabled=true; path=some/path;`
25 //
26 // var help string = z.NewSuperFlagHelp(flagDefaults).
27 // Flag("enabled", "Turns on <something>.").
28 // Flag("path", "The path to <something>.").
29 // Flag("another", "Not present in defaults, but still included.").
30 // String()
31 //
32 // The `help` string would then contain:
33 //
34 // enabled=true; Turns on <something>.
35 // path=some/path; The path to <something>.
36 // another=; Not present in defaults, but still included.
37 //
38 // All flags are sorted alphabetically for consistent `--help` output. Flags with default values are
39 // placed at the top, and everything else goes under.
40 type SuperFlagHelp struct {
41 head string
42 defaults *SuperFlag
43 flags map[string]string
44 }
45
46 func NewSuperFlagHelp(defaults string) *SuperFlagHelp {
47 return &SuperFlagHelp{
48 defaults: NewSuperFlag(defaults),
49 flags: make(map[string]string, 0),
50 }
51 }
52
53 func (h *SuperFlagHelp) Head(head string) *SuperFlagHelp {
54 h.head = head
55 return h
56 }
57
58 func (h *SuperFlagHelp) Flag(name, description string) *SuperFlagHelp {
59 h.flags[name] = description
60 return h
61 }
62
63 func (h *SuperFlagHelp) String() string {
64 defaultLines := make([]string, 0)
65 otherLines := make([]string, 0)
66 for name, help := range h.flags {
67 val, found := h.defaults.m[name]
68 line := fmt.Sprintf(" %s=%s; %s\n", name, val, help)
69 if found {
70 defaultLines = append(defaultLines, line)
71 } else {
72 otherLines = append(otherLines, line)
73 }
74 }
75 sort.Strings(defaultLines)
76 sort.Strings(otherLines)
77 dls := strings.Join(defaultLines, "")
78 ols := strings.Join(otherLines, "")
79 if len(h.defaults.m) == 0 && len(ols) == 0 {
80 // remove last newline
81 dls = dls[:len(dls)-1]
82 }
83 // remove last newline
84 if len(h.defaults.m) == 0 && len(ols) > 1 {
85 ols = ols[:len(ols)-1]
86 }
87 return h.head + "\n" + dls + ols
88 }
89
// parseFlag splits a SuperFlag string such as "key=value; key2=value2" into a map.
//
// Segments are separated by ';' and blank segments are ignored. Each key is trimmed,
// lowercased and has '_' replaced by '-', so "Cache_MB" and "cache-mb" name the same
// option. The value is everything after the first '=' with surrounding whitespace
// removed, so it may itself contain '='. A later duplicate key overwrites an earlier
// one. A non-blank segment without '=' is reported as an error.
func parseFlag(flag string) (map[string]string, error) {
	opts := make(map[string]string)
	for _, segment := range strings.Split(flag, ";") {
		if strings.TrimSpace(segment) == "" {
			continue
		}
		rawKey, rawVal, hasEq := strings.Cut(segment, "=")
		key := strings.TrimSpace(rawKey)
		if !hasEq {
			return nil, fmt.Errorf("superflag: missing value for '%s' in flag: %s", key, flag)
		}
		key = strings.ReplaceAll(strings.ToLower(key), "_", "-")
		opts[key] = strings.TrimSpace(rawVal)
	}
	return opts, nil
}
108
// SuperFlag holds the options of a single compound flag written as
// "key=value; key=value". When built by this file's parser, keys are lowercased and
// use '-' instead of '_'.
type SuperFlag struct {
	m map[string]string // option name -> raw (trimmed) string value
}
112
113 func NewSuperFlag(flag string) *SuperFlag {
114 sf, err := newSuperFlagImpl(flag)
115 if err != nil {
116 log.Fatal(err)
117 }
118 return sf
119 }
120
121 func newSuperFlagImpl(flag string) (*SuperFlag, error) {
122 m, err := parseFlag(flag)
123 if err != nil {
124 return nil, err
125 }
126 return &SuperFlag{m}, nil
127 }
128
129 func (sf *SuperFlag) String() string {
130 if sf == nil {
131 return ""
132 }
133 kvs := make([]string, 0, len(sf.m))
134 for k, v := range sf.m {
135 kvs = append(kvs, fmt.Sprintf("%s=%s", k, v))
136 }
137 return strings.Join(kvs, "; ")
138 }
139
140 func (sf *SuperFlag) MergeAndCheckDefault(flag string) *SuperFlag {
141 sf, err := sf.mergeAndCheckDefaultImpl(flag)
142 if err != nil {
143 log.Fatal(err)
144 }
145 return sf
146 }
147
148 func (sf *SuperFlag) mergeAndCheckDefaultImpl(flag string) (*SuperFlag, error) {
149 if sf == nil {
150 m, err := parseFlag(flag)
151 if err != nil {
152 return nil, err
153 }
154 return &SuperFlag{m}, nil
155 }
156
157 src, err := parseFlag(flag)
158 if err != nil {
159 return nil, err
160 }
161
162 numKeys := len(sf.m)
163 for k := range src {
164 if _, ok := sf.m[k]; ok {
165 numKeys--
166 }
167 }
168 if numKeys != 0 {
169 return nil, fmt.Errorf("superflag: found invalid options: %s.\nvalid options: %v", sf, flag)
170 }
171 for k, v := range src {
172 if _, ok := sf.m[k]; !ok {
173 sf.m[k] = v
174 }
175 }
176 return sf, nil
177 }
178
179 func (sf *SuperFlag) Has(opt string) bool {
180 val := sf.GetString(opt)
181 return val != ""
182 }
183
184 func (sf *SuperFlag) GetDuration(opt string) time.Duration {
185 val := sf.GetString(opt)
186 if val == "" {
187 return time.Duration(0)
188 }
189 if strings.Contains(val, "d") {
190 val = strings.Replace(val, "d", "", 1)
191 days, err := strconv.ParseInt(val, 0, 64)
192 if err != nil {
193 return time.Duration(0)
194 }
195 return time.Hour * 24 * time.Duration(days)
196 }
197 d, err := time.ParseDuration(val)
198 if err != nil {
199 return time.Duration(0)
200 }
201 return d
202 }
203
204 func (sf *SuperFlag) GetBool(opt string) bool {
205 val := sf.GetString(opt)
206 if val == "" {
207 return false
208 }
209 b, err := strconv.ParseBool(val)
210 if err != nil {
211 err = errors.Join(err,
212 fmt.Errorf("Unable to parse %s as bool for key: %s. Options: %s\n", val, opt, sf))
213 log.Fatalf("%+v", err)
214 }
215 return b
216 }
217
218 func (sf *SuperFlag) GetFloat64(opt string) float64 {
219 val := sf.GetString(opt)
220 if val == "" {
221 return 0
222 }
223 f, err := strconv.ParseFloat(val, 64)
224 if err != nil {
225 err = errors.Join(err,
226 fmt.Errorf("Unable to parse %s as float64 for key: %s. Options: %s\n", val, opt, sf))
227 log.Fatalf("%+v", err)
228 }
229 return f
230 }
231
232 func (sf *SuperFlag) GetInt64(opt string) int64 {
233 val := sf.GetString(opt)
234 if val == "" {
235 return 0
236 }
237 i, err := strconv.ParseInt(val, 0, 64)
238 if err != nil {
239 err = errors.Join(err,
240 fmt.Errorf("Unable to parse %s as int64 for key: %s. Options: %s\n", val, opt, sf))
241 log.Fatalf("%+v", err)
242 }
243 return i
244 }
245
246 func (sf *SuperFlag) GetUint64(opt string) uint64 {
247 val := sf.GetString(opt)
248 if val == "" {
249 return 0
250 }
251 u, err := strconv.ParseUint(val, 0, 64)
252 if err != nil {
253 err = errors.Join(err,
254 fmt.Errorf("Unable to parse %s as uint64 for key: %s. Options: %s\n", val, opt, sf))
255 log.Fatalf("%+v", err)
256 }
257 return u
258 }
259
260 func (sf *SuperFlag) GetUint32(opt string) uint32 {
261 val := sf.GetString(opt)
262 if val == "" {
263 return 0
264 }
265 u, err := strconv.ParseUint(val, 0, 32)
266 if err != nil {
267 err = errors.Join(err,
268 fmt.Errorf("Unable to parse %s as uint32 for key: %s. Options: %s\n", val, opt, sf))
269 log.Fatalf("%+v", err)
270 }
271 return uint32(u)
272 }
273
274 func (sf *SuperFlag) GetString(opt string) string {
275 if sf == nil {
276 return ""
277 }
278 return sf.m[opt]
279 }
280
281 func (sf *SuperFlag) GetPath(opt string) string {
282 p := sf.GetString(opt)
283 path, err := expandPath(p)
284 if err != nil {
285 log.Fatalf("Failed to get path: %+v", err)
286 }
287 return path
288 }
289
// expandPath converts path into a cleaned, absolute path.
//
// A leading "~" is replaced with the current user's home directory, but only when
// it is the whole path or is followed by a path separator. "~otheruser/..." is not
// expanded. For example, "~/abc/../cef" becomes "<home>/cef". An empty path is
// returned unchanged with a nil error.
//
// The home directory comes from user.Current. If that lookup fails, os.UserHomeDir
// is tried before giving up. Errors wrap their causes with %w so callers can
// inspect them via errors.Is / errors.As.
func expandPath(path string) (string, error) {
	if path == "" {
		return "", nil
	}
	if path[0] == '~' && (len(path) == 1 || os.IsPathSeparator(path[1])) {
		var home string
		if usr, err := user.Current(); err == nil {
			home = usr.HomeDir
		} else if dir, envErr := os.UserHomeDir(); envErr == nil {
			home = dir
		} else {
			return "", fmt.Errorf("getting the home directory of the user: %w",
				errors.Join(err, envErr))
		}
		path = filepath.Join(home, path[1:])
	}

	abs, err := filepath.Abs(path)
	if err != nil {
		return "", fmt.Errorf("generating absolute path for %q: %w", path, err)
	}
	return abs, nil
}
311