diff --git a/lib/sshutils/sftp/local.go b/lib/sshutils/sftp/local.go index f2a2ebd098a44..04aaf6d469ea2 100644 --- a/lib/sshutils/sftp/local.go +++ b/lib/sshutils/sftp/local.go @@ -21,6 +21,7 @@ import ( "io" "io/fs" "os" + "path/filepath" "time" "github.com/gravitational/trace" @@ -36,9 +37,22 @@ func (l localFS) Type() string { return "local" } +func (l *localFS) Glob(ctx context.Context, pattern string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, trace.Wrap(err) + } + + matches, err := filepath.Glob(pattern) + if err != nil { + return nil, trace.Wrap(err) + } + + return matches, nil +} + func (l localFS) Stat(ctx context.Context, path string) (os.FileInfo, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, trace.Wrap(err) } fi, err := os.Stat(path) @@ -51,28 +65,36 @@ func (l localFS) Stat(ctx context.Context, path string) (os.FileInfo, error) { func (l localFS) ReadDir(ctx context.Context, path string) ([]os.FileInfo, error) { if err := ctx.Err(); err != nil { - return nil, err - } - - // normally os.ReadDir would be used as it's potentially more efficient, - // but because we want os.FileInfos of every file this is easier - f, err := os.Open(path) - if err != nil { return nil, trace.Wrap(err) } - defer f.Close() - fileInfos, err := f.Readdir(-1) + entries, err := os.ReadDir(path) if err != nil { return nil, trace.Wrap(err) } + fileInfos := make([]fs.FileInfo, len(entries)) + for i, entry := range entries { + info, err := entry.Info() + if err != nil { + return nil, trace.Wrap(err) + } + // if the file is a symlink, return the info of the linked file + if info.Mode().Type()&os.ModeSymlink != 0 { + info, err = os.Stat(filepath.Join(path, info.Name())) + if err != nil { + return nil, trace.Wrap(err) + } + } + + fileInfos[i] = info + } return fileInfos, nil } func (l localFS) Open(ctx context.Context, path string) (fs.File, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, trace.Wrap(err) } f, err := os.Open(path) @@ -85,7 +107,7 @@ func (l localFS) Open(ctx context.Context, path string) (fs.File, error) { func (l localFS) Create(ctx context.Context, path string) (io.WriteCloser, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, trace.Wrap(err) } f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, defaults.FilePermissions) @@ -98,7 +120,7 @@ func (l localFS) Create(ctx context.Context, path string) (io.WriteCloser, error func (l localFS) Mkdir(ctx context.Context, path string) error { if err := ctx.Err(); err != nil { - return err + return trace.Wrap(err) } err := os.MkdirAll(path, defaults.DirectoryPermissions) @@ -111,7 +133,7 @@ func (l localFS) Mkdir(ctx context.Context, path string) error { func (l localFS) Chmod(ctx context.Context, path string, mode os.FileMode) error { if err := ctx.Err(); err != nil { - return err + return trace.Wrap(err) } return trace.Wrap(os.Chmod(path, mode)) @@ -119,7 +141,7 @@ func (l localFS) Chmod(ctx context.Context, path string, mode os.FileMode) error func (l localFS) Chtimes(ctx context.Context, path string, atime, mtime time.Time) error { if err := ctx.Err(); err != nil { - return err + return trace.Wrap(err) } return trace.ConvertSystemError(os.Chtimes(path, atime, mtime)) diff --git a/lib/sshutils/sftp/remote.go b/lib/sshutils/sftp/remote.go index d992a2091b928..3283ada75a31e 100644 --- a/lib/sshutils/sftp/remote.go +++ b/lib/sshutils/sftp/remote.go @@ -21,6 +21,7 @@ import ( "io" "io/fs" "os" + portablepath "path" "time" "github.com/gravitational/trace" @@ -37,9 +38,22 @@ func (r *remoteFS) Type() string { return "remote" } +func (r *remoteFS) Glob(ctx context.Context, pattern string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, trace.Wrap(err) + } + + matches, err := r.c.Glob(pattern) + if err != nil { + return nil, trace.Wrap(err) + } + + return matches, nil +} + func (r *remoteFS) Stat(ctx context.Context, path string) (os.FileInfo, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, trace.Wrap(err) } fi, err := r.c.Stat(path) @@ -52,20 +66,29 @@ func (r *remoteFS) Stat(ctx context.Context, path string) (os.FileInfo, error) { func (r *remoteFS) ReadDir(ctx context.Context, path string) ([]os.FileInfo, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, trace.Wrap(err) } fileInfos, err := r.c.ReadDir(path) if err != nil { return nil, trace.Wrap(err) } + for i := range fileInfos { + // if the file is a symlink, return the info of the linked file + if fileInfos[i].Mode().Type()&os.ModeSymlink != 0 { + fileInfos[i], err = r.c.Stat(portablepath.Join(path, fileInfos[i].Name())) + if err != nil { + return nil, trace.Wrap(err) + } + } + } return fileInfos, nil } func (r *remoteFS) Open(ctx context.Context, path string) (fs.File, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, trace.Wrap(err) } f, err := r.c.Open(path) @@ -78,7 +101,7 @@ func (r *remoteFS) Open(ctx context.Context, path string) (fs.File, error) { func (r *remoteFS) Create(ctx context.Context, path string) (io.WriteCloser, error) { if err := ctx.Err(); err != nil { - return nil, err + return nil, trace.Wrap(err) } f, err := r.c.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC) @@ -91,7 +114,7 @@ func (r *remoteFS) Create(ctx context.Context, path string) (io.WriteCloser, err func (r *remoteFS) Mkdir(ctx context.Context, path string) error { if err := ctx.Err(); err != nil { - return err + return trace.Wrap(err) } err := r.c.MkdirAll(path) @@ -104,7 +127,7 @@ func (r *remoteFS) Mkdir(ctx context.Context, path string) error { func (r *remoteFS) Chmod(ctx context.Context, path string, mode os.FileMode) error { if err := ctx.Err(); err != nil { - return err + return trace.Wrap(err) } return trace.Wrap(r.c.Chmod(path, mode)) @@ -112,7 +135,7 @@ func (r *remoteFS) Chmod(ctx context.Context, path string, mode os.FileMode) err func (r *remoteFS) Chtimes(ctx context.Context, path string, atime, mtime time.Time) error { if err := ctx.Err(); err != nil { - return err + return trace.Wrap(err) } return trace.Wrap(r.c.Chtimes(path, atime, mtime)) diff --git a/lib/sshutils/sftp/sftp.go b/lib/sshutils/sftp/sftp.go index 19dea33589355..4bdcdba098642 100644 --- a/lib/sshutils/sftp/sftp.go +++ b/lib/sshutils/sftp/sftp.go @@ -25,7 +25,8 @@ import ( "io" "io/fs" "os" - "path" // SFTP requires Linux-style path separators + "os/user" + "path" // SFTP requires UNIX-style path separators "runtime" "strings" "time" @@ -75,6 +76,8 @@ type Config struct { type FileSystem interface { // Type returns whether the filesystem is "local" or "remote" Type() string + // Glob returns matching files of a glob pattern + Glob(ctx context.Context, pattern string) ([]string, error) // Stat returns info about a file Stat(ctx context.Context, path string) (os.FileInfo, error) // ReadDir returns information about files contained within a directory @@ -196,7 +199,7 @@ func (c *Config) initFS(sshClient *ssh.Client, client *sftp.Client) error { } if c.getHomeDir == nil { - c.getHomeDir = func() (_ string, err error) { + c.getHomeDir = func() (string, error) { return getRemoteHomeDir(sshClient) } } @@ -205,21 +208,34 @@ func (c *Config) initFS(sshClient *ssh.Client, client *sftp.Client) error { } func (c *Config) expandPaths(srcIsRemote, dstIsRemote bool) (err error) { + srcHomeRetriever := getLocalHomeDir if srcIsRemote { - for i, srcPath := range c.srcPaths { - c.srcPaths[i], err = expandPath(srcPath, c.getHomeDir) - if err != nil { - return trace.Wrap(err) - } + srcHomeRetriever = c.getHomeDir + } + for i, srcPath := range c.srcPaths { + c.srcPaths[i], err = expandPath(srcPath, srcHomeRetriever) + if err != nil { + return trace.Wrap(err) } } + + dstHomeRetriever := getLocalHomeDir if dstIsRemote { - c.dstPath, err = expandPath(c.dstPath, c.getHomeDir) + dstHomeRetriever = c.getHomeDir } + c.dstPath, err = expandPath(c.dstPath, dstHomeRetriever) return trace.Wrap(err) } +func getLocalHomeDir() (string, error) { + u, err := user.Current() + if err != nil { + return "", trace.Wrap(err) + } + return u.HomeDir, nil +} + func expandPath(pathStr string, getHomeDir homeDirRetriever) (string, error) { if !needsExpansion(pathStr) { return pathStr, nil @@ -273,9 +289,41 @@ func getRemoteHomeDir(sshClient *ssh.Client) (string, error) { return homeDirBuf.String(), nil } -// transfer preforms file transfers +// transfer performs file transfers func (c *Config) transfer(ctx context.Context) error { + // get info of source files and ensure appropriate options were passed + matchedPaths := make([]string, 0, len(c.srcPaths)) + fileInfos := make([]os.FileInfo, 0, len(c.srcPaths)) + for _, srcPath := range c.srcPaths { + matches, err := c.srcFS.Glob(ctx, srcPath) + if err != nil { + return trace.Wrap(err, "error matching glob pattern %q", srcPath) + } + // clean match paths to ensure they are separated by backslashes, as + // SFTP requires that + for i := range matches { + matches[i] = path.Clean(matches[i]) + } + matchedPaths = append(matchedPaths, matches...) + + for _, match := range matches { + fi, err := c.srcFS.Stat(ctx, match) + if err != nil { + return trace.Wrap(err, "could not access %s path %q", c.srcFS.Type(), match) + } + if fi.IsDir() && !c.opts.Recursive { + // Note: using any other error constructor than BadParameter + // might lead to relogin attempt and a completely obscure + // error message + return trace.BadParameter("%q is a directory, but the recursive option was not passed", match) + } + fileInfos = append(fileInfos, fi) + } + } + + // validate destination path and create it if necessary var dstIsDir bool + c.dstPath = path.Clean(c.dstPath) dstInfo, err := c.dstFS.Stat(ctx, c.dstPath) if err != nil { if !errors.Is(err, os.ErrNotExist) { @@ -283,7 +331,7 @@ func (c *Config) transfer(ctx context.Context) error { } // if there are multiple source paths and the destination path // doesn't exist, create it as a directory - if len(c.srcPaths) > 1 { + if len(matchedPaths) > 1 { if err := c.dstFS.Mkdir(ctx, c.dstPath); err != nil { return trace.Errorf("error creating %s directory %q: %w", c.dstFS.Type(), c.dstPath, err) } @@ -292,33 +340,24 @@ func (c *Config) transfer(ctx context.Context) error { } dstIsDir = true } - } else if len(c.srcPaths) > 1 && !dstInfo.IsDir() { + } else if len(matchedPaths) > 1 && !dstInfo.IsDir() { // if there are multiple source paths, ensure the destination path // is a directory - return trace.BadParameter("%s file %q is not a directory, but multiple source files were specified", - c.dstFS.Type(), - c.dstPath, - ) + if len(matchedPaths) != len(c.srcPaths) { + return trace.BadParameter("%s file %q is not a directory, but multiple source files were matched by a glob pattern", + c.dstFS.Type(), + c.dstPath, + ) + } else { + return trace.BadParameter("%s file %q is not a directory, but multiple source files were specified", + c.dstFS.Type(), + c.dstPath, + ) + } } else if dstInfo.IsDir() { dstIsDir = true } - // get info of source files and ensure appropriate options were passed - fileInfos := make([]os.FileInfo, len(c.srcPaths)) - for i := range c.srcPaths { - fi, err := c.srcFS.Stat(ctx, c.srcPaths[i]) - if err != nil { - return trace.Errorf("could not access %s path %q: %v", c.srcFS.Type(), c.srcPaths[i], err) - } - if fi.IsDir() && !c.opts.Recursive { - // Note: using any other error constructor (e.g. BadParameter) - // might lead to relogin attempt and a completely obscure - // error message - return trace.BadParameter("%q is a directory, but the recursive option was not passed", c.srcPaths[i]) - } - fileInfos[i] = fi - } - for i, fi := range fileInfos { dstPath := c.dstPath if dstIsDir || fi.IsDir() { @@ -326,11 +365,11 @@ func (c *Config) transfer(ctx context.Context) error { } if fi.IsDir() { - if err := c.transferDir(ctx, dstPath, c.srcPaths[i], fi); err != nil { + if err := c.transferDir(ctx, dstPath, matchedPaths[i], fi); err != nil { return trace.Wrap(err) } } else { - if err := c.transferFile(ctx, dstPath, c.srcPaths[i], fi); err != nil { + if err := c.transferFile(ctx, dstPath, matchedPaths[i], fi); err != nil { return trace.Wrap(err) } } @@ -421,7 +460,14 @@ func (c *Config) transferFile(ctx context.Context, dstPath, srcPath string, srcF ) } if n != srcFileInfo.Size() { - return trace.Errorf("short write: written %v, expected %v", n, srcFileInfo.Size()) + return trace.Errorf("error copying %s file %q to %s file %q: short write: wrote %d bytes, expected to write %d bytes", + c.srcFS.Type(), + srcPath, + c.dstFS.Type(), + dstPath, + n, + srcFileInfo.Size(), + ) } if c.opts.PreserveAttrs { diff --git a/lib/sshutils/sftp/sftp_test.go b/lib/sshutils/sftp/sftp_test.go index baaf2f434c861..64f0d1fcb57c1 100644 --- a/lib/sshutils/sftp/sftp_test.go +++ b/lib/sshutils/sftp/sftp_test.go @@ -44,12 +44,13 @@ func TestUpload(t *testing.T) { t.Parallel() tests := []struct { - name string - srcPaths []string - dstPath string - opts Options - files []string - expectedErr string + name string + srcPaths []string + globbedSrcPaths []string + dstPath string + opts Options + files []string + expectedErr string }{ { name: "one file", @@ -143,6 +144,132 @@ func TestUpload(t *testing.T) { "dst/", }, }, + { + name: "globbed files dst doesn't exist", + srcPaths: []string{ + "glob*", + }, + globbedSrcPaths: []string{ + "globS", + "globA", + "globT", + "globB", + }, + dstPath: "dst/", + files: []string{ + "globS", + "globA", + "globT", + "globB", + }, + }, + { + name: "globbed files dst does exist", + srcPaths: []string{ + "glob*", + }, + globbedSrcPaths: []string{ + "globS", + "globA", + "globT", + "globB", + }, + dstPath: "dst/", + files: []string{ + "globS", + "globA", + "globT", + "globB", + "dst/", + }, + }, + { + name: "multiple glob patterns", + srcPaths: []string{ + "glob*", + "*stuff", + }, + globbedSrcPaths: []string{ + "globS", + "globA", + "globT", + "globB", + "mystuff", + "yourstuff", + }, + dstPath: "dst/", + files: []string{ + "globS", + "globA", + "globT", + "globB", + "mystuff", + "yourstuff", + "dst/", + }, + }, + { + name: "multiple glob patterns with normal path", + srcPaths: []string{ + "glob*", + "file", + "*stuff", + }, + globbedSrcPaths: []string{ + "globS", + "globA", + "globT", + "globB", + "file", + "mystuff", + "yourstuff", + }, + dstPath: "dst/", + files: []string{ + "globS", + "globA", + "globT", + "globB", + "file", + "mystuff", + "yourstuff", + "dst/", + }, + }, + { + name: "recursive glob pattern with normal path", + srcPaths: []string{ + "glob*", + "file", + "*stuff", + }, + globbedSrcPaths: []string{ + "globS", + "globA", + "globT", + "globB", + "globfile", + "file", + }, + dstPath: "dst/", + opts: Options{ + Recursive: true, + PreserveAttrs: true, + }, + files: []string{ + "globS/", + "globS/file", + "globA/", + "globA/file", + "globT/", + "globT/file", + "globB/", + "globB/file", + "globfile", + "file", + "dst/", + }, + }, { name: "multiple src dst not dir", srcPaths: []string{ @@ -159,6 +286,20 @@ func TestUpload(t *testing.T) { }, expectedErr: `local file "%s/dst_file" is not a directory, but multiple source files were specified`, }, + { + name: "multiple matches from src dst not dir", + srcPaths: []string{ + "glob*", + }, + dstPath: "dst_file", + files: []string{ + "glob1", + "glob2", + "glob3", + "dst_file", + }, + expectedErr: `local file "%s/dst_file" is not a directory, but multiple source files were matched by a glob pattern`, + }, { name: "src dir with recursive not passed", srcPaths: []string{ @@ -187,9 +328,11 @@ func TestUpload(t *testing.T) { for i := range tt.srcPaths { tt.srcPaths[i] = filepath.Join(tempDir, tt.srcPaths[i]) } + for i := range tt.globbedSrcPaths { + tt.globbedSrcPaths[i] = filepath.Join(tempDir, tt.globbedSrcPaths[i]) + } tt.dstPath = filepath.Join(tempDir, tt.dstPath) - ctx := context.Background() cfg, err := CreateUploadConfig(tt.srcPaths, tt.dstPath, tt.opts) require.NoError(t, err) // use all local filesystems to avoid SSH overhead @@ -197,10 +340,15 @@ func TestUpload(t *testing.T) { err = cfg.initFS(nil, nil) require.NoError(t, err) + ctx := context.Background() err = cfg.transfer(ctx) if tt.expectedErr == "" { require.NoError(t, err) - checkTransfer(t, tt.opts.PreserveAttrs, tt.dstPath, tt.srcPaths...) + srcPaths := tt.srcPaths + if len(tt.globbedSrcPaths) != 0 { + srcPaths = tt.globbedSrcPaths + } + checkTransfer(t, tt.opts.PreserveAttrs, tt.dstPath, srcPaths...) } else { require.EqualError(t, err, fmt.Sprintf(tt.expectedErr, tempDir)) } @@ -212,12 +360,13 @@ func TestDownload(t *testing.T) { t.Parallel() tests := []struct { - name string - srcPath string - dstPath string - opts Options - files []string - expectedErr string + name string + srcPath string + globbedSrcPaths []string + dstPath string + opts Options + files []string + expectedErr string }{ { name: "one file", @@ -240,6 +389,19 @@ func TestDownload(t *testing.T) { }, files: []string{ "src/", + "dst/", + }, + }, + { + name: "nested dirs", + srcPath: "s", + dstPath: "dst/", + opts: Options{ + PreserveAttrs: true, + Recursive: true, + }, + files: []string{ + "s/", "s/file", "s/r/", "s/r/file", @@ -248,6 +410,69 @@ func TestDownload(t *testing.T) { "dst/", }, }, + { + name: "globbed files dst doesn't exist", + srcPath: "glob*", + globbedSrcPaths: []string{ + "globS", + "globA", + "globT", + "globB", + }, + dstPath: "dst/", + files: []string{ + "globS", + "globA", + "globT", + "globB", + }, + }, + { + name: "globbed files dst does exist", + srcPath: "glob*", + globbedSrcPaths: []string{ + "globS", + "globA", + "globT", + "globB", + }, + dstPath: "dst/", + files: []string{ + "globS", + "globA", + "globT", + "globB", + "dst/", + }, + }, + { + name: "recursive glob pattern", + srcPath: "glob*", + globbedSrcPaths: []string{ + "globS", + "globA", + "globT", + "globB", + "globfile", + }, + dstPath: "dst/", + opts: Options{ + Recursive: true, + PreserveAttrs: true, + }, + files: []string{ + "globS/", + "globS/file", + "globA/", + "globA/file", + "globT/", + "globT/file", + "globB/", + "globB/file", + "globfile", + "dst/", + }, + }, { name: "src dir with recursive not passed", srcPath: "src/", @@ -272,9 +497,11 @@ func TestDownload(t *testing.T) { } } tt.srcPath = filepath.Join(tempDir, tt.srcPath) + for i := range tt.globbedSrcPaths { + tt.globbedSrcPaths[i] = filepath.Join(tempDir, tt.globbedSrcPaths[i]) + } tt.dstPath = filepath.Join(tempDir, tt.dstPath) - ctx := context.Background() cfg, err := CreateDownloadConfig(tt.srcPath, tt.dstPath, tt.opts) require.NoError(t, err) // use all local filesystems to avoid SSH overhead @@ -282,10 +509,15 @@ func TestDownload(t *testing.T) { err = cfg.initFS(nil, nil) require.NoError(t, err) + ctx := context.Background() err = cfg.transfer(ctx) if tt.expectedErr == "" { require.NoError(t, err) - checkTransfer(t, tt.opts.PreserveAttrs, tt.dstPath, tt.srcPath) + srcPaths := []string{tt.srcPath} + if len(tt.globbedSrcPaths) != 0 { + srcPaths = tt.globbedSrcPaths + } + checkTransfer(t, tt.opts.PreserveAttrs, tt.dstPath, srcPaths...) } else { require.EqualError(t, err, fmt.Sprintf(tt.expectedErr, tempDir)) } @@ -329,6 +561,28 @@ func TestHomeDirExpansion(t *testing.T) { } } +func TestCopyingSymlinkedFile(t *testing.T) { + tempDir := t.TempDir() + createFile(t, tempDir, "file") + linkPath := filepath.Join(tempDir, "link") + err := os.Symlink(filepath.Join(tempDir, "file"), linkPath) + require.NoError(t, err) + + dstPath := filepath.Join(tempDir, "dst") + cfg, err := CreateDownloadConfig(linkPath, dstPath, Options{}) + require.NoError(t, err) + // use all local filesystems to avoid SSH overhead + cfg.srcFS = &localFS{} + err = cfg.initFS(nil, nil) + require.NoError(t, err) + + ctx := context.Background() + err = cfg.transfer(ctx) + require.NoError(t, err) + + checkTransfer(t, false, dstPath, linkPath) +} + func createFile(t *testing.T, rootDir, path string) { dir := filepath.Dir(path) if dir != path {