From bd17b370dc9e1dbc6cdc39d4408c3c0c10fdc8d3 Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Thu, 22 Jun 2023 21:56:43 +0000 Subject: [PATCH 1/4] Add task-protection handlers to ecs-agent and integrate them with agent --- .../taskprotection/v1/handlers/handlers.go | 335 --------- .../v1/handlers/handlers_test.go | 640 ----------------- agent/handlers/task_server_setup.go | 35 +- agent/handlers/task_server_setup_test.go | 9 +- agent/handlers/v4/tmdsstate.go | 2 + .../ecs-agent/logger/field/constants.go | 1 + .../taskprotection/v1/handlers/handlers.go | 403 +++++++++++ .../tmds/handlers/v4/state/response.go | 1 + ecs-agent/logger/field/constants.go | 1 + .../taskprotection/v1/handlers/handlers.go | 403 +++++++++++ .../v1/handlers/handlers_test.go | 671 ++++++++++++++++++ ecs-agent/tmds/handlers/v4/state/response.go | 1 + 12 files changed, 1513 insertions(+), 989 deletions(-) delete mode 100644 agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go delete mode 100644 agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go create mode 100644 agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go create mode 100644 ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go create mode 100644 ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_test.go diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go deleted file mode 100644 index 80791949cef..00000000000 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers.go +++ /dev/null @@ -1,335 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package handlers - -import ( - "context" - "encoding/json" - "errors" - "fmt" - "net/http" - "time" - - apitask "github.com/aws/amazon-ecs-agent/agent/api/task" - "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" - v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" - "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" - "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" - "github.com/aws/amazon-ecs-agent/ecs-agent/logger" - loggerfield "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" - tpinterface "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" - "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" - "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/request" -) - -const ( - ExpectedProtectionResponseLength = 1 - - // timeout for ECS SDK calls - // must be lower than server write timeout - ecsCallTimeout = 4 * time.Second - ecsCallTimedOutError = "Timed out calling ECS Task Protection API" - taskNotFoundErrorMsg = "Failed to find a task for the request" -) - -// TaskProtectionPath Returns endpoint path for UpdateTaskProtection API -func TaskProtectionPath() string { - return fmt.Sprintf( - "/api/%s/task-protection/v1/state", - utils.ConstructMuxVar(v3.V3EndpointIDMuxName, utils.AnythingButSlashRegEx)) -} - -// TaskProtectionRequest is the Task protection request received from customers pending validation -type TaskProtectionRequest struct { - ProtectionEnabled *bool - ExpiresInMinutes *int64 -} - -// UpdateTaskProtectionHandler returns an HTTP request handler function for -// UpdateTaskProtection API -func UpdateTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsManager credentials.Manager, - factory tpinterface.TaskProtectionClientFactoryInterface, cluster string) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - updateTaskProtectionRequestType := "api/UpdateTaskProtection/v1" - - var request TaskProtectionRequest - jsonDecoder := json.NewDecoder(r.Body) - jsonDecoder.DisallowUnknownFields() - if err := jsonDecoder.Decode(&request); err != nil { - logger.Error("UpdateTaskProtection: failed to decode request", logger.Fields{ - loggerfield.Error: err, - }) - writeJSONResponse(w, http.StatusBadRequest, - types.NewTaskProtectionResponseError(types.NewErrorResponsePtr("", ecs.ErrCodeInvalidParameterException, - "UpdateTaskProtection: failed to decode request"), nil), - updateTaskProtectionRequestType) - return - } - - task, statusCode, errorCode, err := getTaskFromRequest(state, r) - if err != nil { - writeJSONResponse(w, statusCode, - types.NewTaskProtectionResponseError(types.NewErrorResponsePtr("", errorCode, err.Error()), nil), - updateTaskProtectionRequestType) - return - } - - if request.ProtectionEnabled == nil { - writeJSONResponse(w, http.StatusBadRequest, - types.NewTaskProtectionResponseError(types.NewErrorResponsePtr(task.Arn, ecs.ErrCodeInvalidParameterException, - "Invalid request: does not contain 'ProtectionEnabled' field"), nil), - updateTaskProtectionRequestType) - return - } - - taskProtection := types.NewTaskProtection(*request.ProtectionEnabled, request.ExpiresInMinutes) - - logger.Info("UpdateTaskProtection endpoint was called", logger.Fields{ - loggerfield.Cluster: cluster, - loggerfield.TaskARN: task.Arn, - loggerfield.TaskProtection: taskProtection, - }) - - taskRoleCredential, ok := credentialsManager.GetTaskCredentials(task.GetCredentialsID()) - if !ok { - err = fmt.Errorf("Invalid Request: no task IAM role credentials available for task") - logger.Error(err.Error(), logger.Fields{ - loggerfield.TaskARN: task.Arn, - }) - writeJSONResponse(w, http.StatusForbidden, - types.NewTaskProtectionResponseError(types.NewErrorResponsePtr(task.Arn, ecs.ErrCodeAccessDeniedException, err.Error()), nil), - updateTaskProtectionRequestType) - return - } - ecsClient := factory.NewTaskProtectionClient(taskRoleCredential) - - ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout) - defer cancel() - response, err := ecsClient.UpdateTaskProtectionWithContext(ctx, &ecs.UpdateTaskProtectionInput{ - Cluster: aws.String(cluster), - ExpiresInMinutes: taskProtection.GetExpiresInMinutes(), - ProtectionEnabled: aws.Bool(taskProtection.GetProtectionEnabled()), - Tasks: aws.StringSlice([]string{task.Arn}), - }) - - if err != nil { - errorCode, errorMsg, statusCode, reqId := getErrorCodeAndStatusCode(err) - var requestIdString = "" - if reqId != nil { - requestIdString = *reqId - } - logger.Error("Got an exception when calling UpdateTaskProtection.", logger.Fields{ - loggerfield.Error: err, - "ErrorCode": errorCode, - "ExceptionMessage": errorMsg, - "StatusCode": statusCode, - "RequestId": requestIdString, - }) - writeJSONResponse(w, statusCode, types.NewTaskProtectionResponseError(types.NewErrorResponsePtr(task.Arn, errorCode, errorMsg), reqId), - updateTaskProtectionRequestType) - return - } - - logger.Debug("updateTaskProtection response:", logger.Fields{ - loggerfield.TaskProtection: response.ProtectedTasks, - loggerfield.Reason: response.Failures, - }) - - // there are no exceptions but there are failures when setting protection in scheduler - if len(response.Failures) > 0 { - if len(response.Failures) > ExpectedProtectionResponseLength { - err := fmt.Errorf("expect at most %v failure in response, get %v", ExpectedProtectionResponseLength, len(response.Failures)) - logger.Error("Unexpected number of failures", logger.Fields{ - loggerfield.Error: err, - loggerfield.TaskARN: task.Arn, - }) - writeJSONResponse(w, http.StatusInternalServerError, types.NewTaskProtectionResponseError( - types.NewErrorResponsePtr(task.Arn, ecs.ErrCodeServerException, "Unexpected error occurred"), nil), - updateTaskProtectionRequestType) - return - } - writeJSONResponse(w, http.StatusOK, types.NewTaskProtectionResponseFailure(response.Failures[0]), updateTaskProtectionRequestType) - return - } - if len(response.ProtectedTasks) > ExpectedProtectionResponseLength { - err := fmt.Errorf("expect %v protectedTask in response when no failure, get %v", ExpectedProtectionResponseLength, len(response.ProtectedTasks)) - logger.Error("Unexpected number of protections", logger.Fields{ - loggerfield.Error: err, - loggerfield.TaskARN: task.Arn, - }) - writeJSONResponse(w, http.StatusInternalServerError, types.NewTaskProtectionResponseError( - types.NewErrorResponsePtr(task.Arn, ecs.ErrCodeServerException, "Unexpected error occurred"), nil), - updateTaskProtectionRequestType) - return - } - writeJSONResponse(w, http.StatusOK, types.NewTaskProtectionResponseProtection(response.ProtectedTasks[0]), updateTaskProtectionRequestType) - } -} - -// GetTaskProtectionHandler returns a handler function for GetTaskProtection API -func GetTaskProtectionHandler(state dockerstate.TaskEngineState, credentialsManager credentials.Manager, - factory tpinterface.TaskProtectionClientFactoryInterface, cluster string) func(http.ResponseWriter, *http.Request) { - return func(w http.ResponseWriter, r *http.Request) { - getTaskProtectionRequestType := "api/GetTaskProtection/v1" - - task, statusCode, errorCode, err := getTaskFromRequest(state, r) - if err != nil { - writeJSONResponse(w, statusCode, - types.NewTaskProtectionResponseError(types.NewErrorResponsePtr("", errorCode, err.Error()), nil), - getTaskProtectionRequestType) - return - } - - logger.Info("GetTaskProtection endpoint was called", logger.Fields{ - loggerfield.Cluster: cluster, - loggerfield.TaskARN: task.Arn, - }) - - taskRoleCredential, ok := credentialsManager.GetTaskCredentials(task.GetCredentialsID()) - if !ok { - err = fmt.Errorf("Invalid Request: no task IAM role credentials available for task") - logger.Error(err.Error(), logger.Fields{ - loggerfield.TaskARN: task.Arn, - }) - writeJSONResponse(w, http.StatusForbidden, - types.NewTaskProtectionResponseError(types.NewErrorResponsePtr(task.Arn, ecs.ErrCodeAccessDeniedException, err.Error()), nil), - getTaskProtectionRequestType) - return - } - - ecsClient := factory.NewTaskProtectionClient(taskRoleCredential) - - ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout) - defer cancel() - response, err := ecsClient.GetTaskProtectionWithContext(ctx, &ecs.GetTaskProtectionInput{ - Cluster: aws.String(cluster), - Tasks: aws.StringSlice([]string{task.Arn}), - }) - - if err != nil { - errorCode, errorMsg, statusCode, reqId := getErrorCodeAndStatusCode(err) - var requestIdString = "" - if reqId != nil { - requestIdString = *reqId - } - logger.Error("Got an exception when calling GetTaskProtection.", logger.Fields{ - loggerfield.Error: err, - "ErrorCode": errorCode, - "ExceptionMessage": errorMsg, - "StatusCode": statusCode, - "RequestId": requestIdString, - }) - writeJSONResponse(w, statusCode, types.NewTaskProtectionResponseError(types.NewErrorResponsePtr(task.Arn, errorCode, errorMsg), reqId), - getTaskProtectionRequestType) - return - } - - logger.Debug("getTaskProtection response:", logger.Fields{ - loggerfield.TaskProtection: response.ProtectedTasks, - loggerfield.Reason: response.Failures, - }) - - // there are no exceptions but there are failures when getting protection in scheduler - if len(response.Failures) > 0 { - if len(response.Failures) > ExpectedProtectionResponseLength { - err := fmt.Errorf("expect at most %v failure in response, get %v", ExpectedProtectionResponseLength, len(response.Failures)) - logger.Error("Unexpected number of failures", logger.Fields{ - loggerfield.Error: err, - loggerfield.TaskARN: task.Arn, - }) - writeJSONResponse(w, http.StatusInternalServerError, types.NewTaskProtectionResponseError( - types.NewErrorResponsePtr(task.Arn, ecs.ErrCodeServerException, "Unexpected error occurred"), nil), - getTaskProtectionRequestType) - return - } - writeJSONResponse(w, http.StatusOK, types.NewTaskProtectionResponseFailure(response.Failures[0]), getTaskProtectionRequestType) - return - } - - if len(response.ProtectedTasks) > ExpectedProtectionResponseLength { - err := fmt.Errorf("expect %v protectedTask in response when no failure, get %v", ExpectedProtectionResponseLength, len(response.ProtectedTasks)) - logger.Error("Unexpected number of protections", logger.Fields{ - loggerfield.Error: err, - loggerfield.TaskARN: task.Arn, - }) - writeJSONResponse(w, http.StatusInternalServerError, types.NewTaskProtectionResponseError( - types.NewErrorResponsePtr(task.Arn, ecs.ErrCodeServerException, "Unexpected error occurred"), nil), - getTaskProtectionRequestType) - return - } - writeJSONResponse(w, http.StatusOK, types.NewTaskProtectionResponseProtection(response.ProtectedTasks[0]), getTaskProtectionRequestType) - } -} - -// Helper function to parse error to get ErrorCode, ExceptionMessage, HttpStatusCode, RequestID. -// RequestID will be empty if the request is not able to reach AWS -func getErrorCodeAndStatusCode(err error) (string, string, int, *string) { - msg := err.Error() - // The error is a Generic AWS Error with Code, Message, and original error (if any) - if awsErr, ok := err.(awserr.Error); ok { - // The error is an AWS service error occurred - msg = awsErr.Message() - if reqErr, ok := err.(awserr.RequestFailure); ok { - reqId := reqErr.RequestID() - return awsErr.Code(), msg, reqErr.StatusCode(), &reqId - } else if aerr, ok := err.(awserr.Error); ok && aerr.Code() == request.CanceledErrorCode { - return aerr.Code(), ecsCallTimedOutError, http.StatusGatewayTimeout, nil - } else { - logger.Error(fmt.Sprintf("got an exception that does not implement RequestFailure interface but is an aws error. This should not happen, return statusCode 500 for whatever errorCode. Original err: %v.", err)) - return awsErr.Code(), msg, http.StatusInternalServerError, nil - } - } else { - logger.Error(fmt.Sprintf("non aws error received: %v", err)) - return ecs.ErrCodeServerException, msg, http.StatusInternalServerError, nil - } -} - -// Helper function for finding task for the request -func getTaskFromRequest(state dockerstate.TaskEngineState, r *http.Request) (*apitask.Task, int, string, error) { - taskARN, err := v3.GetTaskARNByRequest(r, state) - if err != nil { - logger.Error("Failed to find task ARN for task protection request", logger.Fields{ - loggerfield.Error: err, - }) - return nil, http.StatusNotFound, ecs.ErrCodeResourceNotFoundException, errors.New(taskNotFoundErrorMsg) - } - - task, found := state.TaskByArn(taskARN) - if !found { - logger.Critical("No task was found for taskARN for task protection request", logger.Fields{ - loggerfield.TaskARN: taskARN, - }) - return nil, http.StatusInternalServerError, ecs.ErrCodeServerException, errors.New(taskNotFoundErrorMsg) - } - - return task, http.StatusOK, "", nil -} - -// Writes the provided response to the ResponseWriter and handles any errors -func writeJSONResponse(w http.ResponseWriter, statusCode int, response types.TaskProtectionResponse, requestType string) { - bytes, err := json.Marshal(response) - if err != nil { - logger.Error("Agent API Task Protection V1: failed to marshal response as JSON", logger.Fields{ - "response": response, - loggerfield.Error: err, - }) - utils.WriteJSONToResponse(w, http.StatusInternalServerError, []byte(`{}`), - requestType) - } else { - utils.WriteJSONToResponse(w, statusCode, bytes, requestType) - } -} diff --git a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go b/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go deleted file mode 100644 index 1f17482fbe6..00000000000 --- a/agent/handlers/agentapi/taskprotection/v1/handlers/handlers_test.go +++ /dev/null @@ -1,640 +0,0 @@ -// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"). You may -// not use this file except in compliance with the License. A copy of the -// License is located at -// -// http://aws.amazon.com/apache2.0/ -// -// or in the "license" file accompanying this file. This file is distributed -// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either -// express or implied. See the License for the specific language governing -// permissions and limitations under the License. - -package handlers - -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "net/http/httptest" - "testing" - "time" - - "github.com/aws/amazon-ecs-agent/agent/api/task" - "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" - mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" - tpfactory "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection" - v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" - "github.com/aws/amazon-ecs-agent/agent/utils" - mock_api "github.com/aws/amazon-ecs-agent/ecs-agent/api/mocks" - "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" - mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks" - "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" - tpinterface "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" - "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/golang/mock/gomock" - "github.com/gorilla/mux" - "github.com/stretchr/testify/assert" -) - -const ( - testAccessKey = "accessKey" - testSecretKey = "secretKey" - testSessionToken = "sessionToken" - testCluster = "cluster" - testRegion = "region" - testECSEndpoint = "endpoint" - testTaskCredentialsId = "taskCredentialsId" - testV3EndpointId = "endpointId" - testTaskArn = "taskArn" - testServiceName = "serviceName" - testAcceptInsecureCert = false - protectionEnabledFieldName = "ProtectionEnabled" - expiresInMinutesFieldName = "ExpiresInMinutes" - testExpiresInMinutes = 5 - testProtectionEnabled = true - testRequestID = "requestID" - testFailureReason = "failureReason" -) - -// Tests the path for UpdateTaskProtection API -func TestTaskProtectionPath(t *testing.T) { - assert.Equal(t, "/api/{v3EndpointIDMuxName:[^/]*}/task-protection/v1/state", TaskProtectionPath()) -} - -func getRequestWithUnknownFields(t *testing.T) map[string]interface{} { - request := TaskProtectionRequest{ProtectionEnabled: utils.BoolPtr(false)} - requestJSON, err := json.Marshal(request) - assert.NoError(t, err) - - var rawRequest map[string]interface{} - err = json.Unmarshal(requestJSON, &rawRequest) - assert.NoError(t, err) - rawRequest["UnknownField"] = 5 - return rawRequest -} - -// Helper function for running tests for UpdateTaskProtection handler -func testUpdateTaskProtectionHandler(t *testing.T, state dockerstate.TaskEngineState, - v3EndpointID string, credentialsManager credentials.Manager, - factory tpinterface.TaskProtectionClientFactoryInterface, - request interface{}, expectedResponse interface{}, expectedResponseCode int) { - // Prepare request - requestBytes, err := json.Marshal(request) - assert.NoError(t, err) - bodyReader := bytes.NewReader(requestBytes) - req, err := http.NewRequest("PUT", "", bodyReader) - assert.NoError(t, err) - req = mux.SetURLVars(req, map[string]string{v3.V3EndpointIDMuxName: v3EndpointID}) - - // Call handler - rr := httptest.NewRecorder() - handler := http.HandlerFunc(UpdateTaskProtectionHandler(state, credentialsManager, factory, testCluster)) - handler.ServeHTTP(rr, req) - - expectedResponseJSON, err := json.Marshal(expectedResponse) - assert.NoError(t, err, "Expected response must be JSON encodable") - - // Assert response - assert.Equal(t, "application/json", rr.Header().Get("Content-Type")) - assert.Equal(t, expectedResponseCode, rr.Code) - responseBody, err := io.ReadAll(rr.Body) - assert.NoError(t, err, "Failed to read response body") - assert.Equal(t, string(expectedResponseJSON), string(responseBody)) -} - -func generateRequestIdPtr() *string { - requestIdString := testRequestID - return &requestIdString -} - -// TestUpdateTaskProtectionHandler_InputValidationsDecodeError tests UpdateTaskProtection handler's -// behavior with different invalid inputs with decode error -func TestUpdateTaskProtectionHandler_InputValidationsDecodeError(t *testing.T) { - testCases := []struct { - name string - request interface{} - expectedError *types.ErrorResponse - }{ - { - name: "InvalidTypes", - request: &map[string]interface{}{ - protectionEnabledFieldName: true, - expiresInMinutesFieldName: "badType", - }, - expectedError: &types.ErrorResponse{Code: ecs.ErrCodeInvalidParameterException, Message: "UpdateTaskProtection: failed to decode request"}, - }, - { - name: "UnknownFieldsInRequest", - request: getRequestWithUnknownFields(t), - expectedError: &types.ErrorResponse{Code: ecs.ErrCodeInvalidParameterException, Message: "UpdateTaskProtection: failed to decode request"}, - }, - { - name: "InvalidJSONRequest", - request: "", - expectedError: &types.ErrorResponse{Code: ecs.ErrCodeInvalidParameterException, Message: "UpdateTaskProtection: failed to decode request"}, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - expectedResponse := types.TaskProtectionResponse{Error: tc.expectedError} - testUpdateTaskProtectionHandler(t, mock_dockerstate.NewMockTaskEngineState(ctrl), - testV3EndpointId, nil, nil, tc.request, expectedResponse, http.StatusBadRequest) - }) - } -} - -// TestUpdateTaskProtectionHandlerTaskARNNotFound tests UpdateTaskProtection handler's -// behavior when task ARN was not found for the request. -func TestUpdateTaskProtectionHandlerTaskARNNotFound(t *testing.T) { - request := TaskProtectionRequest{ProtectionEnabled: utils.BoolPtr(false)} - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return("", false) - - expectedResponse := types.TaskProtectionResponse{ - Error: &types.ErrorResponse{ - Code: ecs.ErrCodeResourceNotFoundException, - Message: "Failed to find a task for the request", - }, - } - testUpdateTaskProtectionHandler(t, mockState, testV3EndpointId, nil, nil, request, - expectedResponse, http.StatusNotFound) -} - -// TestUpdateTaskProtectionHandlerTaskNotFound tests UpdateTaskProtection handler's -// behavior when task ARN was not found for the request. -func TestUpdateTaskProtectionHandlerTaskNotFound(t *testing.T) { - request := TaskProtectionRequest{ProtectionEnabled: utils.BoolPtr(false)} - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) - mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(nil, false) - - expectedResponse := types.TaskProtectionResponse{ - Error: &types.ErrorResponse{ - Code: ecs.ErrCodeServerException, - Message: "Failed to find a task for the request", - }, - } - - testUpdateTaskProtectionHandler(t, mockState, testV3EndpointId, nil, nil, request, - expectedResponse, http.StatusInternalServerError) -} - -// TestUpdateTaskProtectionHandler_EmptyRequest tests UpdateTaskProtection handler's behavior with empty inputs -func TestUpdateTaskProtectionHandler_EmptyRequest(t *testing.T) { - expectedError := &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeInvalidParameterException, Message: "Invalid request: does not contain 'ProtectionEnabled' field"} - testTask := task.Task{ - Arn: testTaskArn, - ServiceName: testServiceName, - } - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) - mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true) - expectedResponse := types.TaskProtectionResponse{Error: expectedError} - testUpdateTaskProtectionHandler(t, mockState, testV3EndpointId, nil, nil, - nil, expectedResponse, http.StatusBadRequest) -} - -// TestUpdateTaskProtectionHandlerTaskRoleCredentialsNotFound tests UpdateTaskProtection handler's -// behavior when task IAM role credential is not found for the request. -func TestUpdateTaskProtectionHandlerTaskRoleCredentialsNotFound(t *testing.T) { - request := TaskProtectionRequest{ - ProtectionEnabled: utils.BoolPtr(true), - } - - testTask := task.Task{ - Arn: testTaskArn, - ServiceName: testServiceName, - } - testTask.SetCredentialsID(testTaskCredentialsId) - - factory := tpfactory.TaskProtectionClientFactory{ - Region: testRegion, Endpoint: testECSEndpoint, AcceptInsecureCert: testAcceptInsecureCert, - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockManager := mock_credentials.NewMockManager(ctrl) - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) - mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true) - mockManager.EXPECT().GetTaskCredentials(gomock.Eq(testTaskCredentialsId)).Return(credentials.TaskIAMRoleCredentials{}, false) - - expectedResponse := types.TaskProtectionResponse{ - Error: &types.ErrorResponse{ - Arn: testTaskArn, - Code: ecs.ErrCodeAccessDeniedException, - Message: "Invalid Request: no task IAM role credentials available for task", - }, - } - - testUpdateTaskProtectionHandler(t, mockState, testV3EndpointId, mockManager, factory, request, - expectedResponse, http.StatusForbidden) -} - -// TestUpdateTaskProtectionHandler_PostCall tests UpdateTaskProtection handler's -// behavior when request successfully reached ECS and get response -func TestUpdateTaskProtectionHandler_PostCall(t *testing.T) { - testCases := []struct { - name string - ecsError error - ecsResponse *ecs.UpdateTaskProtectionOutput - expectedProtection *ecs.ProtectedTask - expectedFailure *ecs.Failure - expectedError *types.ErrorResponse - expectedRequestId *string - expectedStatusCode int - time time.Time - }{ - { - name: "RequestFailure_ServerException", - ecsError: awserr.NewRequestFailure(awserr.New(ecs.ErrCodeServerException, "error message", nil), http.StatusInternalServerError, testRequestID), - ecsResponse: &ecs.UpdateTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeServerException, Message: "error message"}, - expectedStatusCode: http.StatusInternalServerError, - expectedRequestId: generateRequestIdPtr(), - }, - { - name: "RequestFailure_OtherExceptions", - ecsError: awserr.NewRequestFailure(awserr.New(ecs.ErrCodeAccessDeniedException, "error message", nil), http.StatusBadRequest, testRequestID), - ecsResponse: &ecs.UpdateTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeAccessDeniedException, Message: "error message"}, - expectedStatusCode: http.StatusBadRequest, - expectedRequestId: generateRequestIdPtr(), - }, - { - name: "NonRequestFailureAwsError", - ecsError: awserr.New(ecs.ErrCodeInvalidParameterException, "error message", nil), - ecsResponse: &ecs.UpdateTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeInvalidParameterException, Message: "error message"}, - expectedStatusCode: http.StatusInternalServerError, - }, - { - name: "Agent timeout", - ecsError: awserr.New(request.CanceledErrorCode, "request cancelled", nil), - ecsResponse: &ecs.UpdateTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{ - Arn: testTaskArn, - Code: request.CanceledErrorCode, - Message: ecsCallTimedOutError, - }, - expectedStatusCode: http.StatusGatewayTimeout, - }, - { - name: "NonAwsError", - ecsError: fmt.Errorf("error message"), - ecsResponse: &ecs.UpdateTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeServerException, Message: "error message"}, - expectedStatusCode: http.StatusInternalServerError, - }, - { - name: "Failure", - ecsError: nil, - ecsResponse: &ecs.UpdateTaskProtectionOutput{ - Failures: []*ecs.Failure{{ - Arn: aws.String(testTaskArn), - Reason: aws.String(testFailureReason), - }}, - ProtectedTasks: []*ecs.ProtectedTask{}, - }, - expectedFailure: &ecs.Failure{ - Arn: aws.String(testTaskArn), - Reason: aws.String(testFailureReason), - }, - expectedStatusCode: http.StatusOK, - }, - { - name: "SuccessProtected", - ecsError: nil, - ecsResponse: &ecs.UpdateTaskProtectionOutput{ - Failures: []*ecs.Failure{}, - ProtectedTasks: []*ecs.ProtectedTask{{ - ProtectionEnabled: aws.Bool(true), - ExpirationDate: aws.Time(time.UnixMilli(0)), - TaskArn: aws.String(testTaskArn), - }}, - }, - expectedProtection: &ecs.ProtectedTask{ - ProtectionEnabled: aws.Bool(true), - ExpirationDate: aws.Time(time.UnixMilli(0)), - TaskArn: aws.String(testTaskArn), - }, - expectedStatusCode: http.StatusOK, - }, - { - name: "SuccessNotProtected", - ecsError: nil, - ecsResponse: &ecs.UpdateTaskProtectionOutput{ - Failures: []*ecs.Failure{}, - ProtectedTasks: []*ecs.ProtectedTask{{ - ProtectionEnabled: aws.Bool(false), - ExpirationDate: nil, - TaskArn: aws.String(testTaskArn), - }}, - }, - expectedProtection: &ecs.ProtectedTask{ - ProtectionEnabled: aws.Bool(false), - ExpirationDate: nil, - TaskArn: aws.String(testTaskArn), - }, - expectedStatusCode: http.StatusOK, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - request := TaskProtectionRequest{ - ProtectionEnabled: utils.BoolPtr(testProtectionEnabled), - ExpiresInMinutes: utils.Int64Ptr(testExpiresInMinutes), - } - - testTask := task.Task{ - Arn: testTaskArn, - ServiceName: testServiceName, - } - testTask.SetCredentialsID(testTaskCredentialsId) - - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockManager := mock_credentials.NewMockManager(ctrl) - mockFactory := tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl) - mockECSClient := mock_api.NewMockECSTaskProtectionSDK(ctrl) - - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) - mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true) - mockManager.EXPECT().GetTaskCredentials(gomock.Eq(testTaskCredentialsId)).Return(credentials.TaskIAMRoleCredentials{}, true) - mockFactory.EXPECT().NewTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient) - mockECSClient.EXPECT(). - UpdateTaskProtectionWithContext(gomock.Any(), gomock.Any()). - Return(tc.ecsResponse, tc.ecsError) - - expectedResponse := types.TaskProtectionResponse{ - Protection: tc.expectedProtection, - Failure: tc.expectedFailure, - Error: tc.expectedError, - RequestID: tc.expectedRequestId, - } - - testUpdateTaskProtectionHandler(t, mockState, testV3EndpointId, mockManager, mockFactory, request, expectedResponse, tc.expectedStatusCode) - }) - } -} - -func testGetTaskProtectionHandler(t *testing.T, state dockerstate.TaskEngineState, - v3EndpointID string, credentialsManager credentials.Manager, - factory tpinterface.TaskProtectionClientFactoryInterface, - expectedResponse interface{}, expectedResponseCode int) { - // Prepare request - bodyReader := bytes.NewReader([]byte{}) - req, err := http.NewRequest("GET", "", bodyReader) - assert.NoError(t, err) - req = mux.SetURLVars(req, map[string]string{v3.V3EndpointIDMuxName: v3EndpointID}) - - // Call handler - rr := httptest.NewRecorder() - handler := http.HandlerFunc(GetTaskProtectionHandler(state, credentialsManager, factory, testCluster)) - handler.ServeHTTP(rr, req) - - expectedResponseJSON, err := json.Marshal(expectedResponse) - assert.NoError(t, err, "Expected response must be JSON encodable") - - // Assert response - assert.Equal(t, expectedResponseCode, rr.Code) - responseBody, err := io.ReadAll(rr.Body) - assert.NoError(t, err, "Failed to read response body") - assert.Equal(t, string(expectedResponseJSON), string(responseBody)) -} - -// TestGetTaskProtectionHandlerTaskARNNotFound tests GetTaskProtection handler's -// behavior when task ARN was not found for the request. -func TestGetTaskProtectionHandlerTaskARNNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return("", false) - - expectedResponse := types.TaskProtectionResponse{ - Error: &types.ErrorResponse{ - Code: ecs.ErrCodeResourceNotFoundException, - Message: "Failed to find a task for the request", - }, - } - testGetTaskProtectionHandler(t, mockState, testV3EndpointId, nil, nil, - expectedResponse, http.StatusNotFound) -} - -// TestGetTaskProtectionHandlerTaskNotFound tests GetTaskProtection handler's -// behavior when task ARN was not found for the request. -func TestGetTaskProtectionHandlerTaskNotFound(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) - mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(nil, false) - - expectedResponse := types.TaskProtectionResponse{ - Error: &types.ErrorResponse{ - Code: ecs.ErrCodeServerException, - Message: "Failed to find a task for the request", - }, - } - - testGetTaskProtectionHandler(t, mockState, testV3EndpointId, nil, nil, - expectedResponse, http.StatusInternalServerError) -} - -// TestGetTaskProtectionHandlerTaskRoleCredentialsNotFound tests GetTaskProtection handler's -// behavior when task IAM role credential is not found for the request. -func TestGetTaskProtectionHandlerTaskRoleCredentialsNotFound(t *testing.T) { - testTask := task.Task{ - Arn: testTaskArn, - ServiceName: testServiceName, - } - testTask.SetCredentialsID(testTaskCredentialsId) - - factory := tpfactory.TaskProtectionClientFactory{ - Region: testRegion, Endpoint: testECSEndpoint, AcceptInsecureCert: testAcceptInsecureCert, - } - - ctrl := gomock.NewController(t) - defer ctrl.Finish() - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockManager := mock_credentials.NewMockManager(ctrl) - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) - mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true) - mockManager.EXPECT().GetTaskCredentials(gomock.Eq(testTaskCredentialsId)).Return(credentials.TaskIAMRoleCredentials{}, false) - - expectedResponse := types.TaskProtectionResponse{ - Error: &types.ErrorResponse{ - Arn: testTaskArn, - Code: ecs.ErrCodeAccessDeniedException, - Message: "Invalid Request: no task IAM role credentials available for task", - }, - } - - testGetTaskProtectionHandler(t, mockState, testV3EndpointId, mockManager, factory, - expectedResponse, http.StatusForbidden) -} - -// TestGetTaskProtectionHandler_PostCall tests GetTaskProtection handler's -// behavior when request successfully reached ECS and get response -func TestGetTaskProtectionHandler_PostCall(t *testing.T) { - testCases := []struct { - name string - ecsError error - ecsResponse *ecs.GetTaskProtectionOutput - expectedProtection *ecs.ProtectedTask - expectedFailure *ecs.Failure - expectedError *types.ErrorResponse - expectedRequestId *string - expectedStatusCode int - time time.Time - }{ - { - name: "RequestFailure_ServerException", - ecsError: awserr.NewRequestFailure(awserr.New(ecs.ErrCodeServerException, "error message", nil), http.StatusInternalServerError, testRequestID), - ecsResponse: &ecs.GetTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeServerException, Message: "error message"}, - expectedStatusCode: http.StatusInternalServerError, - expectedRequestId: generateRequestIdPtr(), - }, - { - name: "RequestFailure_OtherExceptions", - ecsError: awserr.NewRequestFailure(awserr.New(ecs.ErrCodeAccessDeniedException, "error message", nil), http.StatusBadRequest, testRequestID), - ecsResponse: &ecs.GetTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeAccessDeniedException, Message: "error message"}, - expectedStatusCode: http.StatusBadRequest, - expectedRequestId: generateRequestIdPtr(), - }, - { - name: "NonRequestFailureAwsError", - ecsError: awserr.New(ecs.ErrCodeInvalidParameterException, "error message", nil), - ecsResponse: &ecs.GetTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeInvalidParameterException, Message: "error message"}, - expectedStatusCode: http.StatusInternalServerError, - }, - { - name: "Agent timeout", - ecsError: awserr.New(request.CanceledErrorCode, "request cancelled", nil), - ecsResponse: &ecs.GetTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{ - Arn: testTaskArn, - Code: request.CanceledErrorCode, - Message: ecsCallTimedOutError, - }, - expectedStatusCode: http.StatusGatewayTimeout, - }, - { - name: "NonAwsError", - ecsError: fmt.Errorf("error message"), - ecsResponse: &ecs.GetTaskProtectionOutput{}, - expectedError: &types.ErrorResponse{Arn: testTaskArn, Code: ecs.ErrCodeServerException, Message: "error message"}, - expectedStatusCode: http.StatusInternalServerError, - }, - { - name: "Failure", - ecsError: nil, - ecsResponse: &ecs.GetTaskProtectionOutput{ - Failures: []*ecs.Failure{{ - Arn: aws.String(testTaskArn), - Reason: aws.String(testFailureReason), - }}, - ProtectedTasks: []*ecs.ProtectedTask{}, - }, - expectedFailure: &ecs.Failure{ - Arn: aws.String(testTaskArn), - Reason: aws.String(testFailureReason), - }, - expectedStatusCode: http.StatusOK, - }, - { - name: "SuccessProtected", - ecsError: nil, - ecsResponse: &ecs.GetTaskProtectionOutput{ - Failures: []*ecs.Failure{}, - ProtectedTasks: []*ecs.ProtectedTask{{ - ProtectionEnabled: aws.Bool(true), - ExpirationDate: aws.Time(time.UnixMilli(0)), - TaskArn: aws.String(testTaskArn), - }}, - }, - expectedProtection: &ecs.ProtectedTask{ - ProtectionEnabled: aws.Bool(true), - ExpirationDate: aws.Time(time.UnixMilli(0)), - TaskArn: aws.String(testTaskArn), - }, - expectedStatusCode: http.StatusOK, - }, - { - name: "SuccessNotProtected", - ecsError: nil, - ecsResponse: &ecs.GetTaskProtectionOutput{ - Failures: []*ecs.Failure{}, - ProtectedTasks: []*ecs.ProtectedTask{{ - ProtectionEnabled: aws.Bool(false), - ExpirationDate: nil, - TaskArn: aws.String(testTaskArn), - }}, - }, - expectedProtection: &ecs.ProtectedTask{ - ProtectionEnabled: aws.Bool(false), - ExpirationDate: nil, - TaskArn: aws.String(testTaskArn), - }, - expectedStatusCode: http.StatusOK, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - ctrl := gomock.NewController(t) - defer ctrl.Finish() - - testTask := task.Task{ - Arn: testTaskArn, - ServiceName: testServiceName, - } - testTask.SetCredentialsID(testTaskCredentialsId) - - mockState := mock_dockerstate.NewMockTaskEngineState(ctrl) - mockManager := mock_credentials.NewMockManager(ctrl) - mockFactory := tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl) - mockECSClient := mock_api.NewMockECSTaskProtectionSDK(ctrl) - - mockState.EXPECT().TaskARNByV3EndpointID(gomock.Eq(testV3EndpointId)).Return(testTaskArn, true) - mockState.EXPECT().TaskByArn(gomock.Eq(testTaskArn)).Return(&testTask, true) - mockManager.EXPECT().GetTaskCredentials(gomock.Eq(testTaskCredentialsId)).Return(credentials.TaskIAMRoleCredentials{}, true) - mockFactory.EXPECT().NewTaskProtectionClient(gomock.Eq(credentials.TaskIAMRoleCredentials{})).Return(mockECSClient) - mockECSClient.EXPECT(). - GetTaskProtectionWithContext(gomock.Any(), gomock.Any()). - Return(tc.ecsResponse, tc.ecsError) - - expectedResponse := types.TaskProtectionResponse{ - Protection: tc.expectedProtection, - Failure: tc.expectedFailure, - Error: tc.expectedError, - RequestID: tc.expectedRequestId, - } - - testGetTaskProtectionHandler(t, mockState, testV3EndpointId, mockManager, mockFactory, expectedResponse, tc.expectedStatusCode) - }) - } -} diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index b17173f4b1b..35fe453fb20 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -22,7 +22,6 @@ import ( "github.com/aws/amazon-ecs-agent/agent/config" "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate" tpfactory "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection" - tphandlers "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" v2 "github.com/aws/amazon-ecs-agent/agent/handlers/v2" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" v4 "github.com/aws/amazon-ecs-agent/agent/handlers/v4" @@ -32,8 +31,7 @@ import ( auditinterface "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" - tpinterface "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" - + tphandlers "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" tmdsv1 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1" tmdsv2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" @@ -50,9 +48,13 @@ const ( // writeTimeout specifies the maximum duration before timing out write of the response. // The value is set to 5 seconds as per AWS SDK defaults. writeTimeout = 5 * time.Second + + // Timeout for ECS calls. Must be lower than server write timeout defined above. + ecsCallTimeout = 4 * time.Second ) -func taskServerSetup(credentialsManager credentials.Manager, +func taskServerSetup( + credentialsManager credentials.Manager, auditLogger auditinterface.AuditLogger, state dockerstate.TaskEngineState, ecsClient api.ECSClient, @@ -63,7 +65,7 @@ func taskServerSetup(credentialsManager credentials.Manager, availabilityZone string, vpcID string, containerInstanceArn string, - taskProtectionClientFactory tpinterface.TaskProtectionClientFactoryInterface, + taskProtectionClientFactory tphandlers.TaskProtectionClientFactoryInterface, ) (*http.Server, error) { muxRouter := mux.NewRouter() @@ -75,13 +77,18 @@ func taskServerSetup(credentialsManager credentials.Manager, muxRouter.HandleFunc(tmdsv1.CredentialsPath, tmdsv1.CredentialsHandler(credentialsManager, auditLogger)) + tmdsAgentState := v4.NewTMDSAgentState(state, ecsClient, cluster, availabilityZone, vpcID, containerInstanceArn) + metricsFactory := metrics.NewNopEntryFactory() + v2HandlersSetup(muxRouter, state, ecsClient, statsEngine, cluster, credentialsManager, auditLogger, availabilityZone, containerInstanceArn) v3HandlersSetup(muxRouter, state, ecsClient, statsEngine, cluster, availabilityZone, containerInstanceArn) - v4HandlersSetup(muxRouter, state, ecsClient, statsEngine, cluster, availabilityZone, vpcID, containerInstanceArn) + v4HandlersSetup(muxRouter, state, ecsClient, statsEngine, cluster, availabilityZone, vpcID, containerInstanceArn, + tmdsAgentState, metricsFactory) - agentAPIV1HandlersSetup(muxRouter, state, credentialsManager, cluster, taskProtectionClientFactory) + agentAPIV1HandlersSetup(muxRouter, state, credentialsManager, cluster, tmdsAgentState, + taskProtectionClientFactory, metricsFactory) return tmds.NewServer(auditLogger, tmds.WithHandler(muxRouter), @@ -140,9 +147,9 @@ func v4HandlersSetup(muxRouter *mux.Router, availabilityZone string, vpcID string, containerInstanceArn string, + tmdsAgentState *v4.TMDSAgentState, + metricsFactory metrics.EntryFactory, ) { - tmdsAgentState := v4.NewTMDSAgentState(state, ecsClient, cluster, availabilityZone, vpcID, containerInstanceArn) - metricsFactory := metrics.NewNopEntryFactory() muxRouter.HandleFunc(tmdsv4.ContainerMetadataPath(), tmdsv4.ContainerMetadataHandler(tmdsAgentState, metricsFactory)) muxRouter.HandleFunc(tmdsv4.TaskMetadataPath(), tmdsv4.TaskMetadataHandler(tmdsAgentState, metricsFactory)) muxRouter.HandleFunc(tmdsv4.TaskMetadataWithTagsPath(), tmdsv4.TaskMetadataWithTagsHandler(tmdsAgentState, metricsFactory)) @@ -159,17 +166,21 @@ func agentAPIV1HandlersSetup( state dockerstate.TaskEngineState, credentialsManager credentials.Manager, cluster string, - factory tpinterface.TaskProtectionClientFactoryInterface, + agentState *v4.TMDSAgentState, + factory tphandlers.TaskProtectionClientFactoryInterface, + metricsFactory metrics.EntryFactory, ) { muxRouter. HandleFunc( tphandlers.TaskProtectionPath(), - tphandlers.UpdateTaskProtectionHandler(state, credentialsManager, factory, cluster)). + tphandlers.UpdateTaskProtectionHandler(agentState, credentialsManager, + factory, cluster, metricsFactory, ecsCallTimeout)). Methods("PUT") muxRouter. HandleFunc( tphandlers.TaskProtectionPath(), - tphandlers.GetTaskProtectionHandler(state, credentialsManager, factory, cluster)). + tphandlers.GetTaskProtectionHandler(agentState, credentialsManager, + factory, cluster, metricsFactory, ecsCallTimeout)). Methods("GET") } diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index 7542c2161bc..a510affa58c 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -36,7 +36,6 @@ import ( apitaskstatus "github.com/aws/amazon-ecs-agent/agent/api/task/status" "github.com/aws/amazon-ecs-agent/agent/config" mock_dockerstate "github.com/aws/amazon-ecs-agent/agent/engine/dockerstate/mocks" - agentapihandlers "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers" v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3" v4stats "github.com/aws/amazon-ecs-agent/agent/handlers/v4" "github.com/aws/amazon-ecs-agent/agent/stats" @@ -3002,7 +3001,10 @@ func TestGetTaskProtection(t *testing.T) { happyStateExpectations := func(state *mock_dockerstate.MockTaskEngineState) { gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2), + state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), + state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) } happyCredentialsManagerExpectations := func(credsManager *mock_credentials.MockManager) { @@ -3261,7 +3263,7 @@ func TestUpdateTaskProtection(t *testing.T) { } ecsRequestID := "reqid" ecsErrMessage := "ecs error message" - happyReqBody := &agentapihandlers.TaskProtectionRequest{ + happyReqBody := &tpinterface.TaskProtectionRequest{ ProtectionEnabled: protectionEnabled, ExpiresInMinutes: expirationMinutes, } @@ -3270,7 +3272,10 @@ func TestUpdateTaskProtection(t *testing.T) { happyStateExpectations := func(state *mock_dockerstate.MockTaskEngineState) { gomock.InOrder( state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true), + state.EXPECT().TaskByArn(taskARN).Return(task, true).Times(2), + state.EXPECT().ContainerMapByArn(taskARN).Return(containerNameToDockerContainer, true), state.EXPECT().TaskByArn(taskARN).Return(task, true), + state.EXPECT().PulledContainerMapByArn(taskARN).Return(nil, true), ) } happyCredentialsManagerExpectations := func(credsManager *mock_credentials.MockManager) { diff --git a/agent/handlers/v4/tmdsstate.go b/agent/handlers/v4/tmdsstate.go index 635615be4fe..0dc72dac124 100644 --- a/agent/handlers/v4/tmdsstate.go +++ b/agent/handlers/v4/tmdsstate.go @@ -119,6 +119,8 @@ func (s *TMDSAgentState) getTaskMetadata(v3EndpointID string, includeTags bool) "Unable to generate metadata for v4 task: '%s'", taskARN)) } + taskResponse.CredentialsID = task.GetCredentialsID() + // for non-awsvpc task mode if !task.IsNetworkModeAWSVPC() { // fill in non-awsvpc network details for container responses here diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/logger/field/constants.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/logger/field/constants.go index 36cbe6f3053..d9f97940395 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/logger/field/constants.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/logger/field/constants.go @@ -50,4 +50,5 @@ const ( ContainerExitCode = "containerExitCode" TMDSEndpointContainerID = "tmdsEndpointContainerID" MessageID = "messageID" + RequestType = "requestType" ) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go new file mode 100644 index 00000000000..f19df4fc39b --- /dev/null +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go @@ -0,0 +1,403 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package handlers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" + v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/gorilla/mux" +) + +const ( + ExpectedProtectionResponseLength = 1 + ecsCallTimedOutError = "Timed out calling ECS Task Protection API" + taskMetadataFetchFailureMsg = "Failed to find a task for the request" +) + +// TaskProtectionPath Returns endpoint path for UpdateTaskProtection API +func TaskProtectionPath() string { + return fmt.Sprintf( + "/api/%s/task-protection/v1/state", + utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx)) +} + +// TaskProtectionRequest is the Task protection request received from customers pending validation +type TaskProtectionRequest struct { + ProtectionEnabled *bool + ExpiresInMinutes *int64 +} + +// GetTaskProtectionHandler returns a handler function for GetTaskProtection API +func GetTaskProtectionHandler( + agentState state.AgentState, + credentialsManager credentials.Manager, + factory TaskProtectionClientFactoryInterface, + cluster string, + metricsFactory metrics.EntryFactory, + ecsCallTimeout time.Duration, +) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + requestType := "api/GetTaskProtection/v1" + + // Initialize metrics + successMetric := metricsFactory.New(metrics.GetTaskProtectionMetricName) + + // Find task metadata + task, errResponseCode, errResponseBody := getTaskMetadata(r, agentState, requestType) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + if utils.Is5XXStatus(errResponseCode) { + successMetric.WithCount(0).Done(nil)() + } + return + } + logger.Info("GetTaskProtection endpoint was called", logger.Fields{ + field.Cluster: cluster, + field.TaskARN: task.TaskARN, + }) + + // Find task role creds + taskCreds, errResponseCode, errResponseBody := getTaskCredentials(credentialsManager, *task) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // Call ECS TaskProtection API + ecsClient := factory.NewTaskProtectionClient(*taskCreds) + ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout) + defer cancel() + responseBody, err := ecsClient.GetTaskProtectionWithContext(ctx, &ecs.GetTaskProtectionInput{ + Cluster: aws.String(cluster), + Tasks: aws.StringSlice([]string{task.TaskARN}), + }) + if err != nil { + errResponseCode, errResponseBody := logAndHandleECSError(err, *task, requestType) + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // Validate ECS response + errResponseCode, errResponseBody = logAndValidateECSResponse( + responseBody.ProtectedTasks, responseBody.Failures, *task, requestType) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // ECS call was successful + utils.WriteJSONResponse(w, http.StatusOK, + types.NewTaskProtectionResponseProtection(responseBody.ProtectedTasks[0]), requestType) + successMetric.WithCount(1).Done(nil)() + } +} + +// UpdateTaskProtectionHandler returns an HTTP request handler function for UpdateTaskProtection API +func UpdateTaskProtectionHandler( + agentState state.AgentState, + credentialsManager credentials.Manager, + factory TaskProtectionClientFactoryInterface, + cluster string, + metricsFactory metrics.EntryFactory, + ecsCallTimeout time.Duration, +) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + requestType := "api/UpdateTaskProtection/v1" + + // Decode the request + var request TaskProtectionRequest + jsonDecoder := json.NewDecoder(r.Body) + jsonDecoder.DisallowUnknownFields() + if err := jsonDecoder.Decode(&request); err != nil { + logger.Error("UpdateTaskProtection: failed to decode request", logger.Fields{ + field.Error: err, + }) + utils.WriteJSONResponse(w, http.StatusBadRequest, + types.NewTaskProtectionResponseError(types.NewErrorResponsePtr( + "", + ecs.ErrCodeInvalidParameterException, + "UpdateTaskProtection: failed to decode request", + ), nil), + requestType) + return + } + + // Initialize metrics + successMetric := metricsFactory.New(metrics.UpdateTaskProtectionMetricName) + + // Find task metadata + task, errResponseCode, errResponseBody := getTaskMetadata(r, agentState, requestType) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + if utils.Is5XXStatus(errResponseCode) { + successMetric.WithCount(0).Done(nil)() + } + return + } + logger.Info("GetTaskProtection endpoint was called", logger.Fields{ + field.Cluster: cluster, + field.TaskARN: task.TaskARN, + }) + + // Validate the request + if request.ProtectionEnabled == nil { + responseErr := types.NewErrorResponsePtr(task.TaskARN, ecs.ErrCodeInvalidParameterException, + "Invalid request: does not contain 'ProtectionEnabled' field") + response := types.NewTaskProtectionResponseError(responseErr, nil) + utils.WriteJSONResponse(w, http.StatusBadRequest, response, requestType) + return + } + + // Prepare ECS request body + taskProtection := types.NewTaskProtection(*request.ProtectionEnabled, request.ExpiresInMinutes) + logger.Info("UpdateTaskProtection endpoint was called", logger.Fields{ + field.Cluster: cluster, + field.TaskARN: task.TaskARN, + field.TaskProtection: taskProtection, + field.RequestType: requestType, + }) + + // Find task role creds + taskCreds, errResponseCode, errResponseBody := getTaskCredentials(credentialsManager, *task) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // Call ECS TaskProtection API + ecsClient := factory.NewTaskProtectionClient(*taskCreds) + ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout) + defer cancel() + response, err := ecsClient.UpdateTaskProtectionWithContext(ctx, &ecs.UpdateTaskProtectionInput{ + Cluster: aws.String(cluster), + ExpiresInMinutes: taskProtection.GetExpiresInMinutes(), + ProtectionEnabled: aws.Bool(taskProtection.GetProtectionEnabled()), + Tasks: aws.StringSlice([]string{task.TaskARN}), + }) + if err != nil { + errResponseCode, errResponseBody := logAndHandleECSError(err, *task, requestType) + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // Validate ECS response + errResponseCode, errResponseBody = logAndValidateECSResponse( + response.ProtectedTasks, response.Failures, *task, requestType) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // ECS call was successful + utils.WriteJSONResponse(w, http.StatusOK, + types.NewTaskProtectionResponseProtection(response.ProtectedTasks[0]), requestType) + successMetric.WithCount(1).Done(nil)() + } +} + +// Helper function for retrieving task metadata for the request +func getTaskMetadata( + r *http.Request, + agentState state.AgentState, + requestType string, +) (*state.TaskResponse, int, *types.TaskProtectionResponse) { + endpointContainerID := mux.Vars(r)[v4.EndpointContainerIDMuxName] + task, err := agentState.GetTaskMetadata(endpointContainerID) + if err != nil { + logger.Error("Failed to get v4 task metadata", logger.Fields{ + field.TMDSEndpointContainerID: endpointContainerID, + field.Error: err, + field.RequestType: requestType, + }) + + responseCode, responseBody := getTaskMetadataErrorResponse( + endpointContainerID, err, requestType) + return nil, responseCode, &responseBody + } + + return &task, 0, nil +} + +// Helper function for retrieving task role credentials +func getTaskCredentials( + credentialsManager credentials.Manager, + task state.TaskResponse, +) (*credentials.TaskIAMRoleCredentials, int, *types.TaskProtectionResponse) { + taskRoleCredential, ok := credentialsManager.GetTaskCredentials(task.CredentialsID) + if !ok { + errMsg := "Invalid Request: no task IAM role credentials available for task" + logger.Error(errMsg, logger.Fields{field.TaskARN: task.TaskARN}) + responseErr := types.NewErrorResponsePtr(task.TaskARN, ecs.ErrCodeAccessDeniedException, errMsg) + response := types.NewTaskProtectionResponseError(responseErr, nil) + return nil, http.StatusForbidden, &response + } + + return &taskRoleCredential, 0, nil +} + +// Helper function for logging and handling error that occurred when calling ECS TaskProtection API +func logAndHandleECSError( + err error, + task state.TaskResponse, + requestType string, +) (int, types.TaskProtectionResponse) { + errorCode, errorMsg, statusCode, reqId := getErrorCodeAndStatusCode(err) + var requestIdString = "" + if reqId != nil { + requestIdString = *reqId + } + + logger.Error("Got an exception when calling TaskProtection API", logger.Fields{ + field.Error: err, + "ErrorCode": errorCode, + "ExceptionMessage": errorMsg, + "StatusCode": statusCode, + "RequestId": requestIdString, + field.RequestType: requestType, + }) + + responseErr := types.NewErrorResponsePtr(task.TaskARN, errorCode, errorMsg) + response := types.NewTaskProtectionResponseError(responseErr, reqId) + + return statusCode, response +} + +// Helper function for logging and validating ECS TaskProtection API response +func logAndValidateECSResponse( + protectedTasks []*ecs.ProtectedTask, + failures []*ecs.Failure, + task state.TaskResponse, + requestType string, +) (int, *types.TaskProtectionResponse) { + logger.Debug("getTaskProtection response:", logger.Fields{ + field.TaskProtection: protectedTasks, + field.Reason: failures, + }) + + if len(failures) > 0 { + if len(failures) > ExpectedProtectionResponseLength { + err := fmt.Errorf( + "expect at most %v failure in response, get %v", + ExpectedProtectionResponseLength, len(failures)) + logger.Error("Unexpected number of failures", logger.Fields{ + field.Error: err, + field.TaskARN: task.TaskARN, + field.RequestType: requestType, + }) + responseErr := types.NewErrorResponsePtr( + task.TaskARN, ecs.ErrCodeServerException, "Unexpected error occurred") + response := types.NewTaskProtectionResponseError(responseErr, nil) + return http.StatusInternalServerError, &response + } + + response := types.NewTaskProtectionResponseFailure(failures[0]) + return http.StatusOK, &response + } + + if len(protectedTasks) > ExpectedProtectionResponseLength { + err := fmt.Errorf( + "expect %v protectedTask in response when no failure, get %v", + ExpectedProtectionResponseLength, len(protectedTasks)) + logger.Error("Unexpected number of protections", logger.Fields{ + field.Error: err, + field.TaskARN: task.TaskARN, + field.RequestType: requestType, + }) + + responseErr := types.NewErrorResponsePtr( + task.TaskARN, ecs.ErrCodeServerException, "Unexpected error occurred") + response := types.NewTaskProtectionResponseError(responseErr, nil) + return http.StatusInternalServerError, &response + } + + return 0, nil +} + +// Returns an appropriate HTTP response status code and body for the task metadata fetch error. +func getTaskMetadataErrorResponse( + endpointContainerID string, + err error, + requestType string, +) (int, types.TaskProtectionResponse) { + var errContainerLookupFailed *state.ErrorLookupFailure + if errors.As(err, &errContainerLookupFailed) { + responseErr := types.NewErrorResponsePtr( + "", ecs.ErrCodeResourceNotFoundException, taskMetadataFetchFailureMsg) + return http.StatusNotFound, types.NewTaskProtectionResponseError(responseErr, nil) + } + + var errFailedToGetContainerMetadata *state.ErrorMetadataFetchFailure + if errors.As(err, &errFailedToGetContainerMetadata) { + responseErr := types.NewErrorResponsePtr( + "", ecs.ErrCodeServerException, taskMetadataFetchFailureMsg) + return http.StatusInternalServerError, types.NewTaskProtectionResponseError(responseErr, nil) + } + + logger.Error("Unknown error encountered when handling task metadata fetch failure", logger.Fields{ + field.Error: err, + field.RequestType: requestType, + }) + + responseErr := types.NewErrorResponsePtr("", ecs.ErrCodeServerException, taskMetadataFetchFailureMsg) + return http.StatusInternalServerError, types.NewTaskProtectionResponseError(responseErr, nil) +} + +// Helper function to parse error to get ErrorCode, ExceptionMessage, HttpStatusCode, RequestID. +// RequestID will be empty if the request is not able to reach AWS +func getErrorCodeAndStatusCode(err error) (string, string, int, *string) { + msg := err.Error() + // The error is a Generic AWS Error with Code, Message, and original error (if any) + if awsErr, ok := err.(awserr.Error); ok { + // The error is an AWS service error occurred + msg = awsErr.Message() + if reqErr, ok := err.(awserr.RequestFailure); ok { + reqId := reqErr.RequestID() + return awsErr.Code(), msg, reqErr.StatusCode(), &reqId + } else if aerr, ok := err.(awserr.Error); ok && aerr.Code() == request.CanceledErrorCode { + return aerr.Code(), ecsCallTimedOutError, http.StatusGatewayTimeout, nil + } else { + logger.Error(fmt.Sprintf( + "got an exception that does not implement RequestFailure interface but is an aws error. This should not happen, return statusCode 500 for whatever errorCode. Original err: %v.", + err)) + return awsErr.Code(), msg, http.StatusInternalServerError, nil + } + } else { + logger.Error(fmt.Sprintf("non aws error received: %v", err)) + return ecs.ErrCodeServerException, msg, http.StatusInternalServerError, nil + } +} diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go index e0ceb62d1d1..2ff5a661271 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/response.go @@ -33,6 +33,7 @@ type TaskResponse struct { ServiceName string `json:"ServiceName,omitempty"` ClockDrift *ClockDrift `json:"ClockDrift,omitempty"` EphemeralStorageMetrics *EphemeralStorageMetrics `json:"EphemeralStorageMetrics,omitempty"` + CredentialsID string `json:"-"` } // Instance's clock drift status diff --git a/ecs-agent/logger/field/constants.go b/ecs-agent/logger/field/constants.go index 36cbe6f3053..d9f97940395 100644 --- a/ecs-agent/logger/field/constants.go +++ b/ecs-agent/logger/field/constants.go @@ -50,4 +50,5 @@ const ( ContainerExitCode = "containerExitCode" TMDSEndpointContainerID = "tmdsEndpointContainerID" MessageID = "messageID" + RequestType = "requestType" ) diff --git a/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go new file mode 100644 index 00000000000..f19df4fc39b --- /dev/null +++ b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go @@ -0,0 +1,403 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package handlers + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "time" + + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger" + "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" + v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/gorilla/mux" +) + +const ( + ExpectedProtectionResponseLength = 1 + ecsCallTimedOutError = "Timed out calling ECS Task Protection API" + taskMetadataFetchFailureMsg = "Failed to find a task for the request" +) + +// TaskProtectionPath Returns endpoint path for UpdateTaskProtection API +func TaskProtectionPath() string { + return fmt.Sprintf( + "/api/%s/task-protection/v1/state", + utils.ConstructMuxVar(v4.EndpointContainerIDMuxName, utils.AnythingButSlashRegEx)) +} + +// TaskProtectionRequest is the Task protection request received from customers pending validation +type TaskProtectionRequest struct { + ProtectionEnabled *bool + ExpiresInMinutes *int64 +} + +// GetTaskProtectionHandler returns a handler function for GetTaskProtection API +func GetTaskProtectionHandler( + agentState state.AgentState, + credentialsManager credentials.Manager, + factory TaskProtectionClientFactoryInterface, + cluster string, + metricsFactory metrics.EntryFactory, + ecsCallTimeout time.Duration, +) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + requestType := "api/GetTaskProtection/v1" + + // Initialize metrics + successMetric := metricsFactory.New(metrics.GetTaskProtectionMetricName) + + // Find task metadata + task, errResponseCode, errResponseBody := getTaskMetadata(r, agentState, requestType) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + if utils.Is5XXStatus(errResponseCode) { + successMetric.WithCount(0).Done(nil)() + } + return + } + logger.Info("GetTaskProtection endpoint was called", logger.Fields{ + field.Cluster: cluster, + field.TaskARN: task.TaskARN, + }) + + // Find task role creds + taskCreds, errResponseCode, errResponseBody := getTaskCredentials(credentialsManager, *task) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // Call ECS TaskProtection API + ecsClient := factory.NewTaskProtectionClient(*taskCreds) + ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout) + defer cancel() + responseBody, err := ecsClient.GetTaskProtectionWithContext(ctx, &ecs.GetTaskProtectionInput{ + Cluster: aws.String(cluster), + Tasks: aws.StringSlice([]string{task.TaskARN}), + }) + if err != nil { + errResponseCode, errResponseBody := logAndHandleECSError(err, *task, requestType) + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // Validate ECS response + errResponseCode, errResponseBody = logAndValidateECSResponse( + responseBody.ProtectedTasks, responseBody.Failures, *task, requestType) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // ECS call was successful + utils.WriteJSONResponse(w, http.StatusOK, + types.NewTaskProtectionResponseProtection(responseBody.ProtectedTasks[0]), requestType) + successMetric.WithCount(1).Done(nil)() + } +} + +// UpdateTaskProtectionHandler returns an HTTP request handler function for UpdateTaskProtection API +func UpdateTaskProtectionHandler( + agentState state.AgentState, + credentialsManager credentials.Manager, + factory TaskProtectionClientFactoryInterface, + cluster string, + metricsFactory metrics.EntryFactory, + ecsCallTimeout time.Duration, +) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + requestType := "api/UpdateTaskProtection/v1" + + // Decode the request + var request TaskProtectionRequest + jsonDecoder := json.NewDecoder(r.Body) + jsonDecoder.DisallowUnknownFields() + if err := jsonDecoder.Decode(&request); err != nil { + logger.Error("UpdateTaskProtection: failed to decode request", logger.Fields{ + field.Error: err, + }) + utils.WriteJSONResponse(w, http.StatusBadRequest, + types.NewTaskProtectionResponseError(types.NewErrorResponsePtr( + "", + ecs.ErrCodeInvalidParameterException, + "UpdateTaskProtection: failed to decode request", + ), nil), + requestType) + return + } + + // Initialize metrics + successMetric := metricsFactory.New(metrics.UpdateTaskProtectionMetricName) + + // Find task metadata + task, errResponseCode, errResponseBody := getTaskMetadata(r, agentState, requestType) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + if utils.Is5XXStatus(errResponseCode) { + successMetric.WithCount(0).Done(nil)() + } + return + } + logger.Info("GetTaskProtection endpoint was called", logger.Fields{ + field.Cluster: cluster, + field.TaskARN: task.TaskARN, + }) + + // Validate the request + if request.ProtectionEnabled == nil { + responseErr := types.NewErrorResponsePtr(task.TaskARN, ecs.ErrCodeInvalidParameterException, + "Invalid request: does not contain 'ProtectionEnabled' field") + response := types.NewTaskProtectionResponseError(responseErr, nil) + utils.WriteJSONResponse(w, http.StatusBadRequest, response, requestType) + return + } + + // Prepare ECS request body + taskProtection := types.NewTaskProtection(*request.ProtectionEnabled, request.ExpiresInMinutes) + logger.Info("UpdateTaskProtection endpoint was called", logger.Fields{ + field.Cluster: cluster, + field.TaskARN: task.TaskARN, + field.TaskProtection: taskProtection, + field.RequestType: requestType, + }) + + // Find task role creds + taskCreds, errResponseCode, errResponseBody := getTaskCredentials(credentialsManager, *task) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // Call ECS TaskProtection API + ecsClient := factory.NewTaskProtectionClient(*taskCreds) + ctx, cancel := context.WithTimeout(r.Context(), ecsCallTimeout) + defer cancel() + response, err := ecsClient.UpdateTaskProtectionWithContext(ctx, &ecs.UpdateTaskProtectionInput{ + Cluster: aws.String(cluster), + ExpiresInMinutes: taskProtection.GetExpiresInMinutes(), + ProtectionEnabled: aws.Bool(taskProtection.GetProtectionEnabled()), + Tasks: aws.StringSlice([]string{task.TaskARN}), + }) + if err != nil { + errResponseCode, errResponseBody := logAndHandleECSError(err, *task, requestType) + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // Validate ECS response + errResponseCode, errResponseBody = logAndValidateECSResponse( + response.ProtectedTasks, response.Failures, *task, requestType) + if errResponseBody != nil { + utils.WriteJSONResponse(w, errResponseCode, errResponseBody, requestType) + successMetric.WithCount(0).Done(nil)() + return + } + + // ECS call was successful + utils.WriteJSONResponse(w, http.StatusOK, + types.NewTaskProtectionResponseProtection(response.ProtectedTasks[0]), requestType) + successMetric.WithCount(1).Done(nil)() + } +} + +// Helper function for retrieving task metadata for the request +func getTaskMetadata( + r *http.Request, + agentState state.AgentState, + requestType string, +) (*state.TaskResponse, int, *types.TaskProtectionResponse) { + endpointContainerID := mux.Vars(r)[v4.EndpointContainerIDMuxName] + task, err := agentState.GetTaskMetadata(endpointContainerID) + if err != nil { + logger.Error("Failed to get v4 task metadata", logger.Fields{ + field.TMDSEndpointContainerID: endpointContainerID, + field.Error: err, + field.RequestType: requestType, + }) + + responseCode, responseBody := getTaskMetadataErrorResponse( + endpointContainerID, err, requestType) + return nil, responseCode, &responseBody + } + + return &task, 0, nil +} + +// Helper function for retrieving task role credentials +func getTaskCredentials( + credentialsManager credentials.Manager, + task state.TaskResponse, +) (*credentials.TaskIAMRoleCredentials, int, *types.TaskProtectionResponse) { + taskRoleCredential, ok := credentialsManager.GetTaskCredentials(task.CredentialsID) + if !ok { + errMsg := "Invalid Request: no task IAM role credentials available for task" + logger.Error(errMsg, logger.Fields{field.TaskARN: task.TaskARN}) + responseErr := types.NewErrorResponsePtr(task.TaskARN, ecs.ErrCodeAccessDeniedException, errMsg) + response := types.NewTaskProtectionResponseError(responseErr, nil) + return nil, http.StatusForbidden, &response + } + + return &taskRoleCredential, 0, nil +} + +// Helper function for logging and handling error that occurred when calling ECS TaskProtection API +func logAndHandleECSError( + err error, + task state.TaskResponse, + requestType string, +) (int, types.TaskProtectionResponse) { + errorCode, errorMsg, statusCode, reqId := getErrorCodeAndStatusCode(err) + var requestIdString = "" + if reqId != nil { + requestIdString = *reqId + } + + logger.Error("Got an exception when calling TaskProtection API", logger.Fields{ + field.Error: err, + "ErrorCode": errorCode, + "ExceptionMessage": errorMsg, + "StatusCode": statusCode, + "RequestId": requestIdString, + field.RequestType: requestType, + }) + + responseErr := types.NewErrorResponsePtr(task.TaskARN, errorCode, errorMsg) + response := types.NewTaskProtectionResponseError(responseErr, reqId) + + return statusCode, response +} + +// Helper function for logging and validating ECS TaskProtection API response +func logAndValidateECSResponse( + protectedTasks []*ecs.ProtectedTask, + failures []*ecs.Failure, + task state.TaskResponse, + requestType string, +) (int, *types.TaskProtectionResponse) { + logger.Debug("getTaskProtection response:", logger.Fields{ + field.TaskProtection: protectedTasks, + field.Reason: failures, + }) + + if len(failures) > 0 { + if len(failures) > ExpectedProtectionResponseLength { + err := fmt.Errorf( + "expect at most %v failure in response, get %v", + ExpectedProtectionResponseLength, len(failures)) + logger.Error("Unexpected number of failures", logger.Fields{ + field.Error: err, + field.TaskARN: task.TaskARN, + field.RequestType: requestType, + }) + responseErr := types.NewErrorResponsePtr( + task.TaskARN, ecs.ErrCodeServerException, "Unexpected error occurred") + response := types.NewTaskProtectionResponseError(responseErr, nil) + return http.StatusInternalServerError, &response + } + + response := types.NewTaskProtectionResponseFailure(failures[0]) + return http.StatusOK, &response + } + + if len(protectedTasks) > ExpectedProtectionResponseLength { + err := fmt.Errorf( + "expect %v protectedTask in response when no failure, get %v", + ExpectedProtectionResponseLength, len(protectedTasks)) + logger.Error("Unexpected number of protections", logger.Fields{ + field.Error: err, + field.TaskARN: task.TaskARN, + field.RequestType: requestType, + }) + + responseErr := types.NewErrorResponsePtr( + task.TaskARN, ecs.ErrCodeServerException, "Unexpected error occurred") + response := types.NewTaskProtectionResponseError(responseErr, nil) + return http.StatusInternalServerError, &response + } + + return 0, nil +} + +// Returns an appropriate HTTP response status code and body for the task metadata fetch error. +func getTaskMetadataErrorResponse( + endpointContainerID string, + err error, + requestType string, +) (int, types.TaskProtectionResponse) { + var errContainerLookupFailed *state.ErrorLookupFailure + if errors.As(err, &errContainerLookupFailed) { + responseErr := types.NewErrorResponsePtr( + "", ecs.ErrCodeResourceNotFoundException, taskMetadataFetchFailureMsg) + return http.StatusNotFound, types.NewTaskProtectionResponseError(responseErr, nil) + } + + var errFailedToGetContainerMetadata *state.ErrorMetadataFetchFailure + if errors.As(err, &errFailedToGetContainerMetadata) { + responseErr := types.NewErrorResponsePtr( + "", ecs.ErrCodeServerException, taskMetadataFetchFailureMsg) + return http.StatusInternalServerError, types.NewTaskProtectionResponseError(responseErr, nil) + } + + logger.Error("Unknown error encountered when handling task metadata fetch failure", logger.Fields{ + field.Error: err, + field.RequestType: requestType, + }) + + responseErr := types.NewErrorResponsePtr("", ecs.ErrCodeServerException, taskMetadataFetchFailureMsg) + return http.StatusInternalServerError, types.NewTaskProtectionResponseError(responseErr, nil) +} + +// Helper function to parse error to get ErrorCode, ExceptionMessage, HttpStatusCode, RequestID. +// RequestID will be empty if the request is not able to reach AWS +func getErrorCodeAndStatusCode(err error) (string, string, int, *string) { + msg := err.Error() + // The error is a Generic AWS Error with Code, Message, and original error (if any) + if awsErr, ok := err.(awserr.Error); ok { + // The error is an AWS service error occurred + msg = awsErr.Message() + if reqErr, ok := err.(awserr.RequestFailure); ok { + reqId := reqErr.RequestID() + return awsErr.Code(), msg, reqErr.StatusCode(), &reqId + } else if aerr, ok := err.(awserr.Error); ok && aerr.Code() == request.CanceledErrorCode { + return aerr.Code(), ecsCallTimedOutError, http.StatusGatewayTimeout, nil + } else { + logger.Error(fmt.Sprintf( + "got an exception that does not implement RequestFailure interface but is an aws error. This should not happen, return statusCode 500 for whatever errorCode. Original err: %v.", + err)) + return awsErr.Code(), msg, http.StatusInternalServerError, nil + } + } else { + logger.Error(fmt.Sprintf("non aws error received: %v", err)) + return ecs.ErrCodeServerException, msg, http.StatusInternalServerError, nil + } +} diff --git a/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_test.go new file mode 100644 index 00000000000..ad1d3f899cb --- /dev/null +++ b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers_test.go @@ -0,0 +1,671 @@ +// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"). You may +// not use this file except in compliance with the License. A copy of the +// License is located at +// +// http://aws.amazon.com/apache2.0/ +// +// or in the "license" file accompanying this file. This file is distributed +// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either +// express or implied. See the License for the specific language governing +// permissions and limitations under the License. + +package handlers + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/http/httptest" + "testing" + "time" + + mock_api "github.com/aws/amazon-ecs-agent/ecs-agent/api/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/credentials" + mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + mock_metrics "github.com/aws/amazon-ecs-agent/ecs-agent/metrics/mocks" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" + v2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state" + mock_state "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4/state/mocks" + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/aws/request" + + "github.com/golang/mock/gomock" + "github.com/gorilla/mux" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const ( + cluster = "cluster" + endpointId = "endpointId" + ecsCallTimeout = 5 * time.Second + taskARN = "taskARN" + taskRoleCredsID = "taskRoleCredsID" +) + +// Tests the path for UpdateTaskProtection API +func TestTaskProtectionPath(t *testing.T) { + assert.Equal(t, "/api/{endpointContainerIDMuxName:[^/]*}/task-protection/v1/state", TaskProtectionPath()) +} + +type TestCase struct { + requestBody interface{} // Required for UpdateTaskProtection + setAgentStateExpectations func(agentState *mock_state.MockAgentState) + setCredsManagerExpectations func(credsManager *mock_credentials.MockManager) + setFactoryExpectations func(ctrl *gomock.Controller, factory *MockTaskProtectionClientFactoryInterface) + setMetricsExpectations func(ctrl *gomock.Controller, metricsFactory *mock_metrics.MockEntryFactory) + expectedStatusCode int + expectedResponseBody types.TaskProtectionResponse + postAssertions func(t *testing.T) // Any extra assertions for the test case +} + +func testTaskProtectionRequest(t *testing.T, tc TestCase) { + // Mocks + ctrl := gomock.NewController(t) + defer ctrl.Finish() + + agentState := mock_state.NewMockAgentState(ctrl) + credsManager := mock_credentials.NewMockManager(ctrl) + factory := NewMockTaskProtectionClientFactoryInterface(ctrl) + metricsFactory := mock_metrics.NewMockEntryFactory(ctrl) + + if tc.setAgentStateExpectations != nil { + tc.setAgentStateExpectations(agentState) + } + if tc.setCredsManagerExpectations != nil { + tc.setCredsManagerExpectations(credsManager) + } + if tc.setFactoryExpectations != nil { + tc.setFactoryExpectations(ctrl, factory) + } + if tc.setMetricsExpectations != nil { + tc.setMetricsExpectations(ctrl, metricsFactory) + } + + // Setup the handlers + router := mux.NewRouter() + router.HandleFunc( + TaskProtectionPath(), + GetTaskProtectionHandler(agentState, credsManager, factory, cluster, metricsFactory, ecsCallTimeout), + ).Methods("GET") + router.HandleFunc( + TaskProtectionPath(), + UpdateTaskProtectionHandler(agentState, credsManager, factory, cluster, metricsFactory, ecsCallTimeout), + ).Methods("PUT") + + // Create the request + method := "GET" + var requestBody io.Reader + if tc.requestBody != nil { + method = "PUT" + reqBodyBytes, err := json.Marshal(tc.requestBody) + require.NoError(t, err) + requestBody = bytes.NewReader(reqBodyBytes) + } + req, err := http.NewRequest(method, fmt.Sprintf("/api/%s/task-protection/v1/state", endpointId), + requestBody) + require.NoError(t, err) + + // Send the request and record the response + recorder := httptest.NewRecorder() + router.ServeHTTP(recorder, req) + + // Parse the response body + var actualResponseBody types.TaskProtectionResponse + err = json.Unmarshal(recorder.Body.Bytes(), &actualResponseBody) + require.NoError(t, err) + + // Assert status code and body + assert.Equal(t, tc.expectedStatusCode, recorder.Code) + assert.Equal(t, tc.expectedResponseBody, actualResponseBody) + + // Run any post assertions + if tc.postAssertions != nil { + tc.postAssertions(t) + } +} + +func TestGetTaskProtection(t *testing.T) { + // Initialize some data common to the test cases + happyECSInput := ecs.GetTaskProtectionInput{ + Cluster: aws.String(cluster), + Tasks: aws.StringSlice([]string{taskARN}), + } + metricName := metrics.GetTaskProtectionMetricName + + // A helper function for setting expectations on mock ECS Client Factory + factoryExpectations := func( + input ecs.GetTaskProtectionInput, + output *ecs.GetTaskProtectionOutput, + err error, + ) func(*gomock.Controller, *MockTaskProtectionClientFactoryInterface) { + return func(ctrl *gomock.Controller, factory *MockTaskProtectionClientFactoryInterface) { + client := mock_api.NewMockECSTaskProtectionSDK(ctrl) + client.EXPECT().GetTaskProtectionWithContext(gomock.Any(), &input).Return(output, err) + factory.EXPECT().NewTaskProtectionClient(taskRoleCreds()).Return(client) + } + } + + // Test cases start here + t.Run("task lookup failure", func(t *testing.T) { + testTaskProtectionRequest(t, taskMetadataLookupFailureCase(metricName, nil)) + }) + t.Run("task metadata fetch failure", func(t *testing.T) { + testTaskProtectionRequest(t, taskMetadataFetchErrorCase( + state.NewErrorMetadataFetchFailure(""), metricName, nil)) + }) + t.Run("task metadata uknown error", func(t *testing.T) { + testTaskProtectionRequest(t, taskMetadataFetchErrorCase( + errors.New("unknown"), metricName, nil)) + }) + t.Run("task role creds not found", func(t *testing.T) { + testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, nil)) + }) + t.Run("request failure", func(t *testing.T) { + ecsRequestID := "reqID" + ecsErrMessage := "ecs error" + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, nil, + awserr.NewRequestFailure( + awserr.New(ecs.ErrCodeAccessDeniedException, ecsErrMessage, nil), + http.StatusBadRequest, + ecsRequestID, + )), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: types.TaskProtectionResponse{ + RequestID: &ecsRequestID, + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeAccessDeniedException, + Message: ecsErrMessage, + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("agent timeout", func(t *testing.T) { + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, nil, + awserr.New(request.CanceledErrorCode, "request cancelled", nil)), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusGatewayTimeout, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: request.CanceledErrorCode, + Message: "Timed out calling ECS Task Protection API", + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("non-request-failure aws error", func(t *testing.T) { + ecsErrMessage := "ecs error" + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, nil, + awserr.New(ecs.ErrCodeInvalidParameterException, ecsErrMessage, nil)), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeInvalidParameterException, + Message: ecsErrMessage, + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("non-aws error", func(t *testing.T) { + err := errors.New("some error") + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, nil, err), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, Code: ecs.ErrCodeServerException, Message: err.Error(), + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("ecs failure", func(t *testing.T) { + metricsPublishCount := 0 // tracks the number of times metrics were published + ecsFailure := makeECSFailure("ecs failure") + testTaskProtectionRequest(t, TestCase{ + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.GetTaskProtectionOutput{ + Failures: []*ecs.Failure{ecsFailure}, + }, nil), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusOK, + expectedResponseBody: types.TaskProtectionResponse{ + Failure: ecsFailure, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("more than one ecs failure", func(t *testing.T) { + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.GetTaskProtectionOutput{ + Failures: []*ecs.Failure{makeECSFailure("1"), makeECSFailure("2")}, + }, nil), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeServerException, + Message: "Unexpected error occurred", + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("happy case", func(t *testing.T) { + metricsPublishCount := 0 // tracks the number of times metrics were published + protectedTask := ecsProtectedTask() + testTaskProtectionRequest(t, TestCase{ + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.GetTaskProtectionOutput{ + ProtectedTasks: []*ecs.ProtectedTask{&protectedTask}, + }, nil), + setMetricsExpectations: metricsExpectations(metricName, 1, &metricsPublishCount), + expectedStatusCode: http.StatusOK, + expectedResponseBody: types.TaskProtectionResponse{Protection: &protectedTask}, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) +} + +func TestUpdateTaskProtection(t *testing.T) { + // Initialize some data common to the test cases + metricName := metrics.UpdateTaskProtectionMetricName + expiresInMinutes := aws.Int64(5) + protectionEnabled := aws.Bool(true) + happyRequestBody := &TaskProtectionRequest{ + ProtectionEnabled: protectionEnabled, ExpiresInMinutes: expiresInMinutes, + } + happyECSInput := ecs.UpdateTaskProtectionInput{ + Cluster: aws.String(cluster), + Tasks: aws.StringSlice([]string{taskARN}), + ExpiresInMinutes: expiresInMinutes, + ProtectionEnabled: protectionEnabled, + } + + // A helper function for setting expectations on mock ECS Client Factory + factoryExpectations := func( + input ecs.UpdateTaskProtectionInput, + output *ecs.UpdateTaskProtectionOutput, + err error, + ) func(*gomock.Controller, *MockTaskProtectionClientFactoryInterface) { + return func(ctrl *gomock.Controller, factory *MockTaskProtectionClientFactoryInterface) { + client := mock_api.NewMockECSTaskProtectionSDK(ctrl) + client.EXPECT().UpdateTaskProtectionWithContext(gomock.Any(), &input).Return(output, err) + factory.EXPECT().NewTaskProtectionClient(taskRoleCreds()).Return(client) + } + } + + // Test cases start here + t.Run("task lookup failure", func(t *testing.T) { + testTaskProtectionRequest(t, taskMetadataLookupFailureCase(metricName, happyRequestBody)) + }) + t.Run("task metadata fetch failure", func(t *testing.T) { + testTaskProtectionRequest(t, taskMetadataFetchErrorCase( + state.NewErrorMetadataFetchFailure(""), metricName, happyRequestBody)) + }) + t.Run("task metadata unknown error", func(t *testing.T) { + testTaskProtectionRequest(t, taskMetadataFetchErrorCase( + errors.New("unknown"), metricName, happyRequestBody)) + }) + t.Run("unknown field in request", func(t *testing.T) { + testTaskProtectionRequest(t, TestCase{ + requestBody: map[string]interface{}{ + "ProtectionEnabled": true, + "ExpiresInMinutes": 5, + "Unknown": 2, + }, + setMetricsExpectations: nil, // no metrics interaction expected + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Code: ecs.ErrCodeInvalidParameterException, + Message: "UpdateTaskProtection: failed to decode request", + }, + }, + }) + }) + t.Run("invalid type in the request", func(t *testing.T) { + testTaskProtectionRequest(t, TestCase{ + requestBody: map[string]interface{}{"ProtectionEnabled": "bad"}, + setMetricsExpectations: nil, // no metrics interaction expected + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Code: ecs.ErrCodeInvalidParameterException, + Message: "UpdateTaskProtection: failed to decode request", + }, + }, + }) + }) + t.Run("ProtectionEnabled field not found on the request", func(t *testing.T) { + testTaskProtectionRequest(t, TestCase{ + requestBody: &TaskProtectionRequest{ExpiresInMinutes: expiresInMinutes}, + setAgentStateExpectations: happyStateExpectations, + setMetricsExpectations: func(ctrl *gomock.Controller, metricsFactory *mock_metrics.MockEntryFactory) { + // expecting entry creation but no publish + entry := mock_metrics.NewMockEntry(ctrl) + metricsFactory.EXPECT().New(metricName).Return(entry) + }, + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeInvalidParameterException, + Message: "Invalid request: does not contain 'ProtectionEnabled' field", + }, + }, + }) + }) + t.Run("task role creds not found", func(t *testing.T) { + testTaskProtectionRequest(t, taskRoleCredsNotFoundCase(metricName, happyRequestBody)) + }) + t.Run("request failure", func(t *testing.T) { + ecsRequestID := "reqID" + ecsErrMessage := "ecs error" + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + requestBody: happyRequestBody, + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, nil, + awserr.NewRequestFailure( + awserr.New(ecs.ErrCodeAccessDeniedException, ecsErrMessage, nil), + http.StatusBadRequest, + ecsRequestID, + )), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusBadRequest, + expectedResponseBody: types.TaskProtectionResponse{ + RequestID: &ecsRequestID, + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeAccessDeniedException, + Message: ecsErrMessage, + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("agent timeout", func(t *testing.T) { + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + requestBody: happyRequestBody, + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, nil, + awserr.New(request.CanceledErrorCode, "request cancelled", nil)), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusGatewayTimeout, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: request.CanceledErrorCode, + Message: "Timed out calling ECS Task Protection API", + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("non-request-failure aws error", func(t *testing.T) { + ecsErrMessage := "ecs error" + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + requestBody: happyRequestBody, + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, nil, + awserr.New(ecs.ErrCodeInvalidParameterException, ecsErrMessage, nil)), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeInvalidParameterException, + Message: ecsErrMessage, + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("non-aws error", func(t *testing.T) { + err := errors.New("some error") + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + requestBody: happyRequestBody, + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, nil, err), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, Code: ecs.ErrCodeServerException, Message: err.Error(), + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("ecs failure", func(t *testing.T) { + ecsFailure := makeECSFailure("ecs failure") + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + requestBody: happyRequestBody, + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.UpdateTaskProtectionOutput{ + Failures: []*ecs.Failure{ecsFailure}, + }, nil), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusOK, + expectedResponseBody: types.TaskProtectionResponse{ + Failure: ecsFailure, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("more than one ecs failure", func(t *testing.T) { + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + requestBody: happyRequestBody, + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.UpdateTaskProtectionOutput{ + Failures: []*ecs.Failure{makeECSFailure("1"), makeECSFailure("2")}, + }, nil), + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeServerException, + Message: "Unexpected error occurred", + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) + t.Run("happy case", func(t *testing.T) { + protectedTask := ecsProtectedTask() + metricsPublishCount := 0 // tracks the number of times metrics were published + testTaskProtectionRequest(t, TestCase{ + requestBody: happyRequestBody, + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: happyCredsManagerExpectations, + setFactoryExpectations: factoryExpectations(happyECSInput, &ecs.UpdateTaskProtectionOutput{ + ProtectedTasks: []*ecs.ProtectedTask{&protectedTask}, + }, nil), + setMetricsExpectations: metricsExpectations(metricName, 1, &metricsPublishCount), + expectedStatusCode: http.StatusOK, + expectedResponseBody: types.TaskProtectionResponse{Protection: &protectedTask}, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + }) + }) +} + +// Returns an ECS Failure with the given reason. Uses standard Task ARN. +func makeECSFailure(reason string) *ecs.Failure { + return &ecs.Failure{ + Arn: aws.String(taskARN), + Reason: aws.String("ecs failure 1"), + } +} + +// Returns a standard ECS Protected Task for testing. +func ecsProtectedTask() ecs.ProtectedTask { + return ecs.ProtectedTask{ + ProtectionEnabled: aws.Bool(true), + TaskArn: aws.String(taskARN), + } +} + +// Returns a function that sets expectations on mock metrics factory. +// The expectation is for one entry to be created with the provided name and count values. +func metricsExpectations( + name string, + count int, + doneCallCountPtr *int, // incremented when 'Done()' function is called for publishing metrics +) func(*gomock.Controller, *mock_metrics.MockEntryFactory) { + return func(ctrl *gomock.Controller, metricsFactory *mock_metrics.MockEntryFactory) { + entry := mock_metrics.NewMockEntry(ctrl) + gomock.InOrder( + metricsFactory.EXPECT().New(name).Return(entry), + entry.EXPECT().WithCount(count).Return(entry), + entry.EXPECT().Done(nil).Return(func() { + *doneCallCountPtr++ // increment done call count + }), + ) + } +} + +// Function for setting happy case expectations on credentials manager. +// The expectation is for GetTaskCredentials method to be called with standard +// task role credentials ID returning standard task role credentials. +func happyCredsManagerExpectations(credsManager *mock_credentials.MockManager) { + credsManager.EXPECT().GetTaskCredentials(taskRoleCredsID).Return(taskRoleCreds(), true) +} + +// Returns a test case for Task Metadata fetch failure case. +func taskMetadataFetchErrorCase(err error, metricName string, reqBody interface{}) TestCase { + metricsPublishCount := 0 // tracks the number of times metrics were published + return TestCase{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, err) + }, + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + requestBody: reqBody, + expectedStatusCode: http.StatusInternalServerError, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Code: ecs.ErrCodeServerException, + Message: "Failed to find a task for the request", + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + } +} + +// Returns a test case for Task Metadata Lookup failure case. +func taskMetadataLookupFailureCase(metricName string, reqBody interface{}) TestCase { + err := state.NewErrorLookupFailure("external reason") + return TestCase{ + setAgentStateExpectations: func(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{}, err) + }, + setMetricsExpectations: func(ctrl *gomock.Controller, metricsFactory *mock_metrics.MockEntryFactory) { + entry := mock_metrics.NewMockEntry(ctrl) + metricsFactory.EXPECT().New(metricName).Return(entry) + }, + requestBody: reqBody, + expectedStatusCode: http.StatusNotFound, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Code: ecs.ErrCodeResourceNotFoundException, + Message: "Failed to find a task for the request", + }, + }, + } +} + +// Creates a test case for Task Role credentials not found case. +func taskRoleCredsNotFoundCase(metricName string, reqBody interface{}) TestCase { + metricsPublishCount := 0 // tracks the number of times metrics were published + return TestCase{ + setAgentStateExpectations: happyStateExpectations, + setCredsManagerExpectations: func(credsManager *mock_credentials.MockManager) { + credsManager.EXPECT().GetTaskCredentials(taskRoleCredsID). + Return(credentials.TaskIAMRoleCredentials{}, false) + }, + setMetricsExpectations: metricsExpectations(metricName, 0, &metricsPublishCount), + requestBody: reqBody, + expectedStatusCode: http.StatusForbidden, + expectedResponseBody: types.TaskProtectionResponse{ + Error: &types.ErrorResponse{ + Arn: taskARN, + Code: ecs.ErrCodeAccessDeniedException, + Message: "Invalid Request: no task IAM role credentials available for task", + }, + }, + postAssertions: func(t *testing.T) { assert.Equal(t, 1, metricsPublishCount) }, + } +} + +// Function for setting expectations on mock AgentState. +// The expectation is for GetTaskMetadata to be called with the test endpointID +// returning a standard test Task Metadata. +func happyStateExpectations(agentState *mock_state.MockAgentState) { + agentState.EXPECT().GetTaskMetadata(endpointId).Return(state.TaskResponse{ + TaskResponse: &v2.TaskResponse{TaskARN: taskARN}, + CredentialsID: taskRoleCredsID, + }, nil) +} + +// Returns standard Task Role credentials for testing. +func taskRoleCreds() credentials.TaskIAMRoleCredentials { + return credentials.TaskIAMRoleCredentials{ + ARN: "taskRoleCredsARN", + IAMRoleCredentials: credentials.IAMRoleCredentials{ + RoleArn: "roleARN", + AccessKeyID: "accessKeyID", + SecretAccessKey: "secretAccessKey", + }, + } +} diff --git a/ecs-agent/tmds/handlers/v4/state/response.go b/ecs-agent/tmds/handlers/v4/state/response.go index e0ceb62d1d1..2ff5a661271 100644 --- a/ecs-agent/tmds/handlers/v4/state/response.go +++ b/ecs-agent/tmds/handlers/v4/state/response.go @@ -33,6 +33,7 @@ type TaskResponse struct { ServiceName string `json:"ServiceName,omitempty"` ClockDrift *ClockDrift `json:"ClockDrift,omitempty"` EphemeralStorageMetrics *EphemeralStorageMetrics `json:"EphemeralStorageMetrics,omitempty"` + CredentialsID string `json:"-"` } // Instance's clock drift status From 6b58360e9e7f43b8613b416d73b95bc5a874b81c Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Thu, 29 Jun 2023 17:40:11 -0700 Subject: [PATCH 2/4] Some renaming --- agent/handlers/task_server_setup.go | 14 ++++----- agent/handlers/task_server_setup_test.go | 40 ++++++++++++------------ 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/agent/handlers/task_server_setup.go b/agent/handlers/task_server_setup.go index 35fe453fb20..ba8f7816a68 100644 --- a/agent/handlers/task_server_setup.go +++ b/agent/handlers/task_server_setup.go @@ -31,7 +31,7 @@ import ( auditinterface "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" - tphandlers "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" + tp "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" tmdsv1 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1" tmdsv2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2" tmdsv4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" @@ -65,7 +65,7 @@ func taskServerSetup( availabilityZone string, vpcID string, containerInstanceArn string, - taskProtectionClientFactory tphandlers.TaskProtectionClientFactoryInterface, + taskProtectionClientFactory tp.TaskProtectionClientFactoryInterface, ) (*http.Server, error) { muxRouter := mux.NewRouter() @@ -167,19 +167,19 @@ func agentAPIV1HandlersSetup( credentialsManager credentials.Manager, cluster string, agentState *v4.TMDSAgentState, - factory tphandlers.TaskProtectionClientFactoryInterface, + factory tp.TaskProtectionClientFactoryInterface, metricsFactory metrics.EntryFactory, ) { muxRouter. HandleFunc( - tphandlers.TaskProtectionPath(), - tphandlers.UpdateTaskProtectionHandler(agentState, credentialsManager, + tp.TaskProtectionPath(), + tp.UpdateTaskProtectionHandler(agentState, credentialsManager, factory, cluster, metricsFactory, ecsCallTimeout)). Methods("PUT") muxRouter. HandleFunc( - tphandlers.TaskProtectionPath(), - tphandlers.GetTaskProtectionHandler(agentState, credentialsManager, + tp.TaskProtectionPath(), + tp.GetTaskProtectionHandler(agentState, credentialsManager, factory, cluster, metricsFactory, ecsCallTimeout)). Methods("GET") } diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index a510affa58c..457a9e856be 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -47,7 +47,7 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/ecs_client/model/ecs" mock_audit "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/mocks" tmdsresponse "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/response" - tpinterface "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" + tp "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers" tptypes "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/types" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" tmdsv1 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1" @@ -785,7 +785,7 @@ func testErrorResponsesFromServer(t *testing.T, path string, expectedErrorMessag ecsClient := mock_api.NewMockECSClient(ctrl) server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() @@ -822,7 +822,7 @@ func getResponseForCredentialsRequest(t *testing.T, expectedStatus int, ecsClient := mock_api.NewMockECSClient(ctrl) server, err := taskServerSetup(credentialsManager, auditLog, nil, ecsClient, "", nil, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() @@ -881,7 +881,7 @@ func TestV3ContainerAssociations(t *testing.T) { ) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/associations/"+associationType, nil) @@ -913,7 +913,7 @@ func TestV3ContainerAssociation(t *testing.T) { ) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v3BasePath+v3EndpointID+"/associations/"+associationType+"/"+associationName, nil) @@ -944,7 +944,7 @@ func TestV4ContainerAssociations(t *testing.T) { ) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/associations/"+associationType, nil) @@ -976,7 +976,7 @@ func TestV4ContainerAssociation(t *testing.T) { ) server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) recorder := httptest.NewRecorder() req, _ := http.NewRequest("GET", v4BasePath+v3EndpointID+"/associations/"+associationType+"/"+associationName, nil) @@ -1003,7 +1003,7 @@ func TestTaskHTTPEndpoint301Redirect(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for testPath, expectedPath := range testPathsMap { @@ -1046,7 +1046,7 @@ func TestTaskHTTPEndpointErrorCode404(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1086,7 +1086,7 @@ func TestTaskHTTPEndpointErrorCode400(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1125,7 +1125,7 @@ func TestTaskHTTPEndpointErrorCode500(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) for _, testPath := range testPaths { @@ -1195,7 +1195,7 @@ func TestV4TaskNotFoundError404(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) state.EXPECT().TaskARNByV3EndpointID(gomock.Any()).Return("", tc.taskFound).AnyTimes() @@ -1251,7 +1251,7 @@ func TestV4Unexpected500Error(t *testing.T) { server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine, config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID, - containerInstanceArn, tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl)) + containerInstanceArn, tp.NewMockTaskProtectionClientFactoryInterface(ctrl)) require.NoError(t, err) // Initial lookups succeed @@ -1308,7 +1308,7 @@ type TMDSTestCase[R TMDSResponse] struct { setECSClientExpectations func(ecsClient *mock_api.MockECSClient) // Function to set expectations on mock Task Protection Client Factory setTaskProtectionClientFactoryExpectations func( - ctrl *gomock.Controller, factory *tpinterface.MockTaskProtectionClientFactoryInterface) + ctrl *gomock.Controller, factory *tp.MockTaskProtectionClientFactoryInterface) // Function to set expectations on mock Credentials Manager setCredentialsManagerExpectations func(credsManager *mock_credentials.MockManager) // Expected HTTP status code of the response @@ -1335,7 +1335,7 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) { statsEngine := mock_stats.NewMockEngine(ctrl) ecsClient := mock_api.NewMockECSClient(ctrl) credsManager := mock_credentials.NewMockManager(ctrl) - taskProtectionClientFactory := tpinterface.NewMockTaskProtectionClientFactoryInterface(ctrl) + taskProtectionClientFactory := tp.NewMockTaskProtectionClientFactoryInterface(ctrl) // Set expectations on mocks auditLog.EXPECT().Log(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes() @@ -3013,11 +3013,11 @@ func TestGetTaskProtection(t *testing.T) { Return(taskRoleCredentials(), true) } taskProtectionClientFactoryExpectations := func(output *ecs.GetTaskProtectionOutput, err error) func( - *gomock.Controller, *tpinterface.MockTaskProtectionClientFactoryInterface, + *gomock.Controller, *tp.MockTaskProtectionClientFactoryInterface, ) { return func( ctrl *gomock.Controller, - factory *tpinterface.MockTaskProtectionClientFactoryInterface, + factory *tp.MockTaskProtectionClientFactoryInterface, ) { client := mock_taskprotection.NewMockECSTaskProtectionSDK(ctrl) client.EXPECT().GetTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err) @@ -3263,7 +3263,7 @@ func TestUpdateTaskProtection(t *testing.T) { } ecsRequestID := "reqid" ecsErrMessage := "ecs error message" - happyReqBody := &tpinterface.TaskProtectionRequest{ + happyReqBody := &tp.TaskProtectionRequest{ ProtectionEnabled: protectionEnabled, ExpiresInMinutes: expirationMinutes, } @@ -3284,11 +3284,11 @@ func TestUpdateTaskProtection(t *testing.T) { Return(taskRoleCredentials(), true) } taskProtectionClientFactoryExpectations := func(output *ecs.UpdateTaskProtectionOutput, err error) func( - *gomock.Controller, *tpinterface.MockTaskProtectionClientFactoryInterface, + *gomock.Controller, *tp.MockTaskProtectionClientFactoryInterface, ) { return func( ctrl *gomock.Controller, - factory *tpinterface.MockTaskProtectionClientFactoryInterface, + factory *tp.MockTaskProtectionClientFactoryInterface, ) { client := mock_taskprotection.NewMockECSTaskProtectionSDK(ctrl) client.EXPECT().UpdateTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err) From 9af7a91deafa1983f2038bfa0345109df9a9e760 Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Thu, 29 Jun 2023 17:46:21 -0700 Subject: [PATCH 3/4] Remove unnecessary export --- .../handlers/taskprotection/v1/handlers/handlers.go | 10 +++++----- .../handlers/taskprotection/v1/handlers/handlers.go | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go index f19df4fc39b..527d714687b 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go @@ -38,7 +38,7 @@ import ( ) const ( - ExpectedProtectionResponseLength = 1 + expectedProtectionResponseLength = 1 ecsCallTimedOutError = "Timed out calling ECS Task Protection API" taskMetadataFetchFailureMsg = "Failed to find a task for the request" ) @@ -310,10 +310,10 @@ func logAndValidateECSResponse( }) if len(failures) > 0 { - if len(failures) > ExpectedProtectionResponseLength { + if len(failures) > expectedProtectionResponseLength { err := fmt.Errorf( "expect at most %v failure in response, get %v", - ExpectedProtectionResponseLength, len(failures)) + expectedProtectionResponseLength, len(failures)) logger.Error("Unexpected number of failures", logger.Fields{ field.Error: err, field.TaskARN: task.TaskARN, @@ -329,10 +329,10 @@ func logAndValidateECSResponse( return http.StatusOK, &response } - if len(protectedTasks) > ExpectedProtectionResponseLength { + if len(protectedTasks) > expectedProtectionResponseLength { err := fmt.Errorf( "expect %v protectedTask in response when no failure, get %v", - ExpectedProtectionResponseLength, len(protectedTasks)) + expectedProtectionResponseLength, len(protectedTasks)) logger.Error("Unexpected number of protections", logger.Fields{ field.Error: err, field.TaskARN: task.TaskARN, diff --git a/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go index f19df4fc39b..527d714687b 100644 --- a/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/taskprotection/v1/handlers/handlers.go @@ -38,7 +38,7 @@ import ( ) const ( - ExpectedProtectionResponseLength = 1 + expectedProtectionResponseLength = 1 ecsCallTimedOutError = "Timed out calling ECS Task Protection API" taskMetadataFetchFailureMsg = "Failed to find a task for the request" ) @@ -310,10 +310,10 @@ func logAndValidateECSResponse( }) if len(failures) > 0 { - if len(failures) > ExpectedProtectionResponseLength { + if len(failures) > expectedProtectionResponseLength { err := fmt.Errorf( "expect at most %v failure in response, get %v", - ExpectedProtectionResponseLength, len(failures)) + expectedProtectionResponseLength, len(failures)) logger.Error("Unexpected number of failures", logger.Fields{ field.Error: err, field.TaskARN: task.TaskARN, @@ -329,10 +329,10 @@ func logAndValidateECSResponse( return http.StatusOK, &response } - if len(protectedTasks) > ExpectedProtectionResponseLength { + if len(protectedTasks) > expectedProtectionResponseLength { err := fmt.Errorf( "expect %v protectedTask in response when no failure, get %v", - ExpectedProtectionResponseLength, len(protectedTasks)) + expectedProtectionResponseLength, len(protectedTasks)) logger.Error("Unexpected number of protections", logger.Fields{ field.Error: err, field.TaskARN: task.TaskARN, From 29427d82894de796daee3738da554dd914b53e16 Mon Sep 17 00:00:00 2001 From: Amogh Rathore Date: Fri, 30 Jun 2023 17:00:01 +0000 Subject: [PATCH 4/4] Update task metadata test to ensure credentials ID are not in the response --- ecs-agent/tmds/handlers/v4/handlers_test.go | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/ecs-agent/tmds/handlers/v4/handlers_test.go b/ecs-agent/tmds/handlers/v4/handlers_test.go index 19c498e121c..1d514359813 100644 --- a/ecs-agent/tmds/handlers/v4/handlers_test.go +++ b/ecs-agent/tmds/handlers/v4/handlers_test.go @@ -107,8 +107,13 @@ var ( }}, }, } - now = time.Now() - taskResponse = state.TaskResponse{ + now = time.Now() + credentialsID = "credentialsID" +) + +// Returns a standard agent task response +func taskResponse() *state.TaskResponse { + return &state.TaskResponse{ TaskResponse: &v2.TaskResponse{ Cluster: clusterName, TaskARN: taskARN, @@ -136,8 +141,9 @@ var ( UtilizedMiBs: 500, ReservedMiBs: 600, }, + CredentialsID: credentialsID, } -) +} func TestContainerMetadata(t *testing.T) { var setup = func(t *testing.T) (*mux.Router, *gomock.Controller, *mock_state.MockAgentState, @@ -238,14 +244,17 @@ func TestTaskMetadata(t *testing.T) { } t.Run("happy case", func(t *testing.T) { + metadata := taskResponse() + expectedTaskResponse := taskResponse() + expectedTaskResponse.CredentialsID = "" // credentials ID not expected handler, _, agentState, _ := setup(t) agentState.EXPECT(). GetTaskMetadata(endpointContainerID). - Return(taskResponse, nil) + Return(*metadata, nil) testTMDSRequest(t, handler, TMDSTestCase[state.TaskResponse]{ path: path, expectedStatusCode: http.StatusOK, - expectedResponseBody: taskResponse, + expectedResponseBody: *expectedTaskResponse, }) }) t.Run("task lookup failure", func(t *testing.T) {