Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for UpdateTaskProtection API to high-level TMDS tests #3740

Merged
merged 2 commits into from
Jun 12, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
343 changes: 292 additions & 51 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ import (
v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3"
"github.com/aws/amazon-ecs-agent/agent/stats"
mock_stats "github.com/aws/amazon-ecs-agent/agent/stats/mock"
agentutils "github.com/aws/amazon-ecs-agent/agent/utils"
apieni "github.com/aws/amazon-ecs-agent/ecs-agent/api/eni"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks"
Expand Down Expand Up @@ -1527,6 +1526,10 @@ type TMDSResponse interface {
type TMDSTestCase[R TMDSResponse] struct {
// Request path
path string
// Method to use for the request, defaults to GET
method string
// Optional request body
requestBody interface{}
// Function to set expectations on mock task engine state
setStateExpectations func(state *mock_dockerstate.MockTaskEngineState)
// Function to set expectations on mock ECS Client
Expand Down Expand Up @@ -1564,7 +1567,9 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) {

// Set expectations on mocks
auditLog.EXPECT().Log(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes()
tc.setStateExpectations(state)
if tc.setStateExpectations != nil {
tc.setStateExpectations(state)
}
if tc.setECSClientExpectations != nil {
tc.setECSClientExpectations(ecsClient)
}
Expand All @@ -1583,7 +1588,16 @@ func testTMDSRequest[R TMDSResponse](t *testing.T, tc TMDSTestCase[R]) {
require.NoError(t, err)

// Create the request
req, err := http.NewRequest("GET", tc.path, nil)
var reqBody io.Reader
if tc.requestBody != nil {
reqBodyBytes, err := json.Marshal(tc.requestBody)
require.NoError(t, err)
reqBody = bytes.NewReader(reqBodyBytes)
}
if tc.method == "" {
tc.method = "GET"
}
req, err := http.NewRequest(tc.method, tc.path, reqBody)
require.NoError(t, err)
req.RemoteAddr = remoteIP + ":" + remotePort

Expand Down Expand Up @@ -2739,49 +2753,6 @@ func TestV4TaskMetadataWithTags(t *testing.T) {
})
}

// Helper function for testing Agent API Task Protection v1 handlers
func testAgentAPITaskProtectionV1Handler(t *testing.T, requestBody interface{}, method string) {
// Prepare dependency mocks
ctrl := gomock.NewController(t)
defer ctrl.Finish()

task := standardTask()

state := mock_dockerstate.NewMockTaskEngineState(ctrl)
auditLog := mock_audit.NewMockAuditLogger(ctrl)
statsEngine := mock_stats.NewMockEngine(ctrl)
ecsClient := mock_api.NewMockECSClient(ctrl)
ecsClientFactory := agentapihandlers.NewMockTaskProtectionClientFactoryInterface(ctrl)

gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
)

// Set up the server
server, err := taskServerSetup(credentials.NewManager(), auditLog, state, ecsClient, clusterName, statsEngine,
config.DefaultTaskMetadataSteadyStateRate, config.DefaultTaskMetadataBurstRate, "", vpcID,
containerInstanceArn, ecsClientFactory)
require.NoError(t, err)

// Prepare the request
var requestReader io.Reader = nil
if requestBody != nil {
requestBodyJSON, err := json.Marshal(requestBody)
assert.NoError(t, err)
requestReader = bytes.NewReader(requestBodyJSON)
}

// Send request and record response
recorder := httptest.NewRecorder()
req, _ := http.NewRequest(method, fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID),
requestReader)
server.Handler.ServeHTTP(recorder, req)

// assert that there is response
assert.NotNil(t, recorder.Body)
}

func TestGetTaskProtection(t *testing.T) {
path := fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID)

Expand Down Expand Up @@ -3016,10 +2987,280 @@ func TestGetTaskProtection(t *testing.T) {
})
}

// Tests that Agent API v1 UpdateTaskProtection handler is registered correctly
func TestAgentAPIV1UpdateTaskProtectionHandler(t *testing.T) {
requestBody := task_protection_v1.TaskProtectionRequest{
ProtectionEnabled: agentutils.BoolPtr(false),
func TestUpdateTaskProtection(t *testing.T) {
// Set up some fake data
task := standardTask()
protectionEnabled := aws.Bool(true)
expirationMinutes := aws.Int64(5)
ecsInput := ecs.UpdateTaskProtectionInput{
Cluster: aws.String(clusterName),
ProtectionEnabled: protectionEnabled,
ExpiresInMinutes: expirationMinutes,
Tasks: aws.StringSlice([]string{taskARN}),
}
protectedTask := ecs.ProtectedTask{
ProtectionEnabled: aws.Bool(true),
TaskArn: aws.String(taskARN),
}
ecsOutput := ecs.UpdateTaskProtectionOutput{
ProtectedTasks: []*ecs.ProtectedTask{&protectedTask},
}
ecsRequestID := "reqid"
ecsErrMessage := "ecs error message"
happyReqBody := &agentapihandlers.TaskProtectionRequest{
ProtectionEnabled: protectionEnabled,
ExpiresInMinutes: expirationMinutes,
}

// Helper functions to set expectation on mocks
happyStateExpectations := func(state *mock_dockerstate.MockTaskEngineState) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(task, true),
)
}
happyCredentialsManagerExpectations := func(credsManager *mock_credentials.MockManager) {
credsManager.EXPECT().
GetTaskCredentials(task.GetCredentialsID()).
Return(taskRoleCredentials(), true)
}
taskProtectionClientFactoryExpectations := func(output *ecs.UpdateTaskProtectionOutput, err error) func(
*gomock.Controller, *task_protection_v1.MockTaskProtectionClientFactoryInterface,
) {
return func(
ctrl *gomock.Controller,
factory *task_protection_v1.MockTaskProtectionClientFactoryInterface,
) {
client := mock_api.NewMockECSTaskProtectionSDK(ctrl)
client.EXPECT().UpdateTaskProtectionWithContext(gomock.Any(), &ecsInput).Return(output, err)
factory.EXPECT().NewTaskProtectionClient(taskRoleCredentials()).Return(client)
}
}

// Helper function for creating a function that runs a test case
runTest := func(t *testing.T, tc TMDSTestCase[agentapi.TaskProtectionResponse]) func(*testing.T) {
return func(t *testing.T) {
tc.path = fmt.Sprintf("/api/%s/task-protection/v1/state", v3EndpointID)
tc.method = "PUT"
testTMDSRequest(t, tc)
}
}
testAgentAPITaskProtectionV1Handler(t, requestBody, "PUT")

// Test cases start here
t.Run("task ARN not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return("", false),
)
},
expectedStatusCode: http.StatusNotFound,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeResourceNotFoundException,
Message: "Invalid request: no task was found",
},
},
}))
t.Run("task not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: func(state *mock_dockerstate.MockTaskEngineState) {
gomock.InOrder(
state.EXPECT().TaskARNByV3EndpointID(v3EndpointID).Return(taskARN, true),
state.EXPECT().TaskByArn(taskARN).Return(nil, false),
)
},
expectedStatusCode: http.StatusInternalServerError,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeServerException,
Message: "Failed to find a task for the request",
},
},
}))
t.Run("task credentials not found", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: func(credsManager *mock_credentials.MockManager) {
credsManager.
EXPECT().GetTaskCredentials(taskCredentialsID).
Return(credentials.TaskIAMRoleCredentials{}, false)
},
expectedStatusCode: http.StatusForbidden,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeAccessDeniedException,
Message: "Invalid Request: no task IAM role credentials available for task",
},
},
}))
t.Run("ecs call server exception", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil,
awserr.NewRequestFailure(
awserr.New(ecs.ErrCodeServerException, ecsErrMessage, nil),
http.StatusInternalServerError,
ecsRequestID,
),
),
expectedStatusCode: http.StatusInternalServerError,
expectedResponseBody: agentapi.TaskProtectionResponse{
RequestID: &ecsRequestID,
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeServerException,
Message: ecsErrMessage,
},
},
}))
t.Run("ecs call access denied exception", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil,
awserr.NewRequestFailure(
awserr.New(ecs.ErrCodeAccessDeniedException, ecsErrMessage, nil),
http.StatusBadRequest,
ecsRequestID,
),
),
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
RequestID: &ecsRequestID,
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeAccessDeniedException,
Message: ecsErrMessage,
},
},
}))
t.Run("ecs call non-request-failure aws error", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil,
awserr.New(ecs.ErrCodeInvalidParameterException, ecsErrMessage, nil)),
expectedStatusCode: http.StatusInternalServerError,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeInvalidParameterException,
Message: ecsErrMessage,
},
},
}))
t.Run("agent timeout", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil, awserr.New(request.CanceledErrorCode, "request cancelled", nil)),
expectedStatusCode: http.StatusGatewayTimeout,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: request.CanceledErrorCode,
Message: "Timed out calling ECS Task Protection API",
},
},
}))
t.Run("non-aws error", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
nil, errors.New("some error")),
expectedStatusCode: http.StatusInternalServerError,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeServerException,
Message: "some error",
},
},
}))
t.Run("ecs failure", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(
&ecs.UpdateTaskProtectionOutput{
Failures: []*ecs.Failure{{
Arn: aws.String(taskARN),
Reason: aws.String("ecs failure"),
}},
}, nil),
expectedStatusCode: http.StatusOK,
expectedResponseBody: agentapi.TaskProtectionResponse{
Failure: &ecs.Failure{
Arn: aws.String(taskARN),
Reason: aws.String("ecs failure"),
},
},
}))
t.Run("empty request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: map[string]string{},
setStateExpectations: happyStateExpectations,
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Arn: taskARN,
Code: ecs.ErrCodeInvalidParameterException,
Message: "Invalid request: does not contain 'ProtectionEnabled' field",
},
},
}))
t.Run("invalid type in request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: map[string]interface{}{
"ProtectionEnabled": true,
"ExpiresInMinutes": "bad",
},
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeInvalidParameterException,
Message: "UpdateTaskProtection: failed to decode request",
},
},
}))
t.Run("unknown fields in the request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: map[string]interface{}{
"ProtectionEnabled": true,
"ExpiresInMinutes": 5,
"Unknown": "unknown",
},
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeInvalidParameterException,
Message: "UpdateTaskProtection: failed to decode request",
},
},
}))
t.Run("non-JSON object request", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: "bad",
expectedStatusCode: http.StatusBadRequest,
expectedResponseBody: agentapi.TaskProtectionResponse{
Error: &agentapi.ErrorResponse{
Code: ecs.ErrCodeInvalidParameterException,
Message: "UpdateTaskProtection: failed to decode request",
},
},
}))
t.Run("happy case", runTest(t, TMDSTestCase[agentapi.TaskProtectionResponse]{
requestBody: happyReqBody,
setStateExpectations: happyStateExpectations,
setCredentialsManagerExpectations: happyCredentialsManagerExpectations,
setTaskProtectionClientFactoryExpectations: taskProtectionClientFactoryExpectations(&ecsOutput, nil),
expectedStatusCode: http.StatusOK,
expectedResponseBody: agentapi.TaskProtectionResponse{
Protection: &protectedTask,
},
}))
}