Skip to content

Commit

Permalink
Prevent network-blackhole-port from affecting TMDS access (#4403)
Browse files Browse the repository at this point in the history
* Protect TMDS IP from being affected by network-blackhole-port fault

* Fix test

---------

Co-authored-by: xingzhen <[email protected]>
  • Loading branch information
amogh09 and xxx0624 authored Oct 22, 2024
1 parent f7dfa32 commit 79f17a5
Show file tree
Hide file tree
Showing 6 changed files with 117 additions and 35 deletions.
2 changes: 2 additions & 0 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3806,6 +3806,8 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) {
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 37 additions & 15 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
"github.com/aws/amazon-ecs-agent/ecs-agent/metrics"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils"
v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4"
Expand All @@ -56,7 +57,7 @@ const (
requestTimeoutSeconds = 5
// Commands that will be used to start/stop/check fault.
iptablesNewChainCmd = "iptables -w %d -N %s"
iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s --dport %s -j DROP"
iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s -d %s --dport %s -j %s"
iptablesInsertChainCmd = "iptables -w %d -I %s -j %s"
iptablesChainExistCmd = "iptables -w %d -C %s -p %s --dport %s -j DROP"
iptablesClearChainCmd = "iptables -w %d -F %s"
Expand All @@ -71,6 +72,9 @@ const (
tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 2 u32 match ip dst %s flowid 1:1"
tcDeleteQdiscParentCommandString = "tc qdisc del dev %s parent 1:1 handle 10:"
tcDeleteQdiscRootCommandString = "tc qdisc del dev %s root handle 1: prio"
allIPv4CIDR = "0.0.0.0/0"
dropTarget = "DROP"
acceptTarget = "ACCEPT"
)

type FaultHandler struct {
Expand Down Expand Up @@ -220,24 +224,42 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol,
"taskArn": taskArn,
})

// Appending a new rule based on the protocol and port number from the request body
appendRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, requestTimeoutSeconds, chain, protocol, port)
cmdOutput, err = h.runExecCommand(ctx, strings.Split(appendRuleCmdString, " "))
if err != nil {
logger.Error("Unable to append rule to chain", logger.Fields{
"netns": netNs,
"command": appendRuleCmdString,
// Helper function to run iptables rule change commands
var execRuleChangeCommand = func(cmdString string) (string, error) {
// Appending a new rule based on the protocol and port number from the request body
cmdOutput, err = h.runExecCommand(ctx, strings.Split(cmdString, " "))
if err != nil {
logger.Error("Unable to add rule to chain", logger.Fields{
"netns": netNs,
"command": cmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
"error": err,
})
return string(cmdOutput), err
}
logger.Info("Successfully added new rule to iptable chain", logger.Fields{
"command": cmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
"error": err,
})
return string(cmdOutput), err
return "", nil
}

// Add a rule to accept all traffic to TMDS
protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks,
acceptTarget)
if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil {
return out, err
}

// Add a rule to drop all traffic to the port that the fault targets
faultRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, allIPv4CIDR, port, dropTarget)
if out, err := execRuleChangeCommand(faultRuleCmdString); err != nil {
return out, err
}
logger.Info("Successfully appended new rule to iptable chain", logger.Fields{
"command": appendRuleCmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
})

// Inserting the chain into the built-in INPUT/OUTPUT table
insertChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesInsertChainCmd, requestTimeoutSeconds, insertTable, chain)
Expand Down
34 changes: 33 additions & 1 deletion ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand Down Expand Up @@ -554,6 +556,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand All @@ -578,7 +582,7 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
},
},
{
name: fmt.Sprintf("%s fail append rule to chain", startNetworkBlackHolePortTestPrefix),
name: fmt.Sprintf("%s fail append ACCEPT rule to chain", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
requestBody: happyBlackHolePortReqBody,
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
Expand All @@ -603,6 +607,34 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
)
},
},
{
name: fmt.Sprintf("%s fail append DROP rule to chain", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
requestBody: happyBlackHolePortReqBody,
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")),
)
},
},
{
name: fmt.Sprintf("%s fail insert chain to table", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
Expand Down
6 changes: 4 additions & 2 deletions ecs-agent/tmds/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ import (

const (
// TMDS IP and port
IPv4 = "127.0.0.1"
Port = 51679
IPv4 = "127.0.0.1"
Port = 51679
IPForTasks = "169.254.170.2"
PortForTasks = "80"
)

// IPv4 address for TMDS
Expand Down

0 comments on commit 79f17a5

Please sign in to comment.