Skip to content

Commit

Permalink
Fix loading CSI driver container from state if it exists
Browse files Browse the repository at this point in the history
  • Loading branch information
mye956 committed Oct 18, 2023
1 parent 047e722 commit 7a2b7b1
Show file tree
Hide file tree
Showing 9 changed files with 200 additions and 1 deletion.
11 changes: 11 additions & 0 deletions agent/api/container/container.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
md "github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon"

"github.com/aws/aws-sdk-go/aws"
"github.com/cihub/seelog"
Expand Down Expand Up @@ -1508,3 +1509,13 @@ func (c *Container) GetContainerPortRangeMap() map[string]string {
defer c.lock.RUnlock()
return c.ContainerPortRangeMap
}

func (c *Container) IsManagedDaemonContainer() (string, bool) {
c.lock.RLock()
defer c.lock.RUnlock()
containerImage := strings.Split(c.Image, ":")[0]
if md.ManagedDaemonImageNames[containerImage] && c.Type == ContainerManagedDaemon {
return containerImage, true
}
return "", false
}
25 changes: 24 additions & 1 deletion agent/api/container/container_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,13 +130,36 @@ func TestIsInternal(t *testing.T) {
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("IsInternal shoukd return %t for %s", tc.internal, tc.container.String()),
t.Run(fmt.Sprintf("IsInternal should return %t for %s", tc.internal, tc.container.String()),
func(t *testing.T) {
assert.Equal(t, tc.internal, tc.container.IsInternal())
})
}
}

func TestIsManagedDaemonContainer(t *testing.T) {
testCases := []struct {
container *Container
internal bool
isManagedDaemon bool
}{
{&Container{}, false, false},
{&Container{Type: ContainerNormal, Image: "someImage"}, false, false},
{&Container{Type: ContainerManagedDaemon, Image: "ebs-csi-driver:latest"}, true, true},
{&Container{Type: ContainerNormal, Image: "ebs-csi-driver:latest"}, false, false},
{&Container{Type: ContainerManagedDaemon, Image: "someImage"}, true, false},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("IsManagedDaemonContainer should return %t for %s", tc.isManagedDaemon, tc.container.String()),
func(t *testing.T) {
assert.Equal(t, tc.internal, tc.container.IsInternal())
_, ok := tc.container.IsManagedDaemonContainer()
assert.Equal(t, tc.isManagedDaemon, ok)
})
}
}

// TestSetupExecutionRoleFlag tests whether or not the container appropriately
// sets the flag for using execution roles
func TestSetupExecutionRoleFlag(t *testing.T) {
Expand Down
1 change: 1 addition & 0 deletions agent/api/container/containertype.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ var stringToContainerType = map[string]ContainerType{
"EMPTY_HOST_VOLUME": ContainerEmptyHostVolume,
"CNI_PAUSE": ContainerCNIPause,
"NAMESPACE_PAUSE": ContainerNamespacePause,
"MANAGED_DAEMON": ContainerManagedDaemon,
}

// String converts the container type enum to a string
Expand Down
20 changes: 20 additions & 0 deletions agent/api/task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -3723,3 +3723,23 @@ func (task *Task) HasActiveContainers() bool {
}
return false
}

func (task *Task) IsManagedDaemonTask() (string, bool) {
task.lock.RLock()
defer task.lock.RUnlock()

if !task.IsInternal {
return "", false
}

for _, c := range task.Containers {
containerStatus := c.GetKnownStatus()
if containerStatus.IsRunning() && c.IsInternal() {
md, ok := c.IsManagedDaemonContainer()
if ok {
return md, ok
}
}
}
return "", false
}
76 changes: 76 additions & 0 deletions agent/api/task/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5234,3 +5234,79 @@ func TestToHostResources(t *testing.T) {
assert.Equal(t, len(tc.expectedResources["PORTS_UDP"].StringSetValue), len(calcResources["PORTS_UDP"].StringSetValue), "Error converting task UDP port resources")
}
}

func TestIsManagedDaemonTask(t *testing.T) {

testTask1 := &Task{
Containers: []*apicontainer.Container{
{
Type: apicontainer.ContainerManagedDaemon,
Image: "ebs-csi-driver:latest",
KnownStatusUnsafe: apicontainerstatus.ContainerRunning,
},
},
IsInternal: true,
}

testTask2 := &Task{
Containers: []*apicontainer.Container{
{
Type: apicontainer.ContainerNormal,
Image: "someImage",
KnownStatusUnsafe: apicontainerstatus.ContainerRunning,
},
{
Type: apicontainer.ContainerNormal,
Image: "ebs-csi-driver:latest",
KnownStatusUnsafe: apicontainerstatus.ContainerRunning,
},
{
Type: apicontainer.ContainerNormal,
Image: "someImage:latest",
KnownStatusUnsafe: apicontainerstatus.ContainerRunning,
},
},
IsInternal: false,
}

testTask3 := &Task{
Containers: []*apicontainer.Container{
{
Type: apicontainer.ContainerManagedDaemon,
Image: "someImage:latest",
KnownStatusUnsafe: apicontainerstatus.ContainerRunning,
},
},
IsInternal: true,
}

testCases := []struct {
task *Task
internal bool
isManagedDaemon bool
}{
{
task: testTask1,
internal: true,
isManagedDaemon: true,
},
{
task: testTask2,
internal: false,
isManagedDaemon: false,
},
{
task: testTask3,
internal: true,
isManagedDaemon: false,
},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("IsManagedDaemonTask should return %t for %s", tc.isManagedDaemon, tc.task.String()),
func(t *testing.T) {
_, ok := tc.task.IsManagedDaemonTask()
assert.Equal(t, tc.isManagedDaemon, ok)
})
}
}
5 changes: 5 additions & 0 deletions agent/engine/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ func (engine *DockerTaskEngine) loadTasks() error {
for _, task := range tasks {
engine.state.AddTask(task)

md, ok := task.IsManagedDaemonTask()
if ok {
engine.SetDaemonTask(md, task)
}

// Populate ip <-> task mapping if task has a local ip. This mapping is needed for serving v2 task metadata.
if ip := task.GetLocalIPAddress(); ip != "" {
engine.state.AddTaskIPAddress(ip, task.Arn)
Expand Down
51 changes: 51 additions & 0 deletions agent/engine/data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,13 @@ var (
TaskARNUnsafe: testTaskARN,
KnownStatusUnsafe: apicontainerstatus.ContainerPulled,
}
testManagedDaemonContainer = &apicontainer.Container{
Name: "ecs-managed-" + testContainerName,
Image: "ebs-csi-driver",
TaskARNUnsafe: testTaskARN,
Type: apicontainer.ContainerManagedDaemon,
KnownStatusUnsafe: apicontainerstatus.ContainerRunning,
}
testDockerContainer = &apicontainer.DockerContainer{
DockerID: testDockerID,
Container: testContainer,
Expand All @@ -59,6 +66,10 @@ var (
DockerID: testDockerID,
Container: testPulledContainer,
}
testManagedDaemonDockerContainer = &apicontainer.DockerContainer{
DockerID: testDockerID,
Container: testManagedDaemonContainer,
}
testTask = &apitask.Task{
Arn: testTaskARN,
Containers: []*apicontainer.Container{testContainer},
Expand All @@ -69,6 +80,12 @@ var (
Containers: []*apicontainer.Container{testContainer, testPulledContainer},
LocalIPAddressUnsafe: testTaskIP,
}
testTaskWithManagedDaemonContainer = &apitask.Task{
Arn: testTaskARN,
Containers: []*apicontainer.Container{testManagedDaemonContainer},
LocalIPAddressUnsafe: testTaskIP,
IsInternal: true,
}
testImageState = &image.ImageState{
Image: testImage,
PullSucceeded: false,
Expand Down Expand Up @@ -135,6 +152,40 @@ func TestLoadState(t *testing.T) {
assert.Equal(t, testTaskARN, arn)
}

func TestLoadStateWithManagedDaemon(t *testing.T) {
dataClient := newTestDataClient(t)

engine := &DockerTaskEngine{
state: dockerstate.NewTaskEngineState(),
dataClient: dataClient,
daemonTasks: make(map[string]*apitask.Task),
}

require.NoError(t, dataClient.SaveTask(testTaskWithManagedDaemonContainer))
require.NoError(t, dataClient.SaveDockerContainer(testManagedDaemonDockerContainer))
require.NoError(t, dataClient.SaveENIAttachment(testENIAttachment))
require.NoError(t, dataClient.SaveImageState(testImageState))

require.NoError(t, engine.LoadState())
task, ok := engine.state.TaskByArn(testTaskARN)
assert.True(t, ok)
assert.Equal(t, apicontainerstatus.ContainerRunning, task.Containers[0].GetKnownStatus())
_, ok = engine.state.ContainerByID(testDockerID)
assert.True(t, ok)
assert.Len(t, engine.state.AllImageStates(), 1)
assert.Len(t, engine.state.AllENIAttachments(), 1)

// Check ip <-> task arn mapping is loaded in state.
ip, ok := engine.state.GetIPAddressByTaskARN(testTaskARN)
require.True(t, ok)
assert.Equal(t, testTaskIP, ip)
arn, ok := engine.state.GetTaskByIPAddress(testTaskIP)
require.True(t, ok)
assert.Equal(t, testTaskARN, arn)

assert.NotNil(t, engine.GetDaemonTask("ebs-csi-driver"))
}

func TestSaveState(t *testing.T) {
dataClient := newTestDataClient(t)

Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions ecs-agent/manageddaemon/managed_daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ const (
defaultApplicationLogMount = "applicationLogMount"
)

var (
ManagedDaemonImageNames = map[string]bool{
"ebs-csi-driver": true,
}
)

type ManagedDaemon struct {
imageName string
imageTag string
Expand Down

0 comments on commit 7a2b7b1

Please sign in to comment.