From 70572e295b1033a7b58716de630d002f542ce53b Mon Sep 17 00:00:00 2001 From: Trevor Suarez Date: Fri, 11 Aug 2023 20:14:33 -0600 Subject: [PATCH] Adding a feature to control the status code --- error.go | 27 +++++++++++++++++++++++++ error_test.go | 26 ++++++++++++++++++++++++ lieut.go | 13 ++++++++++++ lieut_test.go | 56 +++++++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 116 insertions(+), 6 deletions(-) create mode 100644 error.go create mode 100644 error_test.go diff --git a/error.go b/error.go new file mode 100644 index 0000000..7ea10da --- /dev/null +++ b/error.go @@ -0,0 +1,27 @@ +package lieut + +// StatusCodeError represents an error that reports an associated status code. +type StatusCodeError interface { + error + + // StatusCode returns the status code of the error, which can be used by an + // app's execution error to know which status code to return. + StatusCode() int +} + +type statusCodeError struct { + error + + statusCode int +} + +// ErrWithStatusCode takes an error and a status code and returns a type that +// satisfies StatusCodeError. +func ErrWithStatusCode(err error, statusCode int) StatusCodeError { + return &statusCodeError{error: err, statusCode: statusCode} +} + +// StatusCode returns the status code of the error. +func (e *statusCodeError) StatusCode() int { + return e.statusCode +} diff --git a/error_test.go b/error_test.go new file mode 100644 index 0000000..39157c8 --- /dev/null +++ b/error_test.go @@ -0,0 +1,26 @@ +package lieut + +import ( + "errors" + "testing" +) + +func TestErrWithStatusCode(t *testing.T) { + errMsg := "test error" + errCode := 107 + + err := errors.New(errMsg) + + statusCodeErr := ErrWithStatusCode(err, errCode) + if statusCodeErr == nil { + t.Fatal("ErrWithStatusCode returned nil") + } + + if gotMsg := statusCodeErr.Error(); gotMsg != errMsg { + t.Errorf("err.Error() returned %q, wanted %q", gotMsg, errMsg) + } + + if gotCode := statusCodeErr.StatusCode(); gotCode != errCode { + t.Errorf("err.Error() returned %v, wanted %v", gotCode, errCode) + } +} diff --git a/lieut.go b/lieut.go index a443ddf..d1f5328 100644 --- a/lieut.go +++ b/lieut.go @@ -194,6 +194,10 @@ func (a *MultiCommandApp) CommandNames() []string { // Run takes a context and arguments, runs the expected command, and returns an // exit code. +// +// If the init function or command Executor returns a StatusCodeError, then the +// returned exit code will match that of the value returned by +// StatusCodeError.StatusCode(). func (a *SingleCommandApp) Run(ctx context.Context, arguments []string) int { if len(arguments) == 0 { arguments = os.Args[1:] @@ -216,6 +220,10 @@ func (a *SingleCommandApp) Run(ctx context.Context, arguments []string) int { // Run takes a context and arguments, runs the expected command, and returns an // exit code. +// +// If the init function or command Executor returns a StatusCodeError, then the +// returned exit code will match that of the value returned by +// StatusCodeError.StatusCode(). func (a *MultiCommandApp) Run(ctx context.Context, arguments []string) int { if len(arguments) == 0 { arguments = os.Args[1:] @@ -373,6 +381,11 @@ func (a *app) printErr(err error, pad bool) int { fmt.Fprintf(a.errOut, msgFmt, err) + var statusErr StatusCodeError + if errors.As(err, &statusErr) { + return statusErr.StatusCode() + } + return 1 } diff --git a/lieut_test.go b/lieut_test.go index 6d229e2..39f5134 100644 --- a/lieut_test.go +++ b/lieut_test.go @@ -599,7 +599,7 @@ func TestSingleCommandApp_Run(t *testing.T) { exitCode := app.Run(ctx, args) if exitCode != wantedExitCode { - t.Errorf("app.Run gave %q, wanted %q", exitCode, wantedExitCode) + t.Errorf("app.Run gave %v, wanted %v", exitCode, wantedExitCode) } if !initRan { @@ -638,7 +638,7 @@ func TestSingleCommandApp_Run_EmptyArgsProvided(t *testing.T) { expectedArgs := os.Args[1:] if exitCode := app.Run(context.TODO(), nil); exitCode != 0 { - t.Errorf("app.Run gave non-zero exit code %q", exitCode) + t.Errorf("app.Run gave non-zero exit code %v", exitCode) } if capturedArgs[0] != expectedArgs[0] { @@ -716,6 +716,17 @@ test vTest (%s/%s) wantedOut: "", wantedErrOut: "Error: test init error\n", }, + "initialize returns status code error": { + init: func() error { + return ErrWithStatusCode(errors.New("test init error"), 101) + }, + + args: []string{"test"}, + + wantedExitCode: 101, + wantedOut: "", + wantedErrOut: "Error: test init error\n", + }, "execute returns error": { exec: func(ctx context.Context, arguments []string) error { return errors.New("test exec error") @@ -727,6 +738,17 @@ test vTest (%s/%s) wantedOut: "", wantedErrOut: "\nError: test exec error\n", }, + "execute returns status code error": { + exec: func(ctx context.Context, arguments []string) error { + return ErrWithStatusCode(errors.New("test exec error"), 217) + }, + + args: []string{"test"}, + + wantedExitCode: 217, + wantedOut: "", + wantedErrOut: "\nError: test exec error\n", + }, } { t.Run(testName, func(t *testing.T) { var out, errOut bytes.Buffer @@ -740,7 +762,7 @@ test vTest (%s/%s) exitCode := app.Run(context.TODO(), testData.args) if exitCode != testData.wantedExitCode { - t.Errorf("app.Run gave %q, wanted %q", exitCode, testData.wantedExitCode) + t.Errorf("app.Run gave %v, wanted %v", exitCode, testData.wantedExitCode) } if out.String() != testData.wantedOut { @@ -793,7 +815,7 @@ func TestMultiCommandApp_Run(t *testing.T) { exitCode := app.Run(ctx, args) if exitCode != wantedExitCode { - t.Errorf("app.Run gave %q, wanted %q", exitCode, wantedExitCode) + t.Errorf("app.Run gave %v, wanted %v", exitCode, wantedExitCode) } if !initRan { @@ -837,7 +859,7 @@ func TestMultiCommandApp_Run_EmptyArgsProvided(t *testing.T) { expectedArgs := os.Args[2:] if exitCode := app.Run(context.TODO(), nil); exitCode != 0 { - t.Errorf("app.Run gave non-zero exit code %q", exitCode) + t.Errorf("app.Run gave non-zero exit code %v", exitCode) } if capturedArgs[0] != expectedArgs[0] { @@ -972,6 +994,17 @@ test vTest (%s/%s) wantedOut: "", wantedErrOut: "Error: test init error\n", }, + "initialize returns status code error": { + init: func() error { + return ErrWithStatusCode(errors.New("test init error"), 101) + }, + + args: []string{testCommandInfo.Name}, + + wantedExitCode: 101, + wantedOut: "", + wantedErrOut: "Error: test init error\n", + }, "execute returns error": { exec: func(ctx context.Context, arguments []string) error { return errors.New("test exec error") @@ -983,6 +1016,17 @@ test vTest (%s/%s) wantedOut: "", wantedErrOut: "\nError: test exec error\n", }, + "execute returns status code error": { + exec: func(ctx context.Context, arguments []string) error { + return ErrWithStatusCode(errors.New("test exec error"), 217) + }, + + args: []string{testCommandInfo.Name}, + + wantedExitCode: 217, + wantedOut: "", + wantedErrOut: "\nError: test exec error\n", + }, "unknown command": { args: []string{"thiscommanddoesnotexist"}, @@ -1019,7 +1063,7 @@ test vTest (%s/%s) exitCode := app.Run(context.TODO(), testData.args) if exitCode != testData.wantedExitCode { - t.Errorf("app.Run gave %q, wanted %q", exitCode, testData.wantedExitCode) + t.Errorf("app.Run gave %v, wanted %v", exitCode, testData.wantedExitCode) } if out.String() != testData.wantedOut {