Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor ACS refresh credentials message handling #3830

Merged
merged 1 commit into from
Jul 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 9 additions & 18 deletions agent/acs/handler/acs_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import (
rolecredentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/doctor"
"github.com/aws/amazon-ecs-agent/ecs-agent/eventstream"
"github.com/aws/amazon-ecs-agent/ecs-agent/metrics"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/retry"
"github.com/aws/amazon-ecs-agent/ecs-agent/utils/ttime"
"github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"
Expand Down Expand Up @@ -78,8 +79,8 @@ const (
acsProtocolVersion = 2
// numOfHandlersSendingAcks is the number of handlers that send acks back to ACS and that are not saved across
// sessions. We use this to send pending acks, before agent initiates a disconnect to ACS.
// they are: refreshCredentialsHandler, taskManifestHandler, and payloadHandler
numOfHandlersSendingAcks = 3
// they are: taskManifestHandler and payloadHandler
numOfHandlersSendingAcks = 2
amogh09 marked this conversation as resolved.
Show resolved Hide resolved
)

// Session defines an interface for handler's long-lived connection with ACS.
Expand Down Expand Up @@ -250,13 +251,7 @@ func (acsSession *session) startSessionOnce() error {
func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
cfg := acsSession.agentConfig

refreshCredsHandler := newRefreshCredentialsHandler(acsSession.ctx, cfg.Cluster, acsSession.containerInstanceARN,
client, acsSession.credentialsManager, acsSession.taskEngine)
defer refreshCredsHandler.clearAcks()
refreshCredsHandler.start()
defer refreshCredsHandler.stop()

client.AddRequestHandler(refreshCredsHandler.handlerFunc())
credsMetadataSetter := &credentialsMetadataSetter{taskEngine: acsSession.taskEngine}

eniHandler := &eniHandler{
state: acsSession.state,
Expand All @@ -265,6 +260,8 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {

manifestMessageIDAccessor := &manifestMessageIDAccessor{}

metricsFactory := metrics.NewNopEntryFactory()

// Add TaskManifestHandler
taskManifestHandler := newTaskManifestHandler(acsSession.ctx, cfg.Cluster, acsSession.containerInstanceARN,
client, acsSession.dataClient, acsSession.taskEngine, acsSession.latestSeqNumTaskManifest,
Expand All @@ -286,7 +283,6 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
acsSession.containerInstanceARN,
client,
acsSession.dataClient,
refreshCredsHandler,
acsSession.credentialsManager,
acsSession.taskHandler, acsSession.latestSeqNumTaskManifest)
// Clear the acks channel on return because acks of messageids don't have any value across sessions
Expand All @@ -300,6 +296,8 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
return client.MakeRequest(response)
}
responders := []wsclient.RequestResponder{
acssession.NewRefreshCredentialsResponder(acsSession.credentialsManager, credsMetadataSetter, metricsFactory,
responseSender),
acssession.NewAttachTaskENIResponder(eniHandler, responseSender),
acssession.NewAttachInstanceENIResponder(eniHandler, responseSender),
acssession.NewHeartbeatResponder(acsSession.doctor, responseSender),
Expand All @@ -320,7 +318,7 @@ func (acsSession *session) startACSSession(client wsclient.ClientServer) error {
// Start a connection timer; agent will send pending acks and close its ACS websocket connection
// after this timer expires
connectionTimer := newConnectionTimer(client, acsSession.connectionTime, acsSession.connectionJitter,
&refreshCredsHandler, &taskManifestHandler, &payloadHandler)
&taskManifestHandler, &payloadHandler)
defer connectionTimer.Stop()

// Start a heartbeat timer for closing the connection
Expand Down Expand Up @@ -416,7 +414,6 @@ func newConnectionTimer(
client wsclient.ClientServer,
connectionTime time.Duration,
connectionJitter time.Duration,
refreshCredsHandler *refreshCredentialsHandler,
taskManifestHandler *taskManifestHandler,
payloadHandler *payloadRequestHandler,
) ttime.Timer {
Expand All @@ -427,12 +424,6 @@ func newConnectionTimer(
wg := sync.WaitGroup{}
wg.Add(numOfHandlersSendingAcks)

// send pending creds refresh acks to ACS
go func() {
refreshCredsHandler.sendPendingAcks()
wg.Done()
}()

// send pending task manifest acks and task stop verification acks to ACS
go func() {
taskManifestHandler.sendPendingTaskManifestMessageAck()
Expand Down
4 changes: 2 additions & 2 deletions agent/acs/handler/acs_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1162,8 +1162,6 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) {
// Ensure that credentials manager interface methods are invoked in the
// correct order, with expected arguments
gomock.InOrder(
// Return a task from the engine for GetTaskByArn
taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true),
// The last invocation of SetCredentials is to update
// credentials when a refresh message is received by the handler
credentialsManager.EXPECT().SetTaskCredentials(gomock.Any()).Do(func(creds *rolecredentials.TaskIAMRoleCredentials) {
Expand All @@ -1185,6 +1183,8 @@ func TestStartSessionHandlesRefreshCredentialsMessages(t *testing.T) {
t.Errorf("Mismatch between expected and credentials expected: %v, added: %v", expectedCreds, updatedCredentials)
}
}).Return(nil),
// Return a task from the engine for GetTaskByArn
taskEngine.EXPECT().GetTaskByArn("t1").Return(taskFromEngine, true),
)
serverIn <- sampleRefreshCredentialsMessage

Expand Down
8 changes: 4 additions & 4 deletions agent/acs/handler/attach_eni_handler_common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func testENIAckTimeout(t *testing.T, attachmentType string) {
expiresAt := time.Now().Add(time.Millisecond * testconst.WaitTimeoutMillis)
eniAttachment := &apieni.ENIAttachment{
AttachmentInfo: attachmentinfo.AttachmentInfo{
TaskARN: taskArn,
TaskARN: testconst.TaskARN,
AttachmentARN: attachmentArn,
ExpiresAt: expiresAt,
AttachStatusSent: false,
Expand Down Expand Up @@ -103,7 +103,7 @@ func testENIAckWithinTimeout(t *testing.T, attachmentType string) {
expiresAt := time.Now().Add(time.Millisecond * testconst.WaitTimeoutMillis)
eniAttachment := &apieni.ENIAttachment{
AttachmentInfo: attachmentinfo.AttachmentInfo{
TaskARN: taskArn,
TaskARN: testconst.TaskARN,
AttachmentARN: attachmentArn,
ExpiresAt: expiresAt,
AttachStatusSent: false,
Expand All @@ -130,7 +130,7 @@ func testENIAckWithinTimeout(t *testing.T, attachmentType string) {

// TestHandleENIAttachmentTaskENI tests handling a new task eni
func TestHandleENIAttachmentTaskENI(t *testing.T) {
testHandleENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, taskArn)
testHandleENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, testconst.TaskARN)
}

// TestHandleENIAttachmentInstanceENI tests handling a new instance eni
Expand Down Expand Up @@ -178,7 +178,7 @@ func testHandleENIAttachment(t *testing.T, attachmentType, taskArn string) {

// TestHandleExpiredENIAttachmentTaskENI tests handling an expired task eni
func TestHandleExpiredENIAttachmentTaskENI(t *testing.T) {
testHandleExpiredENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, taskArn)
testHandleExpiredENIAttachment(t, apieni.ENIAttachmentTypeTaskENI, testconst.TaskARN)
}

// TestHandleExpiredENIAttachmentInstanceENI tests handling an expired instance eni
Expand Down
15 changes: 11 additions & 4 deletions agent/acs/handler/payload_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ type payloadRequestHandler struct {
cluster string
containerInstanceArn string
acsClient wsclient.ClientServer
refreshHandler refreshCredentialsHandler
credentialsManager credentials.Manager
latestSeqNumberTaskManifest *int64
}
Expand All @@ -65,7 +64,6 @@ func newPayloadRequestHandler(
containerInstanceArn string,
acsClient wsclient.ClientServer,
dataClient data.Client,
refreshHandler refreshCredentialsHandler,
credentialsManager credentials.Manager,
taskHandler *eventhandler.TaskHandler, seqNumTaskManifest *int64) payloadRequestHandler {
// Create a cancelable context from the parent context
Expand All @@ -82,7 +80,6 @@ func newPayloadRequestHandler(
cluster: cluster,
containerInstanceArn: containerInstanceArn,
acsClient: acsClient,
refreshHandler: refreshHandler,
credentialsManager: credentialsManager,
latestSeqNumberTaskManifest: seqNumTaskManifest,
}
Expand Down Expand Up @@ -187,14 +184,24 @@ func (payloadHandler *payloadRequestHandler) handleSingleMessage(payload *ecsacs
go func() {
// Throw the ack in async; it doesn't really matter all that much and this is blocking handling more tasks.
for _, credentialsAck := range credentialsAcks {
payloadHandler.refreshHandler.ackMessage(credentialsAck)
payloadHandler.makeCredentialsAckRequest(credentialsAck)
}
payloadHandler.ackRequest <- *payload.MessageId
}()

return nil
}

// makeCredentialsAckRequest passes the given IAMRoleCredentialsAckRequest to the
// handler's acsClient via MakeRequest. The ack is logged at debug level before the
// send; if MakeRequest returns an error it is logged as a warning together with the
// ack's credentials ID, and nothing is reported back to the caller.
func (payloadHandler *payloadRequestHandler) makeCredentialsAckRequest(ack *ecsacs.IAMRoleCredentialsAckRequest) {
	seelog.Debugf("ACKing credentials associated with ACS payload message: %s", ack.String())

	if sendErr := payloadHandler.acsClient.MakeRequest(ack); sendErr != nil {
		// Only the credentials ID is included in the warning, not the whole ack.
		credentialsID := aws.StringValue(ack.CredentialsId)
		seelog.Warnf("Error ACKing credentials with credentialsID '%s' associated with ACS payload message, error: %v",
			credentialsID, sendErr)
	}
}

// addPayloadTasks does validation on each task and, for all valid ones, adds
// it to the task engine. It returns a bool indicating if it could add every
// task to the taskEngine and a slice of credential ack requests
Expand Down
23 changes: 4 additions & 19 deletions agent/acs/handler/payload_handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ func setup(t *testing.T) *testHelper {
testconst.ContainerInstanceARN,
mockWsClient,
data.NewNoopClient(),
refreshCredentialsHandler{},
credentialsManager,
taskHandler, &latestSeqNumberTaskManifest)

Expand Down Expand Up @@ -307,11 +306,6 @@ func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) {
}),
)

refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine)
defer refreshCredsHandler.clearAcks()
refreshCredsHandler.start()
tester.payloadHandler.refreshHandler = refreshCredsHandler

go tester.payloadHandler.start()

taskArn := "t1"
Expand All @@ -338,8 +332,8 @@ func TestHandlePayloadMessageCredentialsAckedWhenTaskAdded(t *testing.T) {
},
},
MessageId: aws.String(payloadMessageId),
ClusterArn: aws.String(cluster),
ContainerInstanceArn: aws.String(containerInstance),
ClusterArn: aws.String(testconst.ClusterName),
ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN),
}
err := tester.payloadHandler.handleSingleMessage(payloadMessage)
assert.NoError(t, err, "error handling payload message")
Expand Down Expand Up @@ -496,11 +490,6 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
}),
)

refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine)
defer refreshCredsHandler.clearAcks()
refreshCredsHandler.start()
tester.payloadHandler.refreshHandler = refreshCredsHandler

go tester.payloadHandler.start()

firstTaskArn := "t1"
Expand Down Expand Up @@ -546,8 +535,8 @@ func TestPayloadBufferHandlerWithCredentials(t *testing.T) {
},
},
MessageId: aws.String(payloadMessageId),
ClusterArn: aws.String(cluster),
ContainerInstanceArn: aws.String(containerInstance),
ClusterArn: aws.String(testconst.ClusterName),
ContainerInstanceArn: aws.String(testconst.ContainerInstanceARN),
}

// Wait till we get an ack
Expand Down Expand Up @@ -618,11 +607,7 @@ func TestAddPayloadTaskAddsExecutionRoles(t *testing.T) {
tester.cancel()
}),
)
refreshCredsHandler := newRefreshCredentialsHandler(tester.ctx, testconst.ClusterName, testconst.ContainerInstanceARN, tester.mockWsClient, tester.credentialsManager, tester.mockTaskEngine)
defer refreshCredsHandler.clearAcks()
refreshCredsHandler.start()

tester.payloadHandler.refreshHandler = refreshCredsHandler
go tester.payloadHandler.start()
taskArn := "t1"
credentialsExpiration := "expiration"
Expand Down
Loading