diff --git a/agent/acs/session/task_manifest_responder_test.go b/agent/acs/session/task_manifest_responder_test.go index e3f1023d08d..86f22c5afcb 100644 --- a/agent/acs/session/task_manifest_responder_test.go +++ b/agent/acs/session/task_manifest_responder_test.go @@ -34,11 +34,14 @@ import ( ) const ( - initialSeqNum = 11 - nextSeqNum = 12 - taskARN1 = "arn1" - taskARN2 = "arn2" - taskARN3 = "arn3" + initialSeqNum = 11 + nextSeqNum = 12 + taskARN1 = "arn1" + taskARN2 = "arn2" + taskARN3 = "arn3" + containerName1 = "name1" + containerName2 = "name2" + containerName3 = "name3" ) var expectedTaskManifestAck = &ecsacs.AckRequest{ diff --git a/agent/acs/session/task_stop_verification_ack_responder.go b/agent/acs/session/task_stop_verification_ack_responder.go index b9132239896..96420bc7a4e 100644 --- a/agent/acs/session/task_stop_verification_ack_responder.go +++ b/agent/acs/session/task_stop_verification_ack_responder.go @@ -14,6 +14,7 @@ package session import ( + "github.com/aws/amazon-ecs-agent/agent/data" "github.com/aws/amazon-ecs-agent/agent/engine" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" @@ -23,12 +24,14 @@ import ( // taskStopper implements the TaskStopper interface defined in ecs-agent module. type taskStopper struct { taskEngine engine.TaskEngine + dataClient data.Client } // NewTaskStopper creates a new taskStopper. -func NewTaskStopper(taskEngine engine.TaskEngine) *taskStopper { +func NewTaskStopper(taskEngine engine.TaskEngine, dataClient data.Client) *taskStopper { return &taskStopper{ taskEngine: taskEngine, + dataClient: dataClient, } } @@ -39,7 +42,13 @@ func (ts *taskStopper) StopTask(taskARN string) { loggerfield.TaskARN: task.Arn, }) task.SetDesiredStatus(apitaskstatus.TaskStopped) - ts.taskEngine.AddTask(task) + task.UpdateDesiredStatus() + if err := ts.dataClient.SaveTask(task); err != nil { + logger.Error("Failed to save data for task", logger.Fields{ + loggerfield.TaskARN: task.Arn, + loggerfield.Error: err, + }) + } } else { logger.Debug("Task from task stop verification ACK not found on the instance", logger.Fields{ loggerfield.TaskARN: taskARN, diff --git a/agent/acs/session/task_stop_verification_ack_responder_test.go b/agent/acs/session/task_stop_verification_ack_responder_test.go index cf12c03c866..83127ab2fe7 100644 --- a/agent/acs/session/task_stop_verification_ack_responder_test.go +++ b/agent/acs/session/task_stop_verification_ack_responder_test.go @@ -18,11 +18,14 @@ package session import ( "testing" + "github.com/aws/amazon-ecs-agent/agent/api/container" "github.com/aws/amazon-ecs-agent/agent/api/task" + "github.com/aws/amazon-ecs-agent/agent/data" mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" + apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" "github.com/aws/amazon-ecs-agent/ecs-agent/wsclient" @@ -44,7 +47,7 @@ func setupTaskStopVerificationAckTest(t *testing.T) *taskStopVerificationAckTest manifestMessageIDAccessor := NewManifestMessageIDAccessor() manifestMessageIDAccessor.SetMessageID(testconst.MessageID) taskStopVerificationAckResponder := acssession.NewTaskStopVerificationACKResponder( - NewTaskStopper(taskEngine), + NewTaskStopper(taskEngine, data.NewNoopClient()), manifestMessageIDAccessor, metrics.NewNopEntryFactory()) @@ -58,9 +61,27 @@ func setupTaskStopVerificationAckTest(t *testing.T) *taskStopVerificationAckTest // defaultTasksOnInstance returns a baseline map of tasks that simulates/tracks the tasks on an instance. func defaultTasksOnInstance() map[string]*task.Task { return map[string]*task.Task{ - taskARN1: {Arn: taskARN1, DesiredStatusUnsafe: apitaskstatus.TaskRunning}, - taskARN2: {Arn: taskARN2, DesiredStatusUnsafe: apitaskstatus.TaskRunning}, - taskARN3: {Arn: taskARN3, DesiredStatusUnsafe: apitaskstatus.TaskRunning}, + taskARN1: {Arn: taskARN1, DesiredStatusUnsafe: apitaskstatus.TaskRunning, + Containers: []*container.Container{ + { + Name: containerName1, + DesiredStatusUnsafe: apicontainerstatus.ContainerRunning, + }, + }}, + taskARN2: {Arn: taskARN2, DesiredStatusUnsafe: apitaskstatus.TaskRunning, + Containers: []*container.Container{ + { + Name: containerName2, + DesiredStatusUnsafe: apicontainerstatus.ContainerRunning, + }, + }}, + taskARN3: {Arn: taskARN3, DesiredStatusUnsafe: apitaskstatus.TaskRunning, + Containers: []*container.Container{ + { + Name: containerName3, + DesiredStatusUnsafe: apicontainerstatus.ContainerRunning, + }, + }}, } } @@ -99,21 +120,24 @@ func TestTaskStopVerificationAckResponderStopsMultipleTasks(t *testing.T) { tester.taskEngine.EXPECT().GetTaskByArn(taskARN2).Return(tasksOnInstance[taskARN2], true) tester.taskEngine.EXPECT().GetTaskByArn(taskARN3).Return(tasksOnInstance[taskARN3], true) - tester.taskEngine.EXPECT().AddTask(tasksOnInstance[taskARN2]).Do(func(task *task.Task) { - task.SetDesiredStatus(apitaskstatus.TaskStopped) - }) - tester.taskEngine.EXPECT().AddTask(tasksOnInstance[taskARN3]).Do(func(task *task.Task) { - task.SetDesiredStatus(apitaskstatus.TaskStopped) - }) handleTaskStopVerificationAck := tester.taskStopVerificationAckResponder.HandlerFunc().(func(message *ecsacs.TaskStopVerificationAck)) handleTaskStopVerificationAck(taskStopVerificationAck) - // Only task2 and task3 should be stopped. + // Only task2 and task3 and their containers should be stopped. assert.Equal(t, apitaskstatus.TaskRunning, tasksOnInstance[taskARN1].GetDesiredStatus()) + container1, ok := tasksOnInstance[taskARN1].ContainerByName(containerName1) + assert.True(t, ok) + assert.Equal(t, apicontainerstatus.ContainerRunning, container1.GetDesiredStatus()) assert.Equal(t, apitaskstatus.TaskStopped, tasksOnInstance[taskARN2].GetDesiredStatus()) + container2, ok := tasksOnInstance[taskARN2].ContainerByName(containerName2) + assert.True(t, ok) + assert.Equal(t, apicontainerstatus.ContainerStopped, container2.GetDesiredStatus()) assert.Equal(t, apitaskstatus.TaskStopped, tasksOnInstance[taskARN3].GetDesiredStatus()) + container3, ok := tasksOnInstance[taskARN3].ContainerByName(containerName3) + assert.True(t, ok) + assert.Equal(t, apicontainerstatus.ContainerStopped, container3.GetDesiredStatus()) } @@ -149,22 +173,15 @@ func TestTaskStopVerificationAckResponderStopsAllTasks(t *testing.T) { tester.taskEngine.EXPECT().GetTaskByArn(taskARN1).Return(tasksOnInstance[taskARN1], true) tester.taskEngine.EXPECT().GetTaskByArn(taskARN2).Return(tasksOnInstance[taskARN2], true) tester.taskEngine.EXPECT().GetTaskByArn(taskARN3).Return(tasksOnInstance[taskARN3], true) - tester.taskEngine.EXPECT().AddTask(tasksOnInstance[taskARN1]).Do(func(task *task.Task) { - task.SetDesiredStatus(apitaskstatus.TaskStopped) - }) - tester.taskEngine.EXPECT().AddTask(tasksOnInstance[taskARN2]).Do(func(task *task.Task) { - task.SetDesiredStatus(apitaskstatus.TaskStopped) - }) - tester.taskEngine.EXPECT().AddTask(tasksOnInstance[taskARN3]).Do(func(task *task.Task) { - task.SetDesiredStatus(apitaskstatus.TaskStopped) - }) handleTaskStopVerificationAck := tester.taskStopVerificationAckResponder.HandlerFunc().(func(message *ecsacs.TaskStopVerificationAck)) handleTaskStopVerificationAck(taskStopVerificationAck) - // All tasks on instance should be stopped. + // All tasks and containers on instance should be stopped. for _, task := range tasksOnInstance { assert.Equal(t, apitaskstatus.TaskStopped, task.GetDesiredStatus()) + assert.Equal(t, 1, len(task.Containers)) + assert.Equal(t, apicontainerstatus.ContainerStopped, task.Containers[0].GetDesiredStatus()) } } diff --git a/agent/app/agent.go b/agent/app/agent.go index 5cfa2b52bd6..9bc4087278f 100644 --- a/agent/app/agent.go +++ b/agent/app/agent.go @@ -1044,7 +1044,7 @@ func (agent *ecsAgent) startACSSession( manifestMessageIDAccessor := agentacs.NewManifestMessageIDAccessor() sequenceNumberAccessor := agentacs.NewSequenceNumberAccessor(agent.latestSeqNumberTaskManifest, agent.dataClient) taskComparer := agentacs.NewTaskComparer(taskEngine) - taskStopper := agentacs.NewTaskStopper(taskEngine) + taskStopper := agentacs.NewTaskStopper(taskEngine, agent.dataClient) acsSession := session.NewSession(agent.containerInstanceARN, agent.cfg.Cluster,