From 895e62704c6c71c419a012781d28a35b593fa482 Mon Sep 17 00:00:00 2001 From: Trevor Suarez Date: Fri, 22 Mar 2024 04:33:25 -0600 Subject: [PATCH] Improving flag parse error handling This not only improves the handling of error messages during flag parsing, but also GREATLY reduces the complexity of flag handling in the framework, as we no longer have to worry about complex reflection to set the `Usage` property of a FlagSet. --- flags.go | 111 ++++++-------------------------------------------- flags_test.go | 56 ++----------------------- lieut.go | 18 +++++--- lieut_test.go | 40 +++++------------- 4 files changed, 40 insertions(+), 185 deletions(-) diff --git a/flags.go b/flags.go index 1161e19..f9c7d06 100644 --- a/flags.go +++ b/flags.go @@ -15,6 +15,7 @@ type Flags interface { Parse(arguments []string) error Args() []string PrintDefaults() + Output() io.Writer SetOutput(output io.Writer) } @@ -46,6 +47,17 @@ func createDefaultFlags(name string) *flag.FlagSet { return flag.NewFlagSet(name, flag.ContinueOnError) } +// Parse wraps the inner flag Parse method, making sure that the error output is +// discarded/silenced. +func (f *flagSet) Parse(arguments []string) error { + originalOut := f.Output() + + f.SetOutput(io.Discard) + defer f.SetOutput(originalOut) + + return f.Flags.Parse(arguments) +} + func (a *MultiCommandApp) isUniqueFlagSet(flags Flags) bool { // NOTE: We have to use `reflect.DeepEqual`, because the interface values // could be non-comparable and could panic at runtime. @@ -98,20 +110,10 @@ func (a *app) setupFlagSet(flagSet *flagSet) { } } -// setUsage sets the `.Usage` field of the flags. -func (a *SingleCommandApp) setUsage(flagSet *flagSet) { - setUsage(flagSet, a.PrintHelp) -} - -// setUsage sets the `.Usage` field of the flags for a given command. -func (a *MultiCommandApp) setUsage(flagSet *flagSet, commandName string) { - setUsage(flagSet, func() { a.PrintHelp(commandName) }) -} - // printFlagDefaults wraps the writing of flag default values func (a *app) printFlagDefaults(flags Flags) { var buffer bytes.Buffer - originalOut := a.errOut + originalOut := flags.Output() flags.SetOutput(&buffer) flags.PrintDefaults() @@ -127,90 +129,3 @@ func (a *app) printFlagDefaults(flags Flags) { // Restore the original output flags.SetOutput(originalOut) } - -func setUsage(flagSet *flagSet, usageFunc func()) { - switch flags := flagSet.Flags.(type) { - case *flag.FlagSet: - // If we're dealing with the standard library flags, just set the usage - // function natively - flags.Usage = usageFunc - default: - // Otherwise, we'll have to use reflection to work generically... - // - // TODO: Remove the use of reflection one we can use the type-system to - // reliably detect types with specific fields, and set them. - // - // We CAN try and enforce a specific type of the flag set itself, but - // then we'll only be able to be fully compatible with one flag package, - // and none of the popular forks (like github.com/spf13/pflag). - // - // Until then, we'll HAVE to use reflection... :( - // - // See: https://github.com/golang/go/issues/48522 - usageReflect := reflectFlagsUsage(flagSet.Flags) - if usageReflect == nil || !usageReflect.CanSet() { - return - } - - reflectFunc := reflect.ValueOf(usageFunc) - usageReflect.Set(reflectFunc) - } -} - -func reflectFlagsUsage(flags Flags) *reflect.Value { - flagsReflect := reflect.ValueOf(flags) - flagsReflect = reflectElemUntil(flagsReflect, func(value reflect.Value) bool { - return value.Kind() == reflect.Struct - }) - - if flagsReflect.Kind() != reflect.Struct { - return nil - } - - usageReflect := flagsReflect.FieldByName("Usage") - usageFuncType := reflect.TypeOf(flag.Usage) - - if !usageReflect.IsValid() || !usageReflect.Type().AssignableTo(usageFuncType) { - if embedded := findEmbeddedFlagsStruct(flags); embedded != nil { - return reflectFlagsUsage(embedded) - } - - return nil - } - - return &usageReflect -} - -func findEmbeddedFlagsStruct(flags Flags) Flags { - flagsReflect := reflect.ValueOf(flags) - flagsReflect = reflectElemUntil(flagsReflect, func(value reflect.Value) bool { - return value.Kind() == reflect.Struct - }) - - flagsType := reflect.TypeOf((*Flags)(nil)).Elem() - for i := 0; i < flagsReflect.NumField(); i++ { - field := flagsReflect.Field(i) - - field = reflectElemUntil(field, func(value reflect.Value) bool { - canElem := value.Kind() == reflect.Pointer || value.Kind() == reflect.Interface - isStructPointer := canElem && value.Elem().Kind() == reflect.Struct - - return isStructPointer && - value.CanInterface() && - value.Type().Implements(flagsType) - }) - - if field.IsValid() && field.CanInterface() && field.Type().Implements(flagsType) { - return field.Interface().(Flags) - } - } - - return nil -} - -func reflectElemUntil(value reflect.Value, until func(value reflect.Value) bool) reflect.Value { - for !until(value) && (value.Kind() == reflect.Pointer || value.Kind() == reflect.Interface) { - value = value.Elem() - } - return value -} diff --git a/flags_test.go b/flags_test.go index b505353..bb45b48 100644 --- a/flags_test.go +++ b/flags_test.go @@ -2,10 +2,7 @@ package lieut import ( "context" - "flag" "io" - "os" - "reflect" "testing" ) @@ -22,6 +19,10 @@ func (b *bogusFlags) Args() []string { func (b *bogusFlags) PrintDefaults() { } +func (b *bogusFlags) Output() io.Writer { + return nil +} + func (b *bogusFlags) SetOutput(output io.Writer) { } @@ -73,52 +74,3 @@ func TestBogusFlags_WorkWithMultiCommandApps(t *testing.T) { app.PrintUsageError("foo", nil) app.Run(context.TODO(), nil) } - -func TestUsageIsSetCorrectlyForEmbeddedFlags(t *testing.T) { - customFlags := struct { - bogus *bogusFlags - Bogus *bogusFlags - - *flag.FlagSet - - Usage func(int, string, float64) // Make sure that it doesn't try and set this! - }{ - FlagSet: flag.NewFlagSet("test", flag.ContinueOnError), - - Usage: nil, - } - - app := NewSingleCommandApp(AppInfo{}, nil, &customFlags, os.Stdout, os.Stderr) - - if app == nil { - t.Fatal("NewSingleCommandApp returned nil") - } - - flagsUsageFn := reflect.ValueOf(customFlags.FlagSet.Usage).Pointer() - want := reflect.ValueOf(app.PrintHelp).Pointer() - - if flagsUsageFn != want { - t.Errorf("flags Usage wasn't set correctly, is %v", flagsUsageFn) - } -} - -func TestUsageReflectionIsSafeForEmbeddedBogusFlags(t *testing.T) { - customFlags := struct { - *bogusFlags - Bogus *bogusFlags - - Usage func(int, string, float64) // Make sure that it doesn't try and set this! - }{ - Usage: nil, - } - - app := NewSingleCommandApp(AppInfo{}, nil, &customFlags, os.Stdout, os.Stderr) - - if app == nil { - t.Fatal("NewSingleCommandApp returned nil") - } - - if customFlags.Usage != nil { - t.Error("flags Usage was set when it shouldn't have been") - } -} diff --git a/lieut.go b/lieut.go index a09c016..d1d7ff8 100644 --- a/lieut.go +++ b/lieut.go @@ -83,6 +83,9 @@ type MultiCommandApp struct { } // NewSingleCommandApp returns an initialized SingleCommandApp. +// +// The provided flags should have ContinueOnError ErrorHandling, or else flag +// parsing errors won't properly be displayed/handled. func NewSingleCommandApp(info AppInfo, exec Executor, flags Flags, out io.Writer, errOut io.Writer) *SingleCommandApp { if info.Name == "" { info.Name = inferAppName() @@ -112,7 +115,6 @@ func NewSingleCommandApp(info AppInfo, exec Executor, flags Flags, out io.Writer } app.setupFlagSet(app.flags) - app.setUsage(app.flags) return app } @@ -120,6 +122,9 @@ func NewSingleCommandApp(info AppInfo, exec Executor, flags Flags, out io.Writer // NewMultiCommandApp returns an initialized MultiCommandApp. // // The provided flags are global/shared among the app's commands. +// +// The provided flags should have ContinueOnError ErrorHandling, or else flag +// parsing errors won't properly be displayed/handled. func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writer) *MultiCommandApp { if info.Name == "" { info.Name = inferAppName() @@ -149,7 +154,6 @@ func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writ } app.setupFlagSet(app.flags) - app.setUsage(app.flags, "") return app } @@ -158,6 +162,9 @@ func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writ // // It returns an error if the provided flags have already been used for another // command (or for the globals). +// +// The provided flags should have ContinueOnError ErrorHandling, or else flag +// parsing errors won't properly be displayed/handled. func (a *MultiCommandApp) SetCommand(info CommandInfo, exec Executor, flags Flags) error { if info.Usage == "" { info.Usage = DefaultCommandUsage @@ -174,7 +181,6 @@ func (a *MultiCommandApp) SetCommand(info CommandInfo, exec Executor, flags Flag flagSet := &flagSet{Flags: flags} a.setupFlagSet(flagSet) - a.setUsage(flagSet, info.Name) a.commands[info.Name] = command{info: info, Executor: exec, flags: flagSet} @@ -204,7 +210,8 @@ func (a *SingleCommandApp) Run(ctx context.Context, arguments []string) int { } if err := a.flags.Parse(arguments); err != nil { - return 1 + a.PrintUsageError(err) + return 2 } if intercepted := a.intercept(a.flags); intercepted { @@ -248,7 +255,8 @@ func (a *MultiCommandApp) Run(ctx context.Context, arguments []string) int { } if err := flags.Parse(arguments); err != nil { - return 1 + a.PrintUsageError(commandName, err) + return 2 } if intercepted := a.intercept(flags, commandName); intercepted { diff --git a/lieut_test.go b/lieut_test.go index 348caea..6404705 100644 --- a/lieut_test.go +++ b/lieut_test.go @@ -697,22 +697,14 @@ test vTest (%s/%s) flags: flag.NewFlagSet("test", flag.ContinueOnError), args: []string{"--non-existent-flag=val"}, - wantedExitCode: 1, + wantedExitCode: 2, wantedOut: "", - wantedErrOut: fmt.Sprintf(`flag provided but not defined: -non-existent-flag -Usage: test testing - -A test + wantedErrOut: `Error: flag provided but not defined: -non-existent-flag -Options: - - -help - Display the help message - -version - Display the application version +Usage: test testing -test vTest (%s/%s) -`, runtime.GOOS, runtime.GOARCH), +Run 'test --help' for usage. +`, }, "initialize returns error": { init: func() error { @@ -993,26 +985,14 @@ test vTest (%s/%s) flags: flag.NewFlagSet("test", flag.ContinueOnError), args: []string{"--non-existent-flag=val"}, - wantedExitCode: 1, + wantedExitCode: 2, wantedOut: "", - wantedErrOut: fmt.Sprintf(`flag provided but not defined: -non-existent-flag -Usage: test testing - -A test - -Commands: - - testcommand A test command - -Options: + wantedErrOut: `Error: flag provided but not defined: -non-existent-flag - -help - Display the help message - -version - Display the application version +Usage: test testing -test vTest (%s/%s) -`, runtime.GOOS, runtime.GOARCH), +Run 'test --help' for usage. +`, }, "initialize returns error": { init: func() error {