Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 22 additions & 9 deletions tool/tsh/common/app_aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ func onAWS(cf *CLIConf) error {
}

func shouldUseAWSEndpointURLMode(cf *CLIConf) bool {
inputAWSCommand := strings.Join(removeAWSCommandFlags(cf.AWSCommandArgs), " ")
switch inputAWSCommand {
// `aws ssm start-session` first calls ssm.<region>.amazonaws.com to get an
// stream URL and an token. Then it makes a wss connection with the
// provided token to the provided stream URL. The wss request currently
Expand All @@ -92,27 +94,38 @@ func shouldUseAWSEndpointURLMode(cf *CLIConf) bool {
//
// Reference:
// https://github.com/aws/session-manager-plugin/
return isAWSCommand(cf, "ssm start-session")
}

func isAWSCommand(cf *CLIConf, wantCommand string) bool {
return strings.Join(removeAWSCommandFlags(cf.AWSCommandArgs), " ") == wantCommand
//
// "aws ecs execute-command" also start SSM sessions.
case "ssm start-session", "ecs execute-command":
return true
default:
return false
}
}

func removeAWSCommandFlags(args []string) (ret []string) {
for i := 0; i < len(args); i++ {
arg := args[i]
switch {
case strings.HasPrefix(arg, "--"):
i++
case isAWSFlag(args, i):
// Skip next arg, if next arg is not a flag but a flag value.
if !isAWSFlag(args, i+1) {
i++
}
continue
default:
ret = append(ret, arg)
ret = append(ret, args[i])
}
}
return
}

func isAWSFlag(args []string, i int) bool {
if i >= len(args) {
return false
}
return strings.HasPrefix(args[i], "--")
}

// awsApp is an AWS app that can start local proxies to serve AWS APIs.
type awsApp struct {
cf *CLIConf
Expand Down
16 changes: 16 additions & 0 deletions tool/tsh/common/app_aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,22 @@ func TestAWS(t *testing.T) {
)
require.NoError(t, err)
})
t.Run("aws ecs execute-command", func(t *testing.T) {
// Validate --endpoint-url 127.0.0.1:<port> is added to the command.
validateCmd := func(cmd *exec.Cmd) error {
require.Len(t, cmd.Args, 13)
require.Equal(t, []string{"aws", "ecs", "execute-command", "--debug", "--cluster", "cluster-name", "--task", "task-name", "--command", "/bin/bash", "--interactive", "--endpoint-url"}, cmd.Args[:12])
require.Contains(t, cmd.Args[12], "127.0.0.1:")
return nil
}
err = Run(
context.Background(),
[]string{"aws", "ecs", "execute-command", "--debug", "--cluster", "cluster-name", "--task", "task-name", "--command", "/bin/bash", "--interactive"},
setHomePath(tmpHomePath),
setCmdRunner(validateCmd),
)
require.NoError(t, err)
})
}

func makeUserWithAWSRole(t *testing.T) (types.User, types.Role) {
Expand Down