Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix credential rotation issue for ECS-A Windows #3184

Merged
merged 1 commit into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion agent/app/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ func newAgent(blackholeEC2Metadata bool, acceptInsecureCert *bool) (agent, error
// We instantiate our own credentialProvider for use in acs/tcs. This tries
// to mimic roughly the way it's instantiated by the SDK for a default
// session.
credentialProvider: instancecreds.GetCredentials(),
credentialProvider: instancecreds.GetCredentials(cfg.External.Enabled()),
stateManagerFactory: factory.NewStateManager(),
saveableOptionFactory: factory.NewSaveableOption(),
pauseLoader: pause.New(),
Expand Down
32 changes: 0 additions & 32 deletions agent/credentials/instancecreds/instancecreds.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,42 +16,10 @@ package instancecreds
import (
"sync"

"github.com/aws/amazon-ecs-agent/agent/credentials/providers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/cihub/seelog"
)

var (
credentialChain *credentials.Credentials
mu sync.Mutex
)

// GetCredentials returns the instance credentials chain. This is the default chain
// credentials plus the "rotating shared credentials provider", so credentials will
// be checked in this order:
// 1. Env vars (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY).
// 2. Shared credentials file (https://docs.aws.amazon.com/ses/latest/DeveloperGuide/create-shared-credentials-file.html) (file at ~/.aws/credentials containing access key id and secret access key).
// 3. EC2 role credentials. This is an IAM role that the user specifies when they launch their EC2 container instance (ie ecsInstanceRole (https://docs.aws.amazon.com/AmazonECS/latest/developerguide/instance_IAM_role.html)).
// 4. Rotating shared credentials file located at /rotatingcreds/credentials
func GetCredentials() *credentials.Credentials {
mu.Lock()
if credentialChain == nil {
credProviders := defaults.CredProviders(defaults.Config(), defaults.Handlers())
credProviders = append(credProviders, providers.NewRotatingSharedCredentialsProvider())
credentialChain = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: false,
Providers: credProviders,
})
}
mu.Unlock()

// credentials.Credentials is concurrency-safe, so lock not needed here
v, err := credentialChain.Get()
if err != nil {
seelog.Errorf("Error getting ECS instance credentials from default chain: %s", err)
} else {
seelog.Infof("Successfully got ECS instance credentials from provider: %s", v.ProviderName)
}
return credentialChain
}
52 changes: 52 additions & 0 deletions agent/credentials/instancecreds/instancecreds_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//go:build linux

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package instancecreds

import (
"github.com/aws/amazon-ecs-agent/agent/credentials/providers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/cihub/seelog"
)

// GetCredentials returns the instance credentials chain. This is the default chain
// credentials plus the "rotating shared credentials provider", so credentials will
// be checked in this order:
// 1. Env vars (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY).
// 2. Shared credentials file (https://docs.aws.amazon.com/ses/latest/DeveloperGuide/create-shared-credentials-file.html) (file at ~/.aws/credentials containing access key id and secret access key).
// 3. EC2 role credentials. This is an IAM role that the user specifies when they launch their EC2 container instance (ie ecsInstanceRole (https://docs.aws.amazon.com/AmazonECS/latest/developerguide/instance_IAM_role.html)).
// 4. Rotating shared credentials file located at /rotatingcreds/credentials
func GetCredentials(isExternal bool) *credentials.Credentials {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For Linux mark this as _ since its unused right?

mu.Lock()
if credentialChain == nil {
credProviders := defaults.CredProviders(defaults.Config(), defaults.Handlers())
credProviders = append(credProviders, providers.NewRotatingSharedCredentialsProvider())
credentialChain = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: false,
Providers: credProviders,
})
}
mu.Unlock()

// credentials.Credentials is concurrency-safe, so lock not needed here
v, err := credentialChain.Get()
if err != nil {
seelog.Errorf("Error getting ECS instance credentials from default chain: %s", err)
} else {
seelog.Infof("Successfully got ECS instance credentials from provider: %s", v.ProviderName)
}
return credentialChain
}
9 changes: 4 additions & 5 deletions agent/credentials/instancecreds/instancecreds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,10 @@ import (

func TestGetCredentials(t *testing.T) {
credentialChain = nil
credsA := GetCredentials()
credsA := GetCredentials(false)
require.NotNil(t, credsA)
credsB := GetCredentials()
credsB := GetCredentials(true)
vsiddharth marked this conversation as resolved.
Show resolved Hide resolved
require.NotNil(t, credsB)
require.Equal(t, credsA, credsB)
}

// test that env vars override all other provider types
Expand All @@ -44,7 +43,7 @@ func TestGetCredentials_EnvVars(t *testing.T) {
defer os.Setenv("AWS_ACCESS_KEY_ID", origAKID)
defer os.Setenv("AWS_SECRET_ACCESS_KEY", origSecret)

creds := GetCredentials()
creds := GetCredentials(false)
require.NotNil(t, creds)
v, err := creds.Get()
require.NoError(t, err)
Expand Down Expand Up @@ -81,7 +80,7 @@ aws_secret_access_key = TESTFILESECRET
// reset before exiting
defer os.Setenv("AWS_SHARED_CREDENTIALS_FILE", origEnv)

creds := GetCredentials()
creds := GetCredentials(false)
require.NotNil(t, creds)
v, err := creds.Get()
require.NoError(t, err)
Expand Down
31 changes: 31 additions & 0 deletions agent/credentials/instancecreds/instancecreds_unsupported.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//go:build !linux && !windows

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package instancecreds

import (
"github.com/aws/aws-sdk-go/aws/credentials"
)

// GetCredentials returns the instance credentials chain. This is the default chain
// credentials plus the "rotating shared credentials provider", so credentials will
// be checked in this order:
// 1. Env vars (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY).
// 2. Shared credentials file (https://docs.aws.amazon.com/ses/latest/DeveloperGuide/create-shared-credentials-file.html) (file at ~/.aws/credentials containing access key id and secret access key).
// 3. EC2 role credentials. This is an IAM role that the user specifies when they launch their EC2 container instance (ie ecsInstanceRole (https://docs.aws.amazon.com/AmazonECS/latest/developerguide/instance_IAM_role.html)).
// 4. Rotating shared credentials file located at /rotatingcreds/credentials
func GetCredentials(isExternal bool) *credentials.Credentials {
return nil
}
65 changes: 65 additions & 0 deletions agent/credentials/instancecreds/instancecreds_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
//go:build windows

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package instancecreds

import (
"github.com/aws/amazon-ecs-agent/agent/credentials/providers"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/defaults"
"github.com/cihub/seelog"
)

// GetCredentials returns the instance credentials chain. This is the default chain
// credentials plus the "rotating shared credentials provider", so credentials will
// be checked in this order:
// 1. Env vars (AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY).
// 2. Shared credentials file (https://docs.aws.amazon.com/ses/latest/DeveloperGuide/create-shared-credentials-file.html) (file at ~/.aws/credentials containing access key id and secret access key).
// 3. EC2 role credentials. This is an IAM role that the user specifies when they launch their EC2 container instance (ie ecsInstanceRole (https://docs.aws.amazon.com/AmazonECS/latest/developerguide/instance_IAM_role.html)).
// 4. Rotating shared credentials file located at /rotatingcreds/credentials
//
// The default credential chain provided by the SDK includes:
// * EnvProvider
// * SharedCredentialsProvider
// * RemoteCredProvider (EC2RoleProvider)
//
// In the case of ECS-A on Windows, the `SharedCredentialsProvider` takes
// precedence over the `RotatingSharedCredentialsProvider` and this results
// in the credentials not being refreshed. To mitigate this issue, we will
// reorder the credential chain and ensure that `RotatingSharedCredentialsProvider`
// takes precedence over the `SharedCredentialsProvider` for ECS-A.
func GetCredentials(isExternal bool) *credentials.Credentials {
mu.Lock()
credProviders := defaults.CredProviders(defaults.Config(), defaults.Handlers())
if isExternal {
credProviders = append(credProviders[:1], append([]credentials.Provider{providers.NewRotatingSharedCredentialsProvider()}, credProviders[1:]...)...)
vsiddharth marked this conversation as resolved.
Show resolved Hide resolved
} else {
credProviders = append(credProviders, providers.NewRotatingSharedCredentialsProvider())
}
credentialChain = credentials.NewCredentials(&credentials.ChainProvider{
VerboseErrors: false,
Providers: credProviders,
})
mu.Unlock()

// credentials.Credentials is concurrency-safe, so lock not needed here
v, err := credentialChain.Get()
if err != nil {
seelog.Errorf("Error getting ECS instance credentials from default chain: %s", err)
} else {
seelog.Infof("Successfully got ECS instance credentials from provider: %s", v.ProviderName)
}
return credentialChain
}
22 changes: 22 additions & 0 deletions agent/credentials/providers/credentials_filename_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//go:build linux
vsiddharth marked this conversation as resolved.
Show resolved Hide resolved

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package providers

const (
// defaultRotatingCredentialsFilename is the default location of the credentials file
// for RotatingSharedCredentialsProvider.
defaultRotatingCredentialsFilename = "/rotatingcreds/credentials"
)
22 changes: 22 additions & 0 deletions agent/credentials/providers/credentials_filename_unsupported.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
//go:build !linux && !windows

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package providers

const (
// defaultRotatingCredentialsFilename is the default location of the credentials file
// for RotatingSharedCredentialsProvider.
defaultRotatingCredentialsFilename = "/unsupported/rotatingcreds/credentials"
)
20 changes: 20 additions & 0 deletions agent/credentials/providers/credentials_filename_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
//go:build windows
vsiddharth marked this conversation as resolved.
Show resolved Hide resolved

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package providers

// defaultRotatingCredentialsFilename is the default location of the credentials file
// for RotatingSharedCredentialsProvider.
const defaultRotatingCredentialsFilename = "C:\\Windows\\System32\\config\\systemprofile\\.aws\\credentials"
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ import (
const (
// defaultRotationInterval is how frequently to expire and re-retrieve the credentials from file.
defaultRotationInterval = time.Minute
// defaultFilename is the default location of the credentials file within the container.
defaultFilename = "/rotatingcreds/credentials"
// RotatingSharedCredentialsProviderName is the name of this provider
RotatingSharedCredentialsProviderName = "RotatingSharedCredentialsProvider"
)
Expand All @@ -46,7 +44,7 @@ func NewRotatingSharedCredentialsProvider() *RotatingSharedCredentialsProvider {
return &RotatingSharedCredentialsProvider{
RotationInterval: defaultRotationInterval,
sharedCredentialsProvider: &credentials.SharedCredentialsProvider{
Filename: defaultFilename,
Filename: defaultRotatingCredentialsFilename,
Profile: "default",
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func TestNewRotatingSharedCredentialsProvider(t *testing.T) {
p := NewRotatingSharedCredentialsProvider()
require.Equal(t, time.Minute, p.RotationInterval)
require.Equal(t, "default", p.sharedCredentialsProvider.Profile)
require.Equal(t, "/rotatingcreds/credentials", p.sharedCredentialsProvider.Filename)
require.Equal(t, defaultRotatingCredentialsFilename, p.sharedCredentialsProvider.Filename)
}

func TestRotatingSharedCredentialsProvider_RetrieveFail_BadPath(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion agent/ec2/ec2_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ type ClientImpl struct {
func NewClientImpl(awsRegion string) Client {
ec2Config := aws.NewConfig().WithMaxRetries(clientRetriesNum)
ec2Config.Region = aws.String(awsRegion)
ec2Config.Credentials = instancecreds.GetCredentials()
ec2Config.Credentials = instancecreds.GetCredentials(false)
client := ec2sdk.New(session.New(), ec2Config)
return &ClientImpl{
client: client,
Expand Down
2 changes: 1 addition & 1 deletion agent/ec2/ec2_metadata_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ type ec2MetadataClientImpl struct {
func NewEC2MetadataClient(client HttpClient) EC2MetadataClient {
if client == nil {
config := aws.NewConfig().WithMaxRetries(metadataRetries)
config.Credentials = instancecreds.GetCredentials()
config.Credentials = instancecreds.GetCredentials(false)
return &ec2MetadataClientImpl{
client: ec2metadata.New(session.New(), config),
}
Expand Down
2 changes: 1 addition & 1 deletion agent/ecr/factory.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func getClientConfig(httpClient *http.Client, authData *apicontainer.ECRAuthData
authData.GetPullCredentials().SessionToken)
cfg = cfg.WithCredentials(creds)
} else {
cfg = cfg.WithCredentials(instancecreds.GetCredentials())
cfg = cfg.WithCredentials(instancecreds.GetCredentials(false))
}

return cfg, nil
Expand Down