diff --git a/agent/handlers/agentapi/taskprotection/factory.go b/agent/handlers/agentapi/taskprotection/factory.go new file mode 100644 index 00000000000..89b2bdfb2fc --- /dev/null +++ b/agent/handlers/agentapi/taskprotection/factory.go @@ -0,0 +1,50 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package taskprotection + +import ( + "github.com/aws/amazon-ecs-agent/agent/api/ecsclient" + "github.com/aws/amazon-ecs-agent/agent/httpclient" + + "github.com/aws/amazon-ecs-agent/ecs-agent/api" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" + + "github.com/aws/aws-sdk-go/aws" + awscreds "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go/aws/session" +) + +// TaskProtectionClientFactory implements TaskProtectionClientFactoryInterface +type TaskProtectionClientFactory struct { + Region string + Endpoint string + AcceptInsecureCert bool +} + +// Helper function for retrieving credential from credentials manager and create ecs client +func (factory TaskProtectionClientFactory) NewTaskProtectionClient( + taskRoleCredential credentials.TaskIAMRoleCredentials, +) api.ECSTaskProtectionSDK { + taskCredential := taskRoleCredential.GetIAMRoleCredentials() + cfg := aws.NewConfig(). + WithCredentials(awscreds.NewStaticCredentials(taskCredential.AccessKeyID, + taskCredential.SecretAccessKey, + taskCredential.SessionToken)). + WithRegion(factory.Region). + WithHTTPClient(httpclient.New(ecsclient.RoundtripTimeout, factory.AcceptInsecureCert)). + WithEndpoint(factory.Endpoint) + + ecsClient := ecs.New(session.Must(session.NewSession()), cfg) + return ecsClient +} diff --git a/agent/handlers/agentapi/taskprotection/factory_test.go b/agent/handlers/agentapi/taskprotection/factory_test.go new file mode 100644 index 00000000000..379d385b1b0 --- /dev/null +++ b/agent/handlers/agentapi/taskprotection/factory_test.go @@ -0,0 +1,56 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +package taskprotection + +import ( + "testing" + + "github.com/aws/amazon-ecs-agent/ecs-agent/api" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" +) + +const ( + testAccessKey = "accessKey" + testSecretKey = "secretKey" + testSessionToken = "sessionToken" + testRegion = "region" + testECSEndpoint = "endpoint" + testAcceptInsecureCert = false +) + +// TestGetECSClientHappyCase tests newTaskProtectionClient uses credential in credentials manager and +// returns an ECS client with correct status code and error +func TestGetECSClientHappyCase(t *testing.T) { + testIAMRoleCredentials := credentials.TaskIAMRoleCredentials{ + IAMRoleCredentials: credentials.IAMRoleCredentials{ + AccessKeyID: testAccessKey, + SecretAccessKey: testSecretKey, + SessionToken: testSessionToken, + }, + } + + factory := TaskProtectionClientFactory{ + Region: testRegion, Endpoint: testECSEndpoint, AcceptInsecureCert: testAcceptInsecureCert, + } + + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + ret := factory.NewTaskProtectionClient(testIAMRoleCredentials) + _, ok := ret.(api.ECSTaskProtectionSDK) + + // Assert response + assert.True(t, ok) +} diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go index bd8e7093321..80791949cef 100644 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go +++ b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go @@ -21,23 +21,19 @@ import ( "net/http" "time" - "github.com/aws/amazon-ecs-agent/agent/api/ecsclient" apitask "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" - "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/types" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" - "github.com/aws/amazon-ecs-agent/agent/httpclient" - "github.com/aws/amazon-ecs-agent/ecs-agent/api" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" "github.com/aws/amazon-ecs-agent/ecs-agent/logger" loggerfield "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + tpinterface "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" - awscreds "github.com/aws/aws-sdk-go/aws/credentials" "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" ) const ( @@ -63,17 +59,10 @@ type TaskProtectionRequest struct { ExpiresInMinutes *int64 } -// TaskProtectionClientFactory implements TaskProtectionClientFactoryInterface -type TaskProtectionClientFactory struct { - Region string - Endpoint string - AcceptInsecureCert bool -} - // UpdateTaskProtectionHandler returns an HTTP request handler function for // UpdateTaskProtection API func UpdateTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsManager credentials.Manager, - factory TaskProtectionClientFactoryInterface, cluster string) func(http.ResponseWriter, *http.Request) { + factory tpinterface.TaskProtectionClientFactoryInterface, cluster string) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { updateTaskProtectionRequestType := "api/UpdateTaskProtection/v1" @@ -193,7 +182,7 @@ func UpdateTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsM // GetTaskProtectionHandler returns a handler function for GetTaskProtection API func GetTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsManager credentials.Manager, - factory TaskProtectionClientFactoryInterface, cluster string) func(http.ResponseWriter, *http.Request) { + factory tpinterface.TaskProtectionClientFactoryInterface, cluster string) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { getTaskProtectionRequestType := "api/GetTaskProtection/v1" @@ -286,21 +275,6 @@ func GetTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsMana } } -// Helper function for retrieving credential from credentials manager and create ecs client -func (factory TaskProtectionClientFactory) NewTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK { - taskCredential := taskRoleCredential.GetIAMRoleCredentials() - cfg := aws.NewConfig(). - WithCredentials(awscreds.NewStaticCredentials(taskCredential.AccessKeyID, - taskCredential.SecretAccessKey, - taskCredential.SessionToken)). - WithRegion(factory.Region). - WithHTTPClient(httpclient.New(ecsclient.RoundtripTimeout, factory.AcceptInsecureCert)). - WithEndpoint(factory.Endpoint) - - ecsClient := ecs.New(session.Must(session.NewSession()), cfg) - return ecsClient -} - // Helper function to parse error to get ErrorCode, ExceptionMessage, HttpStatusCode, RequestID. // RequestID will be empty if the request is not able to reach AWS func getErrorCodeAndStatusCode(err error) (string, string, int, *string) { diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go index 3fac184a45a..1f17482fbe6 100644 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go +++ b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go @@ -26,14 +26,15 @@ import ( "github.com/aws/amazon-ecs-agent/agent/api/task" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" - "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/types" + tpfactory "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" "github.com/aws/amazon-ecs-agent/agent/utils" - "github.com/aws/amazon-ecs-agent/ecs-agent/api" mock_api "github.com/aws/amazon-ecs-agent/ecs-agent/api/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" + tpinterface "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/request" @@ -67,32 +68,6 @@ func TestTaskProtectionPath(t *testing.T) { assert.Equal(t, "/api/{v3EndpointIDMuxName:[^/]*}/task-protection/v1/state", TaskProtectionPath()) } -// TestGetECSClientHappyCase tests newTaskProtectionClient uses credential in credentials manager and -// returns an ECS client with correct status code and error -func TestGetECSClientHappyCase(t *testing.T) { - - testIAMRoleCredentials := credentials.TaskIAMRoleCredentials{ - IAMRoleCredentials: credentials.IAMRoleCredentials{ - AccessKeyID: testAccessKey, - SecretAccessKey: testSecretKey, - SessionToken: testSessionToken, - }, - } - - factory := TaskProtectionClientFactory{ - Region: testRegion, Endpoint: testECSEndpoint, AcceptInsecureCert: testAcceptInsecureCert, - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - ret := factory.NewTaskProtectionClient(testIAMRoleCredentials) - _, ok := ret.(api.ECSTaskProtectionSDK) - - // Assert response - assert.True(t, ok) -} - func getRequestWithUnknownFields(t *testing.T) map[string]interface{} { request := TaskProtectionRequest{ProtectionEnabled: utils.BoolPtr(false)} requestJSON, err := json.Marshal(request) @@ -107,7 +82,8 @@ func getRequestWithUnknownFields(t *testing.T) map[string]interface{} { // Helper function for running tests for UpdateTaskProtection handler func testUpdateTaskProtectionHandler(t *testing.T, state dockerstate.TaskEngineState, - v3EndpointID string, credentialsManager credentials.Manager, factory TaskProtectionClientFactoryInterface, + v3EndpointID string, credentialsManager credentials.Manager, + factory tpinterface.TaskProtectionClientFactoryInterface, request interface{}, expectedResponse interface{}, expectedResponseCode int) { // Prepare request requestBytes, err := json.Marshal(request) @@ -249,7 +225,7 @@ func TestUpdateTaskProtectionHandlerTaskRoleCredentialsNotFound(t *testing.T) { } testTask.SetCredentialsID(testTaskCredentialsId) - factory := TaskProtectionClientFactory{ + factory := tpfactory.TaskProtectionClientFactory{ Region: testRegion, Endpoint: testECSEndpoint, AcceptInsecureCert: testAcceptInsecureCert, } @@ -399,7 +375,7 @@ func TestUpdateTaskProtectionHandler_PostCall(t *testing.T) { mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) mockManager := mock_credentials.NewMockManager(ctrl) - mockFactory := NewMockTaskProtectionClientFactoryInterface(ctrl) + mockFactory := tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl) mockECSClient := mock_api.NewMockECSTaskProtectionSDK(ctrl) mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) @@ -423,7 +399,9 @@ func TestUpdateTaskProtectionHandler_PostCall(t *testing.T) { } func testGetTaskProtectionHandler(t *testing.T, state dockerstate.TaskEngineState, - v3EndpointID string, credentialsManager credentials.Manager, factory TaskProtectionClientFactoryInterface, expectedResponse interface{}, expectedResponseCode int) { + v3EndpointID string, credentialsManager credentials.Manager, + factory tpinterface.TaskProtectionClientFactoryInterface, + expectedResponse interface{}, expectedResponseCode int) { // Prepare request bodyReader := bytes.NewReader([]byte{}) req, err := http.NewRequest("GET", "", bodyReader) @@ -492,7 +470,7 @@ func TestGetTaskProtectionHandlerTaskRoleCredentialsNotFound(t *testing.T) { } testTask.SetCredentialsID(testTaskCredentialsId) - factory := TaskProtectionClientFactory{ + factory := tpfactory.TaskProtectionClientFactory{ Region: testRegion, Endpoint: testECSEndpoint, AcceptInsecureCert: testAcceptInsecureCert, } @@ -638,7 +616,7 @@ func TestGetTaskProtectionHandler_PostCall(t *testing.T) { mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) mockManager := mock_credentials.NewMockManager(ctrl) - mockFactory := NewMockTaskProtectionClientFactoryInterface(ctrl) + mockFactory := tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl) mockECSClient := mock_api.NewMockECSTaskProtectionSDK(ctrl) mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index 8cb90400f6c..b17173f4b1b 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -21,7 +21,8 @@ import ( "github.com/aws/amazon-ecs-agent/agent/api" "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" - agentAPITaskProtectionV1 "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" + tpfactory "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection" + tphandlers "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" v2 "github.com/aws/amazon-ecs-agent/agent/handlers/v2" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" v4 "github.com/aws/amazon-ecs-agent/agent/handlers/v4" @@ -31,6 +32,8 @@ import ( auditinterface "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" + tpinterface "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" + tmdsv1 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1" tmdsv2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" @@ -60,7 +63,7 @@ func taskServerSetup(credentialsManager credentials.Manager, availabilityZone string, vpcID string, containerInstanceArn string, - taskProtectionClientFactory agentAPITaskProtectionV1.TaskProtectionClientFactoryInterface, + taskProtectionClientFactory tpinterface.TaskProtectionClientFactoryInterface, ) (*http.Server, error) { muxRouter := mux.NewRouter() @@ -156,17 +159,17 @@ func agentAPIV1HandlersSetup( state dockerstate.TaskEngineState, credentialsManager credentials.Manager, cluster string, - factory agentAPITaskProtectionV1.TaskProtectionClientFactoryInterface, + factory tpinterface.TaskProtectionClientFactoryInterface, ) { muxRouter. HandleFunc( - agentAPITaskProtectionV1.TaskProtectionPath(), - agentAPITaskProtectionV1.UpdateTaskProtectionHandler(state, credentialsManager, factory, cluster)). + tphandlers.TaskProtectionPath(), + tphandlers.UpdateTaskProtectionHandler(state, credentialsManager, factory, cluster)). Methods("PUT") muxRouter. HandleFunc( - agentAPITaskProtectionV1.TaskProtectionPath(), - agentAPITaskProtectionV1.GetTaskProtectionHandler(state, credentialsManager, factory, cluster)). + tphandlers.TaskProtectionPath(), + tphandlers.GetTaskProtectionHandler(state, credentialsManager, factory, cluster)). Methods("GET") } @@ -192,7 +195,7 @@ func ServeTaskHTTPEndpoint( auditLogger := audit.NewAuditLog(containerInstanceArn, cfg, logger) - taskProtectionClientFactory := agentAPITaskProtectionV1.TaskProtectionClientFactory{ + taskProtectionClientFactory := tpfactory.TaskProtectionClientFactory{ Region: cfg.AWSRegion, Endpoint: cfg.APIEndpoint, AcceptInsecureCert: cfg.AcceptInsecureCert, } server, err := taskServerSetup(credentialsManager, auditLogger, state, ecsClient, cfg.Cluster, diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index dd9cf8e88b2..7542c2161bc 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -37,8 +37,6 @@ import ( "github.com/aws/amazon-ecs-agent/agent/config" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" agentapihandlers "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" - task_protection_v1 "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" - agentapi "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/types" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" v4stats "github.com/aws/amazon-ecs-agent/agent/handlers/v4" "github.com/aws/amazon-ecs-agent/agent/stats" @@ -50,6 +48,8 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" mock_audit "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/mocks" tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response" + tpinterface "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" + tptypes "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" tmdsv1 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1" v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" @@ -786,7 +786,7 @@ func testErrorResponsesFromServer(t *testing.T, path string, expectedErrorMessag ecsClient := mock_api.NewMockECSClient(ctrl) server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() @@ -823,7 +823,7 @@ func getResponseForCredentialsRequest(t *testing.T, expectedStatus int, ecsClient := mock_api.NewMockECSClient(ctrl) server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() @@ -882,7 +882,7 @@ func TestV3ContainerAssociations(t *testing.T) { ) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/associations/"+associationType, nil) @@ -914,7 +914,7 @@ func TestV3ContainerAssociation(t *testing.T) { ) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/associations/"+associationType+"/"+associationName, nil) @@ -945,7 +945,7 @@ func TestV4ContainerAssociations(t *testing.T) { ) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/associations/"+associationType, nil) @@ -977,7 +977,7 @@ func TestV4ContainerAssociation(t *testing.T) { ) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/associations/"+associationType+"/"+associationName, nil) @@ -1004,7 +1004,7 @@ func TestTaskHTTPEndpoint301Redirect(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for testPath, expectedPath := range testPathsMap { @@ -1047,7 +1047,7 @@ func TestTaskHTTPEndpointErrorCode404(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1087,7 +1087,7 @@ func TestTaskHTTPEndpointErrorCode400(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1126,7 +1126,7 @@ func TestTaskHTTPEndpointErrorCode500(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1196,7 +1196,7 @@ func TestV4TaskNotFoundError404(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) state.EXPECT().TaskARNByV3EndpointID(gomock.Any()).Return("", tc.taskFound).AnyTimes() @@ -1252,7 +1252,7 @@ func TestV4Unexpected500Error(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) // Initial lookups succeed @@ -1285,7 +1285,7 @@ type TMDSResponse interface { v2.TaskResponse | v4.ContainerResponse | v4.TaskResponse | - agentapi.TaskProtectionResponse | + tptypes.TaskProtectionResponse | types.StatsJSON | v4stats.StatsResponse | map[string]*types.StatsJSON | @@ -1309,7 +1309,7 @@ type TMDSTestCase[R TMDSResponse] struct { setECSClientExpectations func(ecsClient *mock_api.MockECSClient) // Function to set expectations on mock Task Protection Client Factory setTaskProtectionClientFactoryExpectations func( - ctrl *gomock.Controller, factory *agentapihandlers.MockTaskProtectionClientFactoryInterface) + ctrl *gomock.Controller, factory *tpinterface.MockTaskProtectionClientFactoryInterface) // Function to set expectations on mock Credentials Manager setCredentialsManagerExpectations func(credsManager *mock_credentials.MockManager) // Expected HTTP status code of the response @@ -1336,7 +1336,7 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) credsManager := mock_credentials.NewMockManager(ctrl) - taskProtectionClientFactory := agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl) + taskProtectionClientFactory := tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl) // Set expectations on mocks auditLog.EXPECT().Log(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() @@ -3011,11 +3011,11 @@ func TestGetTaskProtection(t *testing.T) { Return(taskRoleCredentials(), true) } taskProtectionClientFactoryExpectations := func(output *ecs.GetTaskProtectionOutput, err error) func( - *gomock.Controller, *task_protection_v1.MockTaskProtectionClientFactoryInterface, + *gomock.Controller, *tpinterface.MockTaskProtectionClientFactoryInterface, ) { return func( ctrl *gomock.Controller, - factory *task_protection_v1.MockTaskProtectionClientFactoryInterface, + factory *tpinterface.MockTaskProtectionClientFactoryInterface, ) { client := mock_taskprotection.NewMockECSTaskProtectionSDK(ctrl) client.EXPECT().GetTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err) @@ -3025,7 +3025,7 @@ func TestGetTaskProtection(t *testing.T) { // Test cases start here t.Run("task ARN not found", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { gomock.InOrder( @@ -3033,8 +3033,8 @@ func TestGetTaskProtection(t *testing.T) { ) }, expectedStatusCode: http.StatusNotFound, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Code: ecs.ErrCodeResourceNotFoundException, Message: "Failed to find a task for the request", }, @@ -3042,7 +3042,7 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("task not found", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { gomock.InOrder( @@ -3051,8 +3051,8 @@ func TestGetTaskProtection(t *testing.T) { ) }, expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Code: ecs.ErrCodeServerException, Message: "Failed to find a task for the request", }, @@ -3060,7 +3060,7 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("task credentials not found", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: func(credsManager *mock_credentials.MockManager) { @@ -3069,8 +3069,8 @@ func TestGetTaskProtection(t *testing.T) { Return(credentials.TaskIAMRoleCredentials{}, false) }, expectedStatusCode: http.StatusForbidden, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeAccessDeniedException, Message: "Invalid Request: no task IAM role credentials available for task", @@ -3079,7 +3079,7 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("ecs call server exception", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3092,9 +3092,9 @@ func TestGetTaskProtection(t *testing.T) { ), ), expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ RequestID: &ecsRequestID, - Error: &agentapi.ErrorResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeServerException, Message: ecsErrMessage, @@ -3103,7 +3103,7 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("ecs call access denied exception", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3116,9 +3116,9 @@ func TestGetTaskProtection(t *testing.T) { ), ), expectedStatusCode: http.StatusBadRequest, - expectedResponseBody: agentapi.TaskProtectionResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ RequestID: &ecsRequestID, - Error: &agentapi.ErrorResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeAccessDeniedException, Message: ecsErrMessage, @@ -3127,7 +3127,7 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("ecs call non-request-failure aws error", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3135,8 +3135,8 @@ func TestGetTaskProtection(t *testing.T) { nil, awserr.New(ecs.ErrCodeInvalidParameterException, ecsErrMessage, nil)), expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeInvalidParameterException, Message: ecsErrMessage, @@ -3145,15 +3145,15 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("agent timeout", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, awserr.New(request.CanceledErrorCode, "request cancelled", nil)), expectedStatusCode: http.StatusGatewayTimeout, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: request.CanceledErrorCode, Message: "Timed out calling ECS Task Protection API", @@ -3162,15 +3162,15 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("non-aws error", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, errors.New("some error")), expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeServerException, Message: "some error", @@ -3179,7 +3179,7 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("ecs failure", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3191,7 +3191,7 @@ func TestGetTaskProtection(t *testing.T) { }}, }, nil), expectedStatusCode: http.StatusOK, - expectedResponseBody: agentapi.TaskProtectionResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ Failure: &ecs.Failure{ Arn: aws.String(taskARN), Reason: aws.String("ecs failure"), @@ -3200,7 +3200,7 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("more than one ecs failure", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3218,8 +3218,8 @@ func TestGetTaskProtection(t *testing.T) { }, }, nil), expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeServerException, Message: "Unexpected error occurred", @@ -3228,13 +3228,13 @@ func TestGetTaskProtection(t *testing.T) { }) }) t.Run("happy case", func(t *testing.T) { - testTMDSRequest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + testTMDSRequest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ path: path, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil), expectedStatusCode: http.StatusOK, - expectedResponseBody: agentapi.TaskProtectionResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ Protection: &protectedTask, }, }) @@ -3279,11 +3279,11 @@ func TestUpdateTaskProtection(t *testing.T) { Return(taskRoleCredentials(), true) } taskProtectionClientFactoryExpectations := func(output *ecs.UpdateTaskProtectionOutput, err error) func( - *gomock.Controller, *task_protection_v1.MockTaskProtectionClientFactoryInterface, + *gomock.Controller, *tpinterface.MockTaskProtectionClientFactoryInterface, ) { return func( ctrl *gomock.Controller, - factory *task_protection_v1.MockTaskProtectionClientFactoryInterface, + factory *tpinterface.MockTaskProtectionClientFactoryInterface, ) { client := mock_taskprotection.NewMockECSTaskProtectionSDK(ctrl) client.EXPECT().UpdateTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err) @@ -3292,7 +3292,7 @@ func TestUpdateTaskProtection(t *testing.T) { } // Helper function for creating a function that runs a test case - runTest := func(t *testing.T, tc TMDSTestCase[agentapi.TaskProtectionResponse]) func(*testing.T) { + runTest := func(t *testing.T, tc TMDSTestCase[tptypes.TaskProtectionResponse]) func(*testing.T) { return func(t *testing.T) { tc.path = fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID) tc.method = "PUT" @@ -3301,7 +3301,7 @@ func TestUpdateTaskProtection(t *testing.T) { } // Test cases start here - t.Run("task ARN not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("task ARN not found", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { gomock.InOrder( @@ -3309,14 +3309,14 @@ func TestUpdateTaskProtection(t *testing.T) { ) }, expectedStatusCode: http.StatusNotFound, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Code: ecs.ErrCodeResourceNotFoundException, Message: "Failed to find a task for the request", }, }, })) - t.Run("task not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("task not found", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) { gomock.InOrder( @@ -3325,14 +3325,14 @@ func TestUpdateTaskProtection(t *testing.T) { ) }, expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Code: ecs.ErrCodeServerException, Message: "Failed to find a task for the request", }, }, })) - t.Run("task credentials not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("task credentials not found", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: func(credsManager *mock_credentials.MockManager) { @@ -3341,15 +3341,15 @@ func TestUpdateTaskProtection(t *testing.T) { Return(credentials.TaskIAMRoleCredentials{}, false) }, expectedStatusCode: http.StatusForbidden, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeAccessDeniedException, Message: "Invalid Request: no task IAM role credentials available for task", }, }, })) - t.Run("ecs call server exception", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("ecs call server exception", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3362,16 +3362,16 @@ func TestUpdateTaskProtection(t *testing.T) { ), ), expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ RequestID: &ecsRequestID, - Error: &agentapi.ErrorResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeServerException, Message: ecsErrMessage, }, }, })) - t.Run("ecs call access denied exception", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("ecs call access denied exception", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3384,16 +3384,16 @@ func TestUpdateTaskProtection(t *testing.T) { ), ), expectedStatusCode: http.StatusBadRequest, - expectedResponseBody: agentapi.TaskProtectionResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ RequestID: &ecsRequestID, - Error: &agentapi.ErrorResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeAccessDeniedException, Message: ecsErrMessage, }, }, })) - t.Run("ecs call non-request-failure aws error", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("ecs call non-request-failure aws error", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3401,45 +3401,45 @@ func TestUpdateTaskProtection(t *testing.T) { nil, awserr.New(ecs.ErrCodeInvalidParameterException, ecsErrMessage, nil)), expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeInvalidParameterException, Message: ecsErrMessage, }, }, })) - t.Run("agent timeout", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("agent timeout", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, awserr.New(request.CanceledErrorCode, "request cancelled", nil)), expectedStatusCode: http.StatusGatewayTimeout, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: request.CanceledErrorCode, Message: "Timed out calling ECS Task Protection API", }, }, })) - t.Run("non-aws error", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("non-aws error", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations( nil, errors.New("some error")), expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeServerException, Message: "some error", }, }, })) - t.Run("ecs failure", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("ecs failure", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3451,14 +3451,14 @@ func TestUpdateTaskProtection(t *testing.T) { }}, }, nil), expectedStatusCode: http.StatusOK, - expectedResponseBody: agentapi.TaskProtectionResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ Failure: &ecs.Failure{ Arn: aws.String(taskARN), Reason: aws.String("ecs failure"), }, }, })) - t.Run("more than on ecs failure", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("more than on ecs failure", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, @@ -3476,70 +3476,70 @@ func TestUpdateTaskProtection(t *testing.T) { }, }, nil), expectedStatusCode: http.StatusInternalServerError, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeServerException, Message: "Unexpected error occurred", }, }, })) - t.Run("empty request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("empty request", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: map[string]string{}, setStateExpectations: happyStateExpectations, expectedStatusCode: http.StatusBadRequest, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Arn: taskARN, Code: ecs.ErrCodeInvalidParameterException, Message: "Invalid request: does not contain 'ProtectionEnabled' field", }, }, })) - t.Run("invalid type in request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("invalid type in request", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: map[string]interface{}{ "ProtectionEnabled": true, "ExpiresInMinutes": "bad", }, expectedStatusCode: http.StatusBadRequest, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Code: ecs.ErrCodeInvalidParameterException, Message: "UpdateTaskProtection: failed to decode request", }, }, })) - t.Run("unknown fields in the request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("unknown fields in the request", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: map[string]interface{}{ "ProtectionEnabled": true, "ExpiresInMinutes": 5, "Unknown": "unknown", }, expectedStatusCode: http.StatusBadRequest, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Code: ecs.ErrCodeInvalidParameterException, Message: "UpdateTaskProtection: failed to decode request", }, }, })) - t.Run("non-JSON object request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("non-JSON object request", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: "bad", expectedStatusCode: http.StatusBadRequest, - expectedResponseBody: agentapi.TaskProtectionResponse{ - Error: &agentapi.ErrorResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ + Error: &tptypes.ErrorResponse{ Code: ecs.ErrCodeInvalidParameterException, Message: "UpdateTaskProtection: failed to decode request", }, }, })) - t.Run("happy case", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{ + t.Run("happy case", runTest(t, TMDSTestCase[tptypes.TaskProtectionResponse]{ requestBody: happyReqBody, setStateExpectations: happyStateExpectations, setCredentialsManagerExpectations: happyCredentialsManagerExpectations, setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil), expectedStatusCode: http.StatusOK, - expectedResponseBody: agentapi.TaskProtectionResponse{ + expectedResponseBody: tptypes.TaskProtectionResponse{ Protection: &protectedTask, }, })) diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/generate_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/generate_mocks.go similarity index 75% rename from agent/handlers/agentapi/taskprotection/v1/handlers/generate_mocks.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/generate_mocks.go index cf06d6fafc2..84c877b2226 100644 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/generate_mocks.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/generate_mocks.go @@ -1,3 +1,3 @@ package handlers -//go:generate mockgen -destination=handlers_mocks.go -package=handlers -copyright_file=../../../../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers TaskProtectionClientFactoryInterface +//go:generate mockgen -destination=handlers_mocks.go -package=handlers -copyright_file=../../../../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers TaskProtectionClientFactoryInterface diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_mocks.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_mocks.go similarity index 95% rename from agent/handlers/agentapi/taskprotection/v1/handlers/handlers_mocks.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_mocks.go index d59f8c4d6b3..22dc4c1fc2d 100644 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_mocks.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_mocks.go @@ -13,7 +13,7 @@ // // Code generated by MockGen. DO NOT EDIT. -// Source: github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers (interfaces: TaskProtectionClientFactoryInterface) +// Source: github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers (interfaces: TaskProtectionClientFactoryInterface) // Package handlers is a generated GoMock package. package handlers diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/interface.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/interface.go similarity index 100% rename from agent/handlers/agentapi/taskprotection/v1/handlers/interface.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/interface.go diff --git a/agent/handlers/agentapi/taskprotection/v1/types/types.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types/types.go similarity index 100% rename from agent/handlers/agentapi/taskprotection/v1/types/types.go rename to agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types/types.go diff --git a/agent/vendor/modules.txt b/agent/vendor/modules.txt index 5a1941b795d..3a91d6ca16d 100644 --- a/agent/vendor/modules.txt +++ b/agent/vendor/modules.txt @@ -34,6 +34,8 @@ github.com/aws/amazon-ecs-agent/ecs-agent/tcs/handler github.com/aws/amazon-ecs-agent/ecs-agent/tcs/model/ecstcs github.com/aws/amazon-ecs-agent/ecs-agent/tmds github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response +github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers +github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1 github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2 diff --git a/ecs-agent/tmds/handlers/taskprotection/v1/handlers/generate_mocks.go b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/generate_mocks.go new file mode 100644 index 00000000000..84c877b2226 --- /dev/null +++ b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/generate_mocks.go @@ -0,0 +1,3 @@ +package handlers + +//go:generate mockgen -destination=handlers_mocks.go -package=handlers -copyright_file=../../../../../../scripts/copyright_file github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers TaskProtectionClientFactoryInterface diff --git a/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_mocks.go b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_mocks.go new file mode 100644 index 00000000000..22dc4c1fc2d --- /dev/null +++ b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_mocks.go @@ -0,0 +1,64 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. +// + +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers (interfaces: TaskProtectionClientFactoryInterface) + +// Package handlers is a generated GoMock package. +package handlers + +import ( + reflect "reflect" + + api "github.com/aws/amazon-ecs-agent/ecs-agent/api" + credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + gomock "github.com/golang/mock/gomock" +) + +// MockTaskProtectionClientFactoryInterface is a mock of TaskProtectionClientFactoryInterface interface. +type MockTaskProtectionClientFactoryInterface struct { + ctrl *gomock.Controller + recorder *MockTaskProtectionClientFactoryInterfaceMockRecorder +} + +// MockTaskProtectionClientFactoryInterfaceMockRecorder is the mock recorder for MockTaskProtectionClientFactoryInterface. +type MockTaskProtectionClientFactoryInterfaceMockRecorder struct { + mock *MockTaskProtectionClientFactoryInterface +} + +// NewMockTaskProtectionClientFactoryInterface creates a new mock instance. +func NewMockTaskProtectionClientFactoryInterface(ctrl *gomock.Controller) *MockTaskProtectionClientFactoryInterface { + mock := &MockTaskProtectionClientFactoryInterface{ctrl: ctrl} + mock.recorder = &MockTaskProtectionClientFactoryInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTaskProtectionClientFactoryInterface) EXPECT() *MockTaskProtectionClientFactoryInterfaceMockRecorder { + return m.recorder +} + +// NewTaskProtectionClient mocks base method. +func (m *MockTaskProtectionClientFactoryInterface) NewTaskProtectionClient(arg0 credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "NewTaskProtectionClient", arg0) + ret0, _ := ret[0].(api.ECSTaskProtectionSDK) + return ret0 +} + +// NewTaskProtectionClient indicates an expected call of NewTaskProtectionClient. +func (mr *MockTaskProtectionClientFactoryInterfaceMockRecorder) NewTaskProtectionClient(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NewTaskProtectionClient", reflect.TypeOf((*MockTaskProtectionClientFactoryInterface)(nil).NewTaskProtectionClient), arg0) +} diff --git a/ecs-agent/tmds/handlers/taskprotection/v1/handlers/interface.go b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/interface.go new file mode 100644 index 00000000000..02c8c4c4dd8 --- /dev/null +++ b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/interface.go @@ -0,0 +1,10 @@ +package handlers + +import ( + "github.com/aws/amazon-ecs-agent/ecs-agent/api" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" +) + +type TaskProtectionClientFactoryInterface interface { + NewTaskProtectionClient(taskRoleCredential credentials.TaskIAMRoleCredentials) api.ECSTaskProtectionSDK +} diff --git a/ecs-agent/tmds/handlers/taskprotection/v1/types/types.go b/ecs-agent/tmds/handlers/taskprotection/v1/types/types.go new file mode 100644 index 00000000000..8bc763437e4 --- /dev/null +++ b/ecs-agent/tmds/handlers/taskprotection/v1/types/types.go @@ -0,0 +1,107 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package types + +import ( + "encoding/json" + "fmt" + + "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" +) + +// taskProtection is type of Protection for a Task +type taskProtection struct { + protectionEnabled bool + expiresInMinutes *int64 +} + +// MarshalJSON is custom JSON marshal function to marshal unexported fields for logging purposes +func (taskProtection *taskProtection) MarshalJSON() ([]byte, error) { + jsonBytes, err := json.Marshal(struct { + ProtectionEnabled bool + ExpiresInMinutes *int64 + }{ + ProtectionEnabled: taskProtection.protectionEnabled, + ExpiresInMinutes: taskProtection.expiresInMinutes, + }) + + if err != nil { + return nil, err + } + + return jsonBytes, nil +} + +// NewTaskProtection creates a taskProtection +func NewTaskProtection(protectionEnabled bool, expiresInMinutes *int64) *taskProtection { + return &taskProtection{ + protectionEnabled: protectionEnabled, + expiresInMinutes: expiresInMinutes, + } +} + +func (taskProtection *taskProtection) GetProtectionEnabled() bool { + return taskProtection.protectionEnabled +} + +func (taskProtection *taskProtection) GetExpiresInMinutes() *int64 { + return taskProtection.expiresInMinutes +} + +func (taskProtection *taskProtection) String() string { + jsonBytes, err := taskProtection.MarshalJSON() + if err != nil { + return fmt.Sprintf("failed to get string representation of taskProtection type: %v", err) + } + return string(jsonBytes) +} + +// TaskProtectionResponse is response type for all Update/GetTaskProtection requests +type TaskProtectionResponse struct { + RequestID *string `json:"requestID,omitempty"` + Protection *ecs.ProtectedTask `json:"protection,omitempty"` + Failure *ecs.Failure `json:"failure,omitempty"` + Error *ErrorResponse `json:"error,omitempty"` +} + +// NewTaskProtectionResponseProtection creates a TaskProtectionResponse when it is a successful response (has protection) +func NewTaskProtectionResponseProtection(protection *ecs.ProtectedTask) TaskProtectionResponse { + return TaskProtectionResponse{Protection: protection} +} + +// NewTaskProtectionResponseFailure creates a TaskProtectionResponse when there is a failed response with failure +func NewTaskProtectionResponseFailure(failure *ecs.Failure) TaskProtectionResponse { + return TaskProtectionResponse{Failure: failure} +} + +// NewTaskProtectionResponseError creates a TaskProtectionResponse when there is an error response with optional requestID +func NewTaskProtectionResponseError(error *ErrorResponse, requestID *string) TaskProtectionResponse { + return TaskProtectionResponse{RequestID: requestID, Error: error} +} + +// ErrorResponse is the type for all Update/GetTaskProtection request errors +type ErrorResponse struct { + Arn string `json:"Arn,omitempty"` + Code string + Message string +} + +// NewErrorResponsePtr creates a *ErrorResponse for Agent input validations failures and exceptions +func NewErrorResponsePtr(arn string, code string, message string) *ErrorResponse { + return &ErrorResponse{ + Arn: arn, + Code: code, + Message: message, + } +}