Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: don't expect response to be json in endpointcreds provider #2381

Merged
merged 2 commits into from
Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/018d3cef4def4b019c5ac7c60555b7e3.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "018d3cef-4def-4b01-9c5a-c7c60555b7e3",
"type": "bugfix",
"description": "Don't expect error responses to have a JSON payload in the endpointcreds provider.",
"modules": [
"credentials"
]
}
23 changes: 19 additions & 4 deletions credentials/endpointcreds/internal/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,16 @@ func New(options Options, optFns ...func(*Options)) *Client {
}

if options.Retryer == nil {
options.Retryer = retry.NewStandard()
// Amazon-owned implementations of this endpoint are known to sometimes
// return plaintext responses (i.e. no Code) like normal, add a few
// additional status codes
options.Retryer = retry.NewStandard(func(o *retry.StandardOptions) {
o.Retryables = append(o.Retryables, retry.RetryableHTTPStatusCode{
Codes: map[int]struct{}{
http.StatusTooManyRequests: {},
},
})
})
}

for _, fn := range optFns {
Expand Down Expand Up @@ -122,9 +131,10 @@ type GetCredentialsOutput struct {

// EndpointError is an error returned from the endpoint service
type EndpointError struct {
Code string `json:"code"`
Message string `json:"message"`
Fault smithy.ErrorFault `json:"-"`
Code string `json:"code"`
Message string `json:"message"`
Fault smithy.ErrorFault `json:"-"`
statusCode int `json:"-"`
}

// Error is the error mesage string
Expand All @@ -146,3 +156,8 @@ func (e *EndpointError) ErrorMessage() string {
func (e *EndpointError) ErrorFault() smithy.ErrorFault {
return e.Fault
}

// HTTPStatusCode implements retry.HTTPStatusCode.
func (e *EndpointError) HTTPStatusCode() int {
return e.statusCode
}
53 changes: 38 additions & 15 deletions credentials/endpointcreds/internal/client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,23 @@ import (

func TestClient_GetCredentials(t *testing.T) {
cases := map[string]struct {
Token string
RelativeURI string
ResponseCode int
ResponseBody []byte
ExpectResult *GetCredentialsOutput
ExpectErr bool
ValidateRequest func(*testing.T, *http.Request)
ValidateError func(*testing.T, error) bool
Token string
RelativeURI string
ResponseCode int
ResponseBody []byte
ResponseContentType string
ExpectResult *GetCredentialsOutput
ExpectErr bool
ValidateRequest func(*testing.T, *http.Request)
ValidateError func(*testing.T, error) bool
}{
"success static": {
ResponseCode: 200,
ResponseBody: []byte(` {
"AccessKeyId" : "FooKey",
"SecretAccessKey" : "FooSecret"
}`),
ResponseContentType: "application/json",
ExpectResult: &GetCredentialsOutput{
AccessKeyID: "FooKey",
SecretAccessKey: "FooSecret",
Expand All @@ -45,6 +47,7 @@ func TestClient_GetCredentials(t *testing.T) {
"AccessKeyId" : "FooKey",
"SecretAccessKey" : "FooSecret"
}`),
ResponseContentType: "application/json",
ExpectResult: &GetCredentialsOutput{
AccessKeyID: "FooKey",
SecretAccessKey: "FooSecret",
Expand All @@ -59,6 +62,7 @@ func TestClient_GetCredentials(t *testing.T) {
"Token": "FooToken",
"Expiration": "2016-02-25T06:03:31Z"
}`),
ResponseContentType: "application/json",
ExpectResult: &GetCredentialsOutput{
AccessKeyID: "FooKey",
SecretAccessKey: "FooSecret",
Expand All @@ -76,6 +80,7 @@ func TestClient_GetCredentials(t *testing.T) {
"AccessKeyId" : "FooKey",
"SecretAccessKey" : "FooSecret"
}`),
ResponseContentType: "application/json",
ValidateRequest: func(t *testing.T, r *http.Request) {
t.Helper()
if e, a := "/path/to/thing", r.URL.Path; e != a {
Expand All @@ -96,7 +101,8 @@ func TestClient_GetCredentials(t *testing.T) {
"code": "Unauthorized",
"message": "not authorized for endpoint"
}`),
ExpectErr: true,
ResponseContentType: "application/json",
ExpectErr: true,
ValidateError: func(t *testing.T, err error) (ok bool) {
t.Helper()
var apiError smithy.APIError
Expand Down Expand Up @@ -126,7 +132,8 @@ func TestClient_GetCredentials(t *testing.T) {
"code": "InternalError",
"message": "an error occurred"
}`),
ExpectErr: true,
ResponseContentType: "application/json",
ExpectErr: true,
ValidateError: func(t *testing.T, err error) (ok bool) {
t.Helper()
var apiError smithy.APIError
Expand All @@ -151,13 +158,28 @@ func TestClient_GetCredentials(t *testing.T) {
},
},
"non-json error response": {
ResponseCode: 500,
ResponseBody: []byte(`<html><body>unexpected message format</body></html>`),
ExpectErr: true,
ResponseCode: 500,
ResponseBody: []byte(`<html><body>unexpected message format</body></html>`),
ResponseContentType: "text/html",
ExpectErr: true,
ValidateError: func(t *testing.T, err error) (ok bool) {
t.Helper()
if e, a := "failed to decode error message", err.Error(); !strings.Contains(a, e) {
t.Errorf("expect %v, got %v", e, a)
var apiError smithy.APIError
if errors.As(err, &apiError) {
if e, a := "", apiError.ErrorCode(); e != a {
t.Errorf("expect %v, got %v", e, a)
ok = false
}
if e, a := "<html><body>unexpected message format</body></html>", apiError.ErrorMessage(); e != a {
t.Errorf("expect %v, got %v", e, a)
ok = false
}
if e, a := smithy.FaultServer, apiError.ErrorFault(); e != a {
t.Errorf("expect %v, got %v", e, a)
ok = false
}
} else {
t.Errorf("expect %T error type, got %T: %v", apiError, err, err)
ok = false
}
return ok
Expand All @@ -177,6 +199,7 @@ func TestClient_GetCredentials(t *testing.T) {

actualReq.Body = ioutil.NopCloser(bytes.NewReader(buf.Bytes()))

w.Header().Set("Content-Type", tt.ResponseContentType)
w.WriteHeader(tt.ResponseCode)
w.Write(tt.ResponseBody)
}))
Expand Down
42 changes: 35 additions & 7 deletions credentials/endpointcreds/internal/client/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/url"

"github.com/aws/smithy-go"
Expand Down Expand Up @@ -104,17 +105,44 @@ func (d *deserializeOpGetCredential) HandleDeserialize(ctx context.Context, in s
}

func deserializeError(response *smithyhttp.Response) error {
var errShape *EndpointError
err := json.NewDecoder(response.Body).Decode(&errShape)
// we could be talking to anything, json isn't guaranteed
// see https://github.com/aws/aws-sdk-go-v2/issues/2316
if response.Header.Get("Content-Type") == "application/json" {
return deserializeJSONError(response)
}

msg, err := io.ReadAll(response.Body)
if err != nil {
return &smithy.DeserializationError{Err: fmt.Errorf("failed to decode error message, %w", err)}
return &smithy.DeserializationError{
Err: fmt.Errorf("read response, %w", err),
}
}

return &EndpointError{
// no sensible value for Code
Message: string(msg),
Fault: stof(response.StatusCode),
statusCode: response.StatusCode,
}
}

if response.StatusCode >= 500 {
errShape.Fault = smithy.FaultServer
} else {
errShape.Fault = smithy.FaultClient
func deserializeJSONError(response *smithyhttp.Response) error {
var errShape *EndpointError
if err := json.NewDecoder(response.Body).Decode(&errShape); err != nil {
return &smithy.DeserializationError{
Err: fmt.Errorf("failed to decode error message, %w", err),
}
}

errShape.Fault = stof(response.StatusCode)
errShape.statusCode = response.StatusCode
return errShape
}

// maps HTTP status code to smithy ErrorFault
func stof(code int) smithy.ErrorFault {
if code >= 500 {
return smithy.FaultServer
}
return smithy.FaultClient
}
73 changes: 73 additions & 0 deletions credentials/endpointcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
Expand Down Expand Up @@ -201,6 +202,9 @@ func TestFailedRetrieveCredentials(t *testing.T) {
"code": "Error",
"message": "Message"
}`))),
Header: http.Header{
"Content-Type": {"application/json"},
},
}, nil
})
})
Expand Down Expand Up @@ -238,3 +242,72 @@ func TestFailedRetrieveCredentials(t *testing.T) {
t.Errorf("expect empty creds not to be expired")
}
}

type mockClientN struct {
responses []*http.Response
index int
}

func (c *mockClientN) Do(r *http.Request) (*http.Response, error) {
resp := c.responses[c.index]
c.index++
return resp, nil
}

func TestRetryHTTPStatusCode(t *testing.T) {
expTime := time.Now().UTC().Add(1 * time.Hour).Format("2006-01-02T15:04:05Z")
credsResp := fmt.Sprintf(`{"AccessKeyID":"AKID","SecretAccessKey":"SECRET","Token":"TOKEN","Expiration":"%s"}`, expTime)

p := endpointcreds.New("http://127.0.0.1", func(o *endpointcreds.Options) {
o.HTTPClient = &mockClientN{
responses: []*http.Response{
{
StatusCode: 429,
Body: io.NopCloser(strings.NewReader("You have made too many requests.")),
Header: http.Header{
"Content-Type": {"text/plain"},
},
},
{
StatusCode: 500,
Body: io.NopCloser(strings.NewReader("Internal server error.")),
Header: http.Header{
"Content-Type": {"text/plain"},
},
},
{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(credsResp)),
Header: http.Header{
"Content-Type": {"application/json"},
},
},
},
}
})

creds, err := p.Retrieve(context.Background())
if err != nil {
t.Fatalf("expect no error, got %v", err)
}

if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if creds.Expired() {
t.Errorf("expect not expired")
}

sdk.NowTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
if !creds.Expired() {
t.Errorf("expect to be expired")
}
}