Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 5 additions & 9 deletions lib/teleterm/vnet/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,14 @@ func (s *Service) GetServiceInfo(ctx context.Context, _ *api.GetServiceInfoReque
return nil, trace.Wrap(err)
}

sshDiag, err := diag.NewSSHDiag(&diag.SSHConfig{
ProfilePath: s.cfg.profilePath,
})
sshConfigChecker, err := diag.NewSSHConfigChecker(s.cfg.profilePath)
if err != nil {
return nil, trace.Wrap(err, "building SSH diagnostic")
return nil, trace.Wrap(err, "building SSH config checker")
}
sshReport, err := sshDiag.Run(ctx)
if err != nil {
return nil, trace.Wrap(err, "running SSH diagnostic")
_, sshConfigured, err := sshConfigChecker.OpenSSHConfigIncludesVNetSSHConfig()
if err != nil && !trace.IsNotFound(err) {
return nil, trace.Wrap(err, "checking SSH configuration")
}
sshConfigured := sshReport.Status == diagv1.CheckReportStatus_CHECK_REPORT_STATUS_OK &&
sshReport.GetSshConfigurationReport().UserOpensshConfigIncludesVnetSshConfig

return &api.GetServiceInfoResponse{
AppDnsZones: unifiedClusterConfig.AppDNSZones(),
Expand Down
152 changes: 89 additions & 63 deletions lib/vnet/diag/ssh.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/dustin/go-humanize"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/profile"
"github.com/gravitational/teleport/api/utils/keypaths"
diagv1 "github.com/gravitational/teleport/gen/proto/go/teleport/lib/vnet/diag/v1"
Expand All @@ -50,27 +51,19 @@ type SSHConfig struct {
// SSHDiag is a diagnostic check that inspects whether the default user OpenSSH
// config file includes VNet's generated SSH config file.
type SSHDiag struct {
cfg *SSHConfig
userHome string
userOpenSSHConfigPath string
vnetSSHConfigPath string
isWindows bool
cfg *SSHConfig
sshConfigChecker *SSHConfigChecker
}

// NewSSHDiag returns a new [SSHDiag].
func NewSSHDiag(cfg *SSHConfig) (*SSHDiag, error) {
userHome, ok := profile.UserHomeDir()
if !ok {
return nil, trace.Errorf("unable to find user's home directory")
sshConfigChecker, err := NewSSHConfigChecker(cfg.ProfilePath)
if err != nil {
return nil, trace.Wrap(err)
}
userOpenSSHConfigPath := filepath.Join(userHome, ".ssh", "config")
vnetSSHConfigPath := filepath.Join(cfg.ProfilePath, keypaths.VNetSSHConfig)
return &SSHDiag{
cfg: cfg,
userHome: userHome,
userOpenSSHConfigPath: userOpenSSHConfigPath,
vnetSSHConfigPath: vnetSSHConfigPath,
isWindows: runtime.GOOS == "windows",
cfg: cfg,
sshConfigChecker: sshConfigChecker,
}, nil
}

Expand Down Expand Up @@ -103,50 +96,84 @@ func (d *SSHDiag) Run(ctx context.Context) (*diagv1.CheckReport, error) {
}

func (d *SSHDiag) run(ctx context.Context) (*diagv1.SSHConfigurationReport, error) {
_, err := os.Stat(d.userOpenSSHConfigPath)
userOpenSSHConfigExists := err == nil
if !userOpenSSHConfigExists {
return &diagv1.SSHConfigurationReport{
UserOpensshConfigPath: d.userOpenSSHConfigPath,
VnetSshConfigPath: d.vnetSSHConfigPath,
}, nil
userOpenSSHConfigContents, included, err := d.sshConfigChecker.OpenSSHConfigIncludesVNetSSHConfig()
if err != nil {
if trace.IsNotFound(err) {
return &diagv1.SSHConfigurationReport{
UserOpensshConfigPath: d.sshConfigChecker.UserOpenSSHConfigPath,
VnetSshConfigPath: d.sshConfigChecker.VNetSSHConfigPath,
}, nil
}
return nil, trace.Wrap(err)
}
if !utf8.Valid(userOpenSSHConfigContents) {
return nil, trace.Errorf("%s is not valid UTF-8", d.sshConfigChecker.UserOpenSSHConfigPath)
}
return &diagv1.SSHConfigurationReport{
UserOpensshConfigPath: d.sshConfigChecker.UserOpenSSHConfigPath,
VnetSshConfigPath: d.sshConfigChecker.VNetSSHConfigPath,
UserOpensshConfigIncludesVnetSshConfig: included,
UserOpensshConfigExists: true,
UserOpensshConfigContents: string(userOpenSSHConfigContents),
}, nil
}

userOpenSSHConfigFile, err := os.Open(d.userOpenSSHConfigPath)
// SSHConfigChecker checks the state of the user's SSH configuration.
type SSHConfigChecker struct {
userHome string
UserOpenSSHConfigPath string
VNetSSHConfigPath string
isWindows bool
}

// NewSSHConfigChecker returns a new SSHConfigChecker.
func NewSSHConfigChecker(profilePath string) (*SSHConfigChecker, error) {
userHome, ok := profile.UserHomeDir()
if !ok {
return nil, trace.Errorf("unable to find user's home directory")
}
userOpenSSHConfigPath := filepath.Join(userHome, ".ssh", "config")
vnetSSHConfigPath := keypaths.VNetSSHConfigPath(profilePath)
return &SSHConfigChecker{
userHome: userHome,
UserOpenSSHConfigPath: userOpenSSHConfigPath,
VNetSSHConfigPath: vnetSSHConfigPath,
isWindows: runtime.GOOS == constants.WindowsOS,
}, nil
}

// OpenSSHConfigIncludesVNetSSHConfig returns the current user OpenSSH
// configuration file contents (~/.ssh/config) and a boolean indicating whether
// it already includes VNet's generated OpenSSH-compatible configuration file.
//
// If ~/.ssh/config does not exist it returns a [trace.NotFoundError]
func (c *SSHConfigChecker) OpenSSHConfigIncludesVNetSSHConfig() ([]byte, bool, error) {
userOpenSSHConfigFile, err := os.Open(c.UserOpenSSHConfigPath)
if err != nil {
return nil, trace.Wrap(trace.ConvertSystemError(err), "opening %s for reading", d.userOpenSSHConfigPath)
return nil, false, trace.Wrap(trace.ConvertSystemError(err), "opening %s for reading", c.UserOpenSSHConfigPath)
}
defer userOpenSSHConfigFile.Close()

userOpenSSHConfigContents, err := io.ReadAll(io.LimitReader(userOpenSSHConfigFile, maxOpenSSHConfigFileSize))
if err != nil {
return nil, trace.Wrap(trace.ConvertSystemError(err), "reading %s", d.userOpenSSHConfigPath)
return nil, false, trace.Wrap(trace.ConvertSystemError(err), "reading %s", c.UserOpenSSHConfigPath)
}
if len(userOpenSSHConfigContents) == maxOpenSSHConfigFileSize {
return nil, trace.Errorf("%s is too large to (max size %s)",
d.userOpenSSHConfigPath, humanize.Bytes(maxOpenSSHConfigFileSize))
}
if !utf8.Valid(userOpenSSHConfigContents) {
return nil, trace.Errorf("%s is not valid UTF-8", d.userOpenSSHConfigPath)
return nil, false, trace.Errorf("%s is too large to read (max size %s)",
c.UserOpenSSHConfigPath, humanize.Bytes(maxOpenSSHConfigFileSize))
}

included, err := d.openSSHConfigIncludesVNetSSHConfig(bytes.NewReader(userOpenSSHConfigContents))
included, err := c.openSSHConfigIncludesVNetSSHConfig(bytes.NewReader(userOpenSSHConfigContents))
if err != nil {
return nil, trace.Wrap(err, "checking if the default user OpenSSH config includes VNet's SSH configuration")
return nil, false, trace.Wrap(err, "checking if the default user OpenSSH config includes VNet's SSH configuration")
}
return &diagv1.SSHConfigurationReport{
UserOpensshConfigPath: d.userOpenSSHConfigPath,
VnetSshConfigPath: d.vnetSSHConfigPath,
UserOpensshConfigIncludesVnetSshConfig: included,
UserOpensshConfigExists: true,
UserOpensshConfigContents: string(userOpenSSHConfigContents),
}, nil
return userOpenSSHConfigContents, included, nil
}

func (d *SSHDiag) openSSHConfigIncludesVNetSSHConfig(r io.Reader) (bool, error) {
func (c *SSHConfigChecker) openSSHConfigIncludesVNetSSHConfig(r io.Reader) (bool, error) {
scanner := bufio.NewScanner(r)
for scanner.Scan() {
if d.openSSHConfigLineIncludesPath(scanner.Text(), d.vnetSSHConfigPath) {
if c.openSSHConfigLineIncludesPath(scanner.Text(), c.VNetSSHConfigPath) {
return true, nil
}
}
Expand All @@ -155,8 +182,8 @@ func (d *SSHDiag) openSSHConfigIncludesVNetSSHConfig(r io.Reader) (bool, error)

// openSSHConfigLineIncludesPath returns true if the given line of an OpenSSH
// configuration file is an include statement for the given path.
func (d *SSHDiag) openSSHConfigLineIncludesPath(line, wantPath string) bool {
wantPath = d.normalizePath(wantPath)
func (c *SSHConfigChecker) openSSHConfigLineIncludesPath(line, wantPath string) bool {
Comment thread
nklaassen marked this conversation as resolved.
wantPath = c.normalizePath(wantPath)
line = strings.TrimSpace(line)

// Only consider lines that begin with "include" (case-insensitive).
Expand All @@ -178,54 +205,53 @@ func (d *SSHDiag) openSSHConfigLineIncludesPath(line, wantPath string) bool {
// returns true. It does support ~ as an alias for the user's home
// directory.
var (
// b is a running buffer holding the current argument as parsed up to
// pathBuf is a running buffer holding the current argument as parsed up to
// the current point.
b strings.Builder
pathBuf strings.Builder
// quote holds the opening quote character if one has been found.
quote = byte(0)
)
loop:
for i := 0; i < len(line); i++ {
c := line[i]
b := line[i]
switch {
case c == '\\' && i < len(line)-1 && canBeEscaped(line[i+1]):
case b == '\\' && i < len(line)-1 && canBeEscaped(line[i+1]):
// Skip the escape char and write the next char literally.
i++
b.WriteByte(line[i])
case quote == 0 && (c == '"' || c == '\''):
pathBuf.WriteByte(line[i])
case quote == 0 && (b == '"' || b == '\''):
// Start of quote
quote = c
case quote != 0 && c == quote:
quote = b
case quote != 0 && b == quote:
// End of quote
quote = 0
case b.Len() == 0 && c == '~':
case pathBuf.Len() == 0 && b == '~':
// Support ~ as an alias for the user's home directory.
b.WriteString(d.userHome)
case quote == 0 && c == '#':
pathBuf.WriteString(c.userHome)
case quote == 0 && b == '#':
// Found an unquoted comment in the middle of the line, ignore the rest.
break loop
case quote == 0 && isSpace(rune(c)):
case quote == 0 && isSpace(rune(b)):
// Reached the end of this argument, check if it matches wantPath.
if d.normalizePath(b.String()) == wantPath {
if c.normalizePath(pathBuf.String()) == wantPath {
return true
}
b.Reset()
pathBuf.Reset()
default:
// By default just append the current character to the current
// argument.
b.WriteByte(c)
// By default just append the current byte to the path.
pathBuf.WriteByte(b)
}
}
if quote != 0 {
// Unmatched quote.
return false
}
// Handle an argument that ends at the end of the line.
return d.normalizePath(b.String()) == wantPath
return c.normalizePath(pathBuf.String()) == wantPath
}

func (d *SSHDiag) normalizePath(path string) string {
if d.isWindows {
func (c *SSHConfigChecker) normalizePath(path string) string {
if c.isWindows {
// Normalize all paths to use unix-style separators since OpenSSH
// supports / or \\ on Windows.
path = strings.ReplaceAll(path, `\`, `/`)
Expand Down
Loading
Loading