Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions lib/srv/server/ssm_install.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,16 @@ import (
libevents "github.com/gravitational/teleport/lib/events"
)

// waiterTimedOutErrorMessage is the error message returned by the AWS SDK command
// executed waiter when it times out.
const waiterTimedOutErrorMessage = "exceeded max wait time for CommandExecuted waiter"
const (
// waiterTimedOutErrorMessage is the error message returned by the AWS SDK
// CommandExecuted waiter when it exceeds its maximum wait time.
waiterTimedOutErrorMessage = "exceeded max wait time for CommandExecuted waiter"

// waiterTransitionedToFailureErrorMessage is the error message returned by the AWS SDK
// CommandExecuted waiter when the command state transitions to a failure state.
//nolint:misspell // ignore Cancelled and Cancelling
// The failure states are Cancelled, TimedOut, Failed and Cancelling.
waiterTransitionedToFailureErrorMessage = "waiter state transitioned to Failure"
)

// SSMClient is the subset of the AWS SSM API required for EC2 discovery.
type SSMClient interface {
Expand Down Expand Up @@ -373,25 +380,29 @@ func (si *SSMInstaller) describeSSMAgentState(ctx context.Context, req SSMRunReq
return ret, nil
}

// skipAWSWaitErr filters the error returned by the waiter's Wait call.
//
// An error whose message equals waiterTimedOutErrorMessage is swallowed and nil
// is returned: a waiter timeout can stand for several different underlying
// conditions, and the caller is expected to inspect the command invocation
// afterwards to learn what actually happened. Any other value (including nil)
// is passed through trace.Wrap and returned.
func skipAWSWaitErr(err error) error {
	// The waiter does not expose a sentinel error here, so the timeout is
	// recognized by comparing the full error message.
	waiterTimedOut := err != nil && err.Error() == waiterTimedOutErrorMessage
	if !waiterTimedOut {
		return trace.Wrap(err)
	}
	return nil
}

func (si *SSMInstaller) checkCommand(ctx context.Context, req SSMRunRequest, commandID, instanceID *string, instanceName string) error {
err := si.getWaiter(req.SSM).Wait(ctx, &ssm.GetCommandInvocationInput{
CommandId: commandID,
InstanceId: instanceID,
// 100 seconds to match v1 sdk waiter default.
}, 100*time.Second)
switch {
case err == nil:
// Command executed successfully.

case err.Error() == waiterTransitionedToFailureErrorMessage:
//nolint:misspell // ignore Cancelled and Cancelling
// When the command invocation state is one of Cancelled, TimedOut, Failed or Cancelling,
// the waiter returns the error message above.
// Ignoring this error allows us to get the actual command status to report it.

if err := skipAWSWaitErr(err); err != nil {
case err.Error() == waiterTimedOutErrorMessage:
// When the Waiter times out, it returns the error message above.
// The command might still be Pending or InProgress.
// Ignoring this error allows us to report that status back to the user.

default:
// For every other unknown error, return the error.
return trace.Wrap(err)
}

Expand Down
66 changes: 65 additions & 1 deletion lib/srv/server/ssm_install_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package server
import (
"context"
"fmt"
"slices"
"testing"
"time"

Expand All @@ -37,6 +38,7 @@ import (
type mockSSMClient struct {
SSMClient
commandOutput *ssm.SendCommandOutput
waiterTimeout bool
commandInvokeOutput map[string]*ssm.GetCommandInvocationOutput
describeOutput *ssm.DescribeInstanceInformationOutput
listCommandInvocations *ssm.ListCommandInvocationsOutput
Expand Down Expand Up @@ -80,9 +82,20 @@ func (sm *mockSSMClient) ListCommandInvocations(_ context.Context, input *ssm.Li
}

func (sm *mockSSMClient) Wait(ctx context.Context, params *ssm.GetCommandInvocationInput, maxWaitDur time.Duration, optFns ...func(*ssm.CommandExecutedWaiterOptions)) error {
if sm.commandOutput.Command.Status == ssmtypes.CommandStatusFailed {
if sm.waiterTimeout {
return trace.Errorf(waiterTimedOutErrorMessage)
}

var failureStates = []ssmtypes.CommandStatus{
ssmtypes.CommandStatusCancelled,
ssmtypes.CommandStatusFailed,
ssmtypes.CommandStatusTimedOut,
ssmtypes.CommandStatusCancelling,
}

if slices.Contains(failureStates, sm.commandOutput.Command.Status) {
return trace.Errorf(waiterTransitionedToFailureErrorMessage)
}
return nil
}

Expand Down Expand Up @@ -250,6 +263,7 @@ func TestSSMInstaller(t *testing.T) {
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
CommandId: aws.String("command-id-1"),
Status: ssmtypes.CommandStatusFailed,
},
},
commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{
Expand Down Expand Up @@ -282,6 +296,55 @@ func TestSSMInstaller(t *testing.T) {
SSMDocumentName: "ssmdocument",
}},
},
{
name: "ssm run takes too long, and waiter times out",
req: SSMRunRequest{
DocumentName: document,
Instances: []EC2Instance{
{InstanceID: "instance-id-1"},
},
IntegrationName: "aws-1",
Params: map[string]string{"token": "abcdefg"},
Region: "eu-central-1",
AccountID: "account-id",
},
client: &mockSSMClient{
waiterTimeout: true,
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
CommandId: aws.String("command-id-1"),
Status: ssmtypes.CommandStatusInProgress,
},
},
commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{
"downloadContent": {
Status: ssmtypes.CommandInvocationStatusInProgress,
StandardErrorContent: aws.String("downloading..."),
StandardOutputContent: aws.String(""),
},
},
},
expectedInstallations: []*SSMInstallationResult{{
IntegrationName: "aws-1",
SSMRunEvent: &events.SSMRun{
Metadata: events.Metadata{
Type: libevent.SSMRunEvent,
Code: libevent.SSMRunFailCode,
},
CommandID: "command-id-1",
InstanceID: "instance-id-1",
AccountID: "account-id",
Region: "eu-central-1",
ExitCode: -1,
Status: string(ssmtypes.CommandInvocationStatusInProgress),
StandardOutput: "",
StandardError: "downloading...",
InvocationURL: "https://eu-central-1.console.aws.amazon.com/systems-manager/run-command/command-id-1/instance-id-1",
},
IssueType: "ec2-ssm-script-failure",
SSMDocumentName: "ssmdocument",
}},
},
{
name: "ssm run failed in run shell script",
req: SSMRunRequest{
Expand All @@ -297,6 +360,7 @@ func TestSSMInstaller(t *testing.T) {
commandOutput: &ssm.SendCommandOutput{
Command: &ssmtypes.Command{
CommandId: aws.String("command-id-1"),
Status: ssmtypes.CommandStatusFailed,
},
},
commandInvokeOutput: map[string]*ssm.GetCommandInvocationOutput{
Expand Down
Loading