From d8f18437179b2c7a55ac47f3b365b62d33638e2f Mon Sep 17 00:00:00 2001 From: Andrew LeFevre Date: Tue, 11 Apr 2023 11:42:20 -0400 Subject: [PATCH] let SFTP server figure out remote users home directories Previously a Teleport client using SFTP would resolve remote host user home directories by making a subsystem request to a Teleport server which would return the home directory. The problem was the subsystem request counted as an open session, which could make the SFTP file transfer fail. This was frustrating and didn't make much sense, but after reading the SFTP specification again I realized that SFTP servers are to handle relative paths by assuming they start at the user's home directory. So let the server figure out the correct path and remove any tilde prefixes from remote paths. --- lib/srv/regular/sshserver.go | 1 + lib/sshutils/sftp/sftp.go | 110 ++++++++++++--------------------- lib/sshutils/sftp/sftp_test.go | 28 ++++++--- 3 files changed, 60 insertions(+), 79 deletions(-) diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index 0ba18c7f6eca8..5dece9a24dbff 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -2040,6 +2040,7 @@ func (s *Server) parseSubsystemRequest(req *ssh.Request, ch ssh.Channel, ctx *sr return parseProxySubsys(r.Name, s, ctx) case s.proxyMode && strings.HasPrefix(r.Name, "proxysites"): return parseProxySitesSubsys(r.Name, s) + // DELETE IN 15.0.0 (deprecated, tsh will not be using this anymore) case r.Name == teleport.GetHomeDirSubsystem: return newHomeDirSubsys(), nil case r.Name == sftpSubsystem: diff --git a/lib/sshutils/sftp/sftp.go b/lib/sshutils/sftp/sftp.go index efadd14c83b9f..0b089d5c7c316 100644 --- a/lib/sshutils/sftp/sftp.go +++ b/lib/sshutils/sftp/sftp.go @@ -18,7 +18,6 @@ limitations under the License. package sftp import ( - "bytes" "context" "errors" "fmt" @@ -26,7 +25,6 @@ import ( "io/fs" "net/http" "os" - "os/user" "path" // SFTP requires UNIX-style path separators "runtime" "strconv" @@ -39,7 +37,6 @@ import ( log "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh" - "github.com/gravitational/teleport" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/sshutils/scp" ) @@ -53,8 +50,6 @@ type Options struct { PreserveAttrs bool } -type homeDirRetriever func() (string, error) - // Config describes the settings of a file transfer type Config struct { srcPaths []string @@ -63,10 +58,6 @@ type Config struct { dstFS FileSystem opts Options - // getHomeDir returns the home directory of the remote user of the - // SSH session - getHomeDir homeDirRetriever - // ProgressStream is a callback to return a read/writer for printing the progress // (used only on the client) ProgressStream func(fileInfo os.FileInfo) io.ReadWriter @@ -322,95 +313,70 @@ func (c *Config) initFS(sshClient *ssh.Client, client *sftp.Client) error { return nil } - if c.getHomeDir == nil { - c.getHomeDir = func() (string, error) { - return getRemoteHomeDir(sshClient) - } - } - return trace.Wrap(c.expandPaths(srcOK, dstOK)) } func (c *Config) expandPaths(srcIsRemote, dstIsRemote bool) (err error) { - srcHomeRetriever := getLocalHomeDir if srcIsRemote { - srcHomeRetriever = c.getHomeDir - } - for i, srcPath := range c.srcPaths { - c.srcPaths[i], err = expandPath(srcPath, srcHomeRetriever) - if err != nil { - return trace.Wrap(err) + for i, srcPath := range c.srcPaths { + c.srcPaths[i], err = expandPath(srcPath) + if err != nil { + return trace.Wrap(err, "error expanding %q", srcPath) + } } } - dstHomeRetriever := getLocalHomeDir if dstIsRemote { - dstHomeRetriever = c.getHomeDir + c.dstPath, err = expandPath(c.dstPath) + if err != nil { + return trace.Wrap(err, "error expanding %q", c.dstPath) + } } - c.dstPath, err = expandPath(c.dstPath, dstHomeRetriever) - return trace.Wrap(err) -} - -func getLocalHomeDir() (string, error) { - u, err := user.Current() - if err != nil { - return "", trace.Wrap(err) - } - return u.HomeDir, nil + return nil } -func expandPath(pathStr string, getHomeDir homeDirRetriever) (string, error) { - if !needsExpansion(pathStr) { +func expandPath(pathStr string) (string, error) { + pfxLen, ok := homeDirPrefixLen(pathStr) + if !ok { return pathStr, nil } - homeDir, err := getHomeDir() - if err != nil { - return "", trace.Wrap(err) + // Removing the home dir prefix would mean returning an empty string, + // which is supported by SFTP but won't be as clear in logs or audit + // events. Since the SFTP server will be rooted at the user's home + // directory, "." and "" are equivalent in this context. + if pathStr == "~" { + return ".", nil + } + if pfxLen == 1 && len(pathStr) > 1 { + return "", trace.BadParameter("expanding remote ~user paths is not supported, specify an absolute path instead") } - // this is safe because we verified that all paths are non-empty - // in CreateUploadConfig/CreateDownloadConfig - return path.Join(homeDir, pathStr[1:]), nil + // if an SFTP path is not absolute, it is assumed to start at the user's + // home directory so just strip the prefix and let the SFTP server + // figure out the correct remote path + return pathStr[pfxLen:], nil } -// needsExpansion returns true if path is '~', '~/', or '~\' on Windows -func needsExpansion(path string) bool { - if len(path) == 1 { - return path == "~" +// homeDirPrefixLen returns the length of a set of characters that +// indicates the user wants the path to begin with a user's home +// directory and a bool that indicates whether such a prefix exists. +func homeDirPrefixLen(path string) (int, bool) { + if strings.HasPrefix(path, "~/") { + return 2, true } - // allow '~\' or '~/' on Windows since '\' is the canonical path // separator but some users may use '/' instead if runtime.GOOS == "windows" && strings.HasPrefix(path, `~\`) { - return true + return 2, true } - return strings.HasPrefix(path, "~/") -} -// getRemoteHomeDir returns the home directory of the remote user of -// the SSH connection -func getRemoteHomeDir(sshClient *ssh.Client) (string, error) { - s, err := sshClient.NewSession() - if err != nil { - return "", trace.Wrap(err) - } - defer s.Close() - if err := s.RequestSubsystem(teleport.GetHomeDirSubsystem); err != nil { - return "", trace.Wrap(err) - } - r, err := s.StdoutPipe() - if err != nil { - return "", trace.Wrap(err) + if len(path) >= 1 && path[0] == '~' { + return 1, true } - var homeDirBuf bytes.Buffer - if _, err := io.Copy(&homeDirBuf, r); err != nil { - return "", trace.Wrap(err) - } - - return homeDirBuf.String(), nil + return -1, false } // transfer performs file transfers @@ -513,6 +479,8 @@ func (c *Config) transfer(ctx context.Context) error { // transferDir transfers a directory func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFileInfo os.FileInfo) error { + c.Log.Debugf("copying %s dir %q to %s dir %q", c.srcFS.Type(), srcPath, c.dstFS.Type(), dstPath) + err := c.dstFS.Mkdir(ctx, dstPath) if err != nil && !errors.Is(err, os.ErrExist) { return trace.Errorf("error creating %s directory %q: %w", c.dstFS.Type(), dstPath, err) @@ -555,6 +523,8 @@ func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFi // transferFile transfers a file func (c *Config) transferFile(ctx context.Context, dstPath, srcPath string, srcFileInfo os.FileInfo) error { + c.Log.Debugf("copying %s file %q to %s file %q", c.srcFS.Type(), srcPath, c.dstFS.Type(), dstPath) + srcFile, err := c.srcFS.Open(ctx, srcPath) if err != nil { return trace.Errorf("error opening %s file %q: %w", c.srcFS.Type(), srcPath, err) diff --git a/lib/sshutils/sftp/sftp_test.go b/lib/sshutils/sftp/sftp_test.go index d9971bb6bb750..1a30ce7e3880a 100644 --- a/lib/sshutils/sftp/sftp_test.go +++ b/lib/sshutils/sftp/sftp_test.go @@ -33,6 +33,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/utils" @@ -560,6 +561,7 @@ func TestHomeDirExpansion(t *testing.T) { name string path string expandedPath string + errCheck require.ErrorAssertionFunc }{ { name: "absolute path", @@ -567,26 +569,34 @@ func TestHomeDirExpansion(t *testing.T) { expandedPath: "/foo/bar", }, { - name: "path with tilde", + name: "path with tilde-slash", path: "~/foo/bar", - expandedPath: "/home/user/foo/bar", + expandedPath: "foo/bar", }, { name: "just tilde", path: "~", - expandedPath: "/home/user", + expandedPath: ".", }, - } - getHomeDirFunc := func() (string, error) { - return "/home/user", nil + { + name: "~user path", + path: "~user/foo", + errCheck: func(t require.TestingT, err error, i ...interface{}) { + require.True(t, trace.IsBadParameter(err)) + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - expanded, err := expandPath(tt.path, getHomeDirFunc) - require.NoError(t, err) - require.Equal(t, tt.expandedPath, expanded) + expanded, err := expandPath(tt.path) + if tt.errCheck == nil { + require.NoError(t, err) + require.Equal(t, tt.expandedPath, expanded) + } else { + tt.errCheck(t, err) + } }) } }