diff --git a/agent/acs/session/task_stop_verification_ack_responder_integ_test.go b/agent/acs/session/task_stop_verification_ack_responder_integ_test.go new file mode 100644 index 00000000000..408c87aeedb --- /dev/null +++ b/agent/acs/session/task_stop_verification_ack_responder_integ_test.go @@ -0,0 +1,227 @@ +//go:build integration +// +build integration + +package session_test + +import ( + "context" + "fmt" + "testing" + + "github.com/aws/amazon-ecs-agent/agent/acs/session" + "github.com/aws/amazon-ecs-agent/agent/api/container" + apitask "github.com/aws/amazon-ecs-agent/agent/api/task" + "github.com/aws/amazon-ecs-agent/agent/data" + "github.com/aws/amazon-ecs-agent/agent/engine" + "github.com/aws/amazon-ecs-agent/agent/taskresource" + "github.com/aws/amazon-ecs-agent/agent/taskresource/envFiles" + resourcestatus "github.com/aws/amazon-ecs-agent/agent/taskresource/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs" + acssession "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session" + "github.com/aws/amazon-ecs-agent/ecs-agent/acs/session/testconst" + apicontainerstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/container/status" + apitaskstatus "github.com/aws/amazon-ecs-agent/ecs-agent/api/task/status" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/aws-sdk-go/aws" + "github.com/stretchr/testify/require" +) + +// Tests that a task, its containers, and its resources are all stopped when a task stop verification ACK message is received. +func TestTaskStopVerificationACKResponder_StopsTaskContainersAndResources(t *testing.T) { + taskEngine, done, dockerClient, _ := engine.SetupIntegTestTaskEngine(engine.DefaultTestConfigIntegTest(), nil, t) + defer done() + + task := engine.CreateTestTask("test_task") + createEnvironmentFileResources(task, 3) + createLongRunningContainers(task, 3) + go taskEngine.AddTask(task) + + for i := 0; i < len(task.Containers); i++ { + engine.VerifyContainerManifestPulledStateChange(t, taskEngine) + } + engine.VerifyTaskManifestPulledStateChange(t, taskEngine) + for i := 0; i < len(task.Containers); i++ { + engine.VerifyContainerRunningStateChange(t, taskEngine) + } + engine.VerifyTaskRunningStateChange(t, taskEngine) + + manifestMessageIDAccessor := session.NewManifestMessageIDAccessor() + require.NoError(t, manifestMessageIDAccessor.SetMessageID("manifest_message_id")) + + taskStopper := session.NewTaskStopper(taskEngine, data.NewNoopClient()) + responder := acssession.NewTaskStopVerificationACKResponder(taskStopper, manifestMessageIDAccessor, metrics.NewNopEntryFactory()) + + handler := responder.HandlerFunc().(func(*ecsacs.TaskStopVerificationAck)) + handler(&ecsacs.TaskStopVerificationAck{ + GeneratedAt: aws.Int64(testconst.DummyInt), + MessageId: aws.String(manifestMessageIDAccessor.GetMessageID()), + StopTasks: []*ecsacs.TaskIdentifier{{TaskArn: aws.String(task.Arn)}}, + }) + + // Wait for all state changes before verifying container, resource, and task statuses. + for i := 0; i < len(task.Containers); i++ { + engine.VerifyContainerStoppedStateChange(t, taskEngine) + } + engine.VerifyTaskStoppedStateChange(t, taskEngine) + + // Verify that all the task's containers have stopped. + for _, container := range task.Containers { + status, _ := dockerClient.DescribeContainer(context.Background(), container.RuntimeID) + require.Equal(t, apicontainerstatus.ContainerStopped, status) + } + // Verify that all the tasks's resources have been removed. + for _, resource := range task.GetResources() { + require.Equal(t, resourcestatus.ResourceRemoved, resource.GetKnownStatus()) + } + // Verify that the task has stopped. + require.Equal(t, apitaskstatus.TaskStopped, task.GetKnownStatus()) +} + +// Tests that only the tasks specified in the task stop verification ACK message are stopped. +func TestTaskStopVerificationACKResponder_StopsSpecificTasks(t *testing.T) { + taskEngine, done, dockerClient, _ := engine.SetupIntegTestTaskEngine(engine.DefaultTestConfigIntegTest(), nil, t) + defer done() + + var tasks []*apitask.Task + for i := 0; i < 3; i++ { + task := engine.CreateTestTask(fmt.Sprintf("test_task_%d", i)) + createLongRunningContainers(task, 1) + go taskEngine.AddTask(task) + + engine.VerifyContainerManifestPulledStateChange(t, taskEngine) + engine.VerifyTaskManifestPulledStateChange(t, taskEngine) + engine.VerifyContainerRunningStateChange(t, taskEngine) + engine.VerifyTaskRunningStateChange(t, taskEngine) + tasks = append(tasks, task) + } + + manifestMessageIDAccessor := session.NewManifestMessageIDAccessor() + require.NoError(t, manifestMessageIDAccessor.SetMessageID("manifest_message_id")) + + taskStopper := session.NewTaskStopper(taskEngine, data.NewNoopClient()) + responder := acssession.NewTaskStopVerificationACKResponder(taskStopper, manifestMessageIDAccessor, metrics.NewNopEntryFactory()) + + // Stop the last 2 tasks. + handler := responder.HandlerFunc().(func(*ecsacs.TaskStopVerificationAck)) + handler(&ecsacs.TaskStopVerificationAck{ + GeneratedAt: aws.Int64(testconst.DummyInt), + MessageId: aws.String(manifestMessageIDAccessor.GetMessageID()), + StopTasks: []*ecsacs.TaskIdentifier{ + {TaskArn: aws.String(tasks[1].Arn)}, + {TaskArn: aws.String(tasks[2].Arn)}, + }, + }) + + // Wait for all state changes before verifying container and task statuses. + for i := 0; i < 2; i++ { + engine.VerifyContainerStoppedStateChange(t, taskEngine) + engine.VerifyTaskStoppedStateChange(t, taskEngine) + } + + // Verify that the last 2 tasks and their containers have stopped. + for _, task := range tasks[1:] { + status, _ := dockerClient.DescribeContainer(context.Background(), task.Containers[0].RuntimeID) + require.Equal(t, apicontainerstatus.ContainerStopped, status) + require.Equal(t, apitaskstatus.TaskStopped, task.GetKnownStatus()) + } + + // Verify that the first task and its container are still running. + status, _ := dockerClient.DescribeContainer(context.Background(), tasks[0].Containers[0].RuntimeID) + require.Equal(t, apicontainerstatus.ContainerRunning, status) + require.Equal(t, apitaskstatus.TaskRunning, tasks[0].GetKnownStatus()) +} + +// Tests simple test cases, such as the happy path for 1 task with 1 container and edge cases where no tasks are stopped. +func TestTaskStopVerificationACKResponder(t *testing.T) { + testCases := []struct { + description string + messageID string + taskArn string + stopTaskArn string + shouldStop bool + }{ + { + description: "stops a task", + messageID: "manifest_message_id", + taskArn: "test_task", + stopTaskArn: "test_task", + shouldStop: true, + }, + { + description: "task not found", + messageID: "manifest_message_id", + taskArn: "test_task", + stopTaskArn: "not_found_task", + }, + { + description: "invalid message id", + taskArn: "test_task", + stopTaskArn: "test_task", + }, + } + + for _, tc := range testCases { + t.Run(tc.description, func(t *testing.T) { + taskEngine, done, dockerClient, _ := engine.SetupIntegTestTaskEngine(engine.DefaultTestConfigIntegTest(), nil, t) + defer done() + + task := engine.CreateTestTask(tc.taskArn) + createLongRunningContainers(task, 1) + go taskEngine.AddTask(task) + + engine.VerifyContainerManifestPulledStateChange(t, taskEngine) + engine.VerifyTaskManifestPulledStateChange(t, taskEngine) + engine.VerifyContainerRunningStateChange(t, taskEngine) + engine.VerifyTaskRunningStateChange(t, taskEngine) + + manifestMessageIDAccessor := session.NewManifestMessageIDAccessor() + manifestMessageIDAccessor.SetMessageID(tc.messageID) + + taskStopper := session.NewTaskStopper(taskEngine, data.NewNoopClient()) + responder := acssession.NewTaskStopVerificationACKResponder(taskStopper, manifestMessageIDAccessor, metrics.NewNopEntryFactory()) + + handler := responder.HandlerFunc().(func(*ecsacs.TaskStopVerificationAck)) + handler(&ecsacs.TaskStopVerificationAck{ + GeneratedAt: aws.Int64(testconst.DummyInt), + MessageId: aws.String(manifestMessageIDAccessor.GetMessageID()), + StopTasks: []*ecsacs.TaskIdentifier{ + {TaskArn: aws.String(tc.stopTaskArn)}, + }, + }) + + if tc.shouldStop { + engine.VerifyContainerStoppedStateChange(t, taskEngine) + engine.VerifyTaskStoppedStateChange(t, taskEngine) + + status, _ := dockerClient.DescribeContainer(context.Background(), task.Containers[0].RuntimeID) + require.Equal(t, apicontainerstatus.ContainerStopped, status) + require.Equal(t, apitaskstatus.TaskStopped, task.GetKnownStatus()) + } else { + status, _ := dockerClient.DescribeContainer(context.Background(), task.Containers[0].RuntimeID) + require.Equal(t, apicontainerstatus.ContainerRunning, status) + require.Equal(t, apitaskstatus.TaskRunning, task.GetKnownStatus()) + } + }) + } +} + +func createEnvironmentFileResources(task *apitask.Task, n int) { + task.ResourcesMapUnsafe = make(map[string][]taskresource.TaskResource) + for i := 0; i < n; i++ { + envFile := &envFiles.EnvironmentFileResource{} + // Set known status to ResourceCreated to avoid downloading files from S3. + envFile.SetKnownStatus(resourcestatus.ResourceCreated) + task.AddResource(envFiles.ResourceName, envFile) + } +} + +func createLongRunningContainers(task *apitask.Task, n int) { + var containers []*container.Container + for i := 0; i < n; i++ { + container := engine.CreateTestContainer() + container.Command = engine.GetLongRunningCommand() + container.Name = fmt.Sprintf("%s-%d", container.Name, i) + containers = append(containers, container) + } + task.Containers = containers +} diff --git a/agent/engine/common_integ_testutil.go b/agent/engine/common_integ_testutil.go index 02a604e3632..86a6b57394f 100644 --- a/agent/engine/common_integ_testutil.go +++ b/agent/engine/common_integ_testutil.go @@ -52,6 +52,11 @@ var ( sdkClientFactory sdkclientfactory.Factory ) +const ( + taskSteadyStatePollInterval = 100 * time.Millisecond + taskSteadyStatePollIntervalJitter = 10 * time.Millisecond +) + func init() { sdkClientFactory = sdkclientfactory.NewFactory(context.TODO(), dockerEndpoint) } @@ -69,7 +74,7 @@ func CreateTestTask(arn string) *apitask.Task { Family: "family", Version: "1", DesiredStatusUnsafe: apitaskstatus.TaskRunning, - Containers: []*apicontainer.Container{createTestContainer()}, + Containers: []*apicontainer.Container{CreateTestContainer()}, } } @@ -119,6 +124,10 @@ func setupGMSALinux(cfg *config.Config, state dockerstate.TaskEngineState, t *te taskEngine := NewDockerTaskEngine(cfg, dockerClient, credentialsManager, eventstream.NewEventStream("ENGINEINTEGTEST", context.Background()), imageManager, &hostResourceManager, state, metadataManager, resourceFields, execcmd.NewManager(), engineserviceconnect.NewManager(), daemonManagers) + // Set the steady state poll interval to a low value so that tasks transition from their current state to their + // desired state faster. This prevents tests from appearing to hang while waiting for state change events. + taskEngine.taskSteadyStatePollInterval = taskSteadyStatePollInterval + taskEngine.taskSteadyStatePollIntervalJitter = taskSteadyStatePollIntervalJitter taskEngine.MustInit(context.TODO()) return taskEngine, func() { taskEngine.Shutdown() @@ -159,7 +168,7 @@ func VerifyTaskManifestPulledStateChange(t *testing.T, taskEngine TaskEngine) { func VerifyContainerRunningStateChange(t *testing.T, taskEngine TaskEngine) { stateChangeEvents := taskEngine.StateChangeEvents() event := <-stateChangeEvents - assert.Equal(t, event.(api.ContainerStateChange).Status, apicontainerstatus.ContainerRunning, + assert.Equal(t, apicontainerstatus.ContainerRunning, event.(api.ContainerStateChange).Status, "Expected container to be RUNNING") } @@ -173,7 +182,7 @@ func VerifyTaskRunningStateChange(t *testing.T, taskEngine TaskEngine) { func verifyContainerRunningStateChangeWithRuntimeID(t *testing.T, taskEngine TaskEngine) { stateChangeEvents := taskEngine.StateChangeEvents() event := <-stateChangeEvents - assert.Equal(t, event.(api.ContainerStateChange).Status, apicontainerstatus.ContainerRunning, + assert.Equal(t, apicontainerstatus.ContainerRunning, event.(api.ContainerStateChange).Status, "Expected container to be RUNNING") assert.NotEqual(t, "", event.(api.ContainerStateChange).RuntimeID, "Expected container runtimeID should not empty") @@ -196,8 +205,9 @@ func verifyExecAgentStateChange(t *testing.T, taskEngine TaskEngine, func VerifyContainerStoppedStateChange(t *testing.T, taskEngine TaskEngine) { stateChangeEvents := taskEngine.StateChangeEvents() event := <-stateChangeEvents + sc := event.(api.ContainerStateChange) assert.Equal(t, event.(api.ContainerStateChange).Status, apicontainerstatus.ContainerStopped, - "Expected container to be STOPPED") + "Expected container %s from task %s to be STOPPED", sc.RuntimeID, sc.TaskArn) } func verifyContainerStoppedStateChangeWithReason(t *testing.T, taskEngine TaskEngine, reason string) { @@ -259,6 +269,10 @@ func SetupIntegTestTaskEngine(cfg *config.Config, state dockerstate.TaskEngineSt taskEngine := NewDockerTaskEngine(cfg, dockerClient, credentialsManager, eventstream.NewEventStream("ENGINEINTEGTEST", context.Background()), imageManager, &hostResourceManager, state, metadataManager, nil, execcmd.NewManager(), engineserviceconnect.NewManager(), daemonManagers) + // Set the steady state poll interval to a low value so that tasks transition from their current state to their + // desired state faster. This prevents tests from appearing to hang while waiting for state change events. + taskEngine.taskSteadyStatePollInterval = taskSteadyStatePollInterval + taskEngine.taskSteadyStatePollIntervalJitter = taskSteadyStatePollIntervalJitter taskEngine.MustInit(context.TODO()) return taskEngine, func() { taskEngine.Shutdown() diff --git a/agent/engine/common_unix_integ_testutil.go b/agent/engine/common_unix_integ_testutil.go index 0b83295f7a6..e97e5af8456 100644 --- a/agent/engine/common_unix_integ_testutil.go +++ b/agent/engine/common_unix_integ_testutil.go @@ -28,13 +28,13 @@ const ( dockerEndpoint = "unix:///var/run/docker.sock" ) -func createTestContainer() *apicontainer.Container { +func CreateTestContainer() *apicontainer.Container { return createTestContainerWithImageAndName(testRegistryImage, "netcat") } -// getLongRunningCommand returns the command that keeps the container running for the container +// GetLongRunningCommand returns the command that keeps the container running for the container // that uses the default integ test image (amazon/amazon-ecs-netkitten for unix) -func getLongRunningCommand() []string { +func GetLongRunningCommand() []string { return []string{"-loop=true"} } diff --git a/agent/engine/engine_integ_test.go b/agent/engine/engine_integ_test.go index cd640d4c058..cbc95956090 100644 --- a/agent/engine/engine_integ_test.go +++ b/agent/engine/engine_integ_test.go @@ -139,7 +139,7 @@ func TestDockerStateToContainerState(t *testing.T) { // let the container keep running to prevent the edge case where it's already stopped when we check whether // it's running - container.Command = getLongRunningCommand() + container.Command = GetLongRunningCommand() client, err := sdkClient.NewClientWithOpts(sdkClient.WithHost(endpoint), sdkClient.WithVersion(sdkclientfactory.GetDefaultVersion().String())) require.NoError(t, err, "Creating go docker client failed") diff --git a/agent/engine/engine_sudo_linux_integ_test.go b/agent/engine/engine_sudo_linux_integ_test.go index e9f5d9add47..0128ec29f82 100644 --- a/agent/engine/engine_sudo_linux_integ_test.go +++ b/agent/engine/engine_sudo_linux_integ_test.go @@ -850,7 +850,7 @@ func TestGMSATaskFile(t *testing.T) { defer os.RemoveAll(testCredSpecFilePath) - testContainer := createTestContainer() + testContainer := CreateTestContainer() testContainer.Name = "testGMSATaskFile" hostConfig := "{\"SecurityOpt\": [\"credentialspec:file:///tmp/test-gmsa.json\"]}" @@ -865,7 +865,7 @@ func TestGMSATaskFile(t *testing.T) { } testTask.Containers[0].TransitionDependenciesMap = make(map[apicontainerstatus.ContainerStatus]apicontainer.TransitionDependencySet) testTask.ResourcesMapUnsafe = make(map[string][]taskresource.TaskResource) - testTask.Containers[0].Command = getLongRunningCommand() + testTask.Containers[0].Command = GetLongRunningCommand() go taskEngine.AddTask(testTask) @@ -944,7 +944,7 @@ func TestGMSADomainlessTaskFile(t *testing.T) { defer os.RemoveAll(testCredSpecFilePath) - testContainer := createTestContainer() + testContainer := CreateTestContainer() testContainer.Name = "testGMSADomainlessTaskFile" testContainer.CredentialSpecs = []string{"credentialspecdomainless:file:///tmp/test-gmsa.json"} @@ -958,7 +958,7 @@ func TestGMSADomainlessTaskFile(t *testing.T) { } testTask.Containers[0].TransitionDependenciesMap = make(map[apicontainerstatus.ContainerStatus]apicontainer.TransitionDependencySet) testTask.ResourcesMapUnsafe = make(map[string][]taskresource.TaskResource) - testTask.Containers[0].Command = getLongRunningCommand() + testTask.Containers[0].Command = GetLongRunningCommand() go taskEngine.AddTask(testTask) @@ -995,7 +995,7 @@ func TestGMSATaskFileS3Err(t *testing.T) { stateChangeEvents := taskEngine.StateChangeEvents() - testContainer := createTestContainer() + testContainer := CreateTestContainer() testContainer.Name = "testGMSATaskFile" hostConfig := "{\"SecurityOpt\": [\"credentialspec:arn:aws:::s3:testbucket/test-gmsa.json\"]}" @@ -1010,7 +1010,7 @@ func TestGMSATaskFileS3Err(t *testing.T) { } testTask.Containers[0].TransitionDependenciesMap = make(map[apicontainerstatus.ContainerStatus]apicontainer.TransitionDependencySet) testTask.ResourcesMapUnsafe = make(map[string][]taskresource.TaskResource) - testTask.Containers[0].Command = getLongRunningCommand() + testTask.Containers[0].Command = GetLongRunningCommand() go taskEngine.AddTask(testTask) @@ -1035,7 +1035,7 @@ func TestGMSATaskFileSSMErr(t *testing.T) { stateChangeEvents := taskEngine.StateChangeEvents() - testContainer := createTestContainer() + testContainer := CreateTestContainer() testContainer.Name = "testGMSATaskFile" hostConfig := "{\"SecurityOpt\": [\"credentialspec:aws:arn:ssm:us-west-2:123456789012:document/test-gmsa.json\"]}" @@ -1050,7 +1050,7 @@ func TestGMSATaskFileSSMErr(t *testing.T) { } testTask.Containers[0].TransitionDependenciesMap = make(map[apicontainerstatus.ContainerStatus]apicontainer.TransitionDependencySet) testTask.ResourcesMapUnsafe = make(map[string][]taskresource.TaskResource) - testTask.Containers[0].Command = getLongRunningCommand() + testTask.Containers[0].Command = GetLongRunningCommand() go taskEngine.AddTask(testTask) @@ -1108,7 +1108,7 @@ func TestGMSANotRunningErr(t *testing.T) { err = ioutil.WriteFile(testCredSpecFilePath, testCredSpecData, 0755) require.NoError(t, err) - testContainer := createTestContainer() + testContainer := CreateTestContainer() testContainer.Name = "testGMSATaskFile" hostConfig := "{\"SecurityOpt\": [\"credentialspec:file:///tmp/test-gmsa.json\"]}" @@ -1123,7 +1123,7 @@ func TestGMSANotRunningErr(t *testing.T) { } testTask.Containers[0].TransitionDependenciesMap = make(map[apicontainerstatus.ContainerStatus]apicontainer.TransitionDependencySet) testTask.ResourcesMapUnsafe = make(map[string][]taskresource.TaskResource) - testTask.Containers[0].Command = getLongRunningCommand() + testTask.Containers[0].Command = GetLongRunningCommand() go taskEngine.AddTask(testTask) diff --git a/agent/engine/engine_unix_integ_test.go b/agent/engine/engine_unix_integ_test.go index 960ccf3b621..760e399617b 100644 --- a/agent/engine/engine_unix_integ_test.go +++ b/agent/engine/engine_unix_integ_test.go @@ -94,7 +94,7 @@ func createTestHealthCheckTask(arn string) *apitask.Task { Family: "family", Version: "1", DesiredStatusUnsafe: apitaskstatus.TaskRunning, - Containers: []*apicontainer.Container{createTestContainer()}, + Containers: []*apicontainer.Container{CreateTestContainer()}, } testTask.Containers[0].Image = testBusyboxImage testTask.Containers[0].Name = "test-health-check" @@ -392,7 +392,7 @@ func TestMultiplePortForwards(t *testing.T) { testTask.Containers[0].Command = []string{fmt.Sprintf("-l=%d", port1), "-serve", serverContent + "1"} testTask.Containers[0].Ports = []apicontainer.PortBinding{{ContainerPort: port1, HostPort: port1}} testTask.Containers[0].Essential = false - testTask.Containers = append(testTask.Containers, createTestContainer()) + testTask.Containers = append(testTask.Containers, CreateTestContainer()) testTask.Containers[1].Name = "nc2" testTask.Containers[1].Command = []string{fmt.Sprintf("-l=%d", port1), "-serve", serverContent + "2"} testTask.Containers[1].Ports = []apicontainer.PortBinding{{ContainerPort: port1, HostPort: port2}} @@ -580,7 +580,7 @@ func TestLinking(t *testing.T) { testArn := "TestLinking" testTask := CreateTestTask(testArn) - testTask.Containers = append(testTask.Containers, createTestContainer()) + testTask.Containers = append(testTask.Containers, CreateTestContainer()) testTask.Containers[0].Command = []string{"-l=80", "-serve", "hello linker"} testTask.Containers[0].Name = "linkee" port := getUnassignedPort() @@ -638,7 +638,7 @@ func TestVolumesFromRO(t *testing.T) { testTask := CreateTestTask("testVolumeROContainer") testTask.Containers[0].Image = testVolumeImage for i := 0; i < 3; i++ { - cont := createTestContainer() + cont := CreateTestContainer() cont.Name = "test" + strconv.Itoa(i) cont.Image = testVolumeImage cont.Essential = i > 0 diff --git a/agent/engine/engine_windows_integ_test.go b/agent/engine/engine_windows_integ_test.go index 1472a6e5582..d69446e2398 100644 --- a/agent/engine/engine_windows_integ_test.go +++ b/agent/engine/engine_windows_integ_test.go @@ -560,6 +560,10 @@ func setupGMSA(cfg *config.Config, state dockerstate.TaskEngineState, t *testing taskEngine := NewDockerTaskEngine(cfg, dockerClient, credentialsManager, eventstream.NewEventStream("ENGINEINTEGTEST", context.Background()), imageManager, &hostResourceManager, state, metadataManager, resourceFields, execcmd.NewManager(), engineserviceconnect.NewManager()) + // Set the steady state poll interval to a low value so that tasks transition from their current state to their + // desired state faster. This prevents tests from appearing to hang while waiting for state change events. + taskEngine.taskSteadyStatePollInterval = taskSteadyStatePollInterval + taskEngine.taskSteadyStatePollIntervalJitter = taskSteadyStatePollIntervalJitter taskEngine.MustInit(context.TODO()) return taskEngine, func() { taskEngine.Shutdown()