Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent network-blackhole-port from affecting TMDS access #4403

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3806,6 +3806,8 @@ func TestRegisterStartBlackholePortFaultHandler(t *testing.T) {
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
}
tcs := generateCommonNetworkFaultInjectionTestCases("start blackhole port", "running", setExecExpectations, happyBlackHolePortReqBody)
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 37 additions & 15 deletions ecs-agent/tmds/handlers/fault/v1/handlers/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"github.com/aws/amazon-ecs-agent/ecs-agent/logger"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/field"
"github.com/aws/amazon-ecs-agent/ecs-agent/metrics"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/fault/v1/types"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils"
v4 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v4"
Expand All @@ -56,7 +57,7 @@ const (
requestTimeoutSeconds = 5
// Commands that will be used to start/stop/check fault.
iptablesNewChainCmd = "iptables -w %d -N %s"
iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s --dport %s -j DROP"
iptablesAppendChainRuleCmd = "iptables -w %d -A %s -p %s -d %s --dport %s -j %s"
iptablesInsertChainCmd = "iptables -w %d -I %s -j %s"
iptablesChainExistCmd = "iptables -w %d -C %s -p %s --dport %s -j DROP"
iptablesClearChainCmd = "iptables -w %d -F %s"
Expand All @@ -71,6 +72,9 @@ const (
tcAddFilterForIPCommandString = "tc filter add dev %s protocol ip parent 1:0 prio 2 u32 match ip dst %s flowid 1:1"
tcDeleteQdiscParentCommandString = "tc qdisc del dev %s parent 1:1 handle 10:"
tcDeleteQdiscRootCommandString = "tc qdisc del dev %s root handle 1: prio"
allIPv4CIDR = "0.0.0.0/0"
dropTarget = "DROP"
acceptTarget = "ACCEPT"
)

type FaultHandler struct {
Expand Down Expand Up @@ -220,24 +224,42 @@ func (h *FaultHandler) startNetworkBlackholePort(ctx context.Context, protocol,
"taskArn": taskArn,
})

// Appending a new rule based on the protocol and port number from the request body
appendRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd, requestTimeoutSeconds, chain, protocol, port)
cmdOutput, err = h.runExecCommand(ctx, strings.Split(appendRuleCmdString, " "))
if err != nil {
logger.Error("Unable to append rule to chain", logger.Fields{
"netns": netNs,
"command": appendRuleCmdString,
// Helper function to run iptables rule change commands
var execRuleChangeCommand = func(cmdString string) (string, error) {
// Appending a new rule based on the protocol and port number from the request body
cmdOutput, err = h.runExecCommand(ctx, strings.Split(cmdString, " "))
if err != nil {
logger.Error("Unable to add rule to chain", logger.Fields{
"netns": netNs,
"command": cmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
"error": err,
})
return string(cmdOutput), err
}
logger.Info("Successfully added new rule to iptable chain", logger.Fields{
"command": cmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
"error": err,
})
return string(cmdOutput), err
return "", nil
}

// Add a rule to accept all traffic to TMDS
protectTMDSRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
xxx0624 marked this conversation as resolved.
Show resolved Hide resolved
requestTimeoutSeconds, chain, protocol, tmds.IPForTasks, tmds.PortForTasks,
acceptTarget)
if out, err := execRuleChangeCommand(protectTMDSRuleCmdString); err != nil {
return out, err
}

// Add a rule to drop all traffic to the port that the fault targets
faultRuleCmdString := nsenterPrefix + fmt.Sprintf(iptablesAppendChainRuleCmd,
requestTimeoutSeconds, chain, protocol, allIPv4CIDR, port, dropTarget)
if out, err := execRuleChangeCommand(faultRuleCmdString); err != nil {
return out, err
}
logger.Info("Successfully appended new rule to iptable chain", logger.Fields{
"command": appendRuleCmdString,
"output": string(cmdOutput),
"taskArn": taskArn,
})

// Inserting the chain into the built-in INPUT/OUTPUT table
insertChainCmdString := nsenterPrefix + fmt.Sprintf(iptablesInsertChainCmd, requestTimeoutSeconds, insertTable, chain)
Expand Down
34 changes: 33 additions & 1 deletion ecs-agent/tmds/handlers/fault/v1/handlers/handlers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand Down Expand Up @@ -554,6 +556,8 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
)
},
},
Expand All @@ -578,7 +582,7 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
},
},
{
name: fmt.Sprintf("%s fail append rule to chain", startNetworkBlackHolePortTestPrefix),
name: fmt.Sprintf("%s fail append ACCEPT rule to chain", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
requestBody: happyBlackHolePortReqBody,
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
Expand All @@ -603,6 +607,34 @@ func generateStartBlackHolePortFaultTestCases() []networkFaultInjectionTestCase
)
},
},
{
name: fmt.Sprintf("%s fail append DROP rule to chain", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
requestBody: happyBlackHolePortReqBody,
expectedResponseBody: types.NewNetworkFaultInjectionErrorResponse(internalError),
setAgentStateExpectations: func(agentState *mock_state.MockAgentState, netConfigClient *netconfig.NetworkConfigClient) {
agentState.EXPECT().GetTaskMetadataWithTaskNetworkConfig(endpointId, netConfigClient).
Return(happyTaskResponse, nil).
Times(1)
},
setExecExpectations: func(exec *mock_execwrapper.MockExec, ctrl *gomock.Controller) {
ctx, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration)
cmdExec := mock_execwrapper.NewMockCmd(ctrl)
gomock.InOrder(
exec.EXPECT().NewExecContextWithTimeout(gomock.Any(), gomock.Any()).Times(1).Return(ctx, cancel),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(iptablesChainNotFoundError), errors.New("exit status 1")),
exec.EXPECT().ConvertToExitError(gomock.Any()).Times(1).Return(nil, true),
exec.EXPECT().GetExitCode(gomock.Any()).Times(1).Return(1),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil),
exec.EXPECT().CommandContext(gomock.Any(), gomock.Any(), gomock.Any()).Times(1).Return(cmdExec),
cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte(internalError), errors.New("exit status 1")),
)
},
},
{
name: fmt.Sprintf("%s fail insert chain to table", startNetworkBlackHolePortTestPrefix),
expectedStatusCode: 500,
Expand Down
6 changes: 4 additions & 2 deletions ecs-agent/tmds/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ import (

const (
// TMDS IP and port
IPv4 = "127.0.0.1"
Port = 51679
IPv4 = "127.0.0.1"
Port = 51679
IPForTasks = "169.254.170.2"
PortForTasks = "80"
)

// IPv4 address for TMDS
Expand Down
Loading