diff --git a/private/protocol/query/unmarshal_error.go b/private/protocol/query/unmarshal_error.go index 08609d92088..cb87b79f07d 100644 --- a/private/protocol/query/unmarshal_error.go +++ b/private/protocol/query/unmarshal_error.go @@ -24,10 +24,14 @@ func UnmarshalError(r *request.Request) { if err != nil && err != io.EOF { r.Error = awserr.New("SerializationError", "failed to decode query XML error response", err) } else { + reqID := resp.RequestID + if reqID == "" { + reqID = r.RequestID + } r.Error = awserr.NewRequestFailure( awserr.New(resp.Code, resp.Message, nil), r.HTTPResponse.StatusCode, - resp.RequestID, + reqID, ) } } diff --git a/private/protocol/rest/unmarshal.go b/private/protocol/rest/unmarshal.go index 27f47b02c71..46837f66c99 100644 --- a/private/protocol/rest/unmarshal.go +++ b/private/protocol/rest/unmarshal.go @@ -26,6 +26,10 @@ func Unmarshal(r *request.Request) { // UnmarshalMeta unmarshals the REST metadata of a response in a REST service func UnmarshalMeta(r *request.Request) { r.RequestID = r.HTTPResponse.Header.Get("X-Amzn-Requestid") + if r.RequestID == "" { + // Alternative version of request id in the header + r.RequestID = r.HTTPResponse.Header.Get("X-Amz-Request-Id") + } if r.DataFilled() { v := reflect.Indirect(reflect.ValueOf(r.Data)) unmarshalLocationElements(r, v) diff --git a/service/s3/unmarshal_error.go b/service/s3/unmarshal_error.go index 30470ac117f..5b877376f13 100644 --- a/service/s3/unmarshal_error.go +++ b/service/s3/unmarshal_error.go @@ -22,17 +22,23 @@ func unmarshalError(r *request.Request) { defer r.HTTPResponse.Body.Close() if r.HTTPResponse.StatusCode == http.StatusMovedPermanently { - r.Error = awserr.New("BucketRegionError", - fmt.Sprintf("incorrect region, the bucket is not in '%s' region", aws.StringValue(r.Config.Region)), nil) + r.Error = awserr.NewRequestFailure( + awserr.New("BucketRegionError", + fmt.Sprintf("incorrect region, the bucket is not in '%s' region", + aws.StringValue(r.Config.Region)), + nil), + r.HTTPResponse.StatusCode, + r.RequestID, + ) return } - if r.HTTPResponse.ContentLength == int64(0) { + if r.HTTPResponse.ContentLength <= 1 { // No body, use status code to generate an awserr.Error r.Error = awserr.NewRequestFailure( awserr.New(strings.Replace(r.HTTPResponse.Status, " ", "", -1), r.HTTPResponse.Status, nil), r.HTTPResponse.StatusCode, - "", + r.RequestID, ) return } @@ -45,7 +51,7 @@ func unmarshalError(r *request.Request) { r.Error = awserr.NewRequestFailure( awserr.New(resp.Code, resp.Message, nil), r.HTTPResponse.StatusCode, - "", + r.RequestID, ) } } diff --git a/service/s3/unmarshal_error_test.go b/service/s3/unmarshal_error_test.go index c4cce13c513..3a4b104d92d 100644 --- a/service/s3/unmarshal_error_test.go +++ b/service/s3/unmarshal_error_test.go @@ -1,9 +1,10 @@ package s3_test import ( - "bytes" + "fmt" "io/ioutil" "net/http" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -15,39 +16,93 @@ import ( "github.com/aws/aws-sdk-go/service/s3" ) -var s3StatusCodeErrorTests = []struct { - scode int - status string - body string - code string - message string -}{ - {301, "Moved Permanently", "", "BucketRegionError", "incorrect region, the bucket is not in 'mock-region' region"}, - {403, "Forbidden", "", "Forbidden", "Forbidden"}, - {400, "Bad Request", "", "BadRequest", "Bad Request"}, - {404, "Not Found", "", "NotFound", "Not Found"}, - {500, "Internal Error", "", "InternalError", "Internal Error"}, +type testErrorCase struct { + RespFn func() *http.Response + ReqID string + Code, Msg string } -func TestStatusCodeError(t *testing.T) { - for _, test := range s3StatusCodeErrorTests { +var testUnmarshalCases = []testErrorCase{ + { + RespFn: func() *http.Response { + return &http.Response{ + StatusCode: 301, + Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}}, + Body: ioutil.NopCloser(nil), + ContentLength: -1, + } + }, + ReqID: "abc123", + Code: "BucketRegionError", Msg: "incorrect region, the bucket is not in 'mock-region' region", + }, + { + RespFn: func() *http.Response { + return &http.Response{ + StatusCode: 403, + Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}}, + Body: ioutil.NopCloser(nil), + ContentLength: 0, + } + }, + ReqID: "abc123", + Code: "Forbidden", Msg: "Forbidden", + }, + { + RespFn: func() *http.Response { + return &http.Response{ + StatusCode: 400, + Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}}, + Body: ioutil.NopCloser(nil), + ContentLength: 0, + } + }, + ReqID: "abc123", + Code: "BadRequest", Msg: "Bad Request", + }, + { + RespFn: func() *http.Response { + return &http.Response{ + StatusCode: 404, + Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}}, + Body: ioutil.NopCloser(nil), + ContentLength: 0, + } + }, + ReqID: "abc123", + Code: "NotFound", Msg: "Not Found", + }, + { + RespFn: func() *http.Response { + body := `SomeExceptionException message` + return &http.Response{ + StatusCode: 500, + Header: http.Header{"X-Amz-Request-Id": []string{"abc123"}}, + Body: ioutil.NopCloser(strings.NewReader(body)), + ContentLength: int64(len(body)), + } + }, + ReqID: "abc123", + Code: "SomeException", Msg: "Exception message", + }, +} + +func TestUnmarshalError(t *testing.T) { + for _, c := range testUnmarshalCases { s := s3.New(unit.Session) s.Handlers.Send.Clear() s.Handlers.Send.PushBack(func(r *request.Request) { - body := ioutil.NopCloser(bytes.NewReader([]byte(test.body))) - r.HTTPResponse = &http.Response{ - ContentLength: int64(len(test.body)), - StatusCode: test.scode, - Status: test.status, - Body: body, - } + r.HTTPResponse = c.RespFn() + r.HTTPResponse.Status = http.StatusText(r.HTTPResponse.StatusCode) }) _, err := s.PutBucketAcl(&s3.PutBucketAclInput{ Bucket: aws.String("bucket"), ACL: aws.String("public-read"), }) + fmt.Printf("%#v\n", err) + assert.Error(t, err) - assert.Equal(t, test.code, err.(awserr.Error).Code()) - assert.Equal(t, test.message, err.(awserr.Error).Message()) + assert.Equal(t, c.Code, err.(awserr.Error).Code()) + assert.Equal(t, c.Msg, err.(awserr.Error).Message()) + assert.Equal(t, c.ReqID, err.(awserr.RequestFailure).RequestID()) } }