diff --git a/api/types/constants.go b/api/types/constants.go index 04239b1653bb4..4670e089b9cb0 100644 --- a/api/types/constants.go +++ b/api/types/constants.go @@ -464,6 +464,9 @@ const ( // found via automatic discovery, to avoid re-running installation // commands on the node. AWSInstanceIDLabel = TeleportNamespace + "/instance-id" + // AWSInstanceRegion is used to identify the region an EC2 + // instance is running in + AWSInstanceRegion = TeleportNamespace + "/aws-region" // SubscriptionIDLabel is used to identify virtual machines by Azure // subscription ID found via automatic discovery, to avoid re-running // installation commands on the node. diff --git a/lib/cloud/aws/imds.go b/lib/cloud/aws/imds.go index 9b2db36cf1c16..873bedec70db9 100644 --- a/lib/cloud/aws/imds.go +++ b/lib/cloud/aws/imds.go @@ -159,3 +159,31 @@ func (client *InstanceMetadataClient) GetID(ctx context.Context) (string, error) return id, nil } + +// GetLocalIPV4 gets the EC2 instance's local ipv4 address. +func (client *InstanceMetadataClient) GetLocalIPV4(ctx context.Context) (string, error) { + ip, err := client.getMetadata(ctx, "local-ipv4") + if err != nil { + return "", trace.Wrap(err) + } + + return ip, nil +} + +// GetPublicIPV4 gets the EC2 instance's local ipv4 address. +func (client *InstanceMetadataClient) GetPublicIPV4(ctx context.Context) (string, error) { + ip, err := client.getMetadata(ctx, "public-ipv4") + if err != nil { + return "", trace.Wrap(err) + } + + return ip, nil +} + +func (client *InstanceMetadataClient) GetAccountID(ctx context.Context) (string, error) { + idOut, err := client.c.GetInstanceIdentityDocument(ctx, &imds.GetInstanceIdentityDocumentInput{}) + if err != nil { + return "", trace.Wrap(err) + } + return idOut.AccountID, nil +} diff --git a/lib/config/configuration.go b/lib/config/configuration.go index 9e38f73bae6e8..2fe0db3e3d0b1 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -176,16 +176,22 @@ type CommandLineFlags struct { // if the value cannot be obtained from the database. DatabaseMySQLServerVersion string - // ProxyServer is the url of the proxy server to connect to + // ProxyServer is the url of the proxy server to connect to. ProxyServer string - // OpenSSHConfigPath is the path of the file to write agentless configuration to + // OpenSSHConfigPath is the path of the file to write agentless configuration to. OpenSSHConfigPath string - // OpenSSHKeysPath is the path to write teleport keys and certs into - OpenSSHKeysPath string - // AdditionalPrincipals are a list of extra principals to include when generating host keys. - AdditionalPrincipals string // RestartOpenSSH indicates whether openssh should be restarted or not. RestartOpenSSH bool + // RestartCommand is the command to use when restarting sshd + RestartCommand string + // CheckCommand is the command to use when checking sshd config validity + CheckCommand string + // Address is the ip address of the OpenSSH node. + Address string + // AdditionalPrincipals is a list of additional principals to include in the SSH cert. + AdditionalPrincipals string + // Directory to store + DataDir string } // ReadConfigFile reads /etc/teleport.yaml (or whatever is passed via --config flag) @@ -2144,6 +2150,63 @@ func Configure(clf *CommandLineFlags, cfg *servicecfg.Config, legacyAppFlags boo return nil } +// ConfigureOpenSSH initializes a config from the commandline flags passed +func ConfigureOpenSSH(clf *CommandLineFlags, cfg *servicecfg.Config) error { + // pass the value of --insecure flag to the runtime + lib.SetInsecureDevMode(clf.InsecureMode) + + // Apply command line --debug flag to override logger severity. + if clf.Debug { + log.SetLevel(log.DebugLevel) + cfg.Log.SetLevel(log.DebugLevel) + cfg.Debug = clf.Debug + } + + if clf.AuthToken != "" { + // store the value of the --token flag: + cfg.SetToken(clf.AuthToken) + } + + log.Debugf("Disabling all services, only the Teleport OpenSSH service can run during the `teleport join openssh` command") + servicecfg.DisableLongRunningServices(cfg) + + cfg.DataDir = clf.DataDir + cfg.Version = defaults.TeleportConfigVersionV3 + cfg.OpenSSH.SSHDConfigPath = clf.OpenSSHConfigPath + cfg.OpenSSH.RestartSSHD = clf.RestartOpenSSH + cfg.OpenSSH.RestartCommand = clf.RestartCommand + cfg.OpenSSH.CheckCommand = clf.CheckCommand + cfg.JoinMethod = types.JoinMethod(clf.JoinMethod) + + hostname, err := os.Hostname() + if err != nil { + return trace.Wrap(err) + } + + cfg.Hostname = hostname + cfg.OpenSSH.InstanceAddr = clf.Address + cfg.OpenSSH.AdditionalPrincipals = []string{hostname, clf.Address} + for _, principal := range strings.Split(clf.AdditionalPrincipals, ",") { + if principal == "" { + continue + } + cfg.OpenSSH.AdditionalPrincipals = append(cfg.OpenSSH.AdditionalPrincipals, principal) + } + cfg.OpenSSH.Labels, err = client.ParseLabelSpec(clf.Labels) + if err != nil { + return trace.Wrap(err) + } + + proxyServer, err := utils.ParseAddr(clf.ProxyServer) + if err != nil { + return trace.Wrap(err) + } + cfg.SetAuthServerAddresses(nil) + cfg.ProxyServer = *proxyServer + + return nil +} + // parseLabels parses the labels command line flag and returns static and // dynamic labels. func parseLabels(spec string) (map[string]string, services.CommandLabels, error) { diff --git a/lib/openssh/sshd.go b/lib/openssh/sshd.go new file mode 100644 index 0000000000000..6277baab2dc2f --- /dev/null +++ b/lib/openssh/sshd.go @@ -0,0 +1,263 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package openssh + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "path/filepath" + "strings" + "text/template" + + "github.com/gravitational/trace" + + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/auth" +) + +var ( + // SSHDConfigPath is the path to write teleport specific SSHD config options + sshdConfigFile = "sshd.conf" + // SSHDKeysDir is the path to the Teleport openssh keys + SSHDKeysDir = "openssh" +) + +func sshdConfigInclude(dataDir string) string { + return fmt.Sprintf("Include %s", filepath.Join(dataDir, sshdConfigFile)) +} + +const DefaultRestartCommand = "systemctl restart sshd" + +const ( + // TeleportKey is the name the OpenSSH private key + TeleportKey = "ssh_host_teleport_key" + // TeleportKey is the name the OpenSSH cert + TeleportCert = TeleportKey + "-cert.pub" + // TeleportOpenSSHCA is the path to the Teleport OpenSSHCA + TeleportOpenSSHCA = "teleport_openssh_ca.pub" +) + +type sshdBackendOperations interface { + restart() error + checkConfig(path string) error +} + +// SSHD is used to update the OpenSSH config +type SSHD struct { + sshd sshdBackendOperations +} + +// NewSSHD initializes SSHD +func NewSSHD(restartCmd string, checkCmd string, sshdConfigPath string) SSHD { + return SSHD{ + sshd: &sshdBackend{ + restartCmd: restartCmd, + checkCmd: checkCmd, + sshdConfigPath: sshdConfigPath, + }, + } +} + +// SSHDConfigUpdate is the list of options to be set in the Teleport OpenSSH config +type SSHDConfigUpdate struct { + // SSHDConfigPath is the path to the OpenSSH sshd_config file + SSHDConfigPath string + // DataDir is the path to the global Teleport datadir + DataDir string +} + +var sshdConfigTmpl = template.Must(template.New("sshd_config_include").Parse(`# Created by 'teleport join openssh', do not edit +TrustedUserCAKeys {{ .OpenSSHCAPath }} +HostKey {{ .HostKeyPath }} +HostCertificate {{ .HostCertPath }} +`)) + +func fmtSSHDConfigUpdate(u SSHDConfigUpdate) (string, error) { + type SSHDConfigUpdateBackend struct { + SSHDConfigUpdate + // OpenSSHCAPath is the path to which Teleport OpenSSHCA will be written + OpenSSHCAPath string + // HostKeyPath is the path to the Teleport Host Key + HostKeyPath string + // HostCertPath is the path to the Teleport OpenSSH cert + HostCertPath string + } + keysDir := filepath.Join(u.DataDir, SSHDKeysDir) + update := SSHDConfigUpdateBackend{ + SSHDConfigUpdate: u, + OpenSSHCAPath: filepath.Join(keysDir, TeleportOpenSSHCA), + HostKeyPath: filepath.Join(keysDir, TeleportKey), + HostCertPath: filepath.Join(keysDir, TeleportCert), + } + + buf := &bytes.Buffer{} + if err := sshdConfigTmpl.Execute(buf, update); err != nil { + return "", trace.Wrap(err) + } + return buf.String(), nil +} + +// UpdateConfig updates the sshd_config file if needed and writes the +// teleport specific configuration +func (s *SSHD) UpdateConfig(u SSHDConfigUpdate, restart bool) error { + configUpdate, err := fmtSSHDConfigUpdate(u) + if err != nil { + return trace.Wrap(err) + } + + if err := writeTempAndRename(filepath.Join(u.DataDir, sshdConfigFile), s.sshd.checkConfig, []byte(configUpdate)); err != nil { + return trace.Wrap(err) + } + + sshdConfigInclude := sshdConfigInclude(u.DataDir) + + needsUpdate, err := checkSSHDConfigAlreadyUpdated(u.SSHDConfigPath, sshdConfigInclude) + if err != nil { + return trace.Wrap(err) + } + if needsUpdate { + if err := prependToSSHDConfig(u.SSHDConfigPath, sshdConfigInclude); err != nil { + return trace.Wrap(err) + } + } + + if restart { + if err := s.sshd.restart(); err != nil { + return trace.Wrap(err) + } + } + + return nil +} + +// WriteKeys writes the OpenSSH keys and CA from the Identity and the +// OpenSSH CA to disk for the OpenSSH daemon to use +func WriteKeys(keysdir string, id *auth.Identity, cas []types.CertAuthority) error { + if err := os.MkdirAll(keysdir, 0o755); err != nil { + return trace.ConvertSystemError(err) + } + + if err := writeTempAndRename(filepath.Join(keysdir, TeleportKey), nil, id.KeyBytes); err != nil { + return trace.ConvertSystemError(err) + } + + if err := writeTempAndRename(filepath.Join(keysdir, TeleportCert), nil, id.CertBytes); err != nil { + return trace.ConvertSystemError(err) + } + + var caKeyBytes []byte + for _, ca := range cas { + for _, key := range ca.GetTrustedSSHKeyPairs() { + pubKey := append(bytes.TrimSpace(key.PublicKey), byte('\n')) + caKeyBytes = append(caKeyBytes, pubKey...) + } + } + + if err := writeTempAndRename(filepath.Join(keysdir, TeleportOpenSSHCA), nil, caKeyBytes); err != nil { + return trace.ConvertSystemError(err) + } + return nil +} + +type sshdBackend struct { + restartCmd string + checkCmd string + sshdConfigPath string +} + +var _ sshdBackendOperations = &sshdBackend{} + +func (b *sshdBackend) checkConfig(path string) error { + cmd := exec.Command("/bin/sh", "-c", fmt.Sprintf("%s %q", b.checkCmd, path)) + if err := cmd.Run(); err != nil { + output, outErr := cmd.CombinedOutput() + if err != nil { + return trace.Wrap(trace.NewAggregate(err, outErr), "invalid sshd config file, failed to get `%s %q` output", b.checkCmd, path) + } + return trace.Wrap(err, "invalid sshd config file %q, not writing", string(output)) + } + return nil +} + +func (b *sshdBackend) restart() error { + if err := b.checkConfig(b.sshdConfigPath); err != nil { + return trace.Wrap(err) + } + + cmd := exec.Command("/bin/sh", "-c", b.restartCmd) + if err := cmd.Run(); err != nil { + return trace.Wrap(err, "failed to restart the sshd service") + } + return nil +} + +// writeTempAndRename creates a temporary file with 0o600 permissions, +// and writes contents to the it, if checkfunc passes without error, +// it'll then rename the file to the path specified with configPath +func writeTempAndRename(configPath string, checkFunc func(string) error, contents []byte) error { + configTmp, err := os.CreateTemp(filepath.Dir(configPath), "") + if err != nil { + return trace.ConvertSystemError(err) + } + tmpName := configTmp.Name() + defer configTmp.Close() + defer os.Remove(tmpName) + + _, err = configTmp.Write(contents) + if err != nil { + return trace.ConvertSystemError(err) + } + if err := configTmp.Close(); err != nil { + return trace.Wrap(err) + } + + if checkFunc != nil { + if err := checkFunc(tmpName); err != nil { + return trace.Wrap(err) + } + } + + if err := os.Rename(tmpName, configPath); err != nil { + return trace.Wrap(err) + } + return nil +} + +func prependToSSHDConfig(sshdConfigPath, config string) error { + contents, err := os.ReadFile(sshdConfigPath) + if err != nil { + return trace.ConvertSystemError(err) + } + line := append([]byte(config), byte('\n')) + contents = append(line, contents...) + + if err := writeTempAndRename(sshdConfigPath, nil, contents); err != nil { + return trace.Wrap(err) + } + + return nil +} + +func checkSSHDConfigAlreadyUpdated(sshdConfigPath, fileContains string) (bool, error) { + contents, err := os.ReadFile(sshdConfigPath) + if err != nil { + return false, trace.ConvertSystemError(err) + } + return !strings.Contains(string(contents), fileContains), nil +} diff --git a/lib/openssh/sshd_test.go b/lib/openssh/sshd_test.go new file mode 100644 index 0000000000000..aa08aa65588c7 --- /dev/null +++ b/lib/openssh/sshd_test.go @@ -0,0 +1,115 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package openssh + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +type testSSHDBackend struct { + didRestart bool +} + +func (b *testSSHDBackend) restart() error { + b.didRestart = true + return nil +} + +func (b *testSSHDBackend) checkConfig(path string) error { + return nil +} + +func TestSSHD(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + + initialSSHDConfig string + expectedSSHDConfigPrefix string + expectedTeleportSSHDConfig string + restart bool + }{ + { + name: "sshd config update with restart", + initialSSHDConfig: "SomeSSHConfig Hello", + expectedSSHDConfigPrefix: "Include %s/sshd.conf", + expectedTeleportSSHDConfig: `# Created by 'teleport join openssh', do not edit +TrustedUserCAKeys %s/teleport_openssh_ca.pub +HostKey %s/ssh_host_teleport_key +HostCertificate %s/ssh_host_teleport_key-cert.pub +`, + restart: true, + }, + { + name: "sshd config update without restart", + initialSSHDConfig: "SomeSSHConfig Hello", + expectedSSHDConfigPrefix: "Include %s/sshd.conf", + expectedTeleportSSHDConfig: `# Created by 'teleport join openssh', do not edit +TrustedUserCAKeys %s/teleport_openssh_ca.pub +HostKey %s/ssh_host_teleport_key +HostCertificate %s/ssh_host_teleport_key-cert.pub +`, + restart: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + testDir := t.TempDir() + backend := &testSSHDBackend{} + sshd := SSHD{ + sshd: backend, + } + + openSSHConfigFile := filepath.Join(testDir, "sshd_config") + if tc.initialSSHDConfig != "" { + require.NoError(t, os.WriteFile(openSSHConfigFile, []byte(tc.initialSSHDConfig), 0o700)) + } + + dataDir := filepath.Join(testDir, "teleport") + require.NoError(t, os.MkdirAll(dataDir, 0o700)) + + err := sshd.UpdateConfig(SSHDConfigUpdate{ + SSHDConfigPath: openSSHConfigFile, + DataDir: dataDir, + }, tc.restart) + require.NoError(t, err) + + teleportSSHDPath := filepath.Join(dataDir, "sshd.conf") + + actualSSHDConfig, err := os.ReadFile(openSSHConfigFile) + require.NoError(t, err) + expectedPrefix := fmt.Sprintf(tc.expectedSSHDConfigPrefix+"\n", dataDir) + require.Equal(t, expectedPrefix+tc.initialSSHDConfig, string(actualSSHDConfig)) + + actualTeleportSSHDConfig, err := os.ReadFile(teleportSSHDPath) + require.NoError(t, err) + openSSHKeyDir := filepath.Join(dataDir, "openssh") + expectedTeleportSSHDConfig := fmt.Sprintf(tc.expectedTeleportSSHDConfig, openSSHKeyDir, openSSHKeyDir, openSSHKeyDir) + + require.Equal(t, expectedTeleportSSHDConfig, string(actualTeleportSSHDConfig)) + + require.Equal(t, tc.restart, backend.didRestart) + + }) + } + +} diff --git a/lib/service/connect.go b/lib/service/connect.go index d10354bdd5ffd..9e6f4a4c8f754 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -17,9 +17,11 @@ limitations under the License. package service import ( + "context" "crypto/tls" "path/filepath" "strings" + "time" "github.com/coreos/go-semver/semver" "github.com/google/uuid" @@ -43,7 +45,9 @@ import ( "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/observability/metrics" + "github.com/gravitational/teleport/lib/openssh" "github.com/gravitational/teleport/lib/reversetunnel" + "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/interval" @@ -684,6 +688,119 @@ func (process *TeleportProcess) firstTimeConnect(role types.SystemRole) (*Connec return connector, nil } +func (process *TeleportProcess) initOpenSSH() { + process.RegisterWithAuthServer(types.RoleNode, SSHIdentityEvent) + process.SSHD = openssh.NewSSHD( + process.Config.OpenSSH.RestartCommand, + process.Config.OpenSSH.CheckCommand, + process.Config.OpenSSH.SSHDConfigPath, + ) + process.RegisterCriticalFunc("openssh.rotate", process.syncOpenSSHRotationState) +} + +func (process *TeleportProcess) syncOpenSSHRotationState() error { + if _, err := process.WaitForEvent(process.GracefulExitContext(), TeleportReadyEvent); err != nil { + return trace.Wrap(err) + } + conn, err := process.WaitForConnector(SSHIdentityEvent, nil) + if conn == nil { + return trace.Wrap(err) + } + defer conn.Close() + + _, err = process.syncRotationState(conn) + if err != nil { + return trace.Wrap(err) + } + + id, err := process.storage.ReadIdentity(auth.IdentityCurrent, types.RoleNode) + if err != nil { + return trace.Wrap(err) + } + + ctx := process.GracefulExitContext() + cas, err := conn.Client.GetCertAuthorities(ctx, types.OpenSSHCA, false) + if err != nil { + return trace.Wrap(err) + } + + keysDir := filepath.Join(process.Config.DataDir, openssh.SSHDKeysDir) + if err := openssh.WriteKeys(keysDir, id, cas); err != nil { + return trace.Wrap(err) + } + + err = process.SSHD.UpdateConfig(openssh.SSHDConfigUpdate{ + SSHDConfigPath: process.Config.OpenSSH.SSHDConfigPath, + DataDir: process.Config.DataDir, + }, process.Config.OpenSSH.RestartSSHD) + if err != nil { + return trace.Wrap(err) + } + + state, err := process.storage.GetState(types.RoleNode) + if err != nil { + return trace.Wrap(err) + } + + mostRecentRotation := state.Spec.Rotation.LastRotated + if state.Spec.Rotation.State == types.RotationStateInProgress && state.Spec.Rotation.Started.After(mostRecentRotation) { + mostRecentRotation = state.Spec.Rotation.Started + } + for _, ca := range cas { + caRot := ca.GetRotation() + if caRot.State == types.RotationStateInProgress && caRot.Started.After(mostRecentRotation) { + mostRecentRotation = caRot.Started + } + + if caRot.LastRotated.After(mostRecentRotation) { + mostRecentRotation = caRot.LastRotated + } + } + + if err := registerServer(process.Config, ctx, conn.Client, mostRecentRotation); err != nil { + return trace.Wrap(err) + } + + // if any of the above exits with non nil error, the process is + // shut down as it is run via RegisterCriticalFunction, so we + // manually shut down here as we dont want teleport to remain + // running after + go func() { + // run in a go routine as process.Shutdown waits until + // all registered services/functions have finished and + // this cant finish if its waiting on this function to + // return + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + process.Shutdown(ctx) + }() + + return nil +} + +func registerServer(a *servicecfg.Config, ctx context.Context, client auth.ClientI, lastRotation time.Time) error { + server, err := types.NewServer(a.HostUUID, types.KindNode, types.ServerSpecV2{ + Addr: a.OpenSSH.InstanceAddr, + Hostname: a.Hostname, + Rotation: types.Rotation{ + LastRotated: lastRotation, + }, + }) + if err != nil { + return trace.Wrap(err) + } + server.SetSubKind(types.SubKindOpenSSHNode) + server.SetStaticLabels(a.OpenSSH.Labels) + if err := server.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + if _, err := client.UpsertNode(ctx, server); err != nil { + return trace.Wrap(err) + } + return nil +} + // periodicSyncRotationState checks rotation state periodically and // takes action if necessary func (process *TeleportProcess) periodicSyncRotationState() error { diff --git a/lib/service/service.go b/lib/service/service.go index 76dbc4fb6248c..30613cd48b859 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -101,6 +101,7 @@ import ( "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/multiplexer" "github.com/gravitational/teleport/lib/observability/tracing" + "github.com/gravitational/teleport/lib/openssh" "github.com/gravitational/teleport/lib/plugin" "github.com/gravitational/teleport/lib/proxy" "github.com/gravitational/teleport/lib/proxy/clusterdial" @@ -387,6 +388,9 @@ type TeleportProcess struct { // TracingProvider is the provider to be used for exporting traces. In the event // that tracing is disabled this will be a no-op provider that drops all spans. TracingProvider *tracing.Provider + + // SSHD is used to execute commands to update or validate OpenSSH config. + SSHD openssh.SSHD } type keyPairKey struct { @@ -1059,6 +1063,7 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { if cfg.Discovery.Enabled { eventMapping.In = append(eventMapping.In, DiscoveryReady) } + process.RegisterEventMapping(eventMapping) if cfg.Auth.Enabled { @@ -1137,7 +1142,12 @@ func NewTeleport(cfg *servicecfg.Config) (*TeleportProcess, error) { serviceStarted = true } - process.RegisterFunc("common.rotate", process.periodicSyncRotationState) + if cfg.OpenSSH.Enabled { + process.initOpenSSH() + serviceStarted = true + } else { + process.RegisterFunc("common.rotate", process.periodicSyncRotationState) + } // run one upload completer per-process // even in sync recording modes, since the recording mode can be changed @@ -3024,6 +3034,17 @@ func (process *TeleportProcess) getAdditionalPrincipals(role types.SystemRole) ( ) addrs = append(addrs, process.Config.WindowsDesktop.PublicAddrs...) } + + if process.Config.OpenSSH.Enabled { + for _, a := range process.Config.OpenSSH.AdditionalPrincipals { + addr, err := utils.ParseAddr(a) + if err != nil { + return nil, nil, trace.Wrap(err) + } + addrs = append(addrs, *addr) + } + } + for _, addr := range addrs { if addr.IsEmpty() { continue @@ -4701,7 +4722,7 @@ func (process *TeleportProcess) registerExpectedServices(cfg *servicecfg.Config) process.SetExpectedInstanceRole(types.RoleAuth, AuthIdentityEvent) } - if cfg.SSH.Enabled { + if cfg.SSH.Enabled || cfg.OpenSSH.Enabled { process.SetExpectedInstanceRole(types.RoleNode, SSHIdentityEvent) } diff --git a/lib/service/servicecfg/config.go b/lib/service/servicecfg/config.go index 8ede8bfbe8b7e..f95e06ee0d190 100644 --- a/lib/service/servicecfg/config.go +++ b/lib/service/servicecfg/config.go @@ -103,6 +103,9 @@ type Config struct { // Discovery defines the discovery service configuration. Discovery DiscoveryConfig + // OpenSSH defines the configuration for an openssh node + OpenSSH OpenSSHConfig + // Okta defines the okta service configuration. Okta OktaConfig @@ -279,6 +282,18 @@ type RoleAndIdentityEvent struct { IdentityEvent string } +// DisableLongRunningServices disables all services but OpenSSH +func DisableLongRunningServices(cfg *Config) { + cfg.Auth.Enabled = false + cfg.Proxy.Enabled = false + cfg.SSH.Enabled = false + cfg.Kube.Enabled = false + cfg.Apps.Enabled = false + cfg.WindowsDesktop.Enabled = false + cfg.Databases.Enabled = false + cfg.Okta.Enabled = false +} + // JoinParams is a set of extra parameters for joining the auth server. type JoinParams struct { Azure AzureJoinParams @@ -666,6 +681,7 @@ func verifyEnabledService(cfg *Config) error { cfg.WindowsDesktop.Enabled, cfg.Discovery.Enabled, cfg.Okta.Enabled, + cfg.OpenSSH.Enabled, } for _, item := range enabled { diff --git a/lib/service/servicecfg/openssh.go b/lib/service/servicecfg/openssh.go new file mode 100644 index 0000000000000..ce62d230c9b04 --- /dev/null +++ b/lib/service/servicecfg/openssh.go @@ -0,0 +1,39 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package servicecfg + +import ( + "github.com/gravitational/teleport/lib/utils" +) + +type OpenSSHConfig struct { + Enabled bool + // SSHDConfigPath is the path to the OpenSSH config file. + SSHDConfigPath string + // RestartSSHD is true if sshd should be restarted after config updates. + RestartSSHD bool + // RestartCommand is the command to use when restarting sshd. + RestartCommand string + // CheckCommand is the command to use when validating sshd config. + CheckCommand string + // AdditionalPrincipals is a list of additional principals to be included. + AdditionalPrincipals []string + // InstanceAddr is the connectable address of the OpenSSh instance. + InstanceAddr string + // ProxyServer is the address of the teleport proxy. + ProxyServer *utils.NetAddr + // Labels are labels to set on the instance. + Labels map[string]string +} diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index aa8905a808927..951a8a78f851a 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -18,12 +18,9 @@ package common import ( "context" - "crypto/tls" "fmt" - "io" "net/url" "os" - "os/exec" "os/user" "path" "path/filepath" @@ -31,28 +28,20 @@ import ( "strings" "github.com/alecthomas/kingpin/v2" - "github.com/google/uuid" "github.com/gravitational/trace" log "github.com/sirupsen/logrus" - "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - "github.com/gravitational/teleport/lib/auth" - "github.com/gravitational/teleport/lib/auth/authclient" - "github.com/gravitational/teleport/lib/auth/native" - "github.com/gravitational/teleport/lib/client" - "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/config" awsconfigurators "github.com/gravitational/teleport/lib/configurators/aws" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/modules" + "github.com/gravitational/teleport/lib/openssh" "github.com/gravitational/teleport/lib/service" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/sshutils/scp" - "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" ) @@ -65,8 +54,6 @@ type Options struct { InitOnly bool } -const agentlessKeysDir = "/etc/teleport/agentless" - // Run inits/starts the process according to the provided options func Run(options Options) (app *kingpin.Application, executedCommand string, conf *servicecfg.Config) { var err error @@ -435,10 +422,15 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con joinOpenSSH.Flag("token", "Invitation token to register with an auth server.").StringVar(&ccf.AuthToken) joinOpenSSH.Flag("join-method", "Method to use to join the cluster (token, iam, ec2).").EnumVar(&ccf.JoinMethod, "token", "iam", "ec2") joinOpenSSH.Flag("openssh-config", fmt.Sprintf("Path to the OpenSSH config file [%v].", "/etc/ssh/sshd_config")).Default("/etc/ssh/sshd_config").StringVar(&ccf.OpenSSHConfigPath) - joinOpenSSH.Flag("openssh-keys-path", fmt.Sprintf("Path to directory to store Teleport keys and certs [%v].", agentlessKeysDir)).Default(agentlessKeysDir).StringVar(&ccf.OpenSSHKeysPath) + joinOpenSSH.Flag("data-dir", fmt.Sprintf("Path to directory to store teleport data [%v].", defaults.DataDir)).Default(defaults.DataDir).StringVar(&ccf.DataDir) joinOpenSSH.Flag("restart-sshd", "Restart OpenSSH.").Default("true").BoolVar(&ccf.RestartOpenSSH) + joinOpenSSH.Flag("sshd-check-command", "Command to use when checking OpenSSH config for validity. (sshd -t -f )").Default("sshd -t -f").StringVar(&ccf.CheckCommand) + joinOpenSSH.Flag("sshd-restart-command", "Command to use when restarting openssh.").Default(openssh.DefaultRestartCommand).StringVar(&ccf.RestartCommand) + joinOpenSSH.Flag("labels", "Comma-separated list of labels for this OpenSSH node, for example env=dev,app=web.").StringVar(&ccf.Labels) + joinOpenSSH.Flag("address", "IP Address of this OpenSSH node.").StringVar(&ccf.Address) + joinOpenSSH.Flag("additional-principals", "Additional principal to include, can be specified multiple times.").StringVar(&ccf.AdditionalPrincipals) joinOpenSSH.Flag("insecure", "Insecure mode disables certificate validation.").BoolVar(&ccf.InsecureMode) - joinOpenSSH.Flag("additional-principals", "Comma separated list of host names the node can be accessed by.").StringVar(&ccf.AdditionalPrincipals) + joinOpenSSH.Flag("debug", "Enable verbose logging to stderr.").Short('d').BoolVar(&ccf.Debug) // parse CLI commands+flags: @@ -525,7 +517,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con configureDiscoveryBootstrapFlags.config.DiscoveryService = true err = onConfigureDiscoveryBootstrap(configureDiscoveryBootstrapFlags) case joinOpenSSH.FullCommand(): - err = onJoinOpenSSH(ccf) + err = onJoinOpenSSH(ccf, conf) } if err != nil { utils.FatalError(err) @@ -550,333 +542,6 @@ func OnStart(clf config.CommandLineFlags, config *servicecfg.Config) error { return service.Run(context.TODO(), *config, nil) } -// GenerateKeys generates TLS and SSH keypairs. -func GenerateKeys() (private, sshpub, tlspub []byte, err error) { - privateKey, publicKey, err := native.GenerateKeyPair() - if err != nil { - return nil, nil, nil, trace.Wrap(err) - } - - sshPrivateKey, err := ssh.ParseRawPrivateKey(privateKey) - if err != nil { - return nil, nil, nil, trace.Wrap(err) - } - - tlsPublicKey, err := tlsca.MarshalPublicKeyFromPrivateKeyPEM(sshPrivateKey) - if err != nil { - return nil, nil, nil, trace.Wrap(err) - } - - return privateKey, publicKey, tlsPublicKey, nil -} - -func authenticatedUserClientFromIdentity(ctx context.Context, fips bool, proxy utils.NetAddr, id *auth.Identity) (auth.ClientI, error) { - var tlsConfig *tls.Config - var err error - var cipherSuites []uint16 - if fips { - cipherSuites = defaults.FIPSCipherSuites - } - tlsConfig, err = id.TLSConfig(cipherSuites) - if err != nil { - return nil, trace.Wrap(err) - } - - sshConfig, err := id.SSHClientConfig(fips) - if err != nil { - return nil, trace.Wrap(err) - } - - authClientConfig := &authclient.Config{ - TLS: tlsConfig, - SSH: sshConfig, - AuthServers: []utils.NetAddr{proxy}, - Log: log.StandardLogger(), - } - - c, err := authclient.Connect(ctx, authClientConfig) - return c, trace.Wrap(err) -} - -func getAWSInstanceHostname(ctx context.Context) (string, error) { - imds, err := aws.NewInstanceMetadataClient(ctx) - if err != nil { - return "", trace.Wrap(err) - } - hostname, err := imds.GetHostname(ctx) - if err != nil { - return "", trace.Wrap(err) - } - hostname = strings.ReplaceAll(hostname, " ", "_") - if utils.IsValidHostname(hostname) { - return hostname, nil - } - return "", trace.NotFound("failed to get a valid hostname from IMDS") -} - -func tryCreateDefaultAgentlesKeysDir(agentlessKeysPath string) error { - baseTeleportDir := filepath.Dir(agentlessKeysPath) - _, err := os.Stat(baseTeleportDir) - if err != nil { - if os.IsNotExist(err) { - log.Debugf("%s did not exist, creating %s", baseTeleportDir, agentlessKeysPath) - return trace.Wrap(os.MkdirAll(agentlessKeysPath, 0700)) - } - return trace.Wrap(err) - } - - var alreadyExistedAndDeleted bool - _, err = os.Stat(agentlessKeysPath) - if err == nil { - log.Debugf("%s already existed, removing old files", agentlessKeysPath) - err = os.RemoveAll(agentlessKeysPath) - if err != nil { - return trace.Wrap(err) - } - alreadyExistedAndDeleted = true - } - - if os.IsNotExist(err) || alreadyExistedAndDeleted { - log.Debugf("%s did not exist, creating", agentlessKeysPath) - return trace.Wrap(os.Mkdir(agentlessKeysPath, 0700)) - } - - return trace.Wrap(err) -} - -func onJoinOpenSSH(clf config.CommandLineFlags) error { - if err := checkSSHDConfigAlreadyUpdated(clf.OpenSSHConfigPath); err != nil { - return trace.Wrap(err) - } - - if clf.Debug { - log.SetLevel(log.DebugLevel) - } - - // Proxy Server and Token are required configuration so confirming they exist before continuing - missingProxyServerFlag := clf.ProxyServer == "" - missingAuthTokenFlag := clf.AuthToken == "" - - if missingProxyServerFlag && missingAuthTokenFlag { - return trace.BadParameter("No proxy address and token specified, check --proxy-server and --token flags were set") - } - if missingProxyServerFlag { - return trace.BadParameter("No proxy server specified, check --proxy-server flag was set") - } - if missingAuthTokenFlag { - return trace.BadParameter("No token specified, check --token flag was set") - } - - addr, err := utils.ParseAddr(clf.ProxyServer) - if err != nil { - return trace.Wrap(err) - } - privateKey, sshPublicKey, tlsPublicKey, err := GenerateKeys() - if err != nil { - return trace.Wrap(err, "unable to generate new keypairs") - } - - ctx := context.Background() - hostname, err := getAWSInstanceHostname(ctx) - if err != nil { - var hostErr error - hostname, hostErr = os.Hostname() - if hostErr != nil { - return trace.NewAggregate(err, hostErr) - } - } - - // TODO(amk) get uuid from a cli argument once agentless inventory management is implemented to allow tsh ssh access via uuid - uuid := uuid.NewString() - - principals := []string{uuid} - for _, principal := range strings.Split(clf.AdditionalPrincipals, ",") { - if principal == "" { - continue - } - principals = append(principals, principal) - } - - registerParams := auth.RegisterParams{ - Token: clf.AuthToken, - AdditionalPrincipals: principals, - JoinMethod: types.JoinMethod(clf.JoinMethod), - ID: auth.IdentityID{ - Role: types.RoleNode, - NodeName: hostname, - HostUUID: uuid, - }, - ProxyServer: *addr, - PublicTLSKey: tlsPublicKey, - PublicSSHKey: sshPublicKey, - GetHostCredentials: client.HostCredentials, - FIPS: clf.FIPS, - } - - if clf.FIPS { - registerParams.CipherSuites = defaults.FIPSCipherSuites - } - - certs, err := auth.Register(registerParams) - if err != nil { - return trace.Wrap(err) - } - - identity, err := auth.ReadIdentityFromKeyPair(privateKey, certs) - if err != nil { - return trace.Wrap(err) - } - - client, err := authenticatedUserClientFromIdentity(ctx, clf.FIPS, *addr, identity) - if err != nil { - return trace.Wrap(err) - } - - cas, err := client.GetCertAuthorities(ctx, types.OpenSSHCA, false) - if err != nil { - return trace.Wrap(err) - } - - var openSSHCA []byte - for _, ca := range cas { - for _, key := range ca.GetActiveKeys().SSH { - openSSHCA = append(openSSHCA, key.PublicKey...) - openSSHCA = append(openSSHCA, byte('\n')) - } - } - - defaultKeysPath := clf.OpenSSHKeysPath == agentlessKeysDir - if defaultKeysPath { - if err := tryCreateDefaultAgentlesKeysDir(agentlessKeysDir); err != nil { - return trace.Wrap(err) - } - } - - fmt.Printf("Writing Teleport keys to %s\n", clf.OpenSSHKeysPath) - if err := writeKeys(clf.OpenSSHKeysPath, privateKey, certs, openSSHCA); err != nil { - if defaultKeysPath { - rmdirErr := os.RemoveAll(agentlessKeysDir) - if rmdirErr != nil { - return trace.NewAggregate(err, rmdirErr) - } - } - return trace.Wrap(err) - } - - fmt.Println("Updating OpenSSH config") - if err := updateSSHDConfig(clf.OpenSSHKeysPath, clf.OpenSSHConfigPath); err != nil { - return trace.Wrap(err) - } - - fmt.Println("Restarting the OpenSSH daemon") - if err := restartSSHD(); err != nil { - return trace.Wrap(err) - } - - return nil -} - -const ( - teleportKey = "teleport" - teleportCert = "teleport-cert.pub" - teleportOpenSSHCA = "teleport_user_ca.pub" -) - -func writeKeys(sshdConfigDir string, private []byte, certs *proto.Certs, openSSHCA []byte) error { - if err := os.WriteFile(filepath.Join(sshdConfigDir, teleportKey), private, 0600); err != nil { - return trace.Wrap(err) - } - - if err := os.WriteFile(filepath.Join(sshdConfigDir, teleportCert), certs.SSH, 0600); err != nil { - return trace.Wrap(err) - } - - if err := os.WriteFile(filepath.Join(sshdConfigDir, teleportOpenSSHCA), openSSHCA, 0600); err != nil { - return trace.Wrap(err) - } - - return nil -} - -const sshdConfigSectionModificationHeader = "### Section created by 'teleport join openssh'" - -func checkSSHDConfigAlreadyUpdated(sshdConfigPath string) error { - contents, err := os.ReadFile(sshdConfigPath) - if err != nil { - return trace.Wrap(err) - } - - if strings.Contains(string(contents), sshdConfigSectionModificationHeader) { - return trace.AlreadyExists("not updating %s as it has already been modified by teleport", sshdConfigPath) - } - return nil -} - -const sshdBinary = "sshd" - -func updateSSHDConfig(keyDir, sshdConfigPath string) error { - // has to write to the beginning of the sshd_config file as - // openssh takes the first occurrence of a setting - sshdConfig, err := os.OpenFile(sshdConfigPath, os.O_RDONLY|os.O_CREATE, 0644) - if err != nil { - return trace.Wrap(err) - } - defer sshdConfig.Close() - - configUpdate := fmt.Sprintf(` -%s -TrustedUserCaKeys %s -HostKey %s -HostCertificate %s -### Section end -`, - sshdConfigSectionModificationHeader, - filepath.Join(keyDir, "teleport_user_ca.pub"), - filepath.Join(keyDir, "teleport"), - filepath.Join(keyDir, "teleport-cert.pub"), - ) - sshdConfigTmp, err := os.CreateTemp(keyDir, "") - if err != nil { - return trace.Wrap(err) - } - defer sshdConfigTmp.Close() - if _, err := sshdConfigTmp.Write([]byte(configUpdate)); err != nil { - return trace.Wrap(err) - } - - if _, err := io.Copy(sshdConfigTmp, sshdConfig); err != nil { - return trace.Wrap(err) - } - - if err := sshdConfigTmp.Sync(); err != nil { - return trace.Wrap(err) - } - - cmd := exec.Command(sshdBinary, "-t", "-f", sshdConfigTmp.Name()) - if err := cmd.Run(); err != nil { - return trace.Wrap(err, "teleport generated an invalid ssh config file, not writing") - } - - if err := os.Rename(sshdConfigTmp.Name(), sshdConfigPath); err != nil { - return trace.Wrap(err) - } - - return nil -} - -func restartSSHD() error { - cmd := exec.Command("sshd", "-t") - if err := cmd.Run(); err != nil { - return trace.Wrap(err, "teleport generated an invalid ssh config file") - } - - cmd = exec.Command("systemctl", "restart", "sshd") - if err := cmd.Run(); err != nil { - return trace.Wrap(err, "teleport failed to restart the sshd service") - } - return nil -} - // onStatus is the handler for "status" CLI command func onStatus() error { sshClient := os.Getenv("SSH_CLIENT") @@ -1165,3 +830,15 @@ func (rw *StdReadWriter) Read(b []byte) (int, error) { func (rw *StdReadWriter) Write(b []byte) (int, error) { return os.Stdout.Write(b) } + +func onJoinOpenSSH(clf config.CommandLineFlags, conf *servicecfg.Config) error { + // configuration merge: defaults -> file-based conf -> CLI conf + conf.OpenSSH.Enabled = true + if err := config.ConfigureOpenSSH(&clf, conf); err != nil { + return trace.Wrap(err) + } + if err := OnStart(clf, conf); err != nil { + return trace.Wrap(err) + } + return nil +}