Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/srv/forward/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ func (h *proxyHandlers) Filelist(req *sftp.Request) (_ sftp.ListerAt, err error)
return lister, nil
}

// RealPath canonicalizes a path name, including resolving ".." and
// following symlinks. Required to implement [sftp.RealPathFileLister].
func (h *proxyHandlers) RealPath(path string) (string, error) {
return h.remoteFS.RealPath(path)
}

func (h *proxyHandlers) sendSFTPEvent(req *sftp.Request, reqErr error) {
wd, err := h.remoteFS.Getwd()
if err != nil {
Expand Down
4 changes: 4 additions & 0 deletions lib/sshutils/sftp/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ func (h *httpFS) Getwd() (string, error) {
return "", nil
}

func (h *httpFS) RealPath(path string) (string, error) {
return path, nil
}

type nopWriteCloser struct {
io.Writer
}
Expand Down
8 changes: 8 additions & 0 deletions lib/sshutils/sftp/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,3 +142,11 @@ func (l localFS) Readlink(name string) (string, error) {
func (l localFS) Getwd() (string, error) {
return os.Getwd()
}

func (l localFS) RealPath(path string) (string, error) {
path, err := filepath.Abs(path)
if err != nil {
return "", err
}
return filepath.EvalSymlinks(path)
}
23 changes: 19 additions & 4 deletions lib/sshutils/sftp/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ type FileSystem interface {
Readlink(name string) (string, error)
// Getwd gets the current working directory.
Getwd() (string, error)
// RealPath canonicalizes a path name, including resolving ".." and
// following symlinks.
RealPath(path string) (string, error)
}

// CreateUploadConfig returns a Config ready to upload files over SFTP.
Expand Down Expand Up @@ -548,7 +551,7 @@ func (c *Config) transfer(ctx context.Context) error {
}

if fi.IsDir() {
if err := c.transferDir(ctx, dstPath, matchedPaths[i], fi); err != nil {
if err := c.transferDir(ctx, dstPath, matchedPaths[i], fi, nil); err != nil {
return trace.Wrap(err)
}
} else {
Expand All @@ -562,10 +565,22 @@ func (c *Config) transfer(ctx context.Context) error {
}

// transferDir transfers a directory
func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFileInfo os.FileInfo) error {
func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFileInfo os.FileInfo, visited map[string]struct{}) error {
if visited == nil {
visited = make(map[string]struct{})
}
realSrcPath, err := c.srcFS.RealPath(srcPath)
if err != nil {
return trace.Wrap(err)
}
if _, ok := visited[realSrcPath]; ok {
c.Log.DebugContext(ctx, "symlink loop detected, directory will be skipped", "link", srcPath, "target", realSrcPath)
return nil
}
visited[realSrcPath] = struct{}{}
c.Log.DebugContext(ctx, "transferring contents of directory", "source_fs", c.srcFS.Type(), "source_path", srcPath, "dest_fs", c.dstFS.Type(), "dest_path", dstPath)

err := c.dstFS.Mkdir(dstPath)
err = c.dstFS.Mkdir(dstPath)
if err != nil && !errors.Is(err, os.ErrExist) {
return trace.Errorf("error creating %s directory %q: %w", c.dstFS.Type(), dstPath, err)
}
Expand All @@ -583,7 +598,7 @@ func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFi
lSubPath := path.Join(srcPath, info.Name())

if info.IsDir() {
if err := c.transferDir(ctx, dstSubPath, lSubPath, info); err != nil {
if err := c.transferDir(ctx, dstSubPath, lSubPath, info, visited); err != nil {
return trace.Wrap(err)
}
} else {
Expand Down
76 changes: 76 additions & 0 deletions lib/sshutils/sftp/sftp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,82 @@ func TestCopyingSymlinkedFile(t *testing.T) {
checkTransfer(t, false, dstPath, linkPath)
}

type mockFS struct {
localFS
fileAccesses map[string]int
}

func (m *mockFS) Open(path string) (File, error) {
realPath, err := filepath.EvalSymlinks(path)
if err != nil {
return nil, trace.Wrap(err)
}
m.fileAccesses[realPath]++
return m.localFS.Open(path)
}

func TestRecursiveSymlinks(t *testing.T) {
// Create files and symlinks.
root := t.TempDir()
t.Chdir(root)
srcDir := filepath.Join(root, "a")
createDir(t, filepath.Join(srcDir, "b/c"))
fileA := "a/a.txt"
fileB := "a/b/b.txt"
fileC := "a/b/c/c.txt"
for _, file := range []string{fileA, fileB, fileC} {
createFile(t, filepath.Join(root, file))
}
require.NoError(t, os.Symlink(srcDir, filepath.Join(srcDir, "abs_link")))
require.NoError(t, os.Symlink("..", filepath.Join(srcDir, "b/rel_link")))

tests := []struct {
name string
srcDir string
}{
{
name: "absolute",
srcDir: srcDir,
},
{
name: "relative",
srcDir: filepath.Base(srcDir),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
// Perform the transfer.
dstDir := filepath.Join(root, "dst")
t.Cleanup(func() { os.RemoveAll(dstDir) })

cfg, err := CreateDownloadConfig(tc.srcDir, dstDir, Options{Recursive: true})
require.NoError(t, err)
// use all local filesystems to avoid SSH overhead
srcFS := &mockFS{fileAccesses: make(map[string]int)}
cfg.srcFS = srcFS
require.NoError(t, cfg.initFS(nil))
require.NoError(t, cfg.transfer(t.Context()))

// Check results. Don't use checkTransfer() as the directories will not have
// matching sizes (the symlinks that aren't copied over).
for _, file := range []string{fileA, fileB, fileC} {
srcFile, err := filepath.EvalSymlinks(filepath.Join(filepath.Dir(tc.srcDir), file))
require.NoError(t, err)
srcInfo, err := os.Stat(srcFile)
require.NoError(t, err)
dstFile, err := filepath.EvalSymlinks(filepath.Join(dstDir, file))
require.NoError(t, err)
dstInfo, err := os.Stat(dstFile)
require.NoError(t, err)
compareFiles(t, false, dstInfo, srcInfo, dstFile, srcFile)
// Check that the file was only opened once.
accesses := srcFS.fileAccesses[srcFile]
require.Equal(t, 1, accesses, "file %q was opened %d times", srcFile, accesses)
}
})
}
}

func TestHTTPUpload(t *testing.T) {
t.Parallel()

Expand Down
7 changes: 7 additions & 0 deletions tool/teleport/common/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"os"
"os/user"
"path"
"path/filepath"
"strings"
"sync"
"time"
Expand Down Expand Up @@ -231,6 +232,12 @@ func (s *sftpHandler) Filelist(req *sftp.Request) (_ sftp.ListerAt, retErr error
return sftputils.HandleFilelist(req, nil /* local filesystem */)
}

// RealPath canonicalizes a path name, including resolving ".." and
// following symlinks. Required to implement [sftp.RealPathFileLister].
func (s *sftpHandler) RealPath(path string) (string, error) {
return filepath.EvalSymlinks(path)
}

func (s *sftpHandler) sendSFTPEvent(req *sftp.Request, reqErr error) {
wd, err := os.Getwd()
if err != nil {
Expand Down
Loading