Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/srv/regular/sshserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -2040,6 +2040,7 @@ func (s *Server) parseSubsystemRequest(req *ssh.Request, ch ssh.Channel, ctx *sr
return parseProxySubsys(r.Name, s, ctx)
case s.proxyMode && strings.HasPrefix(r.Name, "proxysites"):
return parseProxySitesSubsys(r.Name, s)
// DELETE IN 15.0.0 (deprecated, tsh will not be using this anymore)
case r.Name == teleport.GetHomeDirSubsystem:
return newHomeDirSubsys(), nil
case r.Name == sftpSubsystem:
Expand Down
110 changes: 40 additions & 70 deletions lib/sshutils/sftp/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,13 @@ limitations under the License.
package sftp

import (
"bytes"
"context"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"os"
"os/user"
"path" // SFTP requires UNIX-style path separators
"runtime"
"strconv"
Expand All @@ -39,7 +37,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/sshutils/scp"
)
Expand All @@ -53,8 +50,6 @@ type Options struct {
PreserveAttrs bool
}

type homeDirRetriever func() (string, error)

// Config describes the settings of a file transfer
type Config struct {
srcPaths []string
Expand All @@ -63,10 +58,6 @@ type Config struct {
dstFS FileSystem
opts Options

// getHomeDir returns the home directory of the remote user of the
// SSH session
getHomeDir homeDirRetriever

// ProgressStream is a callback to return a read/writer for printing the progress
// (used only on the client)
ProgressStream func(fileInfo os.FileInfo) io.ReadWriter
Expand Down Expand Up @@ -322,95 +313,70 @@ func (c *Config) initFS(sshClient *ssh.Client, client *sftp.Client) error {
return nil
}

if c.getHomeDir == nil {
c.getHomeDir = func() (string, error) {
return getRemoteHomeDir(sshClient)
}
}

return trace.Wrap(c.expandPaths(srcOK, dstOK))
}

func (c *Config) expandPaths(srcIsRemote, dstIsRemote bool) (err error) {
srcHomeRetriever := getLocalHomeDir
if srcIsRemote {
srcHomeRetriever = c.getHomeDir
}
for i, srcPath := range c.srcPaths {
c.srcPaths[i], err = expandPath(srcPath, srcHomeRetriever)
if err != nil {
return trace.Wrap(err)
for i, srcPath := range c.srcPaths {
c.srcPaths[i], err = expandPath(srcPath)
if err != nil {
return trace.Wrap(err, "error expanding %q", srcPath)
}
}
}

dstHomeRetriever := getLocalHomeDir
if dstIsRemote {
dstHomeRetriever = c.getHomeDir
c.dstPath, err = expandPath(c.dstPath)
if err != nil {
return trace.Wrap(err, "error expanding %q", c.dstPath)
}
}
c.dstPath, err = expandPath(c.dstPath, dstHomeRetriever)

return trace.Wrap(err)
}

func getLocalHomeDir() (string, error) {
u, err := user.Current()
if err != nil {
return "", trace.Wrap(err)
}
return u.HomeDir, nil
return nil
}

func expandPath(pathStr string, getHomeDir homeDirRetriever) (string, error) {
if !needsExpansion(pathStr) {
func expandPath(pathStr string) (string, error) {
pfxLen, ok := homeDirPrefixLen(pathStr)
if !ok {
return pathStr, nil
}

homeDir, err := getHomeDir()
if err != nil {
return "", trace.Wrap(err)
// Removing the home dir prefix would mean returning an empty string,
// which is supported by SFTP but won't be as clear in logs or audit
// events. Since the SFTP server will be rooted at the user's home
// directory, "." and "" are equivalent in this context.
Comment thread
capnspacehook marked this conversation as resolved.
if pathStr == "~" {
return ".", nil
}
if pfxLen == 1 && len(pathStr) > 1 {
return "", trace.BadParameter("expanding remote ~user paths is not supported, specify an absolute path instead")
}

// this is safe because we verified that all paths are non-empty
// in CreateUploadConfig/CreateDownloadConfig
return path.Join(homeDir, pathStr[1:]), nil
// if an SFTP path is not absolute, it is assumed to start at the user's
// home directory so just strip the prefix and let the SFTP server
// figure out the correct remote path
return pathStr[pfxLen:], nil
}

// needsExpansion returns true if path is '~', '~/', or '~\' on Windows
func needsExpansion(path string) bool {
if len(path) == 1 {
return path == "~"
// homeDirPrefixLen returns the length of a set of characters that
// indicates the user wants the path to begin with a user's home
// directory and a bool that indicates whether such a prefix exists.
func homeDirPrefixLen(path string) (int, bool) {
if strings.HasPrefix(path, "~/") {
return 2, true
}

// allow '~\' or '~/' on Windows since '\' is the canonical path
// separator but some users may use '/' instead
if runtime.GOOS == "windows" && strings.HasPrefix(path, `~\`) {
return true
return 2, true
}
return strings.HasPrefix(path, "~/")
}

// getRemoteHomeDir returns the home directory of the remote user of
// the SSH connection
func getRemoteHomeDir(sshClient *ssh.Client) (string, error) {
s, err := sshClient.NewSession()
if err != nil {
return "", trace.Wrap(err)
}
defer s.Close()
if err := s.RequestSubsystem(teleport.GetHomeDirSubsystem); err != nil {
return "", trace.Wrap(err)
}
r, err := s.StdoutPipe()
if err != nil {
return "", trace.Wrap(err)
if len(path) >= 1 && path[0] == '~' {
return 1, true
}

var homeDirBuf bytes.Buffer
if _, err := io.Copy(&homeDirBuf, r); err != nil {
return "", trace.Wrap(err)
}

return homeDirBuf.String(), nil
return -1, false
}

// transfer performs file transfers
Expand Down Expand Up @@ -513,6 +479,8 @@ func (c *Config) transfer(ctx context.Context) error {

// transferDir transfers a directory
func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFileInfo os.FileInfo) error {
c.Log.Debugf("copying %s dir %q to %s dir %q", c.srcFS.Type(), srcPath, c.dstFS.Type(), dstPath)

err := c.dstFS.Mkdir(ctx, dstPath)
if err != nil && !errors.Is(err, os.ErrExist) {
return trace.Errorf("error creating %s directory %q: %w", c.dstFS.Type(), dstPath, err)
Expand Down Expand Up @@ -555,6 +523,8 @@ func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFi

// transferFile transfers a file
func (c *Config) transferFile(ctx context.Context, dstPath, srcPath string, srcFileInfo os.FileInfo) error {
c.Log.Debugf("copying %s file %q to %s file %q", c.srcFS.Type(), srcPath, c.dstFS.Type(), dstPath)

srcFile, err := c.srcFS.Open(ctx, srcPath)
if err != nil {
return trace.Errorf("error opening %s file %q: %w", c.srcFS.Type(), srcPath, err)
Expand Down
28 changes: 19 additions & 9 deletions lib/sshutils/sftp/sftp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/gravitational/trace"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -560,33 +561,42 @@ func TestHomeDirExpansion(t *testing.T) {
name string
path string
expandedPath string
errCheck require.ErrorAssertionFunc
}{
{
name: "absolute path",
path: "/foo/bar",
expandedPath: "/foo/bar",
},
{
name: "path with tilde",
name: "path with tilde-slash",
path: "~/foo/bar",
expandedPath: "/home/user/foo/bar",
expandedPath: "foo/bar",
},
{
name: "just tilde",
path: "~",
expandedPath: "/home/user",
expandedPath: ".",
},
}

getHomeDirFunc := func() (string, error) {
return "/home/user", nil
{
name: "~user path",
path: "~user/foo",
errCheck: func(t require.TestingT, err error, i ...interface{}) {
require.True(t, trace.IsBadParameter(err))
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
expanded, err := expandPath(tt.path, getHomeDirFunc)
require.NoError(t, err)
require.Equal(t, tt.expandedPath, expanded)
expanded, err := expandPath(tt.path)
if tt.errCheck == nil {
require.NoError(t, err)
require.Equal(t, tt.expandedPath, expanded)
} else {
tt.errCheck(t, err)
}
})
}
}
Expand Down