Skip to content

Commit

Permalink
Add tests for UpdateTaskProtection API to high-level TMDS tests (#3740)
Browse files Browse the repository at this point in the history
  • Loading branch information
amogh09 authored Jun 12, 2023
1 parent ba48cc1 commit 827d32f
Showing 1 changed file with 292 additions and 51 deletions.
343 changes: 292 additions & 51 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import (
v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3"
"github.com/aws/amazon-ecs-agent/agent/stats"
mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock"
agentutils "github.com/aws/amazon-ecs-agent/agent/utils"
apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks"
Expand Down Expand Up @@ -1527,6 +1526,10 @@ type TMDSResponse interface {
type TMDSTestCase[R TMDSResponse] struct {
// Request path
path string
// Method to use for the request, defaults to GET
method string
// Optional request body
requestBody interface{}
// Function to set expectations on mock task engine state
setStateExpectations func(state *mock_dockerstate.MockTaskEngineState)
// Function to set expectations on mock ECS Client
Expand Down Expand Up @@ -1564,7 +1567,9 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) {

// Set expectations on mocks
auditLog.EXPECT().Log(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tc.setStateExpectations(state)
if tc.setStateExpectations != nil {
tc.setStateExpectations(state)
}
if tc.setECSClientExpectations != nil {
tc.setECSClientExpectations(ecsClient)
}
Expand All @@ -1583,7 +1588,16 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) {
require.NoError(t, err)

// Create the request
req, err := http.NewRequest("GET", tc.path, nil)
var reqBody io.Reader
if tc.requestBody != nil {
reqBodyBytes, err := json.Marshal(tc.requestBody)
require.NoError(t, err)
reqBody = bytes.NewReader(reqBodyBytes)
}
if tc.method == "" {
tc.method = "GET"
}
req, err := http.NewRequest(tc.method, tc.path, reqBody)
require.NoError(t, err)
req.RemoteAddr = remoteIP + ":" + remotePort

Expand Down Expand Up @@ -2739,49 +2753,6 @@ func TestV4TaskMetadataWithTags(t *testing.T) {
})
}

// Helper function for testing Agent API Task Protection v1 handlers
func testAgentAPITaskProtectionV1Handler(t *testing.T, requestBody interface{}, method string) {
// Prepare dependency mocks
ctrl := gomock.NewController(t)
defer ctrl.Finish()

task := standardTask()

state := mock_dockerstate.NewMockTaskEngineState(ctrl)
auditLog := mock_audit.NewMockAuditLogger(ctrl)
statsEngine := mock_stats.NewMockEngine(ctrl)
ecsClient := mock_api.NewMockECSClient(ctrl)
ecsClientFactory := agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)

gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
)

// Set up the server
server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine,
config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID,
containerInstanceArn, ecsClientFactory)
require.NoError(t, err)

// Prepare the request
var requestReader io.Reader = nil
if requestBody != nil {
requestBodyJSON, err := json.Marshal(requestBody)
assert.NoError(t, err)
requestReader = bytes.NewReader(requestBodyJSON)
}

// Send request and record response
recorder := httptest.NewRecorder()
req, _ := http.NewRequest(method, fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID),
requestReader)
server.Handler.ServeHTTP(recorder, req)

// assert that there is response
assert.NotNil(t, recorder.Body)
}

func TestGetTaskProtection(t *testing.T) {
path := fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID)

Expand Down Expand Up @@ -3016,10 +2987,280 @@ func TestGetTaskProtection(t *testing.T) {
})
}

// Tests that Agent API v1 UpdateTaskProtection handler is registered correctly
func TestAgentAPIV1UpdateTaskProtectionHandler(t *testing.T) {
requestBody := task_protection_v1.TaskProtectionRequest{
ProtectionEnabled: agentutils.BoolPtr(false),
func TestUpdateTaskProtection(t *testing.T) {
// Set up some fake data
task := standardTask()
protectionEnabled := aws.Bool(true)
expirationMinutes := aws.Int64(5)
ecsInput := ecs.UpdateTaskProtectionInput{
Cluster: aws.String(clusterName),
ProtectionEnabled: protectionEnabled,
ExpiresInMinutes: expirationMinutes,
Tasks: aws.StringSlice([]string{taskARN}),
}
protectedTask := ecs.ProtectedTask{
ProtectionEnabled: aws.Bool(true),
TaskArn: aws.String(taskARN),
}
ecsOutput := ecs.UpdateTaskProtectionOutput{
ProtectedTasks: []*ecs.ProtectedTask{&protectedTask},
}
ecsRequestID := "reqid"
ecsErrMessage := "ecs error message"
happyReqBody := &agentapihandlers.TaskProtectionRequest{
ProtectionEnabled: protectionEnabled,
ExpiresInMinutes: expirationMinutes,
}

// Helper functions to set expectation on mocks
happyStateExpectations := func(state *mock_dockerstate.MockTaskEngineState) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
)
}
happyCredentialsManagerExpectations := func(credsManager *mock_credentials.MockManager) {
credsManager.EXPECT().
GetTaskCredentials(task.GetCredentialsID()).
Return(taskRoleCredentials(), true)
}
taskProtectionClientFactoryExpectations := func(output *ecs.UpdateTaskProtectionOutput, err error) func(
*gomock.Controller, *task_protection_v1.MockTaskProtectionClientFactoryInterface,
) {
return func(
ctrl *gomock.Controller,
factory *task_protection_v1.MockTaskProtectionClientFactoryInterface,
) {
client := mock_api.NewMockECSTaskProtectionSDK(ctrl)
client.EXPECT().UpdateTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err)
factory.EXPECT().NewTaskProtectionClient(taskRoleCredentials()).Return(client)
}
}

// Helper function for creating a function that runs a test case
runTest := func(t *testing.T, tc TMDSTestCase[agentapi.TaskProtectionResponse]) func(*testing.T) {
return func(t *testing.T) {
tc.path = fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID)
tc.method = "PUT"
testTMDSRequest(t, tc)
}
}
testAgentAPITaskProtectionV1Handler(t, requestBody, "PUT")

// Test cases start here
t.Run("task ARN not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return("", false),
)
},
expectedStatusCode: http.StatusNotFound,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeResourceNotFoundException,
Message: "Invalid request: no task was found",
},
},
}))
t.Run("task not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(nil, false),
)
},
expectedStatusCode: http.StatusInternalServerError,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeServerException,
Message: "Failed to find a task for the request",
},
},
}))
t.Run("task credentials not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: func(credsManager *mock_credentials.MockManager) {
credsManager.
EXPECT().GetTaskCredentials(taskCredentialsID).
Return(credentials.TaskIAMRoleCredentials{}, false)
},
expectedStatusCode: http.StatusForbidden,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeAccessDeniedException,
Message: "Invalid Request: no task IAM role credentials available for task",
},
},
}))
t.Run("ecs call server exception", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil,
awserr.NewRequestFailure(
awserr.New(ecs.ErrCodeServerException, ecsErrMessage, nil),
http.StatusInternalServerError,
ecsRequestID,
),
),
expectedStatusCode: http.StatusInternalServerError,
expectedResponseBody: agentapi.TaskProtectionResponse{
RequestID: &ecsRequestID,
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeServerException,
Message: ecsErrMessage,
},
},
}))
t.Run("ecs call access denied exception", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil,
awserr.NewRequestFailure(
awserr.New(ecs.ErrCodeAccessDeniedException, ecsErrMessage, nil),
http.StatusBadRequest,
ecsRequestID,
),
),
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
RequestID: &ecsRequestID,
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeAccessDeniedException,
Message: ecsErrMessage,
},
},
}))
t.Run("ecs call non-request-failure aws error", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil,
awserr.New(ecs.ErrCodeInvalidParameterException, ecsErrMessage, nil)),
expectedStatusCode: http.StatusInternalServerError,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeInvalidParameterException,
Message: ecsErrMessage,
},
},
}))
t.Run("agent timeout", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil, awserr.New(request.CanceledErrorCode, "request cancelled", nil)),
expectedStatusCode: http.StatusGatewayTimeout,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: request.CanceledErrorCode,
Message: "Timed out calling ECS Task Protection API",
},
},
}))
t.Run("non-aws error", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil, errors.New("some error")),
expectedStatusCode: http.StatusInternalServerError,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeServerException,
Message: "some error",
},
},
}))
t.Run("ecs failure", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
&ecs.UpdateTaskProtectionOutput{
Failures: []*ecs.Failure{{
Arn: aws.String(taskARN),
Reason: aws.String("ecs failure"),
}},
}, nil),
expectedStatusCode: http.StatusOK,
expectedResponseBody: agentapi.TaskProtectionResponse{
Failure: &ecs.Failure{
Arn: aws.String(taskARN),
Reason: aws.String("ecs failure"),
},
},
}))
t.Run("empty request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: map[string]string{},
setStateExpectations: happyStateExpectations,
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeInvalidParameterException,
Message: "Invalid request: does not contain 'ProtectionEnabled' field",
},
},
}))
t.Run("invalid type in request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: map[string]interface{}{
"ProtectionEnabled": true,
"ExpiresInMinutes": "bad",
},
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeInvalidParameterException,
Message: "UpdateTaskProtection: failed to decode request",
},
},
}))
t.Run("unknown fields in the request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: map[string]interface{}{
"ProtectionEnabled": true,
"ExpiresInMinutes": 5,
"Unknown": "unknown",
},
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeInvalidParameterException,
Message: "UpdateTaskProtection: failed to decode request",
},
},
}))
t.Run("non-JSON object request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: "bad",
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeInvalidParameterException,
Message: "UpdateTaskProtection: failed to decode request",
},
},
}))
t.Run("happy case", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil),
expectedStatusCode: http.StatusOK,
expectedResponseBody: agentapi.TaskProtectionResponse{
Protection: &protectedTask,
},
}))
}

0 comments on commit 827d32f

Please sign in to comment.