diff --git a/lib/srv/forward/sftp.go b/lib/srv/forward/sftp.go index 9aa93da548491..f158ee9f1f853 100644 --- a/lib/srv/forward/sftp.go +++ b/lib/srv/forward/sftp.go @@ -211,6 +211,12 @@ func (h *proxyHandlers) Filelist(req *sftp.Request) (_ sftp.ListerAt, err error) return lister, nil } +// RealPath canonicalizes a path name, including resolving ".." and +// following symlinks. Required to implement [sftp.RealPathFileLister]. +func (h *proxyHandlers) RealPath(path string) (string, error) { + return h.remoteFS.RealPath(path) +} + func (h *proxyHandlers) sendSFTPEvent(req *sftp.Request, reqErr error) { wd, err := h.remoteFS.Getwd() if err != nil { diff --git a/lib/sshutils/sftp/http.go b/lib/sshutils/sftp/http.go index 33f5bc60c0b9c..d657039d1d0ef 100644 --- a/lib/sshutils/sftp/http.go +++ b/lib/sshutils/sftp/http.go @@ -173,6 +173,10 @@ func (h *httpFS) Getwd() (string, error) { return "", nil } +func (h *httpFS) RealPath(path string) (string, error) { + return path, nil +} + type nopWriteCloser struct { io.Writer } diff --git a/lib/sshutils/sftp/local.go b/lib/sshutils/sftp/local.go index 8496e87d95abc..d76c779edb1b8 100644 --- a/lib/sshutils/sftp/local.go +++ b/lib/sshutils/sftp/local.go @@ -142,3 +142,11 @@ func (l localFS) Readlink(name string) (string, error) { func (l localFS) Getwd() (string, error) { return os.Getwd() } + +func (l localFS) RealPath(path string) (string, error) { + path, err := filepath.Abs(path) + if err != nil { + return "", err + } + return filepath.EvalSymlinks(path) +} diff --git a/lib/sshutils/sftp/sftp.go b/lib/sshutils/sftp/sftp.go index 69b1146ba23e8..823dfce15b5ab 100644 --- a/lib/sshutils/sftp/sftp.go +++ b/lib/sshutils/sftp/sftp.go @@ -160,6 +160,9 @@ type FileSystem interface { Readlink(name string) (string, error) // Getwd gets the current working directory. Getwd() (string, error) + // RealPath canonicalizes a path name, including resolving ".." and + // following symlinks. + RealPath(path string) (string, error) } // CreateUploadConfig returns a Config ready to upload files over SFTP. @@ -548,7 +551,7 @@ func (c *Config) transfer(ctx context.Context) error { } if fi.IsDir() { - if err := c.transferDir(ctx, dstPath, matchedPaths[i], fi); err != nil { + if err := c.transferDir(ctx, dstPath, matchedPaths[i], fi, nil); err != nil { return trace.Wrap(err) } } else { @@ -562,10 +565,22 @@ func (c *Config) transfer(ctx context.Context) error { } // transferDir transfers a directory -func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFileInfo os.FileInfo) error { +func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFileInfo os.FileInfo, visited map[string]struct{}) error { + if visited == nil { + visited = make(map[string]struct{}) + } + realSrcPath, err := c.srcFS.RealPath(srcPath) + if err != nil { + return trace.Wrap(err) + } + if _, ok := visited[realSrcPath]; ok { + c.Log.DebugContext(ctx, "symlink loop detected, directory will be skipped", "link", srcPath, "target", realSrcPath) + return nil + } + visited[realSrcPath] = struct{}{} c.Log.DebugContext(ctx, "transferring contents of directory", "source_fs", c.srcFS.Type(), "source_path", srcPath, "dest_fs", c.dstFS.Type(), "dest_path", dstPath) - err := c.dstFS.Mkdir(dstPath) + err = c.dstFS.Mkdir(dstPath) if err != nil && !errors.Is(err, os.ErrExist) { return trace.Errorf("error creating %s directory %q: %w", c.dstFS.Type(), dstPath, err) } @@ -583,7 +598,7 @@ func (c *Config) transferDir(ctx context.Context, dstPath, srcPath string, srcFi lSubPath := path.Join(srcPath, info.Name()) if info.IsDir() { - if err := c.transferDir(ctx, dstSubPath, lSubPath, info); err != nil { + if err := c.transferDir(ctx, dstSubPath, lSubPath, info, visited); err != nil { return trace.Wrap(err) } } else { diff --git a/lib/sshutils/sftp/sftp_test.go b/lib/sshutils/sftp/sftp_test.go index 9df26071143c4..642497a34ba32 100644 --- a/lib/sshutils/sftp/sftp_test.go +++ b/lib/sshutils/sftp/sftp_test.go @@ -633,6 +633,82 @@ func TestCopyingSymlinkedFile(t *testing.T) { checkTransfer(t, false, dstPath, linkPath) } +type mockFS struct { + localFS + fileAccesses map[string]int +} + +func (m *mockFS) Open(path string) (File, error) { + realPath, err := filepath.EvalSymlinks(path) + if err != nil { + return nil, trace.Wrap(err) + } + m.fileAccesses[realPath]++ + return m.localFS.Open(path) +} + +func TestRecursiveSymlinks(t *testing.T) { + // Create files and symlinks. + root := t.TempDir() + t.Chdir(root) + srcDir := filepath.Join(root, "a") + createDir(t, filepath.Join(srcDir, "b/c")) + fileA := "a/a.txt" + fileB := "a/b/b.txt" + fileC := "a/b/c/c.txt" + for _, file := range []string{fileA, fileB, fileC} { + createFile(t, filepath.Join(root, file)) + } + require.NoError(t, os.Symlink(srcDir, filepath.Join(srcDir, "abs_link"))) + require.NoError(t, os.Symlink("..", filepath.Join(srcDir, "b/rel_link"))) + + tests := []struct { + name string + srcDir string + }{ + { + name: "absolute", + srcDir: srcDir, + }, + { + name: "relative", + srcDir: filepath.Base(srcDir), + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + // Perform the transfer. + dstDir := filepath.Join(root, "dst") + t.Cleanup(func() { os.RemoveAll(dstDir) }) + + cfg, err := CreateDownloadConfig(tc.srcDir, dstDir, Options{Recursive: true}) + require.NoError(t, err) + // use all local filesystems to avoid SSH overhead + srcFS := &mockFS{fileAccesses: make(map[string]int)} + cfg.srcFS = srcFS + require.NoError(t, cfg.initFS(nil)) + require.NoError(t, cfg.transfer(t.Context())) + + // Check results. Don't use checkTransfer() as the directories will not have + // matching sizes (the symlinks that aren't copied over). + for _, file := range []string{fileA, fileB, fileC} { + srcFile, err := filepath.EvalSymlinks(filepath.Join(filepath.Dir(tc.srcDir), file)) + require.NoError(t, err) + srcInfo, err := os.Stat(srcFile) + require.NoError(t, err) + dstFile, err := filepath.EvalSymlinks(filepath.Join(dstDir, file)) + require.NoError(t, err) + dstInfo, err := os.Stat(dstFile) + require.NoError(t, err) + compareFiles(t, false, dstInfo, srcInfo, dstFile, srcFile) + // Check that the file was only opened once. + accesses := srcFS.fileAccesses[srcFile] + require.Equal(t, 1, accesses, "file %q was opened %d times", srcFile, accesses) + } + }) + } +} + func TestHTTPUpload(t *testing.T) { t.Parallel() diff --git a/tool/teleport/common/sftp.go b/tool/teleport/common/sftp.go index db73eb17e27fc..39de1e3432b13 100644 --- a/tool/teleport/common/sftp.go +++ b/tool/teleport/common/sftp.go @@ -30,6 +30,7 @@ import ( "os" "os/user" "path" + "path/filepath" "strings" "sync" "time" @@ -231,6 +232,12 @@ func (s *sftpHandler) Filelist(req *sftp.Request) (_ sftp.ListerAt, retErr error return sftputils.HandleFilelist(req, nil /* local filesystem */) } +// RealPath canonicalizes a path name, including resolving ".." and +// following symlinks. Required to implement [sftp.RealPathFileLister]. +func (s *sftpHandler) RealPath(path string) (string, error) { + return filepath.EvalSymlinks(path) +} + func (s *sftpHandler) sendSFTPEvent(req *sftp.Request, reqErr error) { wd, err := os.Getwd() if err != nil {