From c2533c6f582626648e4879af07a645bb98123dca Mon Sep 17 00:00:00 2001 From: Aaron Todd Date: Mon, 13 Mar 2023 17:01:50 -0400 Subject: [PATCH] feat: add flag to disable IMDSv1 fallback (#2048) --- .../a1050420a5f7409fbc4d1d82c882d168.json | 8 ++ feature/ec2/imds/api_client.go | 10 +++ feature/ec2/imds/request_middleware_test.go | 78 ++++++++++++++++-- feature/ec2/imds/token_provider.go | 82 ++++++++++++------- 4 files changed, 142 insertions(+), 36 deletions(-) create mode 100644 .changelog/a1050420a5f7409fbc4d1d82c882d168.json diff --git a/.changelog/a1050420a5f7409fbc4d1d82c882d168.json b/.changelog/a1050420a5f7409fbc4d1d82c882d168.json new file mode 100644 index 00000000000..4996783e782 --- /dev/null +++ b/.changelog/a1050420a5f7409fbc4d1d82c882d168.json @@ -0,0 +1,8 @@ +{ + "id": "a1050420-a5f7-409f-bc4d-1d82c882d168", + "type": "feature", + "description": "Add flag to disable IMDSv1 fallback", + "modules": [ + "feature/ec2/imds" + ] +} \ No newline at end of file diff --git a/feature/ec2/imds/api_client.go b/feature/ec2/imds/api_client.go index f97730bd931..e55edd992e2 100644 --- a/feature/ec2/imds/api_client.go +++ b/feature/ec2/imds/api_client.go @@ -174,6 +174,16 @@ type Options struct { // The logger writer interface to write logging messages to. Logger logging.Logger + // Configure IMDSv1 fallback behavior. By default, the client will attempt + // to fall back to IMDSv1 as needed for backwards compatibility. When set to [aws.FalseTernary] + // the client will return any errors encountered from attempting to fetch a token + // instead of silently using the insecure data flow of IMDSv1. + // + // See [configuring IMDS] for more information. + // + // [configuring IMDS]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html + EnableFallback aws.Ternary + // provides the caching of API tokens used for operation calls. If unset, // the API token will not be retrieved for the operation. tokenProvider *tokenProvider diff --git a/feature/ec2/imds/request_middleware_test.go b/feature/ec2/imds/request_middleware_test.go index 0e751b0ce2f..53fd8b6ed73 100644 --- a/feature/ec2/imds/request_middleware_test.go +++ b/feature/ec2/imds/request_middleware_test.go @@ -5,6 +5,7 @@ import ( "context" "encoding/hex" "fmt" + "github.com/aws/aws-sdk-go-v2/aws" "io" "io/ioutil" "net/http" @@ -330,11 +331,12 @@ func (h *successAPIResponseHandler) ServeHTTP(w http.ResponseWriter, r *http.Req func TestRequestGetToken(t *testing.T) { cases := map[string]struct { - GetHandler func(*testing.T) http.Handler - APICallCount int - ExpectTrace []string - ExpectContent []byte - ExpectErr string + GetHandler func(*testing.T) http.Handler + APICallCount int + ExpectTrace []string + ExpectContent []byte + ExpectErr string + EnableFallback aws.Ternary }{ "secure": { ExpectTrace: []string{ @@ -496,8 +498,69 @@ func TestRequestGetToken(t *testing.T) { }), )) }, + ExpectErr: "failed to get API token", + }, + + // retryable token error with fallback enabled (default) + "token failure fallback enabled": { + ExpectTrace: []string{ + getTokenPath, + getTokenPath, + getTokenPath, + "/latest/foo", + }, + APICallCount: 1, + GetHandler: func(t *testing.T) http.Handler { + return newTestServeMux(t, + newInsecureAPIHandler(t, + 500, + &successAPIResponseHandler{t: t, + path: "/latest/foo", + method: "GET", + body: []byte("hello"), + }, + )) + }, ExpectContent: []byte("hello"), - ExpectErr: "EC2 IMDS failed", + }, + // retryable token error with fallback disabled + "token failure fallback disabled": { + ExpectTrace: []string{ + getTokenPath, + getTokenPath, + getTokenPath, + }, + APICallCount: 1, + GetHandler: func(t *testing.T) http.Handler { + return newTestServeMux(t, + newInsecureAPIHandler(t, + 500, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Errorf("expected no call to API handler") + http.Error(w, "", 400) + }), + )) + }, + ExpectErr: "failed to get API token", + EnableFallback: aws.BoolTernary(false), + }, + "insecure 403 fallback disabled": { + ExpectTrace: []string{ + getTokenPath, + }, + APICallCount: 1, + GetHandler: func(t *testing.T) http.Handler { + return newTestServeMux(t, + newInsecureAPIHandler(t, + 403, + http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + t.Errorf("expected no call to API handler") + http.Error(w, "", 400) + }), + )) + }, + ExpectErr: "failed to get API token", + EnableFallback: aws.BoolTernary(false), }, } @@ -515,7 +578,8 @@ func TestRequestGetToken(t *testing.T) { defer server.Close() client := New(Options{ - Endpoint: server.URL, + Endpoint: server.URL, + EnableFallback: c.EnableFallback, }) ctx := context.Background() diff --git a/feature/ec2/imds/token_provider.go b/feature/ec2/imds/token_provider.go index 275fade488a..5703c6e16ad 100644 --- a/feature/ec2/imds/token_provider.go +++ b/feature/ec2/imds/token_provider.go @@ -4,12 +4,14 @@ import ( "context" "errors" "fmt" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/smithy-go" + "github.com/aws/smithy-go/logging" "net/http" "sync" "sync/atomic" "time" - smithy "github.com/aws/smithy-go" "github.com/aws/smithy-go/middleware" smithyhttp "github.com/aws/smithy-go/transport/http" ) @@ -68,7 +70,7 @@ func (t *tokenProvider) HandleFinalize( ) ( out middleware.FinalizeOutput, metadata middleware.Metadata, err error, ) { - if !t.enabled() { + if t.fallbackEnabled() && !t.enabled() { // short-circuits to insecure data flow if token provider is disabled. return next.HandleFinalize(ctx, input) } @@ -115,23 +117,15 @@ func (t *tokenProvider) HandleDeserialize( } if resp.StatusCode == http.StatusUnauthorized { // unauthorized - err = &retryableError{Err: err} t.enable() + err = &retryableError{Err: err, isRetryable: true} } return out, metadata, err } -type retryableError struct { - Err error -} - -func (*retryableError) RetryableError() bool { return true } - -func (e *retryableError) Error() string { return e.Err.Error() } - func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) { - if !t.enabled() { + if t.fallbackEnabled() && !t.enabled() { return nil, &bypassTokenRetrievalError{ Err: fmt.Errorf("cannot get API token, provider disabled"), } @@ -147,7 +141,7 @@ func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) tok, err = t.updateToken(ctx) if err != nil { - return nil, fmt.Errorf("cannot get API token, %w", err) + return nil, err } return tok, nil @@ -167,17 +161,19 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) { TokenTTL: t.tokenTTL, }) if err != nil { - // change the disabled flag on token provider to true, when error is request timeout error. var statusErr interface{ HTTPStatusCode() int } if errors.As(err, &statusErr) { switch statusErr.HTTPStatusCode() { - - // Disable get token if failed because of 403, 404, or 405 + // Disable future get token if failed because of 403, 404, or 405 case http.StatusForbidden, http.StatusNotFound, http.StatusMethodNotAllowed: - t.disable() + if t.fallbackEnabled() { + logger := middleware.GetLogger(ctx) + logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err) + t.disable() + } // 400 errors are terminal, and need to be upstreamed case http.StatusBadRequest: @@ -192,8 +188,17 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) { atomic.StoreUint32(&t.disabled, 1) } - // Token couldn't be retrieved, but bypass this, and allow the - // request to continue. + if !t.fallbackEnabled() { + // NOTE: getToken() is an implementation detail of some outer operation + // (e.g. GetMetadata). It has its own retries that have already been exhausted. + // Mark the underlying error as a terminal error. + err = &retryableError{Err: err, isRetryable: false} + return nil, err + } + + // Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request + // and allow the request to proceed. Future requests _may_ re-attempt fetching a + // token if not disabled. return nil, &bypassTokenRetrievalError{Err: err} } @@ -206,21 +211,21 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) { return tok, nil } -type bypassTokenRetrievalError struct { - Err error -} - -func (e *bypassTokenRetrievalError) Error() string { - return fmt.Sprintf("bypass token retrieval, %v", e.Err) -} - -func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err } - // enabled returns if the token provider is current enabled or not. func (t *tokenProvider) enabled() bool { return atomic.LoadUint32(&t.disabled) == 0 } +// fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise +func (t *tokenProvider) fallbackEnabled() bool { + switch t.client.options.EnableFallback { + case aws.FalseTernary: + return false + default: + return true + } +} + // disable disables the token provider and it will no longer attempt to inject // the token, nor request updates. func (t *tokenProvider) disable() { @@ -235,3 +240,22 @@ func (t *tokenProvider) enable() { t.tokenMux.Unlock() atomic.StoreUint32(&t.disabled, 0) } + +type bypassTokenRetrievalError struct { + Err error +} + +func (e *bypassTokenRetrievalError) Error() string { + return fmt.Sprintf("bypass token retrieval, %v", e.Err) +} + +func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err } + +type retryableError struct { + Err error + isRetryable bool +} + +func (e *retryableError) RetryableError() bool { return e.isRetryable } + +func (e *retryableError) Error() string { return e.Err.Error() }