diff --git a/bool.go b/bool.go index 70e2e0a6..1ebac16a 100644 --- a/bool.go +++ b/bool.go @@ -37,23 +37,23 @@ func (b *boolValue) IsBoolFlag() bool { return true } // BoolVar defines a bool flag with specified name, default value, and usage string. // The argument p points to a bool variable in which to store the value of the flag. func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) { - f.VarP(newBoolValue(value, p), name, "", usage) + f.OVarP(newBoolValue(value, p), name, "", usage, true) } // Like BoolVar, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage string) { - f.VarP(newBoolValue(value, p), name, shorthand, usage) + f.OVarP(newBoolValue(value, p), name, shorthand, usage, true) } // BoolVar defines a bool flag with specified name, default value, and usage string. // The argument p points to a bool variable in which to store the value of the flag. func BoolVar(p *bool, name string, value bool, usage string) { - CommandLine.VarP(newBoolValue(value, p), name, "", usage) + CommandLine.OVarP(newBoolValue(value, p), name, "", usage, true) } // Like BoolVar, but accepts a shorthand letter that can be used after a single dash. func BoolVarP(p *bool, name, shorthand string, value bool, usage string) { - CommandLine.VarP(newBoolValue(value, p), name, shorthand, usage) + CommandLine.OVarP(newBoolValue(value, p), name, shorthand, usage, true) } // Bool defines a bool flag with specified name, default value, and usage string. @@ -61,6 +61,7 @@ func BoolVarP(p *bool, name, shorthand string, value bool, usage string) { func (f *FlagSet) Bool(name string, value bool, usage string) *bool { p := new(bool) f.BoolVarP(p, name, "", value, usage) + f.MarkOptional(name) return p } @@ -68,16 +69,21 @@ func (f *FlagSet) Bool(name string, value bool, usage string) *bool { func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool { p := new(bool) f.BoolVarP(p, name, shorthand, value, usage) + f.MarkOptional(name) return p } // Bool defines a bool flag with specified name, default value, and usage string. // The return value is the address of a bool variable that stores the value of the flag. func Bool(name string, value bool, usage string) *bool { - return CommandLine.BoolP(name, "", value, usage) + b := CommandLine.BoolP(name, "", value, usage) + CommandLine.MarkOptional(name) + return b } // Like Bool, but accepts a shorthand letter that can be used after a single dash. func BoolP(name, shorthand string, value bool, usage string) *bool { - return CommandLine.BoolP(name, shorthand, value, usage) + b := CommandLine.BoolP(name, shorthand, value, usage) + CommandLine.MarkOptional(name) + return b } diff --git a/bool_test.go b/bool_test.go index a2e1c5dc..f42880c6 100644 --- a/bool_test.go +++ b/bool_test.go @@ -60,6 +60,7 @@ func setUpFlagSet(tristate *triStateValue) *FlagSet { f := NewFlagSet("test", ContinueOnError) *tristate = triStateFalse f.VarP(tristate, "tristate", "t", "tristate value (true, maybe or false)") + f.MarkOptional("tristate") return f } diff --git a/flag.go b/flag.go index 55594df4..182100cd 100644 --- a/flag.go +++ b/flag.go @@ -152,6 +152,7 @@ type Flag struct { Value Value // value as set DefValue string // default value (as text); for usage message Changed bool // If the user set the value (or if left to default) + Optional bool // If the flag argument is optional Deprecated string // If this flag is deprecated, this string is the new or now thing to use Annotations map[string][]string // used by cobra.Command bash autocomple code } @@ -250,6 +251,11 @@ func (f *FlagSet) Lookup(name string) *Flag { return f.lookup(f.normalizeFlagName(name)) } +// Changed checks by name if a given flag has been changed. +func (f *FlagSet) Changed(name string) bool { + return f.Lookup(name).Changed +} + // lookup returns the Flag structure of the named flag, returning nil if none exists. func (f *FlagSet) lookup(name NormalizedName) *Flag { return f.formal[name] @@ -265,6 +271,21 @@ func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error { return nil } +// Mark a flag argument optional in your program +func (f *FlagSet) MarkOptional(name string) error { + flag := f.Lookup(name) + if flag == nil { + return fmt.Errorf("flag %q does not exist", name) + } + flag.Optional = true + return nil +} + +// Changed checks by name if a given flag has been set. +func Changed(name string) bool { + return CommandLine.Changed(name) +} + // Lookup returns the Flag structure of the named command-line flag, // returning nil if none exists. func Lookup(name string) *Flag { @@ -409,6 +430,11 @@ func (f *FlagSet) Var(value Value, name string, usage string) { // Like Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { + f.OVarP(value, name, shorthand, usage, false) +} + +// Like VarP, but allows to mark the flag as allowing optional argument. +func (f *FlagSet) OVarP(value Value, name, shorthand, usage string, optional bool) { // Remember the default value as a string; it won't change. flag := &Flag{ Name: name, @@ -416,6 +442,7 @@ func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { Usage: usage, Value: value, DefValue: value.String(), + Optional: optional, } f.AddFlag(flag) } @@ -466,6 +493,11 @@ func VarP(value Value, name, shorthand, usage string) { CommandLine.VarP(value, name, shorthand, usage) } +// Like VarP, but allows to mark the flag as allowing optional argument. +func OVarP(value Value, name, shorthand, usage string, optional bool) { + CommandLine.OVarP(value, name, shorthand, usage, optional) +} + // failf prints to standard error a formatted error and usage message and // returns the error. func (f *FlagSet) failf(format string, a ...interface{}) error { @@ -527,11 +559,22 @@ func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error) return } if len(split) == 1 { - if bv, ok := flag.Value.(boolFlag); !ok || !bv.IsBoolFlag() { + if !flag.Optional { err = f.failf("flag needs an argument: %s", s) return } - f.setFlag(flag, "true", s) + if bv, ok := flag.Value.(boolFlag); ok && bv.IsBoolFlag() { + f.setFlag(flag, "true", s) + } else { + v := flag.DefValue + if len(args) == 1 { + v = args[0] + } + if e := f.setFlag(flag, v, s); e != nil { + err = e + return + } + } } else { if e := f.setFlag(flag, split[1], s); e != nil { err = e @@ -575,8 +618,11 @@ func (f *FlagSet) parseShortArg(s string, args []string) (a []string, err error) break } if len(args) == 0 { - err = f.failf("flag needs an argument: %q in -%s", c, shorthands) - return + if !flag.Optional { + err = f.failf("flag needs an argument: %q in -%s", c, shorthands) + return + } + args = append(args, flag.DefValue) } if e := f.setFlag(flag, args[0], s); e != nil { err = e diff --git a/flag_test.go b/flag_test.go index 7d114b22..0b872903 100644 --- a/flag_test.go +++ b/flag_test.go @@ -17,14 +17,15 @@ import ( ) var ( - test_bool = Bool("test_bool", false, "bool value") - test_int = Int("test_int", 0, "int value") - test_int64 = Int64("test_int64", 0, "int64 value") - test_uint = Uint("test_uint", 0, "uint value") - test_uint64 = Uint64("test_uint64", 0, "uint64 value") - test_string = String("test_string", "0", "string value") - test_float64 = Float64("test_float64", 0, "float64 value") - test_duration = Duration("test_duration", 0, "time.Duration value") + test_bool = Bool("test_bool", false, "bool value") + test_int = Int("test_int", 0, "int value") + test_int64 = Int64("test_int64", 0, "int64 value") + test_uint = Uint("test_uint", 0, "uint value") + test_uint64 = Uint64("test_uint64", 0, "uint64 value") + test_string = String("test_string", "0", "string value") + test_float64 = Float64("test_float64", 0, "float64 value") + test_duration = Duration("test_duration", 0, "time.Duration value") + test_optional_int = Int("test_optional_int", 0, "optional int value") ) func boolString(s string) string { @@ -55,7 +56,7 @@ func TestEverything(t *testing.T) { } } VisitAll(visitor) - if len(m) != 8 { + if len(m) != 9 { t.Error("VisitAll misses some flags") for k, v := range m { t.Log(k, *v) @@ -78,9 +79,10 @@ func TestEverything(t *testing.T) { Set("test_string", "1") Set("test_float64", "1") Set("test_duration", "1s") + Set("test_optional_int", "1") desired = "1" Visit(visitor) - if len(m) != 8 { + if len(m) != 9 { t.Error("Visit fails after set") for k, v := range m { t.Log(k, *v) @@ -119,6 +121,10 @@ func testParse(f *FlagSet, t *testing.T) { stringFlag := f.String("string", "0", "string value") float64Flag := f.Float64("float64", 0, "float64 value") durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value") + optionalIntNoValueFlag := f.Int("optional-int-no-value", 9, "int value") + optionalIntWithValueFlag := f.Int("optional-int-with-value", 9, "int value") + f.MarkOptional("optional-int-no-value") + f.MarkOptional("optional-int-with-value") extra := "one-extra-argument" args := []string{ "--bool", @@ -131,6 +137,8 @@ func testParse(f *FlagSet, t *testing.T) { "--string=hello", "--float64=2718e28", "--duration=2m", + "--optional-int-no-value", + "--optional-int-with-value=42", extra, } if err := f.Parse(args); err != nil { @@ -169,6 +177,12 @@ func testParse(f *FlagSet, t *testing.T) { if *durationFlag != 2*time.Minute { t.Error("duration flag should be 2m, is ", *durationFlag) } + if *optionalIntNoValueFlag != 9 { + t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag) + } + if *optionalIntWithValueFlag != 42 { + t.Error("optional int flag should be 42, is ", *optionalIntWithValueFlag) + } if len(f.Args()) != 1 { t.Error("expected one argument, got", len(f.Args())) } else if f.Args()[0] != extra {