From 7a2b7b1c9b7105b5ed890bfb60b6176d66b63193 Mon Sep 17 00:00:00 2001 From: Michael Ye Date: Tue, 17 Oct 2023 23:40:49 +0000 Subject: [PATCH] Fix loading CSI driver container from state if it exists --- agent/api/container/container.go | 11 +++ agent/api/container/container_test.go | 25 +++++- agent/api/container/containertype.go | 1 + agent/api/task/task.go | 20 +++++ agent/api/task/task_test.go | 76 +++++++++++++++++++ agent/engine/data.go | 5 ++ agent/engine/data_test.go | 51 +++++++++++++ .../ecs-agent/manageddaemon/managed_daemon.go | 6 ++ ecs-agent/manageddaemon/managed_daemon.go | 6 ++ 9 files changed, 200 insertions(+), 1 deletion(-) diff --git a/agent/api/container/container.go b/agent/api/container/container.go index 8c7af30d4ec..2b4ecc7e364 100644 --- a/agent/api/container/container.go +++ b/agent/api/container/container.go @@ -28,6 +28,7 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + md "github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" @@ -1508,3 +1509,13 @@ func (c *Container) GetContainerPortRangeMap() map[string]string { defer c.lock.RUnlock() return c.ContainerPortRangeMap } + +func (c *Container) IsManagedDaemonContainer() (string, bool) { + c.lock.RLock() + defer c.lock.RUnlock() + containerImage := strings.Split(c.Image, ":")[0] + if md.ManagedDaemonImageNames[containerImage] && c.Type == ContainerManagedDaemon { + return containerImage, true + } + return "", false +} diff --git a/agent/api/container/container_test.go b/agent/api/container/container_test.go index db34d75ed41..c8ad001a27a 100644 --- a/agent/api/container/container_test.go +++ b/agent/api/container/container_test.go @@ -130,13 +130,36 @@ func TestIsInternal(t *testing.T) { } for _, tc := range testCases { - t.Run(fmt.Sprintf("IsInternal shoukd return %t for %s", tc.internal, tc.container.String()), + t.Run(fmt.Sprintf("IsInternal should return %t for %s", tc.internal, tc.container.String()), func(t *testing.T) { assert.Equal(t, tc.internal, tc.container.IsInternal()) }) } } +func TestIsManagedDaemonContainer(t *testing.T) { + testCases := []struct { + container *Container + internal bool + isManagedDaemon bool + }{ + {&Container{}, false, false}, + {&Container{Type: ContainerNormal, Image: "someImage"}, false, false}, + {&Container{Type: ContainerManagedDaemon, Image: "ebs-csi-driver:latest"}, true, true}, + {&Container{Type: ContainerNormal, Image: "ebs-csi-driver:latest"}, false, false}, + {&Container{Type: ContainerManagedDaemon, Image: "someImage"}, true, false}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("IsManagedDaemonContainer should return %t for %s", tc.isManagedDaemon, tc.container.String()), + func(t *testing.T) { + assert.Equal(t, tc.internal, tc.container.IsInternal()) + _, ok := tc.container.IsManagedDaemonContainer() + assert.Equal(t, tc.isManagedDaemon, ok) + }) + } +} + // TestSetupExecutionRoleFlag tests whether or not the container appropriately // sets the flag for using execution roles func TestSetupExecutionRoleFlag(t *testing.T) { diff --git a/agent/api/container/containertype.go b/agent/api/container/containertype.go index 740f504229e..165692e35f8 100644 --- a/agent/api/container/containertype.go +++ b/agent/api/container/containertype.go @@ -53,6 +53,7 @@ var stringToContainerType = map[string]ContainerType{ "EMPTY_HOST_VOLUME": ContainerEmptyHostVolume, "CNI_PAUSE": ContainerCNIPause, "NAMESPACE_PAUSE": ContainerNamespacePause, + "MANAGED_DAEMON": ContainerManagedDaemon, } // String converts the container type enum to a string diff --git a/agent/api/task/task.go b/agent/api/task/task.go index 689f1c20282..44c8fa62873 100644 --- a/agent/api/task/task.go +++ b/agent/api/task/task.go @@ -3723,3 +3723,23 @@ func (task *Task) HasActiveContainers() bool { } return false } + +func (task *Task) IsManagedDaemonTask() (string, bool) { + task.lock.RLock() + defer task.lock.RUnlock() + + if !task.IsInternal { + return "", false + } + + for _, c := range task.Containers { + containerStatus := c.GetKnownStatus() + if containerStatus.IsRunning() && c.IsInternal() { + md, ok := c.IsManagedDaemonContainer() + if ok { + return md, ok + } + } + } + return "", false +} diff --git a/agent/api/task/task_test.go b/agent/api/task/task_test.go index 7bddfb78e31..48cd14c1d09 100644 --- a/agent/api/task/task_test.go +++ b/agent/api/task/task_test.go @@ -5234,3 +5234,79 @@ func TestToHostResources(t *testing.T) { assert.Equal(t, len(tc.expectedResources["PORTS_UDP"].StringSetValue), len(calcResources["PORTS_UDP"].StringSetValue), "Error converting task UDP port resources") } } + +func TestIsManagedDaemonTask(t *testing.T) { + + testTask1 := &Task{ + Containers: []*apicontainer.Container{ + { + Type: apicontainer.ContainerManagedDaemon, + Image: "ebs-csi-driver:latest", + KnownStatusUnsafe: apicontainerstatus.ContainerRunning, + }, + }, + IsInternal: true, + } + + testTask2 := &Task{ + Containers: []*apicontainer.Container{ + { + Type: apicontainer.ContainerNormal, + Image: "someImage", + KnownStatusUnsafe: apicontainerstatus.ContainerRunning, + }, + { + Type: apicontainer.ContainerNormal, + Image: "ebs-csi-driver:latest", + KnownStatusUnsafe: apicontainerstatus.ContainerRunning, + }, + { + Type: apicontainer.ContainerNormal, + Image: "someImage:latest", + KnownStatusUnsafe: apicontainerstatus.ContainerRunning, + }, + }, + IsInternal: false, + } + + testTask3 := &Task{ + Containers: []*apicontainer.Container{ + { + Type: apicontainer.ContainerManagedDaemon, + Image: "someImage:latest", + KnownStatusUnsafe: apicontainerstatus.ContainerRunning, + }, + }, + IsInternal: true, + } + + testCases := []struct { + task *Task + internal bool + isManagedDaemon bool + }{ + { + task: testTask1, + internal: true, + isManagedDaemon: true, + }, + { + task: testTask2, + internal: false, + isManagedDaemon: false, + }, + { + task: testTask3, + internal: true, + isManagedDaemon: false, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("IsManagedDaemonTask should return %t for %s", tc.isManagedDaemon, tc.task.String()), + func(t *testing.T) { + _, ok := tc.task.IsManagedDaemonTask() + assert.Equal(t, tc.isManagedDaemon, ok) + }) + } +} diff --git a/agent/engine/data.go b/agent/engine/data.go index eae131792e6..af288963a49 100644 --- a/agent/engine/data.go +++ b/agent/engine/data.go @@ -52,6 +52,11 @@ func (engine *DockerTaskEngine) loadTasks() error { for _, task := range tasks { engine.state.AddTask(task) + md, ok := task.IsManagedDaemonTask() + if ok { + engine.SetDaemonTask(md, task) + } + // Populate ip <-> task mapping if task has a local ip. This mapping is needed for serving v2 task metadata. if ip := task.GetLocalIPAddress(); ip != "" { engine.state.AddTaskIPAddress(ip, task.Arn) diff --git a/agent/engine/data_test.go b/agent/engine/data_test.go index c58329eeb07..f0e1a54847a 100644 --- a/agent/engine/data_test.go +++ b/agent/engine/data_test.go @@ -51,6 +51,13 @@ var ( TaskARNUnsafe: testTaskARN, KnownStatusUnsafe: apicontainerstatus.ContainerPulled, } + testManagedDaemonContainer = &apicontainer.Container{ + Name: "ecs-managed-" + testContainerName, + Image: "ebs-csi-driver", + TaskARNUnsafe: testTaskARN, + Type: apicontainer.ContainerManagedDaemon, + KnownStatusUnsafe: apicontainerstatus.ContainerRunning, + } testDockerContainer = &apicontainer.DockerContainer{ DockerID: testDockerID, Container: testContainer, @@ -59,6 +66,10 @@ var ( DockerID: testDockerID, Container: testPulledContainer, } + testManagedDaemonDockerContainer = &apicontainer.DockerContainer{ + DockerID: testDockerID, + Container: testManagedDaemonContainer, + } testTask = &apitask.Task{ Arn: testTaskARN, Containers: []*apicontainer.Container{testContainer}, @@ -69,6 +80,12 @@ var ( Containers: []*apicontainer.Container{testContainer, testPulledContainer}, LocalIPAddressUnsafe: testTaskIP, } + testTaskWithManagedDaemonContainer = &apitask.Task{ + Arn: testTaskARN, + Containers: []*apicontainer.Container{testManagedDaemonContainer}, + LocalIPAddressUnsafe: testTaskIP, + IsInternal: true, + } testImageState = &image.ImageState{ Image: testImage, PullSucceeded: false, @@ -135,6 +152,40 @@ func TestLoadState(t *testing.T) { assert.Equal(t, testTaskARN, arn) } +func TestLoadStateWithManagedDaemon(t *testing.T) { + dataClient := newTestDataClient(t) + + engine := &DockerTaskEngine{ + state: dockerstate.NewTaskEngineState(), + dataClient: dataClient, + daemonTasks: make(map[string]*apitask.Task), + } + + require.NoError(t, dataClient.SaveTask(testTaskWithManagedDaemonContainer)) + require.NoError(t, dataClient.SaveDockerContainer(testManagedDaemonDockerContainer)) + require.NoError(t, dataClient.SaveENIAttachment(testENIAttachment)) + require.NoError(t, dataClient.SaveImageState(testImageState)) + + require.NoError(t, engine.LoadState()) + task, ok := engine.state.TaskByArn(testTaskARN) + assert.True(t, ok) + assert.Equal(t, apicontainerstatus.ContainerRunning, task.Containers[0].GetKnownStatus()) + _, ok = engine.state.ContainerByID(testDockerID) + assert.True(t, ok) + assert.Len(t, engine.state.AllImageStates(), 1) + assert.Len(t, engine.state.AllENIAttachments(), 1) + + // Check ip <-> task arn mapping is loaded in state. + ip, ok := engine.state.GetIPAddressByTaskARN(testTaskARN) + require.True(t, ok) + assert.Equal(t, testTaskIP, ip) + arn, ok := engine.state.GetTaskByIPAddress(testTaskIP) + require.True(t, ok) + assert.Equal(t, testTaskARN, arn) + + assert.NotNil(t, engine.GetDaemonTask("ebs-csi-driver")) +} + func TestSaveState(t *testing.T) { dataClient := newTestDataClient(t) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon/managed_daemon.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon/managed_daemon.go index 553d4608e0a..7074224c758 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon/managed_daemon.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon/managed_daemon.go @@ -33,6 +33,12 @@ const ( defaultApplicationLogMount = "applicationLogMount" ) +var ( + ManagedDaemonImageNames = map[string]bool{ + "ebs-csi-driver": true, + } +) + type ManagedDaemon struct { imageName string imageTag string diff --git a/ecs-agent/manageddaemon/managed_daemon.go b/ecs-agent/manageddaemon/managed_daemon.go index 553d4608e0a..7074224c758 100644 --- a/ecs-agent/manageddaemon/managed_daemon.go +++ b/ecs-agent/manageddaemon/managed_daemon.go @@ -33,6 +33,12 @@ const ( defaultApplicationLogMount = "applicationLogMount" ) +var ( + ManagedDaemonImageNames = map[string]bool{ + "ebs-csi-driver": true, + } +) + type ManagedDaemon struct { imageName string imageTag string