Skip to content

Commit

Permalink
NewInstanceCredentialsCache returns *aws.CredentialsCache (aws#4431)
Browse files Browse the repository at this point in the history
  • Loading branch information
tinnywang authored Nov 18, 2024
1 parent c24cdae commit ee3bb78
Show file tree
Hide file tree
Showing 9 changed files with 55 additions and 49 deletions.
12 changes: 5 additions & 7 deletions agent/app/agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ type ecsAgent struct {
dockerClient dockerapi.DockerClient
containerInstanceARN string
credentialProvider *aws_credentials.Credentials
credentialsCache awsv2.CredentialsProvider
credentialsCache *awsv2.CredentialsCache
stateManagerFactory factory.StateManager
saveableOptionFactory factory.SaveableOption
pauseLoader loader.Loader
Expand Down Expand Up @@ -234,12 +234,10 @@ func newAgent(blackholeEC2Metadata bool, acceptInsecureCert *bool) (agent, error
metadataManager = containermetadata.NewManager(dockerClient, cfg)
}

credentialsCache := awsv2.NewCredentialsCache(
providers.NewInstanceCredentialsCache(
cfg.External.Enabled(),
providers.NewRotatingSharedCredentialsProviderV2(),
nil,
),
credentialsCache := providers.NewInstanceCredentialsCache(
cfg.External.Enabled(),
providers.NewRotatingSharedCredentialsProviderV2(),
nil,
)
initialSeqNumber := int64(-1)
return &ecsAgent{
Expand Down
24 changes: 12 additions & 12 deletions agent/app/agent_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func TestDoStartRegisterContainerInstanceErrorTerminal(t *testing.T) {
ctx: ctx,
cfg: &cfg,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
dockerClient: dockerClient,
mobyPlugins: mockMobyPlugins,
ec2MetadataClient: mockEC2Metadata,
Expand Down Expand Up @@ -332,7 +332,7 @@ func TestDoStartRegisterContainerInstanceErrorNonTerminal(t *testing.T) {
cfg: &cfg,
dockerClient: dockerClient,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
ec2MetadataClient: mockEC2Metadata,
terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
Expand Down Expand Up @@ -515,7 +515,7 @@ func testDoStartHappyPathWithConditions(t *testing.T, blackholed bool, warmPools
dockerClient: dockerClient,
dataClient: dataClient,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
metadataManager: containermetadata,
terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
Expand Down Expand Up @@ -1015,7 +1015,7 @@ func TestReregisterContainerInstanceHappyPath(t *testing.T) {
cfg: &cfg,
dockerClient: mockDockerClient,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
ec2MetadataClient: mockEC2Metadata,
serviceconnectManager: mockServiceConnectManager,
Expand Down Expand Up @@ -1075,7 +1075,7 @@ func TestReregisterContainerInstanceInstanceTypeChanged(t *testing.T) {
cfg: &cfg,
dockerClient: mockDockerClient,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
ec2MetadataClient: mockEC2Metadata,
mobyPlugins: mockMobyPlugins,
serviceconnectManager: mockServiceConnectManager,
Expand Down Expand Up @@ -1135,7 +1135,7 @@ func TestReregisterContainerInstanceAttributeError(t *testing.T) {
ec2MetadataClient: mockEC2Metadata,
dockerClient: mockDockerClient,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
serviceconnectManager: mockServiceConnectManager,
daemonManagers: mockDaemonManagers,
Expand Down Expand Up @@ -1194,7 +1194,7 @@ func TestReregisterContainerInstanceNonTerminalError(t *testing.T) {
dockerClient: mockDockerClient,
ec2MetadataClient: mockEC2Metadata,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
serviceconnectManager: mockServiceConnectManager,
daemonManagers: mockDaemonManagers,
Expand Down Expand Up @@ -1254,7 +1254,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetHappyPath(t *t
dockerClient: mockDockerClient,
ec2MetadataClient: mockEC2Metadata,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
serviceconnectManager: mockServiceConnectManager,
daemonManagers: mockDaemonManagers,
Expand Down Expand Up @@ -1312,7 +1312,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetCanRetryError(
dockerClient: mockDockerClient,
ec2MetadataClient: mockEC2Metadata,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
serviceconnectManager: mockServiceConnectManager,
daemonManagers: mockDaemonManagers,
Expand Down Expand Up @@ -1370,7 +1370,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetCannotRetryErr
ec2MetadataClient: mockEC2Metadata,
dockerClient: mockDockerClient,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
serviceconnectManager: mockServiceConnectManager,
daemonManagers: mockDaemonManagers,
Expand Down Expand Up @@ -1427,7 +1427,7 @@ func TestRegisterContainerInstanceWhenContainerInstanceARNIsNotSetAttributeError
ec2MetadataClient: mockEC2Metadata,
dockerClient: mockDockerClient,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
mobyPlugins: mockMobyPlugins,
serviceconnectManager: mockServiceConnectManager,
daemonManagers: mockDaemonManagers,
Expand Down Expand Up @@ -1486,7 +1486,7 @@ func TestRegisterContainerInstanceInvalidParameterTerminalError(t *testing.T) {
ec2MetadataClient: mockEC2Metadata,
cfg: &cfg,
pauseLoader: mockPauseLoader,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
dockerClient: dockerClient,
mobyPlugins: mockMobyPlugins,
terminationHandler: func(taskEngineState dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
Expand Down
12 changes: 6 additions & 6 deletions agent/app/agent_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func TestDoStartTaskENIHappyPath(t *testing.T) {
agent := &ecsAgent{
ctx: ctx,
cfg: &cfg,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
dataClient: data.NewNoopClient(),
dockerClient: dockerClient,
pauseLoader: mockPauseLoader,
Expand Down Expand Up @@ -510,7 +510,7 @@ func TestDoStartCgroupInitHappyPath(t *testing.T) {
agent := &ecsAgent{
ctx: ctx,
cfg: &cfg,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
pauseLoader: mockPauseLoader,
dockerClient: dockerClient,
terminationHandler: func(state dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
Expand Down Expand Up @@ -579,7 +579,7 @@ func TestDoStartCgroupInitErrorPath(t *testing.T) {
agent := &ecsAgent{
ctx: ctx,
cfg: &cfg,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
dockerClient: dockerClient,
pauseLoader: mockPauseLoader,
terminationHandler: func(state dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
Expand Down Expand Up @@ -689,7 +689,7 @@ func TestDoStartGPUManagerHappyPath(t *testing.T) {
agent := &ecsAgent{
ctx: ctx,
cfg: &cfg,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
dockerClient: dockerClient,
pauseLoader: mockPauseLoader,
terminationHandler: func(state dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
Expand Down Expand Up @@ -751,7 +751,7 @@ func TestDoStartGPUManagerInitError(t *testing.T) {
agent := &ecsAgent{
ctx: ctx,
cfg: &cfg,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
dockerClient: dockerClient,
pauseLoader: mockPauseLoader,
terminationHandler: func(state dockerstate.TaskEngineState, dataClient data.Client, taskEngine engine.TaskEngine, cancel context.CancelFunc) {
Expand Down Expand Up @@ -799,7 +799,7 @@ func TestDoStartTaskENIPauseError(t *testing.T) {
agent := &ecsAgent{
ctx: ctx,
cfg: &cfg,
credentialsCache: mockCredentialsProvider,
credentialsCache: awsv2.NewCredentialsCache(mockCredentialsProvider),
dockerClient: dockerClient,
pauseLoader: mockPauseLoader,
cniClient: cniClient,
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ import (
"github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds"
)

type InstanceCredentialsCache struct {
type InstanceCredentialsProvider struct {
providers []aws.CredentialsProvider
}

func (p *InstanceCredentialsCache) Retrieve(ctx context.Context) (aws.Credentials, error) {
func (p *InstanceCredentialsProvider) Retrieve(ctx context.Context) (aws.Credentials, error) {
var errs []error
for _, provider := range p.providers {
creds, err := provider.Retrieve(ctx)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,18 @@ func NewInstanceCredentialsCache(
isExternal bool,
rotatingSharedCreds aws.CredentialsProvider,
imdsClient ec2rolecreds.GetMetadataAPIClient,
) *InstanceCredentialsCache {
) *aws.CredentialsCache {
// If imdsClient is nil, the SDK will default to the EC2 IMDS client.
// Pass a non-nil imdsClient to stub it out in tests.
options := func(o *ec2rolecreds.Options) {
o.Client = imdsClient
}
return &InstanceCredentialsCache{
providers: []aws.CredentialsProvider{
defaultCreds(options),
rotatingSharedCreds,
return aws.NewCredentialsCache(
&InstanceCredentialsProvider{
providers: []aws.CredentialsProvider{
defaultCreds(options),
rotatingSharedCreds,
},
},
}
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewInstanceCredentialsCache(
isExternal bool,
rotatingSharedCreds aws.CredentialsProvider,
imdsClient ec2rolecreds.GetMetadataAPIClient,
) *InstanceCredentialsCache {
) *aws.CredentialsCache {
var providers []aws.CredentialsProvider

// If imdsClient is nil, the SDK will default to the EC2 IMDS client.
Expand All @@ -73,9 +73,11 @@ func NewInstanceCredentialsCache(
}
}

return &InstanceCredentialsCache{
providers: providers,
}
return aws.NewCredentialsCache(
&InstanceCredentialsProvider{
providers: providers,
},
)
}

var envCreds = aws.CredentialsProviderFunc(func(ctx context.Context) (aws.Credentials, error) {
Expand Down

0 comments on commit ee3bb78

Please sign in to comment.