diff --git a/lib/teleterm/vnet/service.go b/lib/teleterm/vnet/service.go index c3b17dd6f9a53..e92f981187ff6 100644 --- a/lib/teleterm/vnet/service.go +++ b/lib/teleterm/vnet/service.go @@ -238,18 +238,14 @@ func (s *Service) GetServiceInfo(ctx context.Context, _ *api.GetServiceInfoReque return nil, trace.Wrap(err) } - sshDiag, err := diag.NewSSHDiag(&diag.SSHConfig{ - ProfilePath: s.cfg.profilePath, - }) + sshConfigChecker, err := diag.NewSSHConfigChecker(s.cfg.profilePath) if err != nil { - return nil, trace.Wrap(err, "building SSH diagnostic") + return nil, trace.Wrap(err, "building SSH config checker") } - sshReport, err := sshDiag.Run(ctx) - if err != nil { - return nil, trace.Wrap(err, "running SSH diagnostic") + _, sshConfigured, err := sshConfigChecker.OpenSSHConfigIncludesVNetSSHConfig() + if err != nil && !trace.IsNotFound(err) { + return nil, trace.Wrap(err, "checking SSH configuration") } - sshConfigured := sshReport.Status == diagv1.CheckReportStatus_CHECK_REPORT_STATUS_OK && - sshReport.GetSshConfigurationReport().UserOpensshConfigIncludesVnetSshConfig return &api.GetServiceInfoResponse{ AppDnsZones: unifiedClusterConfig.AppDNSZones(), diff --git a/lib/vnet/diag/ssh.go b/lib/vnet/diag/ssh.go index 27be9b1afd741..c0708aa650b2e 100644 --- a/lib/vnet/diag/ssh.go +++ b/lib/vnet/diag/ssh.go @@ -31,6 +31,7 @@ import ( "github.com/dustin/go-humanize" "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/api/utils/keypaths" diagv1 "github.com/gravitational/teleport/gen/proto/go/teleport/lib/vnet/diag/v1" @@ -50,27 +51,19 @@ type SSHConfig struct { // SSHDiag is a diagnostic check that inspects whether the default user OpenSSH // config file includes VNet's generated SSH config file. type SSHDiag struct { - cfg *SSHConfig - userHome string - userOpenSSHConfigPath string - vnetSSHConfigPath string - isWindows bool + cfg *SSHConfig + sshConfigChecker *SSHConfigChecker } // NewSSHDiag returns a new [SSHDiag]. func NewSSHDiag(cfg *SSHConfig) (*SSHDiag, error) { - userHome, ok := profile.UserHomeDir() - if !ok { - return nil, trace.Errorf("unable to find user's home directory") + sshConfigChecker, err := NewSSHConfigChecker(cfg.ProfilePath) + if err != nil { + return nil, trace.Wrap(err) } - userOpenSSHConfigPath := filepath.Join(userHome, ".ssh", "config") - vnetSSHConfigPath := filepath.Join(cfg.ProfilePath, keypaths.VNetSSHConfig) return &SSHDiag{ - cfg: cfg, - userHome: userHome, - userOpenSSHConfigPath: userOpenSSHConfigPath, - vnetSSHConfigPath: vnetSSHConfigPath, - isWindows: runtime.GOOS == "windows", + cfg: cfg, + sshConfigChecker: sshConfigChecker, }, nil } @@ -103,50 +96,84 @@ func (d *SSHDiag) Run(ctx context.Context) (*diagv1.CheckReport, error) { } func (d *SSHDiag) run(ctx context.Context) (*diagv1.SSHConfigurationReport, error) { - _, err := os.Stat(d.userOpenSSHConfigPath) - userOpenSSHConfigExists := err == nil - if !userOpenSSHConfigExists { - return &diagv1.SSHConfigurationReport{ - UserOpensshConfigPath: d.userOpenSSHConfigPath, - VnetSshConfigPath: d.vnetSSHConfigPath, - }, nil + userOpenSSHConfigContents, included, err := d.sshConfigChecker.OpenSSHConfigIncludesVNetSSHConfig() + if err != nil { + if trace.IsNotFound(err) { + return &diagv1.SSHConfigurationReport{ + UserOpensshConfigPath: d.sshConfigChecker.UserOpenSSHConfigPath, + VnetSshConfigPath: d.sshConfigChecker.VNetSSHConfigPath, + }, nil + } + return nil, trace.Wrap(err) + } + if !utf8.Valid(userOpenSSHConfigContents) { + return nil, trace.Errorf("%s is not valid UTF-8", d.sshConfigChecker.UserOpenSSHConfigPath) } + return &diagv1.SSHConfigurationReport{ + UserOpensshConfigPath: d.sshConfigChecker.UserOpenSSHConfigPath, + VnetSshConfigPath: d.sshConfigChecker.VNetSSHConfigPath, + UserOpensshConfigIncludesVnetSshConfig: included, + UserOpensshConfigExists: true, + UserOpensshConfigContents: string(userOpenSSHConfigContents), + }, nil +} - userOpenSSHConfigFile, err := os.Open(d.userOpenSSHConfigPath) +// SSHConfigChecker checks the state of the user's SSH configuration. +type SSHConfigChecker struct { + userHome string + UserOpenSSHConfigPath string + VNetSSHConfigPath string + isWindows bool +} + +// NewSSHConfigChecker returns a new SSHConfigChecker. +func NewSSHConfigChecker(profilePath string) (*SSHConfigChecker, error) { + userHome, ok := profile.UserHomeDir() + if !ok { + return nil, trace.Errorf("unable to find user's home directory") + } + userOpenSSHConfigPath := filepath.Join(userHome, ".ssh", "config") + vnetSSHConfigPath := keypaths.VNetSSHConfigPath(profilePath) + return &SSHConfigChecker{ + userHome: userHome, + UserOpenSSHConfigPath: userOpenSSHConfigPath, + VNetSSHConfigPath: vnetSSHConfigPath, + isWindows: runtime.GOOS == constants.WindowsOS, + }, nil +} + +// OpenSSHConfigIncludesVNetSSHConfig returns the current user OpenSSH +// configuration file contents (~/.ssh/config) and a boolean indicating whether +// it already includes VNet's generated OpenSSH-compatible configuration file. +// +// If ~/.ssh/config does not exist it returns a [trace.NotFoundError] +func (c *SSHConfigChecker) OpenSSHConfigIncludesVNetSSHConfig() ([]byte, bool, error) { + userOpenSSHConfigFile, err := os.Open(c.UserOpenSSHConfigPath) if err != nil { - return nil, trace.Wrap(trace.ConvertSystemError(err), "opening %s for reading", d.userOpenSSHConfigPath) + return nil, false, trace.Wrap(trace.ConvertSystemError(err), "opening %s for reading", c.UserOpenSSHConfigPath) } defer userOpenSSHConfigFile.Close() userOpenSSHConfigContents, err := io.ReadAll(io.LimitReader(userOpenSSHConfigFile, maxOpenSSHConfigFileSize)) if err != nil { - return nil, trace.Wrap(trace.ConvertSystemError(err), "reading %s", d.userOpenSSHConfigPath) + return nil, false, trace.Wrap(trace.ConvertSystemError(err), "reading %s", c.UserOpenSSHConfigPath) } if len(userOpenSSHConfigContents) == maxOpenSSHConfigFileSize { - return nil, trace.Errorf("%s is too large to (max size %s)", - d.userOpenSSHConfigPath, humanize.Bytes(maxOpenSSHConfigFileSize)) - } - if !utf8.Valid(userOpenSSHConfigContents) { - return nil, trace.Errorf("%s is not valid UTF-8", d.userOpenSSHConfigPath) + return nil, false, trace.Errorf("%s is too large to read (max size %s)", + c.UserOpenSSHConfigPath, humanize.Bytes(maxOpenSSHConfigFileSize)) } - included, err := d.openSSHConfigIncludesVNetSSHConfig(bytes.NewReader(userOpenSSHConfigContents)) + included, err := c.openSSHConfigIncludesVNetSSHConfig(bytes.NewReader(userOpenSSHConfigContents)) if err != nil { - return nil, trace.Wrap(err, "checking if the default user OpenSSH config includes VNet's SSH configuration") + return nil, false, trace.Wrap(err, "checking if the default user OpenSSH config includes VNet's SSH configuration") } - return &diagv1.SSHConfigurationReport{ - UserOpensshConfigPath: d.userOpenSSHConfigPath, - VnetSshConfigPath: d.vnetSSHConfigPath, - UserOpensshConfigIncludesVnetSshConfig: included, - UserOpensshConfigExists: true, - UserOpensshConfigContents: string(userOpenSSHConfigContents), - }, nil + return userOpenSSHConfigContents, included, nil } -func (d *SSHDiag) openSSHConfigIncludesVNetSSHConfig(r io.Reader) (bool, error) { +func (c *SSHConfigChecker) openSSHConfigIncludesVNetSSHConfig(r io.Reader) (bool, error) { scanner := bufio.NewScanner(r) for scanner.Scan() { - if d.openSSHConfigLineIncludesPath(scanner.Text(), d.vnetSSHConfigPath) { + if c.openSSHConfigLineIncludesPath(scanner.Text(), c.VNetSSHConfigPath) { return true, nil } } @@ -155,8 +182,8 @@ func (d *SSHDiag) openSSHConfigIncludesVNetSSHConfig(r io.Reader) (bool, error) // openSSHConfigLineIncludesPath returns true if the given line of an OpenSSH // configuration file is an include statement for the given path. -func (d *SSHDiag) openSSHConfigLineIncludesPath(line, wantPath string) bool { - wantPath = d.normalizePath(wantPath) +func (c *SSHConfigChecker) openSSHConfigLineIncludesPath(line, wantPath string) bool { + wantPath = c.normalizePath(wantPath) line = strings.TrimSpace(line) // Only consider lines that begin with "include" (case-insensitive). @@ -178,42 +205,41 @@ func (d *SSHDiag) openSSHConfigLineIncludesPath(line, wantPath string) bool { // returns true. It does support ~ as an alias for the user's home // directory. var ( - // b is a running buffer holding the current argument as parsed up to + // pathBuf is a running buffer holding the current argument as parsed up to // the current point. - b strings.Builder + pathBuf strings.Builder // quote holds the opening quote character if one has been found. quote = byte(0) ) loop: for i := 0; i < len(line); i++ { - c := line[i] + b := line[i] switch { - case c == '\\' && i < len(line)-1 && canBeEscaped(line[i+1]): + case b == '\\' && i < len(line)-1 && canBeEscaped(line[i+1]): // Skip the escape char and write the next char literally. i++ - b.WriteByte(line[i]) - case quote == 0 && (c == '"' || c == '\''): + pathBuf.WriteByte(line[i]) + case quote == 0 && (b == '"' || b == '\''): // Start of quote - quote = c - case quote != 0 && c == quote: + quote = b + case quote != 0 && b == quote: // End of quote quote = 0 - case b.Len() == 0 && c == '~': + case pathBuf.Len() == 0 && b == '~': // Support ~ as an alias for the user's home directory. - b.WriteString(d.userHome) - case quote == 0 && c == '#': + pathBuf.WriteString(c.userHome) + case quote == 0 && b == '#': // Found an unquoted comment in the middle of the line, ignore the rest. break loop - case quote == 0 && isSpace(rune(c)): + case quote == 0 && isSpace(rune(b)): // Reached the end of this argument, check if it matches wantPath. - if d.normalizePath(b.String()) == wantPath { + if c.normalizePath(pathBuf.String()) == wantPath { return true } - b.Reset() + pathBuf.Reset() default: - // By default just append the current character to the current - // argument. - b.WriteByte(c) + // By default just append the current byte to the path. + pathBuf.WriteByte(b) } } if quote != 0 { @@ -221,11 +247,11 @@ loop: return false } // Handle an argument that ends at the end of the line. - return d.normalizePath(b.String()) == wantPath + return c.normalizePath(pathBuf.String()) == wantPath } -func (d *SSHDiag) normalizePath(path string) string { - if d.isWindows { +func (c *SSHConfigChecker) normalizePath(path string) string { + if c.isWindows { // Normalize all paths to use unix-style separators since OpenSSH // supports / or \\ on Windows. path = strings.ReplaceAll(path, `\`, `/`) diff --git a/lib/vnet/diag/ssh_test.go b/lib/vnet/diag/ssh_test.go index 2a1549a49c172..a152f38c1ac90 100644 --- a/lib/vnet/diag/ssh_test.go +++ b/lib/vnet/diag/ssh_test.go @@ -19,6 +19,7 @@ package diag import ( "os" "path/filepath" + "strings" "testing" "github.com/stretchr/testify/require" @@ -27,171 +28,173 @@ import ( diagv1 "github.com/gravitational/teleport/gen/proto/go/teleport/lib/vnet/diag/v1" ) -// TestSSHDiag tests the SSH configuration diagnostic, specifically its ability -// to check whether an OpenSSH config file includes the VNet SSH config file. -func TestSSHDiag(t *testing.T) { - t.Parallel() - for _, tc := range []struct { - desc string - profilePath string - userHome string - isWindows bool - input string - expect bool - }{ - { - desc: "empty", - profilePath: `/Users/user/.tsh`, - userHome: `/Users/user`, - }, - { - desc: "macos tsh", - profilePath: `/Users/user/.tsh`, - userHome: `/Users/user`, - input: `Include /Users/user/.tsh/vnet_ssh_config`, - expect: true, - }, - { - desc: "macos tsh ~", - profilePath: `/Users/user/.tsh`, - userHome: `/Users/user`, - input: `Include ~/.tsh/vnet_ssh_config`, - expect: true, - }, - { - desc: "macos connect", - profilePath: `/Users/user/Application Support/Teleport Connect/tsh`, - userHome: `/Users/user`, - input: `Include "/Users/user/Application Support/Teleport Connect/tsh/vnet_ssh_config"`, - expect: true, - }, - { - desc: "macos connect ~", - profilePath: `/Users/user/Application Support/Teleport Connect/tsh`, - userHome: `/Users/user`, - input: `Include "~/Application Support/Teleport Connect/tsh/vnet_ssh_config"`, - expect: true, - }, - { - desc: "macos tsh not match connect", - profilePath: `/Users/user/.tsh`, - userHome: `/Users/user`, - input: `Include "/Users/user/Application Support/Teleport Connect/tsh/vnet_ssh_config"`, - }, - { - desc: "macos connect not match tsh", - profilePath: `/Users/user/Application Support/Teleport Connect/tsh`, - userHome: `/Users/user`, - input: `Include /Users/user/.tsh/vnet_ssh_config`, - }, - { - desc: "windows tsh", - profilePath: `C:\Users\User\.tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "C:\\Users\\User\\.tsh\\vnet_ssh_config"`, - expect: true, - }, - { - desc: "windows tsh unescaped", - profilePath: `C:\Users\User\.tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "C:\Users\User\.tsh\vnet_ssh_config"`, - expect: true, - }, - { - desc: "windows tsh unix path", - profilePath: `C:\Users\User\.tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "C:/Users/User/.tsh/vnet_ssh_config"`, - expect: true, - }, - { - desc: "windows tsh ~", - profilePath: `C:\Users\User\.tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "~\\.tsh\\vnet_ssh_config"`, - expect: true, - }, - { - desc: "windows connect", - profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "C:\\Users\\User\\AppData\\Roaming\\Teleport\ Connect\\tsh\\vnet_ssh_config"`, - expect: true, - }, - { - desc: "windows connect unescaped", - profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "C:\Users\User\AppData\Roaming\Teleport Connect\tsh\vnet_ssh_config"`, - expect: true, - }, - { - desc: "windows connect unix path", - profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "C:/Users/User/AppData/Roaming/Teleport\ Connect/tsh/vnet_ssh_config"`, - expect: true, - }, - { - desc: "windows connect ~", - profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "~\\AppData\\Roaming\\Teleport\ Connect\\tsh\\vnet_ssh_config"`, - expect: true, - }, - { - desc: "windows tsh not match connect", - profilePath: `C:\Users\User\.tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "C:\\Users\\User\\AppData\\Roaming\\Teleport\ Connect\\tsh\\vnet_ssh_config"`, - }, - { - desc: "windows connect not match tsh", - profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, - userHome: `C:\Users\User`, - isWindows: true, - input: `Include "C:\\Users\\User\\.tsh\\vnet_ssh_config"`, - }, - { - desc: "some other file", - profilePath: `/Users/user/.tsh`, - input: `Include /Users/user/.tsh/ssh_config`, - }, - { - desc: "multiple includes", - profilePath: `/Users/user/.tsh`, - userHome: `/Users/user`, - input: ` +var sshDiagTestCases = []struct { + desc string + profilePath string + userHome string + isWindows bool + input string + expect bool +}{ + { + desc: "empty", + profilePath: `/Users/user/.tsh`, + userHome: `/Users/user`, + }, + { + desc: "macos tsh", + profilePath: `/Users/user/.tsh`, + userHome: `/Users/user`, + input: `Include /Users/user/.tsh/vnet_ssh_config`, + expect: true, + }, + { + desc: "macos tsh ~", + profilePath: `/Users/user/.tsh`, + userHome: `/Users/user`, + input: `Include ~/.tsh/vnet_ssh_config`, + expect: true, + }, + { + desc: "macos connect", + profilePath: `/Users/user/Application Support/Teleport Connect/tsh`, + userHome: `/Users/user`, + input: `Include "/Users/user/Application Support/Teleport Connect/tsh/vnet_ssh_config"`, + expect: true, + }, + { + desc: "macos connect ~", + profilePath: `/Users/user/Application Support/Teleport Connect/tsh`, + userHome: `/Users/user`, + input: `Include "~/Application Support/Teleport Connect/tsh/vnet_ssh_config"`, + expect: true, + }, + { + desc: "macos tsh not match connect", + profilePath: `/Users/user/.tsh`, + userHome: `/Users/user`, + input: `Include "/Users/user/Application Support/Teleport Connect/tsh/vnet_ssh_config"`, + }, + { + desc: "macos connect not match tsh", + profilePath: `/Users/user/Application Support/Teleport Connect/tsh`, + userHome: `/Users/user`, + input: `Include /Users/user/.tsh/vnet_ssh_config`, + }, + { + desc: "windows tsh", + profilePath: `C:\Users\User\.tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "C:\\Users\\User\\.tsh\\vnet_ssh_config"`, + expect: true, + }, + { + desc: "windows tsh unescaped", + profilePath: `C:\Users\User\.tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "C:\Users\User\.tsh\vnet_ssh_config"`, + expect: true, + }, + { + desc: "windows tsh unix path", + profilePath: `C:\Users\User\.tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "C:/Users/User/.tsh/vnet_ssh_config"`, + expect: true, + }, + { + desc: "windows tsh ~", + profilePath: `C:\Users\User\.tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "~\\.tsh\\vnet_ssh_config"`, + expect: true, + }, + { + desc: "windows connect", + profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "C:\\Users\\User\\AppData\\Roaming\\Teleport\ Connect\\tsh\\vnet_ssh_config"`, + expect: true, + }, + { + desc: "windows connect unescaped", + profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "C:\Users\User\AppData\Roaming\Teleport Connect\tsh\vnet_ssh_config"`, + expect: true, + }, + { + desc: "windows connect unix path", + profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "C:/Users/User/AppData/Roaming/Teleport\ Connect/tsh/vnet_ssh_config"`, + expect: true, + }, + { + desc: "windows connect ~", + profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "~\\AppData\\Roaming\\Teleport\ Connect\\tsh\\vnet_ssh_config"`, + expect: true, + }, + { + desc: "windows tsh not match connect", + profilePath: `C:\Users\User\.tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "C:\\Users\\User\\AppData\\Roaming\\Teleport\ Connect\\tsh\\vnet_ssh_config"`, + }, + { + desc: "windows connect not match tsh", + profilePath: `C:\Users\User\AppData\Roaming\Teleport Connect\tsh`, + userHome: `C:\Users\User`, + isWindows: true, + input: `Include "C:\\Users\\User\\.tsh\\vnet_ssh_config"`, + }, + { + desc: "some other file", + profilePath: `/Users/user/.tsh`, + input: `Include /Users/user/.tsh/ssh_config`, + }, + { + desc: "multiple includes", + profilePath: `/Users/user/.tsh`, + userHome: `/Users/user`, + input: ` Include ~/.ssh/include/* Include /Users/user/ssh_config Include /Users/user/.tsh/vnet_ssh_config `, - expect: true, - }, - { - desc: "commented", - profilePath: `/Users/user/.tsh`, - userHome: `/Users/user`, - input: `Include #/Users/user/.tsh/vnet_ssh_config`, - }, - { - desc: "single quotes", - profilePath: `/Users/user/.tsh`, - userHome: `/Users/user`, - input: `Include '/Users/user/.tsh/vnet_ssh_config'`, - expect: true, - }, - } { + expect: true, + }, + { + desc: "commented", + profilePath: `/Users/user/.tsh`, + userHome: `/Users/user`, + input: `Include #/Users/user/.tsh/vnet_ssh_config`, + }, + { + desc: "single quotes", + profilePath: `/Users/user/.tsh`, + userHome: `/Users/user`, + input: `Include '/Users/user/.tsh/vnet_ssh_config'`, + expect: true, + }, +} + +// TestSSHDiag tests the SSH configuration diagnostic, specifically its ability +// to check whether an OpenSSH config file includes the VNet SSH config file. +func TestSSHDiag(t *testing.T) { + t.Parallel() + for _, tc := range sshDiagTestCases { t.Run(tc.desc, func(t *testing.T) { diag, err := NewSSHDiag(&SSHConfig{ ProfilePath: tc.profilePath, @@ -200,9 +203,9 @@ Include /Users/user/.tsh/vnet_ssh_config userOpenSSHConfigPath := filepath.Join(t.TempDir(), "test_ssh_config") // Override isWindows and paths for the purpose of the test. - diag.isWindows = tc.isWindows - diag.userHome = tc.userHome - diag.userOpenSSHConfigPath = userOpenSSHConfigPath + diag.sshConfigChecker.isWindows = tc.isWindows + diag.sshConfigChecker.userHome = tc.userHome + diag.sshConfigChecker.UserOpenSSHConfigPath = userOpenSSHConfigPath if len(tc.input) > 0 { require.NoError(t, os.WriteFile(userOpenSSHConfigPath, []byte(tc.input), 0o600)) @@ -227,3 +230,21 @@ Include /Users/user/.tsh/vnet_ssh_config }) } } + +// FuzzOpenSSHConfigIncludesPath fuzzes [SSHConfigChecker.openSSHConfigIncludesVNetSSHConfig] +// to make sure it won't panic on arbitrary input. +func FuzzOpenSSHConfigIncludesPath(f *testing.F) { + // Add all test cases as the base test corpus. + for _, tc := range sshDiagTestCases { + f.Add(tc.isWindows, tc.profilePath, tc.input) + } + f.Fuzz(func(t *testing.T, isWindows bool, profilePath, input string) { + vnetSSHConfigPath := keypaths.VNetSSHConfigPath(profilePath) + sshConfigChecker := &SSHConfigChecker{ + VNetSSHConfigPath: vnetSSHConfigPath, + isWindows: isWindows, + } + // Can't deterministically check the result for fuzzed inputs but it shouldn't panic. + sshConfigChecker.openSSHConfigIncludesVNetSSHConfig(strings.NewReader(input)) + }) +} diff --git a/lib/vnet/opensshconfig.go b/lib/vnet/opensshconfig.go index 65f2587f63796..e1ff49cb17222 100644 --- a/lib/vnet/opensshconfig.go +++ b/lib/vnet/opensshconfig.go @@ -21,6 +21,7 @@ import ( "cmp" "context" "encoding/pem" + "fmt" "io" "os" "path/filepath" @@ -40,6 +41,8 @@ import ( "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/keypaths" "github.com/gravitational/teleport/lib/cryptosuites" + libutils "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/teleport/lib/vnet/diag" ) const ( @@ -261,3 +264,79 @@ type configFileTemplateInput struct { PrivateKeyPath string KnownHostsPath string } + +type autoConfigureOpenSSHOptions struct { + overrideUserSSHConfigPath string +} +type autoConfigureOpenSSHOption func(*autoConfigureOpenSSHOptions) + +func withUserSSHConfigPathOverride(path string) autoConfigureOpenSSHOption { + return func(opts *autoConfigureOpenSSHOptions) { + opts.overrideUserSSHConfigPath = path + } +} + +// AutoConfigureOpenSSH adds an Include directive to the default user OpenSSH +// config file (~/.ssh/config) to include the vnet_ssh_config file found under +// profilePath. +func AutoConfigureOpenSSH(ctx context.Context, profilePath string, opts ...autoConfigureOpenSSHOption) (err error) { + var options autoConfigureOpenSSHOptions + for _, opt := range opts { + opt(&options) + } + + sshConfigChecker, err := diag.NewSSHConfigChecker(profilePath) + if err != nil { + return trace.Wrap(err) + } + + if options.overrideUserSSHConfigPath != "" { + sshConfigChecker.UserOpenSSHConfigPath = options.overrideUserSSHConfigPath + } + + // Create ~/.ssh if it does not exist yet. + err = trace.ConvertSystemError(os.Mkdir( + filepath.Dir(sshConfigChecker.UserOpenSSHConfigPath), os.FileMode(0o700))) + switch { + case trace.IsAlreadyExists(err): + // This is fine/expected. + case err != nil: + return trace.Wrap(err, "creating directory for %s", sshConfigChecker.UserOpenSSHConfigPath) + } + + // There should not be much lock contention on this file and it's okay if + // this fails so just try once to grab the lock. + unlock, err := libutils.FSTryWriteLock(sshConfigChecker.UserOpenSSHConfigPath) + if err != nil { + return trace.Wrap(err, "getting write lock for %s", sshConfigChecker.UserOpenSSHConfigPath) + } + defer func() { + unlockErr := unlock() + err = trace.NewAggregate(err, trace.Wrap(unlockErr, "unlocking %s", sshConfigChecker.UserOpenSSHConfigPath)) + }() + + currentContents, alreadyIncluded, err := sshConfigChecker.OpenSSHConfigIncludesVNetSSHConfig() + switch { + case trace.IsNotFound(err): + // This is fine, the file will be created with a single include. + case err != nil: + return trace.Wrap(err) + case alreadyIncluded: + return trace.AlreadyExists("%s is already included in %s", + sshConfigChecker.VNetSSHConfigPath, sshConfigChecker.UserOpenSSHConfigPath) + } + + // Add the include at the top of the file for 2 reasons: + // - options set first take precedence over options set later in the file + // - if the include line is added after an existing Host block it will only + // be included if the host block matches + var newContents bytes.Buffer + fmt.Fprintf(&newContents, `# Include Teleport VNet generated configuration +Include "%s" + +`, sshConfigChecker.VNetSSHConfigPath) + newContents.Write(currentContents) + + err = renameio.WriteFile(sshConfigChecker.UserOpenSSHConfigPath, newContents.Bytes(), filePerms) + return trace.Wrap(trace.ConvertSystemError(err), "writing to %s", sshConfigChecker.UserOpenSSHConfigPath) +} diff --git a/lib/vnet/opensshconfig_test.go b/lib/vnet/opensshconfig_test.go index 4e933fc3bc915..9bed1ed5b0410 100644 --- a/lib/vnet/opensshconfig_test.go +++ b/lib/vnet/opensshconfig_test.go @@ -20,10 +20,13 @@ import ( "context" "fmt" "os" + "path/filepath" "testing" "time" + "github.com/gravitational/trace" "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/utils/keypaths" @@ -116,3 +119,85 @@ func TestSSHConfigurator(t *testing.T) { _, err = os.Stat(keypaths.VNetSSHConfigPath(homePath)) require.ErrorIs(t, err, os.ErrNotExist) } + +func TestAutoConfigureOpenSSH(t *testing.T) { + d := t.TempDir() + profilePath := filepath.Join(d, ".tsh") + vnetSSHConfigPath := keypaths.VNetSSHConfigPath(profilePath) + userOpenSSHConfigPath := filepath.Join(d, ".ssh", "config") + expectedInclude := fmt.Sprintf(`# Include Teleport VNet generated configuration +Include "%s" + +`, vnetSSHConfigPath) + for _, tc := range []struct { + desc string + userOpenSSHConfigExists bool + userOpenSSHConfigContents string + expectAlreadyIncludedError bool + expectUserOpenSSHConfigContents string + }{ + { + // When the user OpenSSH config file doesn't exist, it should be + // created with the include. + desc: "no file", + expectUserOpenSSHConfigContents: expectedInclude, + }, + { + // When the user OpenSSH config file already exists but it's empty, + // the include should be added. + desc: "empty file", + userOpenSSHConfigExists: true, + expectUserOpenSSHConfigContents: expectedInclude, + }, + { + // When the user OpenSSH config file already exists with some + // content, the include should be added at the top. + desc: "not empty", + userOpenSSHConfigExists: true, + userOpenSSHConfigContents: "something\nsomethingelse\n", + expectUserOpenSSHConfigContents: expectedInclude + "something\nsomethingelse\n", + }, + { + // When the user OpenSSH config file already includes VNet's config + // file, it should return an AlreadyExists error and the file + // should not be modified. + desc: "already included", + userOpenSSHConfigExists: true, + userOpenSSHConfigContents: expectedInclude, + expectAlreadyIncludedError: true, + expectUserOpenSSHConfigContents: expectedInclude, + }, + { + // When the user OpenSSH config file already includes VNet's config + // file along with existing content, it should return an + // AlreadyExists error and the file should not be modified. + desc: "already included with extra content", + userOpenSSHConfigExists: true, + userOpenSSHConfigContents: "something\n" + expectedInclude + "somethingelse", + expectAlreadyIncludedError: true, + expectUserOpenSSHConfigContents: "something\n" + expectedInclude + "somethingelse", + }, + } { + t.Run(tc.desc, func(t *testing.T) { + if tc.userOpenSSHConfigExists { + // Write the existing user OpenSSH config file if it's supposed + // to exist for this test case. + require.NoError(t, os.WriteFile(userOpenSSHConfigPath, + []byte(tc.userOpenSSHConfigContents), filePerms)) + } + + err := AutoConfigureOpenSSH(t.Context(), profilePath, withUserSSHConfigPathOverride(userOpenSSHConfigPath)) + + if tc.expectAlreadyIncludedError { + assert.ErrorIs(t, err, trace.AlreadyExists("%s is already included in %s", + vnetSSHConfigPath, userOpenSSHConfigPath)) + } else { + assert.NoError(t, err) + } + + contents, err := os.ReadFile(userOpenSSHConfigPath) + require.NoError(t, err) + assert.Equal(t, tc.expectUserOpenSSHConfigContents, string(contents)) + }) + } +} diff --git a/tool/tsh/common/tsh.go b/tool/tsh/common/tsh.go index d76fed163de01..448a39cc90f5b 100644 --- a/tool/tsh/common/tsh.go +++ b/tool/tsh/common/tsh.go @@ -1361,6 +1361,7 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { workloadIdentityCmd := newWorkloadIdentityCommands(app) vnetCommand := newVnetCommand(app) + vnetSSHAutoConfigCommand := newVnetSSHAutoConfigCommand(app) vnetAdminSetupCommand := newVnetAdminSetupCommand(app) vnetDaemonCommand := newVnetDaemonCommand(app) vnetServiceCommand := newVnetServiceCommand(app) @@ -1781,6 +1782,8 @@ func Run(ctx context.Context, args []string, opts ...CliOption) error { err = workloadIdentityCmd.issueX509.run(&cf) case vnetCommand.FullCommand(): err = vnetCommand.run(&cf) + case vnetSSHAutoConfigCommand.FullCommand(): + err = vnetSSHAutoConfigCommand.run(&cf) case vnetAdminSetupCommand.FullCommand(): err = vnetAdminSetupCommand.run(&cf) case vnetDaemonCommand.FullCommand(): diff --git a/tool/tsh/common/vnet.go b/tool/tsh/common/vnet.go index ead9d27eac297..addebb9dc4d56 100644 --- a/tool/tsh/common/vnet.go +++ b/tool/tsh/common/vnet.go @@ -23,6 +23,7 @@ import ( "github.com/alecthomas/kingpin/v2" "github.com/gravitational/trace" + "github.com/gravitational/teleport/api/profile" "github.com/gravitational/teleport/lib/vnet" ) @@ -77,6 +78,22 @@ func (c *vnetCommand) run(cf *CLIConf) error { return trace.Wrap(vnetProcess.Wait()) } +type vnetSSHAutoConfigCommand struct { + *kingpin.CmdClause +} + +func newVnetSSHAutoConfigCommand(app *kingpin.Application) *vnetSSHAutoConfigCommand { + cmd := &vnetSSHAutoConfigCommand{ + CmdClause: app.Command("vnet-ssh-autoconfig", "Automatically include VNet's generated OpenSSH-compatible config file in ~/.ssh/config."), + } + return cmd +} + +func (c *vnetSSHAutoConfigCommand) run(cf *CLIConf) error { + err := vnet.AutoConfigureOpenSSH(cf.Context, profile.FullProfilePath(cf.HomePath)) + return trace.Wrap(err) +} + func newVnetAdminSetupCommand(app *kingpin.Application) vnetCLICommand { return newPlatformVnetAdminSetupCommand(app) } @@ -104,6 +121,7 @@ type vnetCommandNotSupported struct{} func (vnetCommandNotSupported) FullCommand() string { return "" } + func (vnetCommandNotSupported) run(*CLIConf) error { panic("vnetCommandNotSupported.run should never be called, this is a bug") }