Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions api/types/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,9 @@ const (
// found via automatic discovery, to avoid re-running installation
// commands on the node.
AWSInstanceIDLabel = TeleportNamespace + "/instance-id"
// AWSInstanceRegion is used to identify the region an EC2
// instance is running in
AWSInstanceRegion = TeleportNamespace + "/aws-region"
// SubscriptionIDLabel is used to identify virtual machines by Azure
// subscription ID found via automatic discovery, to avoid re-running
// installation commands on the node.
Expand Down
28 changes: 28 additions & 0 deletions lib/cloud/aws/imds.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,31 @@ func (client *InstanceMetadataClient) GetID(ctx context.Context) (string, error)

return id, nil
}

// GetLocalIPV4 gets the EC2 instance's local ipv4 address.
func (client *InstanceMetadataClient) GetLocalIPV4(ctx context.Context) (string, error) {
ip, err := client.getMetadata(ctx, "local-ipv4")
if err != nil {
return "", trace.Wrap(err)
}

return ip, nil
}

// GetPublicIPV4 gets the EC2 instance's local ipv4 address.
func (client *InstanceMetadataClient) GetPublicIPV4(ctx context.Context) (string, error) {
ip, err := client.getMetadata(ctx, "public-ipv4")
if err != nil {
return "", trace.Wrap(err)
}

return ip, nil
}

func (client *InstanceMetadataClient) GetAccountID(ctx context.Context) (string, error) {
idOut, err := client.c.GetInstanceIdentityDocument(ctx, &imds.GetInstanceIdentityDocumentInput{})
if err != nil {
return "", trace.Wrap(err)
}
return idOut.AccountID, nil
}
75 changes: 69 additions & 6 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,16 +176,22 @@ type CommandLineFlags struct {
// if the value cannot be obtained from the database.
DatabaseMySQLServerVersion string

// ProxyServer is the url of the proxy server to connect to
// ProxyServer is the url of the proxy server to connect to.
ProxyServer string
// OpenSSHConfigPath is the path of the file to write agentless configuration to
// OpenSSHConfigPath is the path of the file to write agentless configuration to.
OpenSSHConfigPath string
// OpenSSHKeysPath is the path to write teleport keys and certs into
OpenSSHKeysPath string
// AdditionalPrincipals are a list of extra principals to include when generating host keys.
AdditionalPrincipals string
// RestartOpenSSH indicates whether openssh should be restarted or not.
RestartOpenSSH bool
// RestartCommand is the command to use when restarting sshd
RestartCommand string
// CheckCommand is the command to use when checking sshd config validity
CheckCommand string
// Address is the ip address of the OpenSSH node.
Address string
// AdditionalPrincipals is a list of additional principals to include in the SSH cert.
AdditionalPrincipals string
// Directory to store
DataDir string
}

// ReadConfigFile reads /etc/teleport.yaml (or whatever is passed via --config flag)
Expand Down Expand Up @@ -2144,6 +2150,63 @@ func Configure(clf *CommandLineFlags, cfg *servicecfg.Config, legacyAppFlags boo
return nil
}

// ConfigureOpenSSH initializes a config from the commandline flags passed
func ConfigureOpenSSH(clf *CommandLineFlags, cfg *servicecfg.Config) error {
// pass the value of --insecure flag to the runtime
lib.SetInsecureDevMode(clf.InsecureMode)

// Apply command line --debug flag to override logger severity.
if clf.Debug {
log.SetLevel(log.DebugLevel)
cfg.Log.SetLevel(log.DebugLevel)
cfg.Debug = clf.Debug
}

if clf.AuthToken != "" {
// store the value of the --token flag:
cfg.SetToken(clf.AuthToken)
}

log.Debugf("Disabling all services, only the Teleport OpenSSH service can run during the `teleport join openssh` command")
servicecfg.DisableLongRunningServices(cfg)

cfg.DataDir = clf.DataDir
cfg.Version = defaults.TeleportConfigVersionV3
cfg.OpenSSH.SSHDConfigPath = clf.OpenSSHConfigPath
cfg.OpenSSH.RestartSSHD = clf.RestartOpenSSH
cfg.OpenSSH.RestartCommand = clf.RestartCommand
cfg.OpenSSH.CheckCommand = clf.CheckCommand
cfg.JoinMethod = types.JoinMethod(clf.JoinMethod)

hostname, err := os.Hostname()
if err != nil {
return trace.Wrap(err)
}

cfg.Hostname = hostname
Comment thread
lxea marked this conversation as resolved.
cfg.OpenSSH.InstanceAddr = clf.Address
cfg.OpenSSH.AdditionalPrincipals = []string{hostname, clf.Address}
for _, principal := range strings.Split(clf.AdditionalPrincipals, ",") {
if principal == "" {
continue
}
cfg.OpenSSH.AdditionalPrincipals = append(cfg.OpenSSH.AdditionalPrincipals, principal)
}
cfg.OpenSSH.Labels, err = client.ParseLabelSpec(clf.Labels)
if err != nil {
return trace.Wrap(err)
}

proxyServer, err := utils.ParseAddr(clf.ProxyServer)
if err != nil {
return trace.Wrap(err)
}
cfg.SetAuthServerAddresses(nil)
cfg.ProxyServer = *proxyServer
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we also handle clf.AuthServerAddr here, in case someone wants to point an openssh node directly at an auth?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe, the existing discovery stuff all assumes its using a proxy

should there be separate exclusive proxy/auth flags in that case?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might be fine then; it's a bit weird that teleport join only supports --proxy-address while the normal start uses --auth-address tho, especially since openssh mode will never actually open a reverse tunnel connection.


return nil
}

// parseLabels parses the labels command line flag and returns static and
// dynamic labels.
func parseLabels(spec string) (map[string]string, services.CommandLabels, error) {
Expand Down
263 changes: 263 additions & 0 deletions lib/openssh/sshd.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,263 @@
/*
Copyright 2023 Gravitational, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package openssh

import (
"bytes"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"text/template"

"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/auth"
)

var (
// SSHDConfigPath is the path to write teleport specific SSHD config options
sshdConfigFile = "sshd.conf"
// SSHDKeysDir is the path to the Teleport openssh keys
SSHDKeysDir = "openssh"
)

func sshdConfigInclude(dataDir string) string {
return fmt.Sprintf("Include %s", filepath.Join(dataDir, sshdConfigFile))
}

const DefaultRestartCommand = "systemctl restart sshd"

const (
// TeleportKey is the name the OpenSSH private key
TeleportKey = "ssh_host_teleport_key"
// TeleportKey is the name the OpenSSH cert
TeleportCert = TeleportKey + "-cert.pub"
// TeleportOpenSSHCA is the path to the Teleport OpenSSHCA
TeleportOpenSSHCA = "teleport_openssh_ca.pub"
)

type sshdBackendOperations interface {
restart() error
checkConfig(path string) error
}
Comment thread
lxea marked this conversation as resolved.
Outdated

// SSHD is used to update the OpenSSH config
type SSHD struct {
sshd sshdBackendOperations
}

// NewSSHD initializes SSHD
func NewSSHD(restartCmd string, checkCmd string, sshdConfigPath string) SSHD {
return SSHD{
sshd: &sshdBackend{
restartCmd: restartCmd,
checkCmd: checkCmd,
sshdConfigPath: sshdConfigPath,
},
}
}

// SSHDConfigUpdate is the list of options to be set in the Teleport OpenSSH config
type SSHDConfigUpdate struct {
// SSHDConfigPath is the path to the OpenSSH sshd_config file
SSHDConfigPath string
// DataDir is the path to the global Teleport datadir
DataDir string
}

var sshdConfigTmpl = template.Must(template.New("sshd_config_include").Parse(`# Created by 'teleport join openssh', do not edit
TrustedUserCAKeys {{ .OpenSSHCAPath }}
HostKey {{ .HostKeyPath }}
HostCertificate {{ .HostCertPath }}
`))

func fmtSSHDConfigUpdate(u SSHDConfigUpdate) (string, error) {
type SSHDConfigUpdateBackend struct {
SSHDConfigUpdate
// OpenSSHCAPath is the path to which Teleport OpenSSHCA will be written
OpenSSHCAPath string
// HostKeyPath is the path to the Teleport Host Key
HostKeyPath string
// HostCertPath is the path to the Teleport OpenSSH cert
HostCertPath string
}
keysDir := filepath.Join(u.DataDir, SSHDKeysDir)
update := SSHDConfigUpdateBackend{
SSHDConfigUpdate: u,
OpenSSHCAPath: filepath.Join(keysDir, TeleportOpenSSHCA),
HostKeyPath: filepath.Join(keysDir, TeleportKey),
HostCertPath: filepath.Join(keysDir, TeleportCert),
}

buf := &bytes.Buffer{}
if err := sshdConfigTmpl.Execute(buf, update); err != nil {
return "", trace.Wrap(err)
}
return buf.String(), nil
}

// UpdateConfig updates the sshd_config file if needed and writes the
// teleport specific configuration
func (s *SSHD) UpdateConfig(u SSHDConfigUpdate, restart bool) error {
configUpdate, err := fmtSSHDConfigUpdate(u)
if err != nil {
return trace.Wrap(err)
}

if err := writeTempAndRename(filepath.Join(u.DataDir, sshdConfigFile), s.sshd.checkConfig, []byte(configUpdate)); err != nil {
return trace.Wrap(err)
}

sshdConfigInclude := sshdConfigInclude(u.DataDir)

needsUpdate, err := checkSSHDConfigAlreadyUpdated(u.SSHDConfigPath, sshdConfigInclude)
if err != nil {
return trace.Wrap(err)
}
if needsUpdate {
if err := prependToSSHDConfig(u.SSHDConfigPath, sshdConfigInclude); err != nil {
return trace.Wrap(err)
}
}

if restart {
if err := s.sshd.restart(); err != nil {
return trace.Wrap(err)
}
}

return nil
}

// WriteKeys writes the OpenSSH keys and CA from the Identity and the
// OpenSSH CA to disk for the OpenSSH daemon to use
func WriteKeys(keysdir string, id *auth.Identity, cas []types.CertAuthority) error {
if err := os.MkdirAll(keysdir, 0o755); err != nil {
return trace.ConvertSystemError(err)
}

if err := writeTempAndRename(filepath.Join(keysdir, TeleportKey), nil, id.KeyBytes); err != nil {
return trace.ConvertSystemError(err)
}

if err := writeTempAndRename(filepath.Join(keysdir, TeleportCert), nil, id.CertBytes); err != nil {
return trace.ConvertSystemError(err)
}

var caKeyBytes []byte
for _, ca := range cas {
for _, key := range ca.GetTrustedSSHKeyPairs() {
pubKey := append(bytes.TrimSpace(key.PublicKey), byte('\n'))
caKeyBytes = append(caKeyBytes, pubKey...)
}
}

if err := writeTempAndRename(filepath.Join(keysdir, TeleportOpenSSHCA), nil, caKeyBytes); err != nil {
return trace.ConvertSystemError(err)
}
return nil
}

type sshdBackend struct {
restartCmd string
checkCmd string
sshdConfigPath string
}

var _ sshdBackendOperations = &sshdBackend{}

func (b *sshdBackend) checkConfig(path string) error {
cmd := exec.Command("/bin/sh", "-c", fmt.Sprintf("%s %q", b.checkCmd, path))
if err := cmd.Run(); err != nil {
output, outErr := cmd.CombinedOutput()
if err != nil {
return trace.Wrap(trace.NewAggregate(err, outErr), "invalid sshd config file, failed to get `%s %q` output", b.checkCmd, path)
}
return trace.Wrap(err, "invalid sshd config file %q, not writing", string(output))
}
return nil
}

func (b *sshdBackend) restart() error {
if err := b.checkConfig(b.sshdConfigPath); err != nil {
return trace.Wrap(err)
}

cmd := exec.Command("/bin/sh", "-c", b.restartCmd)
if err := cmd.Run(); err != nil {
return trace.Wrap(err, "failed to restart the sshd service")
}
return nil
}

// writeTempAndRename creates a temporary file with 0o600 permissions,
// and writes contents to the it, if checkfunc passes without error,
// it'll then rename the file to the path specified with configPath
func writeTempAndRename(configPath string, checkFunc func(string) error, contents []byte) error {
configTmp, err := os.CreateTemp(filepath.Dir(configPath), "")
if err != nil {
return trace.ConvertSystemError(err)
}
tmpName := configTmp.Name()
defer configTmp.Close()
defer os.Remove(tmpName)

_, err = configTmp.Write(contents)
if err != nil {
return trace.ConvertSystemError(err)
}
if err := configTmp.Close(); err != nil {
return trace.Wrap(err)
}

if checkFunc != nil {
if err := checkFunc(tmpName); err != nil {
return trace.Wrap(err)
}
}

if err := os.Rename(tmpName, configPath); err != nil {
return trace.Wrap(err)
}
return nil
}

func prependToSSHDConfig(sshdConfigPath, config string) error {
contents, err := os.ReadFile(sshdConfigPath)
if err != nil {
return trace.ConvertSystemError(err)
}
line := append([]byte(config), byte('\n'))
contents = append(line, contents...)

if err := writeTempAndRename(sshdConfigPath, nil, contents); err != nil {
return trace.Wrap(err)
}

return nil
}

func checkSSHDConfigAlreadyUpdated(sshdConfigPath, fileContains string) (bool, error) {
contents, err := os.ReadFile(sshdConfigPath)
if err != nil {
return false, trace.ConvertSystemError(err)
}
return !strings.Contains(string(contents), fileContains), nil
}
Loading