From 6d611f82fc4e514baaadb1bebdcf21ad3033d569 Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Wed, 24 May 2023 13:09:35 -0700 Subject: [PATCH] Add more tests for v2, v3, and v4 container metadata handlers (#3708) --- agent/handlers/task_server_setup_test.go | 555 +++++++++++++++-------- 1 file changed, 366 insertions(+), 189 deletions(-) diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 4471d3bdc0c..5b7d714cc5d 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -330,6 +330,11 @@ var ( DockerName: containerName, Container: container1, } + bridgeContainerNoNetwork = &apicontainer.DockerContainer{ + DockerID: containerID, + DockerName: containerName, + Container: container, + } containerNameToBridgeContainer = map[string]*apicontainer.DockerContainer{ taskARN: bridgeContainer, } @@ -401,12 +406,6 @@ var ( HostPort: containerPort, }, }, - Networks: []tmdsresponse.Network{ - { - NetworkMode: utils.NetworkModeAWSVPC, - IPv4Addresses: []string{eniIPv4Address}, - }, - }, }, Networks: []v4.Network{{ Network: tmdsresponse.Network{ @@ -481,22 +480,19 @@ var ( Containers: []v4.ContainerResponse{expectedV4ContainerResponse, expectedV4PulledContainerResponse}, VPCID: vpcID, } - expectedV4BridgeContainerResponse = v4.ContainerResponse{ - ContainerResponse: &expectedBridgeContainerResponse, - Networks: []v4.Network{{ - Network: tmdsresponse.Network{ - NetworkMode: bridgeMode, - IPv4Addresses: []string{bridgeIPAddr}, - }, - NetworkInterfaceProperties: v4.NetworkInterfaceProperties{ - AttachmentIndex: nil, - IPV4SubnetCIDRBlock: "", - MACAddress: "", - PrivateDNSName: "", - SubnetGatewayIPV4Address: "", - }}, + expectedV4BridgeContainerResponse = v4ContainerResponseFromV2(expectedBridgeContainerResponse, []v4.Network{{ + Network: tmdsresponse.Network{ + NetworkMode: bridgeMode, + IPv4Addresses: []string{bridgeIPAddr}, }, - } + NetworkInterfaceProperties: v4.NetworkInterfaceProperties{ + AttachmentIndex: nil, + IPV4SubnetCIDRBlock: "", + MACAddress: "", + PrivateDNSName: "", + SubnetGatewayIPV4Address: "", + }}, + }) expectedV4BridgeTaskResponse = v4.TaskResponse{ TaskResponse: &v2.TaskResponse{ Cluster: clusterName, @@ -521,6 +517,16 @@ var ( } ) +// Creates a v4 ContainerResponse given a v2 ContainerResponse and v4 networks +func v4ContainerResponseFromV2( + v2ContainerResponse v2.ContainerResponse, networks []v4.Network) v4.ContainerResponse { + v2ContainerResponse.Networks = nil + return v4.ContainerResponse{ + ContainerResponse: &v2ContainerResponse, + Networks: networks, + } +} + func init() { container.SetLabels(labels) container1.SetLabels(labels) @@ -878,38 +884,6 @@ func TestV2TaskWithTagsMetadata(t *testing.T) { } } -func TestV2ContainerMetadata(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - state := mock_dockerstate.NewMockTaskEngineState(ctrl) - auditLog := mock_audit.NewMockAuditLogger(ctrl) - statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - - gomock.InOrder( - state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true), - state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), - state.EXPECT().TaskByID(containerID).Return(task, true), - ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, - config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) - require.NoError(t, err) - - recorder := httptest.NewRecorder() - req, _ := http.NewRequest("GET", v2BaseMetadataPath+"/"+containerID, nil) - req.RemoteAddr = remoteIP + ":" + remotePort - server.Handler.ServeHTTP(recorder, req) - res, err := ioutil.ReadAll(recorder.Body) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, recorder.Code) - var containerResponse v2.ContainerResponse - err = json.Unmarshal(res, &containerResponse) - assert.NoError(t, err) - assert.Equal(t, expectedContainerResponse, containerResponse) -} - func TestV2ContainerStats(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1060,37 +1034,6 @@ func TestV3BridgeTaskMetadata(t *testing.T) { assert.Equal(t, expectedBridgeTaskResponse, taskResponse) } -func TestV3BridgeContainerMetadata(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - state := mock_dockerstate.NewMockTaskEngineState(ctrl) - auditLog := mock_audit.NewMockAuditLogger(ctrl) - statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - - gomock.InOrder( - state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), - state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), - ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, - config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) - require.NoError(t, err) - recorder := httptest.NewRecorder() - req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID, nil) - server.Handler.ServeHTTP(recorder, req) - res, err := ioutil.ReadAll(recorder.Body) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, recorder.Code) - var containerResponse v2.ContainerResponse - err = json.Unmarshal(res, &containerResponse) - assert.NoError(t, err) - assert.Equal(t, expectedBridgeContainerResponse, containerResponse) -} - // Test API calls for propagating Tags to Task Metadata func TestV3TaskMetadataWithTags(t *testing.T) { ctrl := gomock.NewController(t) @@ -1164,36 +1107,6 @@ func TestV3TaskMetadataWithTags(t *testing.T) { assert.Equal(t, expectedTaskResponseWithTags, taskResponse) } -func TestV3ContainerMetadata(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - state := mock_dockerstate.NewMockTaskEngineState(ctrl) - auditLog := mock_audit.NewMockAuditLogger(ctrl) - statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - - gomock.InOrder( - state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), - state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), - state.EXPECT().TaskByID(containerID).Return(task, true), - ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, - config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) - require.NoError(t, err) - recorder := httptest.NewRecorder() - req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID, nil) - server.Handler.ServeHTTP(recorder, req) - res, err := ioutil.ReadAll(recorder.Body) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, recorder.Code) - var containerResponse v2.ContainerResponse - err = json.Unmarshal(res, &containerResponse) - assert.NoError(t, err) - assert.Equal(t, expectedContainerResponse, containerResponse) -} - func TestV3TaskStats(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -1358,7 +1271,6 @@ func TestV4TaskMetadata(t *testing.T) { assert.NoError(t, err) expectedV4TaskResponse.TaskResponse.Containers = nil - expectedV4ContainerResponse.ContainerResponse.Networks = nil assert.Equal(t, expectedV4TaskResponse, taskResponse) } @@ -1392,47 +1304,9 @@ func TestV4TaskMetadataWithPulledContainers(t *testing.T) { err = json.Unmarshal(res, &taskResponse) assert.NoError(t, err) expectedV4PulledTaskResponse.TaskResponse.Containers = nil - expectedV4ContainerResponse.ContainerResponse.Networks = nil assert.Equal(t, expectedV4PulledTaskResponse, taskResponse) } -func TestV4ContainerMetadata(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - state := mock_dockerstate.NewMockTaskEngineState(ctrl) - auditLog := mock_audit.NewMockAuditLogger(ctrl) - statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - - gomock.InOrder( - state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), - state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), - state.EXPECT().TaskByID(containerID).Return(task, true).Times(2), - ) - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, - config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, availabilityzone, vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) - require.NoError(t, err) - recorder := httptest.NewRecorder() - req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID, nil) - server.Handler.ServeHTTP(recorder, req) - res, err := ioutil.ReadAll(recorder.Body) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, recorder.Code) - - var containerResponse v4.ContainerResponse - err = json.Unmarshal(res, &containerResponse) - assert.NoError(t, err) - - // v4.ContainerMetadata overrides Networks properties defined in v2.ContainerResponse - // during json.Unmarshal(), values for the Networks property will be written to v4.ContainerMetadata.Networks - // instead of v4.ContainerMetadata.(v2.ContainerMetadata).Networks - // v2.ContainerMetadata.Networks should be nil - expectedV4ContainerResponse.ContainerResponse.Networks = nil - assert.Equal(t, expectedV4ContainerResponse, containerResponse) -} - // Test API calls for propagating Tags to v4 Task Metadata func TestV4TaskMetadataWithTags(t *testing.T) { ctrl := gomock.NewController(t) @@ -1506,7 +1380,6 @@ func TestV4TaskMetadataWithTags(t *testing.T) { assert.NoError(t, err) expectedv4TaskResponseWithTags.TaskResponse.Containers = nil - expectedV4ContainerResponse.ContainerResponse.Networks = nil assert.Equal(t, expectedv4TaskResponseWithTags, taskResponse) } @@ -1543,7 +1416,6 @@ func TestV4BridgeTaskMetadata(t *testing.T) { assert.NoError(t, err) expectedV4BridgeTaskResponse.TaskResponse.Containers = nil - expectedV4BridgeContainerResponse.ContainerResponse.Networks = nil assert.Equal(t, expectedV4BridgeTaskResponse, taskResponse) } @@ -1580,40 +1452,6 @@ func TestV4BridgeTaskMetadataAllowMissingContainerNetwork(t *testing.T) { assert.NoError(t, err) } -func TestV4BridgeContainerMetadata(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - state := mock_dockerstate.NewMockTaskEngineState(ctrl) - auditLog := mock_audit.NewMockAuditLogger(ctrl) - statsEngine := mock_stats.NewMockEngine(ctrl) - ecsClient := mock_api.NewMockECSClient(ctrl) - - gomock.InOrder( - state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), - state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), - state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), - ) - - server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, region, statsEngine, - config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, endpoint, acceptInsecureCert) - require.NoError(t, err) - recorder := httptest.NewRecorder() - req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID, nil) - server.Handler.ServeHTTP(recorder, req) - res, err := ioutil.ReadAll(recorder.Body) - assert.NoError(t, err) - assert.Equal(t, http.StatusOK, recorder.Code) - var containerResponse v4.ContainerResponse - err = json.Unmarshal(res, &containerResponse) - assert.NoError(t, err) - - expectedV4BridgeContainerResponse.ContainerResponse.Networks = nil - assert.Equal(t, expectedV4BridgeContainerResponse, containerResponse) -} - func TestV4TaskStats(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() @@ -2037,6 +1875,345 @@ func TestV4Unexpected500Error(t *testing.T) { } } +// Types of TMDS responses, add more types as needed +type TMDSResponse interface { + v2.ContainerResponse | v4.ContainerResponse | string +} + +// Represents a test case for TMDS. Supports generic TMDS response body types using type parametesrs. +type TMDSTestCase[R TMDSResponse] struct { + // Request path + path string + // Function to set expectations on mock task engine state + setStateExpectations func(state *mock_dockerstate.MockTaskEngineState) + // Expected HTTP status code of the response + expectedStatusCode int + // Expected response body, all JSON compatible types are accepted + expectedResponseBody R +} + +// Tests a TMDS request as per the provided test case. +// This function can be used to test all metadata and stats endpoints. +// It - +// 1. Initializes a TMDS server +// 2. Creates a request as per the test case and sends it to the server +// 3. Unmarshals the JSON response body +// 4. Asserts that the response status code is as expected +// 5. Asserts that the unmarshaled resopnse body is as expected +func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { + // Define mocks + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + state := mock_dockerstate.NewMockTaskEngineState(ctrl) + auditLog := mock_audit.NewMockAuditLogger(ctrl) + statsEngine := mock_stats.NewMockEngine(ctrl) + ecsClient := mock_api.NewMockECSClient(ctrl) + + // Set expectations on mocks + auditLog.EXPECT().Log(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() + tc.setStateExpectations(state) + + // Initialize server + server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, + clusterName, region, statsEngine, + config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, + containerInstanceArn, endpoint, acceptInsecureCert) + require.NoError(t, err) + + // Create the request + req, err := http.NewRequest("GET", tc.path, nil) + require.NoError(t, err) + req.RemoteAddr = remoteIP + ":" + remotePort + + // Send the request and record the response + recorder := httptest.NewRecorder() + server.Handler.ServeHTTP(recorder, req) + + // Parse the response body + var actualResponseBody R + err = json.Unmarshal(recorder.Body.Bytes(), &actualResponseBody) + require.NoError(t, err) + + // Assert status code and body + assert.Equal(t, tc.expectedStatusCode, recorder.Code) + assert.Equal(t, tc.expectedResponseBody, actualResponseBody) +} + +// Tests for v2 container metadata endpoint +func TestV2ContainerMetadata(t *testing.T) { + t.Run("task not found by IP", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v2BaseMetadataPath + "/" + containerID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().GetTaskByIPAddress(remoteIP).Return("", false), + ) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: fmt.Sprintf( + "Unable to get task arn from request: unable to associate '%s' with task", + remoteIP), + }) + }) + t.Run("invalid container ID", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v2BaseMetadataPath + "/" + containerID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true), + state.EXPECT().ContainerByID(containerID).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: fmt.Sprintf( + "Unable to generate metadata for container '%s'", containerID), + }) + }) + t.Run("task not found but container ID is valid", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v2BaseMetadataPath + "/" + containerID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), + state.EXPECT().TaskByID(containerID).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: fmt.Sprintf( + "Unable to generate metadata for container '%s'", containerID), + }) + }) + t.Run("happy case", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[v2.ContainerResponse]{ + path: v2BaseMetadataPath + "/" + containerID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().GetTaskByIPAddress(remoteIP).Return(taskARN, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), + state.EXPECT().TaskByID(containerID).Return(task, true), + ) + }, + expectedStatusCode: http.StatusOK, + expectedResponseBody: expectedContainerResponse, + }) + }) +} + +func TestV3ContainerMetadata(t *testing.T) { + t.Run("v3EndpointID invalid", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v3BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return("", false), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf( + "V3 container metadata handler: unable to get container ID from request: unable to get docker ID from v3 endpoint ID: %s", + v3EndpointID), + }) + }) + t.Run("container not found but ID is valid", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v3BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf( + "Unable to generate metadata for container '%s'", containerID), + }) + }) + t.Run("happy case", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[v2.ContainerResponse]{ + path: v3BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), + state.EXPECT().TaskByID(containerID).Return(task, true), + ) + }, + expectedStatusCode: http.StatusOK, + expectedResponseBody: expectedContainerResponse, + }) + }) + t.Run("bridge mode container not found when looking up network settings", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v3BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), + state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), + state.EXPECT().ContainerByID(containerID).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf("Unable to find container '%s'", containerID), + }) + }) + t.Run("bridge mode container no network settings", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v3BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), + state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf( + "Unable to generate network response for container '%s'", containerID), + }) + }) + t.Run("happy case bridge mode", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[v2.ContainerResponse]{ + path: v3BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), + state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), + ) + }, + expectedStatusCode: http.StatusOK, + expectedResponseBody: expectedBridgeContainerResponse, + }) + }) +} + +func TestV4ContainerMetadata(t *testing.T) { + t.Run("v3EndpointID is invalid", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v4BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return("", false), + ) + }, + expectedStatusCode: http.StatusNotFound, + expectedResponseBody: fmt.Sprintf( + "V4 container metadata handler: unable to get container ID from request: unable to get docker ID from v3 endpoint ID: %s", + v3EndpointID), + }) + }) + t.Run("container not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v4BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf( + "unable to generate metadata for container '%s'", containerID), + }) + }) + t.Run("task not found", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v4BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), + state.EXPECT().TaskByID(containerID).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf( + "unable to generate metadata for container '%s'", containerID), + }) + }) + t.Run("awsvpc task not found on second lookup", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v4BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), + state.EXPECT().TaskByID(containerID).Return(task, true), + state.EXPECT().TaskByID(containerID).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf( + "unable to generate metadata for container '%s'", containerID), + }) + }) + t.Run("happy case awsvpc task", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[v4.ContainerResponse]{ + path: v4BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(dockerContainer, true), + state.EXPECT().TaskByID(containerID).Return(task, true).Times(2), + ) + }, + expectedStatusCode: http.StatusOK, + expectedResponseBody: expectedV4ContainerResponse, + }) + }) + t.Run("bridge mode container not found during network population", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v4BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), + state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), + state.EXPECT().ContainerByID(containerID).Return(nil, false), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf("unable to find container '%s'", containerID), + }) + }) + t.Run("bridge mode no network settings", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[string]{ + path: v4BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), + state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainerNoNetwork, true), + ) + }, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: fmt.Sprintf( + "unable to generate network response for container '%s'", containerID), + }) + }) + t.Run("happy case bridge mode", func(t *testing.T) { + testTMDSRequest(t, TMDSTestCase[v4.ContainerResponse]{ + path: v4BasePath + v3EndpointID, + setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { + gomock.InOrder( + state.EXPECT().DockerIDByV3EndpointID(v3EndpointID).Return(containerID, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), + state.EXPECT().TaskByID(containerID).Return(bridgeTask, true), + state.EXPECT().ContainerByID(containerID).Return(bridgeContainer, true), + ) + }, + expectedStatusCode: http.StatusOK, + expectedResponseBody: expectedV4BridgeContainerResponse, + }) + }) +} + // Helper function for testing Agent API Task Protection v1 handlers func testAgentAPITaskProtectionV1Handler(t *testing.T, requestBody interface{}, method string) { // Prepare dependency mocks