diff --git a/agent/app/agent_capability_unix.go b/agent/app/agent_capability_unix.go index aceef91c60b..a2da74c6a0f 100644 --- a/agent/app/agent_capability_unix.go +++ b/agent/app/agent_capability_unix.go @@ -18,6 +18,7 @@ package app import ( "context" + "fmt" "os/exec" "path/filepath" "strings" @@ -30,6 +31,7 @@ import ( "github.com/aws/amazon-ecs-agent/agent/taskresource/volume" "github.com/aws/amazon-ecs-agent/agent/utils" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" "github.com/aws/aws-sdk-go/aws" "github.com/cihub/seelog" @@ -45,6 +47,7 @@ const ( modInfoCmd = "modinfo" faultInjectionKernelModules = "sch_netem" ctxTimeoutDuration = 60 * time.Second + tcShowCmdString = "tc qdisc show dev %s root" ) var ( @@ -250,6 +253,7 @@ var isFaultInjectionToolingAvailable = checkFaultInjectionTooling // wrapper around exec.LookPath var lookPathFunc = exec.LookPath var osExecWrapper = execwrapper.NewExec() +var networkConfigClient = netconfig.NewNetworkConfigClient() // checkFaultInjectionTooling checks for the required network packages like iptables, tc // to be available on the host before ecs.capability.fault-injection can be advertised @@ -263,7 +267,7 @@ func checkFaultInjectionTooling() bool { return false } } - return checkFaultInjectionModules() + return checkFaultInjectionModules() && checkTCShowTooling() } // checkFaultInjectionModules checks for the required kernel modules such as sch_netem to be installed @@ -278,3 +282,20 @@ func checkFaultInjectionModules() bool { } return true } + +func checkTCShowTooling() bool { + ctxWithTimeout, cancel := context.WithTimeout(context.Background(), ctxTimeoutDuration) + defer cancel() + hostDeviceName, netErr := netconfig.DefaultNetInterfaceName(networkConfigClient.NetlinkClient) + if netErr != nil { + seelog.Warnf("Failed to obtain the network interface device name on the host: %v", netErr) + return false + } + tcShowCmd := fmt.Sprintf(tcShowCmdString, hostDeviceName) + _, err := osExecWrapper.CommandContext(ctxWithTimeout, tcShowCmd).CombinedOutput() + if err != nil { + seelog.Warnf("Failed to call %s which is needed for fault-injection feature: %v", err) + return false + } + return true +} diff --git a/agent/app/agent_capability_unix_test.go b/agent/app/agent_capability_unix_test.go index 526c898e7a7..ac983563d08 100644 --- a/agent/app/agent_capability_unix_test.go +++ b/agent/app/agent_capability_unix_test.go @@ -19,6 +19,8 @@ package app import ( "context" "errors" + "fmt" + "net" "os" "os/exec" "path/filepath" @@ -40,12 +42,36 @@ import ( mock_mobypkgwrapper "github.com/aws/amazon-ecs-agent/agent/utils/mobypkgwrapper/mocks" "github.com/aws/amazon-ecs-agent/ecs-agent/api/ecs/model/ecs" md "github.com/aws/amazon-ecs-agent/ecs-agent/manageddaemon" + "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/utils/netconfig" "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper" mock_execwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/execwrapper/mocks" + mock_netlinkwrapper "github.com/aws/amazon-ecs-agent/ecs-agent/utils/netlinkwrapper/mocks" "github.com/aws/aws-sdk-go/aws" aws_credentials "github.com/aws/aws-sdk-go/aws/credentials" "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" + "github.com/vishvananda/netlink" +) + +const ( + deviceName = "eth0" + internalError = "internal error" +) + +var ( + routes = []netlink.Route{ + netlink.Route{ + Gw: net.ParseIP("10.194.20.1"), + Dst: nil, + LinkIndex: 0, + }, + } + link = &netlink.Device{ + LinkAttrs: netlink.LinkAttrs{ + Index: 0, + Name: deviceName, + }, + } ) func init() { @@ -982,8 +1008,10 @@ func TestCheckFaultInjectionTooling(t *testing.T) { lookPathFunc = originalLookPath }() originalOSExecWrapper := execwrapper.NewExec() + originalNetConfig := netconfig.NewNetworkConfigClient() defer func() { osExecWrapper = originalOSExecWrapper + networkConfigClient = originalNetConfig }() t.Run("all tools and kernel modules available", func(t *testing.T) { @@ -994,9 +1022,19 @@ func TestCheckFaultInjectionTooling(t *testing.T) { defer ctrl.Finish() mockExec := mock_execwrapper.NewMockExec(ctrl) cmdExec := mock_execwrapper.NewMockCmd(ctrl) + mock_netlinkwrapper := mock_netlinkwrapper.NewMockNetLink(ctrl) + + gomock.InOrder( + mock_netlinkwrapper.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(routes, nil).AnyTimes(), + mock_netlinkwrapper.EXPECT().LinkByIndex(link.Attrs().Index).Return(link, nil).AnyTimes(), + ) + networkConfigClient.NetlinkClient = mock_netlinkwrapper gomock.InOrder( mockExec.EXPECT().CommandContext(gomock.Any(), modInfoCmd, faultInjectionKernelModules).Times(1).Return(cmdExec), cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + + mockExec.EXPECT().CommandContext(gomock.Any(), fmt.Sprintf(tcShowCmdString, deviceName)).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), ) osExecWrapper = mockExec assert.True(t, @@ -1022,6 +1060,58 @@ func TestCheckFaultInjectionTooling(t *testing.T) { "Expected checkFaultInjectionTooling to return false when kernel modules are not available") }) + t.Run("failed to obtain default host device name", func(t *testing.T) { + lookPathFunc = func(file string) (string, error) { + return "/usr/bin" + file, nil + } + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockExec := mock_execwrapper.NewMockExec(ctrl) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + mock_netlinkwrapper := mock_netlinkwrapper.NewMockNetLink(ctrl) + + gomock.InOrder( + mock_netlinkwrapper.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(routes, errors.New(internalError)).AnyTimes(), + ) + networkConfigClient.NetlinkClient = mock_netlinkwrapper + gomock.InOrder( + mockExec.EXPECT().CommandContext(gomock.Any(), modInfoCmd, faultInjectionKernelModules).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + ) + osExecWrapper = mockExec + assert.False(t, + checkFaultInjectionTooling(), + "Expected checkFaultInjectionTooling to return false when unable to find default host interface name") + }) + + t.Run("failed tc show command", func(t *testing.T) { + lookPathFunc = func(file string) (string, error) { + return "/usr/bin" + file, nil + } + ctrl := gomock.NewController(t) + defer ctrl.Finish() + mockExec := mock_execwrapper.NewMockExec(ctrl) + cmdExec := mock_execwrapper.NewMockCmd(ctrl) + mock_netlinkwrapper := mock_netlinkwrapper.NewMockNetLink(ctrl) + + gomock.InOrder( + mock_netlinkwrapper.EXPECT().RouteList(nil, netlink.FAMILY_ALL).Return(routes, nil).AnyTimes(), + mock_netlinkwrapper.EXPECT().LinkByIndex(link.Attrs().Index).Return(link, nil).AnyTimes(), + ) + networkConfigClient.NetlinkClient = mock_netlinkwrapper + gomock.InOrder( + mockExec.EXPECT().CommandContext(gomock.Any(), modInfoCmd, faultInjectionKernelModules).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, nil), + + mockExec.EXPECT().CommandContext(gomock.Any(), fmt.Sprintf(tcShowCmdString, deviceName)).Times(1).Return(cmdExec), + cmdExec.EXPECT().CombinedOutput().Times(1).Return([]byte{}, errors.New("What is \"root\"? Try \"tc qdisc help\".")), + ) + osExecWrapper = mockExec + assert.False(t, + checkFaultInjectionTooling(), + "Expected checkFaultInjectionTooling to return false when required tc show command failed") + }) + tools := []string{"iptables", "tc", "nsenter"} for _, tool := range tools { t.Run(tool+" missing", func(t *testing.T) {