diff --git a/flags.go b/flags.go index 1161e19..f9c7d06 100644 --- a/flags.go +++ b/flags.go @@ -15,6 +15,7 @@ type Flags interface { Parse(arguments []string) error Args() []string PrintDefaults() + Output() io.Writer SetOutput(output io.Writer) } @@ -46,6 +47,17 @@ func createDefaultFlags(name string) *flag.FlagSet { return flag.NewFlagSet(name, flag.ContinueOnError) } +// Parse wraps the inner flag Parse method, making sure that the error output is +// discarded/silenced. +func (f *flagSet) Parse(arguments []string) error { + originalOut := f.Output() + + f.SetOutput(io.Discard) + defer f.SetOutput(originalOut) + + return f.Flags.Parse(arguments) +} + func (a *MultiCommandApp) isUniqueFlagSet(flags Flags) bool { // NOTE: We have to use `reflect.DeepEqual`, because the interface values // could be non-comparable and could panic at runtime. @@ -98,20 +110,10 @@ func (a *app) setupFlagSet(flagSet *flagSet) { } } -// setUsage sets the `.Usage` field of the flags. -func (a *SingleCommandApp) setUsage(flagSet *flagSet) { - setUsage(flagSet, a.PrintHelp) -} - -// setUsage sets the `.Usage` field of the flags for a given command. -func (a *MultiCommandApp) setUsage(flagSet *flagSet, commandName string) { - setUsage(flagSet, func() { a.PrintHelp(commandName) }) -} - // printFlagDefaults wraps the writing of flag default values func (a *app) printFlagDefaults(flags Flags) { var buffer bytes.Buffer - originalOut := a.errOut + originalOut := flags.Output() flags.SetOutput(&buffer) flags.PrintDefaults() @@ -127,90 +129,3 @@ func (a *app) printFlagDefaults(flags Flags) { // Restore the original output flags.SetOutput(originalOut) } - -func setUsage(flagSet *flagSet, usageFunc func()) { - switch flags := flagSet.Flags.(type) { - case *flag.FlagSet: - // If we're dealing with the standard library flags, just set the usage - // function natively - flags.Usage = usageFunc - default: - // Otherwise, we'll have to use reflection to work generically... - // - // TODO: Remove the use of reflection one we can use the type-system to - // reliably detect types with specific fields, and set them. - // - // We CAN try and enforce a specific type of the flag set itself, but - // then we'll only be able to be fully compatible with one flag package, - // and none of the popular forks (like github.com/spf13/pflag). - // - // Until then, we'll HAVE to use reflection... :( - // - // See: https://github.com/golang/go/issues/48522 - usageReflect := reflectFlagsUsage(flagSet.Flags) - if usageReflect == nil || !usageReflect.CanSet() { - return - } - - reflectFunc := reflect.ValueOf(usageFunc) - usageReflect.Set(reflectFunc) - } -} - -func reflectFlagsUsage(flags Flags) *reflect.Value { - flagsReflect := reflect.ValueOf(flags) - flagsReflect = reflectElemUntil(flagsReflect, func(value reflect.Value) bool { - return value.Kind() == reflect.Struct - }) - - if flagsReflect.Kind() != reflect.Struct { - return nil - } - - usageReflect := flagsReflect.FieldByName("Usage") - usageFuncType := reflect.TypeOf(flag.Usage) - - if !usageReflect.IsValid() || !usageReflect.Type().AssignableTo(usageFuncType) { - if embedded := findEmbeddedFlagsStruct(flags); embedded != nil { - return reflectFlagsUsage(embedded) - } - - return nil - } - - return &usageReflect -} - -func findEmbeddedFlagsStruct(flags Flags) Flags { - flagsReflect := reflect.ValueOf(flags) - flagsReflect = reflectElemUntil(flagsReflect, func(value reflect.Value) bool { - return value.Kind() == reflect.Struct - }) - - flagsType := reflect.TypeOf((*Flags)(nil)).Elem() - for i := 0; i < flagsReflect.NumField(); i++ { - field := flagsReflect.Field(i) - - field = reflectElemUntil(field, func(value reflect.Value) bool { - canElem := value.Kind() == reflect.Pointer || value.Kind() == reflect.Interface - isStructPointer := canElem && value.Elem().Kind() == reflect.Struct - - return isStructPointer && - value.CanInterface() && - value.Type().Implements(flagsType) - }) - - if field.IsValid() && field.CanInterface() && field.Type().Implements(flagsType) { - return field.Interface().(Flags) - } - } - - return nil -} - -func reflectElemUntil(value reflect.Value, until func(value reflect.Value) bool) reflect.Value { - for !until(value) && (value.Kind() == reflect.Pointer || value.Kind() == reflect.Interface) { - value = value.Elem() - } - return value -} diff --git a/flags_test.go b/flags_test.go index b505353..bb45b48 100644 --- a/flags_test.go +++ b/flags_test.go @@ -2,10 +2,7 @@ package lieut import ( "context" - "flag" "io" - "os" - "reflect" "testing" ) @@ -22,6 +19,10 @@ func (b *bogusFlags) Args() []string { func (b *bogusFlags) PrintDefaults() { } +func (b *bogusFlags) Output() io.Writer { + return nil +} + func (b *bogusFlags) SetOutput(output io.Writer) { } @@ -73,52 +74,3 @@ func TestBogusFlags_WorkWithMultiCommandApps(t *testing.T) { app.PrintUsageError("foo", nil) app.Run(context.TODO(), nil) } - -func TestUsageIsSetCorrectlyForEmbeddedFlags(t *testing.T) { - customFlags := struct { - bogus *bogusFlags - Bogus *bogusFlags - - *flag.FlagSet - - Usage func(int, string, float64) // Make sure that it doesn't try and set this! - }{ - FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), - - Usage: nil, - } - - app := NewSingleCommandApp(AppInfo{}, nil, &customFlags, os.Stdout, os.Stderr) - - if app == nil { - t.Fatal("NewSingleCommandApp returned nil") - } - - flagsUsageFn := reflect.ValueOf(customFlags.FlagSet.Usage).Pointer() - want := reflect.ValueOf(app.PrintHelp).Pointer() - - if flagsUsageFn != want { - t.Errorf("flags Usage wasn't set correctly, is %v", flagsUsageFn) - } -} - -func TestUsageReflectionIsSafeForEmbeddedBogusFlags(t *testing.T) { - customFlags := struct { - *bogusFlags - Bogus *bogusFlags - - Usage func(int, string, float64) // Make sure that it doesn't try and set this! - }{ - Usage: nil, - } - - app := NewSingleCommandApp(AppInfo{}, nil, &customFlags, os.Stdout, os.Stderr) - - if app == nil { - t.Fatal("NewSingleCommandApp returned nil") - } - - if customFlags.Usage != nil { - t.Error("flags Usage was set when it shouldn't have been") - } -} diff --git a/lieut.go b/lieut.go index a09c016..d1d7ff8 100644 --- a/lieut.go +++ b/lieut.go @@ -83,6 +83,9 @@ type MultiCommandApp struct { } // NewSingleCommandApp returns an initialized SingleCommandApp. +// +// The provided flags should have ContinueOnError ErrorHandling, or else flag +// parsing errors won't properly be displayed/handled. func NewSingleCommandApp(info AppInfo, exec Executor, flags Flags, out io.Writer, errOut io.Writer) *SingleCommandApp { if info.Name == "" { info.Name = inferAppName() @@ -112,7 +115,6 @@ func NewSingleCommandApp(info AppInfo, exec Executor, flags Flags, out io.Writer } app.setupFlagSet(app.flags) - app.setUsage(app.flags) return app } @@ -120,6 +122,9 @@ func NewSingleCommandApp(info AppInfo, exec Executor, flags Flags, out io.Writer // NewMultiCommandApp returns an initialized MultiCommandApp. // // The provided flags are global/shared among the app's commands. +// +// The provided flags should have ContinueOnError ErrorHandling, or else flag +// parsing errors won't properly be displayed/handled. func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writer) *MultiCommandApp { if info.Name == "" { info.Name = inferAppName() @@ -149,7 +154,6 @@ func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writ } app.setupFlagSet(app.flags) - app.setUsage(app.flags, "") return app } @@ -158,6 +162,9 @@ func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writ // // It returns an error if the provided flags have already been used for another // command (or for the globals). +// +// The provided flags should have ContinueOnError ErrorHandling, or else flag +// parsing errors won't properly be displayed/handled. func (a *MultiCommandApp) SetCommand(info CommandInfo, exec Executor, flags Flags) error { if info.Usage == "" { info.Usage = DefaultCommandUsage @@ -174,7 +181,6 @@ func (a *MultiCommandApp) SetCommand(info CommandInfo, exec Executor, flags Flag flagSet := &flagSet{Flags: flags} a.setupFlagSet(flagSet) - a.setUsage(flagSet, info.Name) a.commands[info.Name] = command{info: info, Executor: exec, flags: flagSet} @@ -204,7 +210,8 @@ func (a *SingleCommandApp) Run(ctx context.Context, arguments []string) int { } if err := a.flags.Parse(arguments); err != nil { - return 1 + a.PrintUsageError(err) + return 2 } if intercepted := a.intercept(a.flags); intercepted { @@ -248,7 +255,8 @@ func (a *MultiCommandApp) Run(ctx context.Context, arguments []string) int { } if err := flags.Parse(arguments); err != nil { - return 1 + a.PrintUsageError(commandName, err) + return 2 } if intercepted := a.intercept(flags, commandName); intercepted { diff --git a/lieut_test.go b/lieut_test.go index 348caea..6404705 100644 --- a/lieut_test.go +++ b/lieut_test.go @@ -697,22 +697,14 @@ test vTest (%s/%s) flags: flag.NewFlagSet("test", flag.ContinueOnError), args: []string{"--non-existent-flag=val"}, - wantedExitCode: 1, + wantedExitCode: 2, wantedOut: "", - wantedErrOut: fmt.Sprintf(`flag provided but not defined: -non-existent-flag -Usage: test testing - -A test + wantedErrOut: `Error: flag provided but not defined: -non-existent-flag -Options: - - -help - Display the help message - -version - Display the application version +Usage: test testing -test vTest (%s/%s) -`, runtime.GOOS, runtime.GOARCH), +Run 'test --help' for usage. +`, }, "initialize returns error": { init: func() error { @@ -993,26 +985,14 @@ test vTest (%s/%s) flags: flag.NewFlagSet("test", flag.ContinueOnError), args: []string{"--non-existent-flag=val"}, - wantedExitCode: 1, + wantedExitCode: 2, wantedOut: "", - wantedErrOut: fmt.Sprintf(`flag provided but not defined: -non-existent-flag -Usage: test testing - -A test - -Commands: - - testcommand A test command - -Options: + wantedErrOut: `Error: flag provided but not defined: -non-existent-flag - -help - Display the help message - -version - Display the application version +Usage: test testing -test vTest (%s/%s) -`, runtime.GOOS, runtime.GOARCH), +Run 'test --help' for usage. +`, }, "initialize returns error": { init: func() error {