Skip to content

Commit

Permalink
Improving flag parse error handling
Browse files Browse the repository at this point in the history
This not only improves the handling of error messages during flag
parsing, but also GREATLY reduces the complexity of flag handling in the
framework, as we no longer have to worry about complex reflection to set
the `Usage` property of a FlagSet.
  • Loading branch information
Rican7 committed Mar 22, 2024
1 parent 9694bf3 commit 895e627
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 185 deletions.
111 changes: 13 additions & 98 deletions flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type Flags interface {
Parse(arguments []string) error
Args() []string
PrintDefaults()
Output() io.Writer
SetOutput(output io.Writer)
}

Expand Down Expand Up @@ -46,6 +47,17 @@ func createDefaultFlags(name string) *flag.FlagSet {
return flag.NewFlagSet(name, flag.ContinueOnError)
}

// Parse wraps the inner flag Parse method, making sure that the error output is
// discarded/silenced.
func (f *flagSet) Parse(arguments []string) error {
originalOut := f.Output()

f.SetOutput(io.Discard)
defer f.SetOutput(originalOut)

return f.Flags.Parse(arguments)
}

func (a *MultiCommandApp) isUniqueFlagSet(flags Flags) bool {
// NOTE: We have to use `reflect.DeepEqual`, because the interface values
// could be non-comparable and could panic at runtime.
Expand Down Expand Up @@ -98,20 +110,10 @@ func (a *app) setupFlagSet(flagSet *flagSet) {
}
}

// setUsage sets the `.Usage` field of the flags.
func (a *SingleCommandApp) setUsage(flagSet *flagSet) {
setUsage(flagSet, a.PrintHelp)
}

// setUsage sets the `.Usage` field of the flags for a given command.
func (a *MultiCommandApp) setUsage(flagSet *flagSet, commandName string) {
setUsage(flagSet, func() { a.PrintHelp(commandName) })
}

// printFlagDefaults wraps the writing of flag default values
func (a *app) printFlagDefaults(flags Flags) {
var buffer bytes.Buffer
originalOut := a.errOut
originalOut := flags.Output()

flags.SetOutput(&buffer)
flags.PrintDefaults()
Expand All @@ -127,90 +129,3 @@ func (a *app) printFlagDefaults(flags Flags) {
// Restore the original output
flags.SetOutput(originalOut)
}

func setUsage(flagSet *flagSet, usageFunc func()) {
switch flags := flagSet.Flags.(type) {
case *flag.FlagSet:
// If we're dealing with the standard library flags, just set the usage
// function natively
flags.Usage = usageFunc
default:
// Otherwise, we'll have to use reflection to work generically...
//
// TODO: Remove the use of reflection one we can use the type-system to
// reliably detect types with specific fields, and set them.
//
// We CAN try and enforce a specific type of the flag set itself, but
// then we'll only be able to be fully compatible with one flag package,
// and none of the popular forks (like github.com/spf13/pflag).
//
// Until then, we'll HAVE to use reflection... :(
//
// See: https://github.com/golang/go/issues/48522
usageReflect := reflectFlagsUsage(flagSet.Flags)
if usageReflect == nil || !usageReflect.CanSet() {
return
}

reflectFunc := reflect.ValueOf(usageFunc)
usageReflect.Set(reflectFunc)
}
}

func reflectFlagsUsage(flags Flags) *reflect.Value {
flagsReflect := reflect.ValueOf(flags)
flagsReflect = reflectElemUntil(flagsReflect, func(value reflect.Value) bool {
return value.Kind() == reflect.Struct
})

if flagsReflect.Kind() != reflect.Struct {
return nil
}

usageReflect := flagsReflect.FieldByName("Usage")
usageFuncType := reflect.TypeOf(flag.Usage)

if !usageReflect.IsValid() || !usageReflect.Type().AssignableTo(usageFuncType) {
if embedded := findEmbeddedFlagsStruct(flags); embedded != nil {
return reflectFlagsUsage(embedded)
}

return nil
}

return &usageReflect
}

func findEmbeddedFlagsStruct(flags Flags) Flags {
flagsReflect := reflect.ValueOf(flags)
flagsReflect = reflectElemUntil(flagsReflect, func(value reflect.Value) bool {
return value.Kind() == reflect.Struct
})

flagsType := reflect.TypeOf((*Flags)(nil)).Elem()
for i := 0; i < flagsReflect.NumField(); i++ {
field := flagsReflect.Field(i)

field = reflectElemUntil(field, func(value reflect.Value) bool {
canElem := value.Kind() == reflect.Pointer || value.Kind() == reflect.Interface
isStructPointer := canElem && value.Elem().Kind() == reflect.Struct

return isStructPointer &&
value.CanInterface() &&
value.Type().Implements(flagsType)
})

if field.IsValid() && field.CanInterface() && field.Type().Implements(flagsType) {
return field.Interface().(Flags)
}
}

return nil
}

func reflectElemUntil(value reflect.Value, until func(value reflect.Value) bool) reflect.Value {
for !until(value) && (value.Kind() == reflect.Pointer || value.Kind() == reflect.Interface) {
value = value.Elem()
}
return value
}
56 changes: 4 additions & 52 deletions flags_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,7 @@ package lieut

import (
"context"
"flag"
"io"
"os"
"reflect"
"testing"
)

Expand All @@ -22,6 +19,10 @@ func (b *bogusFlags) Args() []string {
func (b *bogusFlags) PrintDefaults() {
}

func (b *bogusFlags) Output() io.Writer {
return nil
}

func (b *bogusFlags) SetOutput(output io.Writer) {
}

Expand Down Expand Up @@ -73,52 +74,3 @@ func TestBogusFlags_WorkWithMultiCommandApps(t *testing.T) {
app.PrintUsageError("foo", nil)
app.Run(context.TODO(), nil)
}

func TestUsageIsSetCorrectlyForEmbeddedFlags(t *testing.T) {
customFlags := struct {
bogus *bogusFlags
Bogus *bogusFlags

*flag.FlagSet

Usage func(int, string, float64) // Make sure that it doesn't try and set this!
}{
FlagSet: flag.NewFlagSet("test", flag.ContinueOnError),

Usage: nil,
}

app := NewSingleCommandApp(AppInfo{}, nil, &customFlags, os.Stdout, os.Stderr)

if app == nil {
t.Fatal("NewSingleCommandApp returned nil")
}

flagsUsageFn := reflect.ValueOf(customFlags.FlagSet.Usage).Pointer()
want := reflect.ValueOf(app.PrintHelp).Pointer()

if flagsUsageFn != want {
t.Errorf("flags Usage wasn't set correctly, is %v", flagsUsageFn)
}
}

func TestUsageReflectionIsSafeForEmbeddedBogusFlags(t *testing.T) {
customFlags := struct {
*bogusFlags
Bogus *bogusFlags

Usage func(int, string, float64) // Make sure that it doesn't try and set this!
}{
Usage: nil,
}

app := NewSingleCommandApp(AppInfo{}, nil, &customFlags, os.Stdout, os.Stderr)

if app == nil {
t.Fatal("NewSingleCommandApp returned nil")
}

if customFlags.Usage != nil {
t.Error("flags Usage was set when it shouldn't have been")
}
}
18 changes: 13 additions & 5 deletions lieut.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ type MultiCommandApp struct {
}

// NewSingleCommandApp returns an initialized SingleCommandApp.
//
// The provided flags should have ContinueOnError ErrorHandling, or else flag
// parsing errors won't properly be displayed/handled.
func NewSingleCommandApp(info AppInfo, exec Executor, flags Flags, out io.Writer, errOut io.Writer) *SingleCommandApp {
if info.Name == "" {
info.Name = inferAppName()
Expand Down Expand Up @@ -112,14 +115,16 @@ func NewSingleCommandApp(info AppInfo, exec Executor, flags Flags, out io.Writer
}

app.setupFlagSet(app.flags)
app.setUsage(app.flags)

return app
}

// NewMultiCommandApp returns an initialized MultiCommandApp.
//
// The provided flags are global/shared among the app's commands.
//
// The provided flags should have ContinueOnError ErrorHandling, or else flag
// parsing errors won't properly be displayed/handled.
func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writer) *MultiCommandApp {
if info.Name == "" {
info.Name = inferAppName()
Expand Down Expand Up @@ -149,7 +154,6 @@ func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writ
}

app.setupFlagSet(app.flags)
app.setUsage(app.flags, "")

return app
}
Expand All @@ -158,6 +162,9 @@ func NewMultiCommandApp(info AppInfo, flags Flags, out io.Writer, errOut io.Writ
//
// It returns an error if the provided flags have already been used for another
// command (or for the globals).
//
// The provided flags should have ContinueOnError ErrorHandling, or else flag
// parsing errors won't properly be displayed/handled.
func (a *MultiCommandApp) SetCommand(info CommandInfo, exec Executor, flags Flags) error {
if info.Usage == "" {
info.Usage = DefaultCommandUsage
Expand All @@ -174,7 +181,6 @@ func (a *MultiCommandApp) SetCommand(info CommandInfo, exec Executor, flags Flag
flagSet := &flagSet{Flags: flags}

a.setupFlagSet(flagSet)
a.setUsage(flagSet, info.Name)

a.commands[info.Name] = command{info: info, Executor: exec, flags: flagSet}

Expand Down Expand Up @@ -204,7 +210,8 @@ func (a *SingleCommandApp) Run(ctx context.Context, arguments []string) int {
}

if err := a.flags.Parse(arguments); err != nil {
return 1
a.PrintUsageError(err)
return 2
}

if intercepted := a.intercept(a.flags); intercepted {
Expand Down Expand Up @@ -248,7 +255,8 @@ func (a *MultiCommandApp) Run(ctx context.Context, arguments []string) int {
}

if err := flags.Parse(arguments); err != nil {
return 1
a.PrintUsageError(commandName, err)
return 2
}

if intercepted := a.intercept(flags, commandName); intercepted {
Expand Down
40 changes: 10 additions & 30 deletions lieut_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -697,22 +697,14 @@ test vTest (%s/%s)
flags: flag.NewFlagSet("test", flag.ContinueOnError),
args: []string{"--non-existent-flag=val"},

wantedExitCode: 1,
wantedExitCode: 2,
wantedOut: "",
wantedErrOut: fmt.Sprintf(`flag provided but not defined: -non-existent-flag
Usage: test testing
A test
wantedErrOut: `Error: flag provided but not defined: -non-existent-flag
Options:
-help
Display the help message
-version
Display the application version
Usage: test testing
test vTest (%s/%s)
`, runtime.GOOS, runtime.GOARCH),
Run 'test --help' for usage.
`,
},
"initialize returns error": {
init: func() error {
Expand Down Expand Up @@ -993,26 +985,14 @@ test vTest (%s/%s)
flags: flag.NewFlagSet("test", flag.ContinueOnError),
args: []string{"--non-existent-flag=val"},

wantedExitCode: 1,
wantedExitCode: 2,
wantedOut: "",
wantedErrOut: fmt.Sprintf(`flag provided but not defined: -non-existent-flag
Usage: test testing
A test
Commands:
testcommand A test command
Options:
wantedErrOut: `Error: flag provided but not defined: -non-existent-flag
-help
Display the help message
-version
Display the application version
Usage: test testing
test vTest (%s/%s)
`, runtime.GOOS, runtime.GOARCH),
Run 'test --help' for usage.
`,
},
"initialize returns error": {
init: func() error {
Expand Down

0 comments on commit 895e627

Please sign in to comment.