Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -2907,6 +2907,11 @@ func ConfigureOpenSSH(clf *CommandLineFlags, cfg *servicecfg.Config) error {
cfg.SetTokenSecret(clf.TokenSecret)
}

// apply --skip-version-check flag.
if clf.SkipVersionCheck {
cfg.SkipVersionCheck = clf.SkipVersionCheck
}

slog.DebugContext(context.Background(), "Disabling all services, only the Teleport OpenSSH service can run during the `teleport join openssh` command")
servicecfg.DisableLongRunningServices(cfg)

Expand Down
45 changes: 40 additions & 5 deletions lib/openssh/sshd.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ package openssh

import (
"bytes"
"context"
"fmt"
"log/slog"
"os"
"os/exec"
"path/filepath"
Expand All @@ -44,8 +46,6 @@ func sshdConfigInclude(dataDir string) string {
return fmt.Sprintf("Include %s", filepath.Join(dataDir, sshdConfigFile))
}

const DefaultRestartCommand = "systemctl restart sshd"

const (
// TeleportKey is the name the OpenSSH private key
TeleportKey = "ssh_host_teleport_key"
Expand Down Expand Up @@ -202,10 +202,17 @@ func (b *sshdBackend) restart() error {
return trace.Wrap(err)
}

cmd := exec.Command("/bin/sh", "-c", b.restartCmd)
if err := cmd.Run(); err != nil {
return trace.Wrap(err, "failed to restart the sshd service")
if b.restartCmd == "" {
if err := defaultRestart(); err != nil {
return trace.Wrap(err, "failed to restart OpenSSH")
}
} else {
cmd := exec.Command("/bin/sh", "-c", b.restartCmd)
if err := cmd.Run(); err != nil {
return trace.Wrap(err, "failed to restart OpenSSH")
}
}

return nil
}

Expand Down Expand Up @@ -263,3 +270,31 @@ func checkSSHDConfigAlreadyUpdated(sshdConfigPath, fileContains string) (bool, e
}
return !strings.Contains(string(contents), fileContains), nil
}

// defaultRestart invokes systemctl to restart the OpenSSH service. Due to varying names
// of OpenSSH services on different distributions both ssh and the sshd service are attempted
// to be restarted if they are active. An error is returned if any of the enabled services fail to be restarted.
func defaultRestart() error {
var restartErrors []error
for _, service := range []string{"ssh", "sshd"} {
sshShowCommand := exec.Command("/bin/sh", "-c", "systemctl show --property=ActiveState "+service+".service")
out, err := sshShowCommand.CombinedOutput()
if err != nil {
return trace.Wrap(err, "listing OpenSSH services")
}

const activeService = "ActiveState=active"
if !bytes.Equal([]byte(activeService), out) {
slog.DebugContext(context.Background(), "skipping inactive OpenSSH service", "service", service)
continue
}

slog.DebugContext(context.Background(), "restarting active OpenSSH service", "service", service)
restartCommand := exec.Command("/bin/sh", "-c", "systemctl restart "+service)
if err := restartCommand.Run(); err != nil {
restartErrors = append(restartErrors, err)
}
}

return trace.NewAggregate(restartErrors...)
}
8 changes: 5 additions & 3 deletions tool/teleport/common/teleport.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@ import (
"github.com/gravitational/teleport/lib/defaults"
dtconfig "github.com/gravitational/teleport/lib/devicetrust/config"
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/openssh"
"github.com/gravitational/teleport/lib/selinux"
"github.com/gravitational/teleport/lib/service"
"github.com/gravitational/teleport/lib/service/servicecfg"
Expand Down Expand Up @@ -492,12 +491,15 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con
joinOpenSSH.Flag("data-dir", fmt.Sprintf("Path to directory to store teleport data [%v].", defaults.DataDir)).Default(defaults.DataDir).StringVar(&ccf.DataDir)
joinOpenSSH.Flag("restart-sshd", "Restart OpenSSH.").Default("true").BoolVar(&ccf.RestartOpenSSH)
joinOpenSSH.Flag("sshd-check-command", "Command to use when checking OpenSSH config for validity. (sshd -t -f <sshd_config>)").Default("sshd -t -f").StringVar(&ccf.CheckCommand)
joinOpenSSH.Flag("sshd-restart-command", "Command to use when restarting openssh.").Default(openssh.DefaultRestartCommand).StringVar(&ccf.RestartCommand)
joinOpenSSH.Flag("sshd-restart-command", "Command to use when restarting openssh.").StringVar(&ccf.RestartCommand)
joinOpenSSH.Flag("labels", "Comma-separated list of labels for this OpenSSH node, for example env=dev,app=web.").StringVar(&ccf.Labels)
joinOpenSSH.Flag("address", "Hostname or IP address of this OpenSSH node.").StringVar(&ccf.Address)
joinOpenSSH.Flag("additional-principals", "Additional principal to include, can be specified multiple times.").StringVar(&ccf.AdditionalPrincipals)
joinOpenSSH.Flag("insecure", "Insecure mode disables certificate validation.").BoolVar(&ccf.InsecureMode)

joinOpenSSH.Flag("skip-version-check",
"Skip version checking between server and client.").
Default("false").
BoolVar(&ccf.SkipVersionCheck)
joinOpenSSH.Flag("debug", "Enable verbose logging to stderr.").Short('d').BoolVar(&ccf.Debug)

integrationCmd := app.Command("integration", "Integration commands")
Expand Down
Loading