diff --git a/agent/handlers/task_server_setup_test.go b/agent/handlers/task_server_setup_test.go index c328cd077a9..221382023a6 100644 --- a/agent/handlers/task_server_setup_test.go +++ b/agent/handlers/task_server_setup_test.go @@ -3806,6 +3806,8 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) { cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) } tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 42fa7956106..7841937fd38 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -30,6 +30,7 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" @@ -56,7 +57,7 @@ const ( requestTimeoutSeconds = 5 // Commands that will be used to start/stop/check fault. iptablesNewChainCmd = "iptables -w %d -N %s" - iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s --dport %s -j DROP" + iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s -d %s --dport %s -j %s" iptablesInsertChainCmd = "iptables -w %d -I %s -j %s" iptablesChainExistCmd = "iptables -w %d -C %s -p %s --dport %s -j DROP" iptablesClearChainCmd = "iptables -w %d -F %s" @@ -71,6 +72,9 @@ const ( tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 2 u32 match ip dst %s flowid 1:1" tcDeleteQdiscParentCommandString = "tc qdisc del dev %s parent 1:1 handle 10:" tcDeleteQdiscRootCommandString = "tc qdisc del dev %s root handle 1: prio" + allIPv4CIDR = "0.0.0.0/0" + dropTarget = "DROP" + acceptTarget = "ACCEPT" ) type FaultHandler struct { @@ -220,24 +224,42 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, "taskArn": taskArn, }) - // Appending a new rule based on the protocol and port number from the request body - appendRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, requestTimeoutSeconds, chain, protocol, port) - cmdOutput, err = h.runExecCommand(ctx, strings.Split(appendRuleCmdString, " ")) - if err != nil { - logger.Error("Unable to append rule to chain", logger.Fields{ - "netns": netNs, - "command": appendRuleCmdString, + // Helper function to run iptables rule change commands + var execRuleChangeCommand = func(cmdString string) (string, error) { + // Appending a new rule based on the protocol and port number from the request body + cmdOutput, err = h.runExecCommand(ctx, strings.Split(cmdString, " ")) + if err != nil { + logger.Error("Unable to add rule to chain", logger.Fields{ + "netns": netNs, + "command": cmdString, + "output": string(cmdOutput), + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + logger.Info("Successfully added new rule to iptable chain", logger.Fields{ + "command": cmdString, "output": string(cmdOutput), "taskArn": taskArn, - "error": err, }) - return string(cmdOutput), err + return "", nil + } + + // Add a rule to accept all traffic to TMDS + protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, + requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks, + acceptTarget) + if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil { + return out, err + } + + // Add a rule to drop all traffic to the port that the fault targets + faultRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, + requestTimeoutSeconds, chain, protocol, allIPv4CIDR, port, dropTarget) + if out, err := execRuleChangeCommand(faultRuleCmdString); err != nil { + return out, err } - logger.Info("Successfully appended new rule to iptable chain", logger.Fields{ - "command": appendRuleCmdString, - "output": string(cmdOutput), - "taskArn": taskArn, - }) // Inserting the chain into the built-in INPUT/OUTPUT table insertChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesInsertChainCmd, requestTimeoutSeconds, insertTable, chain) diff --git a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go index 0d72f659de0..7248bd7ff25 100644 --- a/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go +++ b/agent/vendor/github.com/aws/amazon-ecs-agent/ecs-agent/tmds/server.go @@ -29,8 +29,10 @@ import ( const ( // TMDS IP and port - IPv4 = "127.0.0.1" - Port = 51679 + IPv4 = "127.0.0.1" + Port = 51679 + IPForTasks = "169.254.170.2" + PortForTasks = "80" ) // IPv4 address for TMDS diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go index 42fa7956106..7841937fd38 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go @@ -30,6 +30,7 @@ import ( "github.com/aws/amazon-ecs-agent/ecs-agent/logger" "github.com/aws/amazon-ecs-agent/ecs-agent/logger/field" "github.com/aws/amazon-ecs-agent/ecs-agent/metrics" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types" "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils" v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4" @@ -56,7 +57,7 @@ const ( requestTimeoutSeconds = 5 // Commands that will be used to start/stop/check fault. iptablesNewChainCmd = "iptables -w %d -N %s" - iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s --dport %s -j DROP" + iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s -d %s --dport %s -j %s" iptablesInsertChainCmd = "iptables -w %d -I %s -j %s" iptablesChainExistCmd = "iptables -w %d -C %s -p %s --dport %s -j DROP" iptablesClearChainCmd = "iptables -w %d -F %s" @@ -71,6 +72,9 @@ const ( tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 2 u32 match ip dst %s flowid 1:1" tcDeleteQdiscParentCommandString = "tc qdisc del dev %s parent 1:1 handle 10:" tcDeleteQdiscRootCommandString = "tc qdisc del dev %s root handle 1: prio" + allIPv4CIDR = "0.0.0.0/0" + dropTarget = "DROP" + acceptTarget = "ACCEPT" ) type FaultHandler struct { @@ -220,24 +224,42 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol, "taskArn": taskArn, }) - // Appending a new rule based on the protocol and port number from the request body - appendRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, requestTimeoutSeconds, chain, protocol, port) - cmdOutput, err = h.runExecCommand(ctx, strings.Split(appendRuleCmdString, " ")) - if err != nil { - logger.Error("Unable to append rule to chain", logger.Fields{ - "netns": netNs, - "command": appendRuleCmdString, + // Helper function to run iptables rule change commands + var execRuleChangeCommand = func(cmdString string) (string, error) { + // Appending a new rule based on the protocol and port number from the request body + cmdOutput, err = h.runExecCommand(ctx, strings.Split(cmdString, " ")) + if err != nil { + logger.Error("Unable to add rule to chain", logger.Fields{ + "netns": netNs, + "command": cmdString, + "output": string(cmdOutput), + "taskArn": taskArn, + "error": err, + }) + return string(cmdOutput), err + } + logger.Info("Successfully added new rule to iptable chain", logger.Fields{ + "command": cmdString, "output": string(cmdOutput), "taskArn": taskArn, - "error": err, }) - return string(cmdOutput), err + return "", nil + } + + // Add a rule to accept all traffic to TMDS + protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, + requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks, + acceptTarget) + if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil { + return out, err + } + + // Add a rule to drop all traffic to the port that the fault targets + faultRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, + requestTimeoutSeconds, chain, protocol, allIPv4CIDR, port, dropTarget) + if out, err := execRuleChangeCommand(faultRuleCmdString); err != nil { + return out, err } - logger.Info("Successfully appended new rule to iptable chain", logger.Fields{ - "command": appendRuleCmdString, - "output": string(cmdOutput), - "taskArn": taskArn, - }) // Inserting the chain into the built-in INPUT/OUTPUT table insertChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesInsertChainCmd, requestTimeoutSeconds, insertTable, chain) diff --git a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go index d32206d6215..6ba1242d33d 100644 --- a/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go +++ b/ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go @@ -521,6 +521,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) }, }, @@ -554,6 +556,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) }, }, @@ -578,7 +582,7 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase }, }, { - name: fmt.Sprintf("%s fail append rule to chain", startNetworkBlackHolePortTestPrefix), + name: fmt.Sprintf("%s fail append ACCEPT rule to chain", startNetworkBlackHolePortTestPrefix), expectedStatusCode: 500, requestBody: happyBlackHolePortReqBody, expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), @@ -603,6 +607,34 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase ) }, }, + { + name: fmt.Sprintf("%s fail append DROP rule to chain", startNetworkBlackHolePortTestPrefix), + expectedStatusCode: 500, + requestBody: happyBlackHolePortReqBody, + expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError), + setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) { + agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient). + Return(happyTaskResponse, nil). + Times(1) + }, + setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) { + ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + gomock.InOrder( + exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")), + exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true), + exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")), + ) + }, + }, { name: fmt.Sprintf("%s fail insert chain to table", startNetworkBlackHolePortTestPrefix), expectedStatusCode: 500, diff --git a/ecs-agent/tmds/server.go b/ecs-agent/tmds/server.go index 0d72f659de0..7248bd7ff25 100644 --- a/ecs-agent/tmds/server.go +++ b/ecs-agent/tmds/server.go @@ -29,8 +29,10 @@ import ( const ( // TMDS IP and port - IPv4 = "127.0.0.1" - Port = 51679 + IPv4 = "127.0.0.1" + Port = 51679 + IPForTasks = "169.254.170.2" + PortForTasks = "80" ) // IPv4 address for TMDS