diff --git a/tool/tsh/common/app_aws.go b/tool/tsh/common/app_aws.go index 2cbc2307e1072..d92496d56c7d0 100644 --- a/tool/tsh/common/app_aws.go +++ b/tool/tsh/common/app_aws.go @@ -78,6 +78,8 @@ func onAWS(cf *CLIConf) error { } func shouldUseAWSEndpointURLMode(cf *CLIConf) bool { + inputAWSCommand := strings.Join(removeAWSCommandFlags(cf.AWSCommandArgs), " ") + switch inputAWSCommand { // `aws ssm start-session` first calls ssm..amazonaws.com to get an // stream URL and an token. Then it makes a wss connection with the // provided token to the provided stream URL. The wss request currently @@ -92,27 +94,38 @@ func shouldUseAWSEndpointURLMode(cf *CLIConf) bool { // // Reference: // https://github.com/aws/session-manager-plugin/ - return isAWSCommand(cf, "ssm start-session") -} - -func isAWSCommand(cf *CLIConf, wantCommand string) bool { - return strings.Join(removeAWSCommandFlags(cf.AWSCommandArgs), " ") == wantCommand + // + // "aws ecs execute-command" also start SSM sessions. + case "ssm start-session", "ecs execute-command": + return true + default: + return false + } } func removeAWSCommandFlags(args []string) (ret []string) { for i := 0; i < len(args); i++ { - arg := args[i] switch { - case strings.HasPrefix(arg, "--"): - i++ + case isAWSFlag(args, i): + // Skip next arg, if next arg is not a flag but a flag value. + if !isAWSFlag(args, i+1) { + i++ + } continue default: - ret = append(ret, arg) + ret = append(ret, args[i]) } } return } +func isAWSFlag(args []string, i int) bool { + if i >= len(args) { + return false + } + return strings.HasPrefix(args[i], "--") +} + // awsApp is an AWS app that can start local proxies to serve AWS APIs. type awsApp struct { cf *CLIConf diff --git a/tool/tsh/common/app_aws_test.go b/tool/tsh/common/app_aws_test.go index f71a099a8d977..79b02bc6beb60 100644 --- a/tool/tsh/common/app_aws_test.go +++ b/tool/tsh/common/app_aws_test.go @@ -159,6 +159,22 @@ func TestAWS(t *testing.T) { ) require.NoError(t, err) }) + t.Run("aws ecs execute-command", func(t *testing.T) { + // Validate --endpoint-url 127.0.0.1: is added to the command. + validateCmd := func(cmd *exec.Cmd) error { + require.Len(t, cmd.Args, 13) + require.Equal(t, []string{"aws", "ecs", "execute-command", "--debug", "--cluster", "cluster-name", "--task", "task-name", "--command", "/bin/bash", "--interactive", "--endpoint-url"}, cmd.Args[:12]) + require.Contains(t, cmd.Args[12], "127.0.0.1:") + return nil + } + err = Run( + context.Background(), + []string{"aws", "ecs", "execute-command", "--debug", "--cluster", "cluster-name", "--task", "task-name", "--command", "/bin/bash", "--interactive"}, + setHomePath(tmpHomePath), + setCmdRunner(validateCmd), + ) + require.NoError(t, err) + }) } func makeUserWithAWSRole(t *testing.T) (types.User, types.Role) {