diff --git a/lib/cloud/aws/errors.go b/lib/cloud/aws/errors.go index 1d5afbfe8e909..19e90b7b8c43c 100644 --- a/lib/cloud/aws/errors.go +++ b/lib/cloud/aws/errors.go @@ -19,11 +19,13 @@ package aws import ( "errors" "net/http" + "strings" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" iamTypes "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/iam" + "github.com/aws/aws-sdk-go/service/redshiftserverless" "github.com/gravitational/trace" ) @@ -47,6 +49,12 @@ func convertRequestFailureErrorFromStatusCode(statusCode int, requestErr error) return trace.AlreadyExists(requestErr.Error()) case http.StatusNotFound: return trace.NotFound(requestErr.Error()) + case http.StatusBadRequest: + // Some services like memorydb, redshiftserverless may return 400 with + // "AccessDeniedException" instead of 403. + if strings.Contains(requestErr.Error(), redshiftserverless.ErrCodeAccessDeniedException) { + return trace.AccessDenied(requestErr.Error()) + } } return requestErr // Return unmodified. diff --git a/lib/cloud/aws/errors_test.go b/lib/cloud/aws/errors_test.go index a2631857ad898..f2b484314eb46 100644 --- a/lib/cloud/aws/errors_test.go +++ b/lib/cloud/aws/errors_test.go @@ -17,17 +17,75 @@ limitations under the License. package aws import ( + "errors" "net/http" "testing" awshttp "github.com/aws/aws-sdk-go-v2/aws/transport/http" iamTypes "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" smithyhttp "github.com/aws/smithy-go/transport/http" "github.com/gravitational/trace" "github.com/stretchr/testify/require" ) +func TestConvertRequestFailureError(t *testing.T) { + t.Parallel() + + fakeRequestID := "11111111-2222-3333-3333-333333333334" + + tests := []struct { + name string + inputError error + wantUnmodified bool + wantIsError func(error) bool + }{ + { + name: "StatusForbidden", + inputError: awserr.NewRequestFailure(awserr.New("code", "message", nil), http.StatusForbidden, fakeRequestID), + wantIsError: trace.IsAccessDenied, + }, + { + name: "StatusConflict", + inputError: awserr.NewRequestFailure(awserr.New("code", "message", nil), http.StatusConflict, fakeRequestID), + wantIsError: trace.IsAlreadyExists, + }, + { + name: "StatusNotFound", + inputError: awserr.NewRequestFailure(awserr.New("code", "message", nil), http.StatusNotFound, fakeRequestID), + wantIsError: trace.IsNotFound, + }, + { + name: "StatusBadRequest", + inputError: awserr.NewRequestFailure(awserr.New("code", "message", nil), http.StatusBadRequest, fakeRequestID), + wantUnmodified: true, + }, + { + name: "StatusBadRequest with AccessDeniedException", + inputError: awserr.NewRequestFailure(awserr.New("AccessDeniedException", "message", nil), http.StatusBadRequest, fakeRequestID), + wantIsError: trace.IsAccessDenied, + }, + { + name: "not AWS error", + inputError: errors.New("not-aws-error"), + wantUnmodified: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := ConvertRequestFailureError(test.inputError) + + if test.wantUnmodified { + require.Equal(t, test.inputError, err) + } else { + require.True(t, test.wantIsError(err)) + } + }) + } +} + func TestConvertIAMv2Error(t *testing.T) { for _, tt := range []struct { name string