Skip to content

Commit

Permalink
feat: add flag to disable IMDSv1 fallback (#2048)
Browse files Browse the repository at this point in the history
  • Loading branch information
aajtodd authored Mar 13, 2023
1 parent fe59fdb commit c2533c6
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 36 deletions.
8 changes: 8 additions & 0 deletions .changelog/a1050420a5f7409fbc4d1d82c882d168.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "a1050420-a5f7-409f-bc4d-1d82c882d168",
"type": "feature",
"description": "Add flag to disable IMDSv1 fallback",
"modules": [
"feature/ec2/imds"
]
}
10 changes: 10 additions & 0 deletions feature/ec2/imds/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ type Options struct {
// The logger writer interface to write logging messages to.
Logger logging.Logger

// Configure IMDSv1 fallback behavior. By default, the client will attempt
// to fall back to IMDSv1 as needed for backwards compatibility. When set to [aws.FalseTernary]
// the client will return any errors encountered from attempting to fetch a token
// instead of silently using the insecure data flow of IMDSv1.
//
// See [configuring IMDS] for more information.
//
// [configuring IMDS]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
EnableFallback aws.Ternary

// provides the caching of API tokens used for operation calls. If unset,
// the API token will not be retrieved for the operation.
tokenProvider *tokenProvider
Expand Down
78 changes: 71 additions & 7 deletions feature/ec2/imds/request_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/hex"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -330,11 +331,12 @@ func (h *successAPIResponseHandler) ServeHTTP(w http.ResponseWriter, r *http.Req

func TestRequestGetToken(t *testing.T) {
cases := map[string]struct {
GetHandler func(*testing.T) http.Handler
APICallCount int
ExpectTrace []string
ExpectContent []byte
ExpectErr string
GetHandler func(*testing.T) http.Handler
APICallCount int
ExpectTrace []string
ExpectContent []byte
ExpectErr string
EnableFallback aws.Ternary
}{
"secure": {
ExpectTrace: []string{
Expand Down Expand Up @@ -496,8 +498,69 @@ func TestRequestGetToken(t *testing.T) {
}),
))
},
ExpectErr: "failed to get API token",
},

// retryable token error with fallback enabled (default)
"token failure fallback enabled": {
ExpectTrace: []string{
getTokenPath,
getTokenPath,
getTokenPath,
"/latest/foo",
},
APICallCount: 1,
GetHandler: func(t *testing.T) http.Handler {
return newTestServeMux(t,
newInsecureAPIHandler(t,
500,
&successAPIResponseHandler{t: t,
path: "/latest/foo",
method: "GET",
body: []byte("hello"),
},
))
},
ExpectContent: []byte("hello"),
ExpectErr: "EC2 IMDS failed",
},
// retryable token error with fallback disabled
"token failure fallback disabled": {
ExpectTrace: []string{
getTokenPath,
getTokenPath,
getTokenPath,
},
APICallCount: 1,
GetHandler: func(t *testing.T) http.Handler {
return newTestServeMux(t,
newInsecureAPIHandler(t,
500,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Errorf("expected no call to API handler")
http.Error(w, "", 400)
}),
))
},
ExpectErr: "failed to get API token",
EnableFallback: aws.BoolTernary(false),
},
"insecure 403 fallback disabled": {
ExpectTrace: []string{
getTokenPath,
},
APICallCount: 1,
GetHandler: func(t *testing.T) http.Handler {
return newTestServeMux(t,
newInsecureAPIHandler(t,
403,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Errorf("expected no call to API handler")
http.Error(w, "", 400)
}),
))
},
ExpectErr: "failed to get API token",
EnableFallback: aws.BoolTernary(false),
},
}

Expand All @@ -515,7 +578,8 @@ func TestRequestGetToken(t *testing.T) {
defer server.Close()

client := New(Options{
Endpoint: server.URL,
Endpoint: server.URL,
EnableFallback: c.EnableFallback,
})

ctx := context.Background()
Expand Down
82 changes: 53 additions & 29 deletions feature/ec2/imds/token_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/logging"
"net/http"
"sync"
"sync/atomic"
"time"

smithy "github.com/aws/smithy-go"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
Expand Down Expand Up @@ -68,7 +70,7 @@ func (t *tokenProvider) HandleFinalize(
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
if !t.enabled() {
if t.fallbackEnabled() && !t.enabled() {
// short-circuits to insecure data flow if token provider is disabled.
return next.HandleFinalize(ctx, input)
}
Expand Down Expand Up @@ -115,23 +117,15 @@ func (t *tokenProvider) HandleDeserialize(
}

if resp.StatusCode == http.StatusUnauthorized { // unauthorized
err = &retryableError{Err: err}
t.enable()
err = &retryableError{Err: err, isRetryable: true}
}

return out, metadata, err
}

type retryableError struct {
Err error
}

func (*retryableError) RetryableError() bool { return true }

func (e *retryableError) Error() string { return e.Err.Error() }

func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
if !t.enabled() {
if t.fallbackEnabled() && !t.enabled() {
return nil, &bypassTokenRetrievalError{
Err: fmt.Errorf("cannot get API token, provider disabled"),
}
Expand All @@ -147,7 +141,7 @@ func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error)

tok, err = t.updateToken(ctx)
if err != nil {
return nil, fmt.Errorf("cannot get API token, %w", err)
return nil, err
}

return tok, nil
Expand All @@ -167,17 +161,19 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
TokenTTL: t.tokenTTL,
})
if err != nil {
// change the disabled flag on token provider to true, when error is request timeout error.
var statusErr interface{ HTTPStatusCode() int }
if errors.As(err, &statusErr) {
switch statusErr.HTTPStatusCode() {

// Disable get token if failed because of 403, 404, or 405
// Disable future get token if failed because of 403, 404, or 405
case http.StatusForbidden,
http.StatusNotFound,
http.StatusMethodNotAllowed:

t.disable()
if t.fallbackEnabled() {
logger := middleware.GetLogger(ctx)
logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
t.disable()
}

// 400 errors are terminal, and need to be upstreamed
case http.StatusBadRequest:
Expand All @@ -192,8 +188,17 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
atomic.StoreUint32(&t.disabled, 1)
}

// Token couldn't be retrieved, but bypass this, and allow the
// request to continue.
if !t.fallbackEnabled() {
// NOTE: getToken() is an implementation detail of some outer operation
// (e.g. GetMetadata). It has its own retries that have already been exhausted.
// Mark the underlying error as a terminal error.
err = &retryableError{Err: err, isRetryable: false}
return nil, err
}

// Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request
// and allow the request to proceed. Future requests _may_ re-attempt fetching a
// token if not disabled.
return nil, &bypassTokenRetrievalError{Err: err}
}

Expand All @@ -206,21 +211,21 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
return tok, nil
}

type bypassTokenRetrievalError struct {
Err error
}

func (e *bypassTokenRetrievalError) Error() string {
return fmt.Sprintf("bypass token retrieval, %v", e.Err)
}

func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err }

// enabled returns if the token provider is current enabled or not.
func (t *tokenProvider) enabled() bool {
return atomic.LoadUint32(&t.disabled) == 0
}

// fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise
func (t *tokenProvider) fallbackEnabled() bool {
switch t.client.options.EnableFallback {
case aws.FalseTernary:
return false
default:
return true
}
}

// disable disables the token provider and it will no longer attempt to inject
// the token, nor request updates.
func (t *tokenProvider) disable() {
Expand All @@ -235,3 +240,22 @@ func (t *tokenProvider) enable() {
t.tokenMux.Unlock()
atomic.StoreUint32(&t.disabled, 0)
}

type bypassTokenRetrievalError struct {
Err error
}

func (e *bypassTokenRetrievalError) Error() string {
return fmt.Sprintf("bypass token retrieval, %v", e.Err)
}

func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err }

type retryableError struct {
Err error
isRetryable bool
}

func (e *retryableError) RetryableError() bool { return e.isRetryable }

func (e *retryableError) Error() string { return e.Err.Error() }

0 comments on commit c2533c6

Please sign in to comment.