Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add flag to disable IMDSv1 fallback #2048

Merged
merged 8 commits into from
Mar 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .changelog/a1050420a5f7409fbc4d1d82c882d168.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"id": "a1050420-a5f7-409f-bc4d-1d82c882d168",
"type": "feature",
"description": "Add flag to disable IMDSv1 fallback",
"modules": [
"feature/ec2/imds"
]
}
10 changes: 10 additions & 0 deletions feature/ec2/imds/api_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,16 @@ type Options struct {
// The logger writer interface to write logging messages to.
Logger logging.Logger

// Configure IMDSv1 fallback behavior. By default, the client will attempt
// to fall back to IMDSv1 as needed for backwards compatibility. When set to [aws.FalseTernary]
// the client will return any errors encountered from attempting to fetch a token
// instead of silently using the insecure data flow of IMDSv1.
//
// See [configuring IMDS] for more information.
//
// [configuring IMDS]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
EnableFallback aws.Ternary

// provides the caching of API tokens used for operation calls. If unset,
// the API token will not be retrieved for the operation.
tokenProvider *tokenProvider
Expand Down
78 changes: 71 additions & 7 deletions feature/ec2/imds/request_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"
"encoding/hex"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -330,11 +331,12 @@ func (h *successAPIResponseHandler) ServeHTTP(w http.ResponseWriter, r *http.Req

func TestRequestGetToken(t *testing.T) {
cases := map[string]struct {
GetHandler func(*testing.T) http.Handler
APICallCount int
ExpectTrace []string
ExpectContent []byte
ExpectErr string
GetHandler func(*testing.T) http.Handler
APICallCount int
ExpectTrace []string
ExpectContent []byte
ExpectErr string
EnableFallback aws.Ternary
}{
"secure": {
ExpectTrace: []string{
Expand Down Expand Up @@ -496,8 +498,69 @@ func TestRequestGetToken(t *testing.T) {
}),
))
},
ExpectErr: "failed to get API token",
},

// retryable token error with fallback enabled (default)
"token failure fallback enabled": {
ExpectTrace: []string{
getTokenPath,
getTokenPath,
getTokenPath,
"/latest/foo",
},
APICallCount: 1,
GetHandler: func(t *testing.T) http.Handler {
return newTestServeMux(t,
newInsecureAPIHandler(t,
500,
&successAPIResponseHandler{t: t,
path: "/latest/foo",
method: "GET",
body: []byte("hello"),
},
))
},
ExpectContent: []byte("hello"),
ExpectErr: "EC2 IMDS failed",
},
// retryable token error with fallback disabled
"token failure fallback disabled": {
ExpectTrace: []string{
getTokenPath,
getTokenPath,
getTokenPath,
},
APICallCount: 1,
GetHandler: func(t *testing.T) http.Handler {
return newTestServeMux(t,
newInsecureAPIHandler(t,
500,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Errorf("expected no call to API handler")
http.Error(w, "", 400)
}),
))
},
ExpectErr: "failed to get API token",
EnableFallback: aws.BoolTernary(false),
},
"insecure 403 fallback disabled": {
ExpectTrace: []string{
getTokenPath,
},
APICallCount: 1,
GetHandler: func(t *testing.T) http.Handler {
return newTestServeMux(t,
newInsecureAPIHandler(t,
403,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Errorf("expected no call to API handler")
http.Error(w, "", 400)
}),
))
},
ExpectErr: "failed to get API token",
EnableFallback: aws.BoolTernary(false),
},
}

Expand All @@ -515,7 +578,8 @@ func TestRequestGetToken(t *testing.T) {
defer server.Close()

client := New(Options{
Endpoint: server.URL,
Endpoint: server.URL,
EnableFallback: c.EnableFallback,
})

ctx := context.Background()
Expand Down
82 changes: 53 additions & 29 deletions feature/ec2/imds/token_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@ import (
"context"
"errors"
"fmt"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/smithy-go"
"github.com/aws/smithy-go/logging"
"net/http"
"sync"
"sync/atomic"
"time"

smithy "github.com/aws/smithy-go"
"github.com/aws/smithy-go/middleware"
smithyhttp "github.com/aws/smithy-go/transport/http"
)
Expand Down Expand Up @@ -68,7 +70,7 @@ func (t *tokenProvider) HandleFinalize(
) (
out middleware.FinalizeOutput, metadata middleware.Metadata, err error,
) {
if !t.enabled() {
if t.fallbackEnabled() && !t.enabled() {
// short-circuits to insecure data flow if token provider is disabled.
return next.HandleFinalize(ctx, input)
}
Expand Down Expand Up @@ -115,23 +117,15 @@ func (t *tokenProvider) HandleDeserialize(
}

if resp.StatusCode == http.StatusUnauthorized { // unauthorized
err = &retryableError{Err: err}
t.enable()
err = &retryableError{Err: err, isRetryable: true}
}

return out, metadata, err
}

type retryableError struct {
Err error
}

func (*retryableError) RetryableError() bool { return true }

func (e *retryableError) Error() string { return e.Err.Error() }

func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error) {
if !t.enabled() {
if t.fallbackEnabled() && !t.enabled() {
return nil, &bypassTokenRetrievalError{
Err: fmt.Errorf("cannot get API token, provider disabled"),
}
Expand All @@ -147,7 +141,7 @@ func (t *tokenProvider) getToken(ctx context.Context) (tok *apiToken, err error)

tok, err = t.updateToken(ctx)
if err != nil {
return nil, fmt.Errorf("cannot get API token, %w", err)
return nil, err
}

return tok, nil
Expand All @@ -167,17 +161,19 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
TokenTTL: t.tokenTTL,
})
if err != nil {
// change the disabled flag on token provider to true, when error is request timeout error.
var statusErr interface{ HTTPStatusCode() int }
if errors.As(err, &statusErr) {
switch statusErr.HTTPStatusCode() {

// Disable get token if failed because of 403, 404, or 405
// Disable future get token if failed because of 403, 404, or 405
case http.StatusForbidden,
http.StatusNotFound,
http.StatusMethodNotAllowed:

t.disable()
if t.fallbackEnabled() {
logger := middleware.GetLogger(ctx)
logger.Logf(logging.Warn, "falling back to IMDSv1: %v", err)
t.disable()
}

// 400 errors are terminal, and need to be upstreamed
case http.StatusBadRequest:
Expand All @@ -192,8 +188,17 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
atomic.StoreUint32(&t.disabled, 1)
}

// Token couldn't be retrieved, but bypass this, and allow the
// request to continue.
if !t.fallbackEnabled() {
// NOTE: getToken() is an implementation detail of some outer operation
// (e.g. GetMetadata). It has its own retries that have already been exhausted.
// Mark the underlying error as a terminal error.
err = &retryableError{Err: err, isRetryable: false}
return nil, err
}

// Token couldn't be retrieved, fallback to IMDSv1 insecure flow for this request
// and allow the request to proceed. Future requests _may_ re-attempt fetching a
// token if not disabled.
return nil, &bypassTokenRetrievalError{Err: err}
}

Expand All @@ -206,21 +211,21 @@ func (t *tokenProvider) updateToken(ctx context.Context) (*apiToken, error) {
return tok, nil
}

type bypassTokenRetrievalError struct {
Err error
}

func (e *bypassTokenRetrievalError) Error() string {
return fmt.Sprintf("bypass token retrieval, %v", e.Err)
}

func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err }

// enabled returns if the token provider is current enabled or not.
func (t *tokenProvider) enabled() bool {
return atomic.LoadUint32(&t.disabled) == 0
}

// fallbackEnabled returns false if EnableFallback is [aws.FalseTernary], true otherwise
func (t *tokenProvider) fallbackEnabled() bool {
switch t.client.options.EnableFallback {
case aws.FalseTernary:
return false
default:
return true
}
}

// disable disables the token provider and it will no longer attempt to inject
// the token, nor request updates.
func (t *tokenProvider) disable() {
Expand All @@ -235,3 +240,22 @@ func (t *tokenProvider) enable() {
t.tokenMux.Unlock()
atomic.StoreUint32(&t.disabled, 0)
}

type bypassTokenRetrievalError struct {
Err error
}

func (e *bypassTokenRetrievalError) Error() string {
return fmt.Sprintf("bypass token retrieval, %v", e.Err)
}

func (e *bypassTokenRetrievalError) Unwrap() error { return e.Err }

type retryableError struct {
Err error
isRetryable bool
}

func (e *retryableError) RetryableError() bool { return e.isRetryable }

func (e *retryableError) Error() string { return e.Err.Error() }