diff --git a/api/utils/sshutils/ssh.go b/api/utils/sshutils/ssh.go index 69f3e13040dd6..f062e3924c4f0 100644 --- a/api/utils/sshutils/ssh.go +++ b/api/utils/sshutils/ssh.go @@ -271,7 +271,7 @@ func WithDialer(dialer contextDialer) RunSSHOption { } // RunSSH runs a command on an SSH server and returns the output. -func RunSSH(ctx context.Context, addr, command string, cfg *ssh.ClientConfig, opts ...RunSSHOption) ([]byte, error) { +func RunSSH(ctx context.Context, addr, command string, cfg *ssh.ClientConfig, opts ...RunSSHOption) ([]byte, []byte, error) { var options runSSHOpts for _, opt := range opts { opt(&options) @@ -279,26 +279,28 @@ func RunSSH(ctx context.Context, addr, command string, cfg *ssh.ClientConfig, op conn, err := options.dialContext(ctx, "tcp", addr) if err != nil { - return nil, trace.Wrap(err) + return nil, nil, trace.Wrap(err) } clientConn, newCh, requestsCh, err := ssh.NewClientConn(conn, addr, cfg) if err != nil { - return nil, trace.Wrap(err) + return nil, nil, trace.Wrap(err) } sshClient := ssh.NewClient(clientConn, newCh, requestsCh) defer sshClient.Close() session, err := sshClient.NewSession() if err != nil { - return nil, trace.Wrap(err) + return nil, nil, trace.Wrap(err) } defer session.Close() // Execute the command. - var b bytes.Buffer - session.Stdout = &b + var stdout bytes.Buffer + session.Stdout = &stdout + var stderr bytes.Buffer + session.Stderr = &stderr err = session.Run(command) - return b.Bytes(), trace.Wrap(err) + return stdout.Bytes(), stderr.Bytes(), trace.Wrap(err) } // ChannelReadWriter represents the data streams of an ssh.Channel-like object. diff --git a/lib/cloud/gcp/vm.go b/lib/cloud/gcp/vm.go index fc468feb01224..3fa118fa893b2 100644 --- a/lib/cloud/gcp/vm.go +++ b/lib/cloud/gcp/vm.go @@ -47,6 +47,9 @@ import ( // sshUser is the user to log in as on GCP VMs. const sshUser = "teleport" +// sshDefaultTimeout is the default timeout for dialing an instance. +const sshDefaultTimeout = 10 * time.Second + // convertAPIError converts an error from the GCP API into a trace error. func convertAPIError(err error) error { var googleError *googleapi.Error @@ -449,10 +452,6 @@ type RunCommandRequest struct { Script string // SSHPort is the ssh server port to connect to. Defaults to 22. SSHPort string - // UseExternalIP, if true, connects to the instance with an external IP - // address instead of the internal one. This is necessary if the instance - // isn't in the same VPC as this client. - UseExternalIP bool dialContext func(ctx context.Context, network, addr string) (net.Conn, error) } @@ -468,7 +467,9 @@ func (req *RunCommandRequest) CheckAndSetDefaults() error { req.SSHPort = "22" } if req.dialContext == nil { - dialer := net.Dialer{} + dialer := net.Dialer{ + Timeout: sshDefaultTimeout, + } req.dialContext = dialer.DialContext } return nil @@ -509,11 +510,14 @@ func RunCommand(ctx context.Context, req *RunCommandRequest) error { return trace.NotFound(`Instance %v is missing host keys. Did you enable guest attributes on the instance? https://cloud.google.com/solutions/connecting-securely#storing_host_keys_by_enabling_guest_attributes`, req.Name) } - ipAddress := instance.internalIPAddress - if req.UseExternalIP { - ipAddress = instance.externalIPAddress + var ipAddrs []string + if instance.externalIPAddress != "" { + ipAddrs = append(ipAddrs, instance.externalIPAddress) + } + if instance.internalIPAddress != "" { + ipAddrs = append(ipAddrs, instance.internalIPAddress) } - if ipAddress == "" { + if len(ipAddrs) == 0 { return trace.NotFound("Instance %v is missing an IP address.", req.Name) } keyReq := &SSHKeyRequest{ @@ -550,15 +554,25 @@ https://cloud.google.com/solutions/connecting-securely#storing_host_keys_by_enab }, HostKeyCallback: callback, } - addr := net.JoinHostPort(instance.externalIPAddress, req.SSHPort) - out, err := sshutils.RunSSH(ctx, addr, req.Script, config, sshutils.WithDialer(req.dialContext)) - if err != nil { - logrus.WithError(err).Debugf("Command exited with error.") + var errs []error + for _, ip := range ipAddrs { + addr := net.JoinHostPort(ip, req.SSHPort) + stdout, stderr, err := sshutils.RunSSH(ctx, addr, req.Script, config, sshutils.WithDialer(req.dialContext)) + logrus.Debug(string(stdout)) + logrus.Debug(string(stderr)) + if err == nil { + return nil + } + + // An exit error means the connection was successful, so don't try another address. if errors.Is(err, &ssh.ExitError{}) { - logrus.Debugf(string(out)) + return trace.Wrap(err) } - return trace.Wrap(err) + errs = append(errs, err) } - return nil + + err = trace.NewAggregate(errs...) + logrus.WithError(err).Debug("Command exited with error.") + return err }