Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions api/utils/sshutils/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -271,34 +271,36 @@ func WithDialer(dialer contextDialer) RunSSHOption {
}

// RunSSH runs a command on an SSH server and returns the output.
func RunSSH(ctx context.Context, addr, command string, cfg *ssh.ClientConfig, opts ...RunSSHOption) ([]byte, error) {
func RunSSH(ctx context.Context, addr, command string, cfg *ssh.ClientConfig, opts ...RunSSHOption) ([]byte, []byte, error) {
var options runSSHOpts
for _, opt := range opts {
opt(&options)
}

conn, err := options.dialContext(ctx, "tcp", addr)
if err != nil {
return nil, trace.Wrap(err)
return nil, nil, trace.Wrap(err)
}

clientConn, newCh, requestsCh, err := ssh.NewClientConn(conn, addr, cfg)
if err != nil {
return nil, trace.Wrap(err)
return nil, nil, trace.Wrap(err)
}
sshClient := ssh.NewClient(clientConn, newCh, requestsCh)
defer sshClient.Close()
session, err := sshClient.NewSession()
if err != nil {
return nil, trace.Wrap(err)
return nil, nil, trace.Wrap(err)
}
defer session.Close()

// Execute the command.
var b bytes.Buffer
session.Stdout = &b
var stdout bytes.Buffer
session.Stdout = &stdout
var stderr bytes.Buffer
session.Stderr = &stderr
err = session.Run(command)
return b.Bytes(), trace.Wrap(err)
return stdout.Bytes(), stderr.Bytes(), trace.Wrap(err)
}

// ChannelReadWriter represents the data streams of an ssh.Channel-like object.
Expand Down
46 changes: 30 additions & 16 deletions lib/cloud/gcp/vm.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ import (
// sshUser is the user to log in as on GCP VMs.
const sshUser = "teleport"

// sshDefaultTimeout is the default timeout for dialing an instance.
const sshDefaultTimeout = 10 * time.Second

// convertAPIError converts an error from the GCP API into a trace error.
func convertAPIError(err error) error {
var googleError *googleapi.Error
Expand Down Expand Up @@ -447,10 +450,6 @@ type RunCommandRequest struct {
Script string
// SSHPort is the ssh server port to connect to. Defaults to 22.
SSHPort string
// UseExternalIP, if true, connects to the instance with an external IP
// address instead of the internal one. This is necessary if the instance
// isn't in the same VPC as this client.
UseExternalIP bool

dialContext func(ctx context.Context, network, addr string) (net.Conn, error)
}
Expand All @@ -466,7 +465,9 @@ func (req *RunCommandRequest) CheckAndSetDefaults() error {
req.SSHPort = "22"
}
if req.dialContext == nil {
dialer := net.Dialer{}
dialer := net.Dialer{
Timeout: sshDefaultTimeout,
}
req.dialContext = dialer.DialContext
}
return nil
Expand Down Expand Up @@ -507,11 +508,14 @@ func RunCommand(ctx context.Context, req *RunCommandRequest) error {
return trace.NotFound(`Instance %v is missing host keys. Did you enable guest attributes on the instance?
https://cloud.google.com/solutions/connecting-securely#storing_host_keys_by_enabling_guest_attributes`, req.Name)
}
ipAddress := instance.internalIPAddress
if req.UseExternalIP {
ipAddress = instance.externalIPAddress
var ipAddrs []string
if instance.externalIPAddress != "" {
ipAddrs = append(ipAddrs, instance.externalIPAddress)
}
if instance.internalIPAddress != "" {
ipAddrs = append(ipAddrs, instance.internalIPAddress)
}
if ipAddress == "" {
if len(ipAddrs) == 0 {
return trace.NotFound("Instance %v is missing an IP address.", req.Name)
}
keyReq := &SSHKeyRequest{
Expand Down Expand Up @@ -548,15 +552,25 @@ https://cloud.google.com/solutions/connecting-securely#storing_host_keys_by_enab
},
HostKeyCallback: callback,
}
addr := net.JoinHostPort(instance.externalIPAddress, req.SSHPort)

out, err := sshutils.RunSSH(ctx, addr, req.Script, config, sshutils.WithDialer(req.dialContext))
if err != nil {
logrus.WithError(err).Debugf("Command exited with error.")
var errs []error
for _, ip := range ipAddrs {
addr := net.JoinHostPort(ip, req.SSHPort)
stdout, stderr, err := sshutils.RunSSH(ctx, addr, req.Script, config, sshutils.WithDialer(req.dialContext))
logrus.Debug(string(stdout))
logrus.Debug(string(stderr))
if err == nil {
return nil
}

// An exit error means the connection was successful, so don't try another address.
if errors.Is(err, &ssh.ExitError{}) {
logrus.Debugf(string(out))
return trace.Wrap(err)
}
return trace.Wrap(err)
errs = append(errs, err)
}
return nil

err = trace.NewAggregate(errs...)
logrus.WithError(err).Debug("Command exited with error.")
return err
}