Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lib/sshutils/sftp/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -404,10 +404,19 @@ func (c *Config) transfer(ctx context.Context) error {
matchedPaths := make([]string, 0, len(c.srcPaths))
fileInfos := make([]os.FileInfo, 0, len(c.srcPaths))
for _, srcPath := range c.srcPaths {
// This source path may or may not contain a glob pattern, but
// try and glob just in case. It is also possible the user
// specified a file path containing glob pattern characters but
// means the literal path without globbing, in which case we'll
// use the raw source path as the sole match below.
matches, err := c.srcFS.Glob(ctx, srcPath)
if err != nil {
return trace.Wrap(err, "error matching glob pattern %q", srcPath)
}
if len(matches) == 0 {
matches = []string{srcPath}
}

// clean match paths to ensure they are separated by backslashes, as
// SFTP requires that
for i := range matches {
Expand Down
46 changes: 35 additions & 11 deletions lib/sshutils/sftp/sftp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package sftp
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"math/rand"
Expand Down Expand Up @@ -54,7 +55,7 @@ func TestUpload(t *testing.T) {
dstPath string
opts Options
files []string
expectedErr string
errCheck require.ErrorAssertionFunc
}{
{
name: "one file",
Expand Down Expand Up @@ -245,7 +246,6 @@ func TestUpload(t *testing.T) {
srcPaths: []string{
"glob*",
"file",
"*stuff",
},
globbedSrcPaths: []string{
"globS",
Expand Down Expand Up @@ -288,7 +288,9 @@ func TestUpload(t *testing.T) {
"tres",
"dst_file",
},
expectedErr: `local file "%s/dst_file" is not a directory, but multiple source files were specified`,
errCheck: func(t require.TestingT, err error, i ...interface{}) {
require.EqualError(t, err, fmt.Sprintf(`local file "%s/dst_file" is not a directory, but multiple source files were specified`, i[0]))
},
},
{
name: "multiple matches from src dst not dir",
Expand All @@ -302,7 +304,9 @@ func TestUpload(t *testing.T) {
"glob3",
"dst_file",
},
expectedErr: `local file "%s/dst_file" is not a directory, but multiple source files were matched by a glob pattern`,
errCheck: func(t require.TestingT, err error, i ...interface{}) {
require.EqualError(t, err, fmt.Sprintf(`local file "%s/dst_file" is not a directory, but multiple source files were matched by a glob pattern`, i[0]))
},
},
{
name: "src dir with recursive not passed",
Expand All @@ -313,7 +317,18 @@ func TestUpload(t *testing.T) {
files: []string{
"src/",
},
expectedErr: `"%s/src" is a directory, but the recursive option was not passed`,
errCheck: func(t require.TestingT, err error, i ...interface{}) {
require.EqualError(t, err, fmt.Sprintf(`"%s/src" is a directory, but the recursive option was not passed`, i[0]))
},
},
{
name: "non-existent src file",
srcPaths: []string{
"idontexist",
},
errCheck: func(t require.TestingT, err error, i ...interface{}) {
require.True(t, errors.Is(err, os.ErrNotExist))
},
},
}

Expand Down Expand Up @@ -346,15 +361,15 @@ func TestUpload(t *testing.T) {

ctx := context.Background()
err = cfg.transfer(ctx)
if tt.expectedErr == "" {
if tt.errCheck == nil {
require.NoError(t, err)
srcPaths := tt.srcPaths
if len(tt.globbedSrcPaths) != 0 {
srcPaths = tt.globbedSrcPaths
}
checkTransfer(t, tt.opts.PreserveAttrs, tt.dstPath, srcPaths...)
} else {
require.EqualError(t, err, fmt.Sprintf(tt.expectedErr, tempDir))
tt.errCheck(t, err, tempDir)
}
})
}
Expand All @@ -370,7 +385,7 @@ func TestDownload(t *testing.T) {
dstPath string
opts Options
files []string
expectedErr string
errCheck require.ErrorAssertionFunc
}{
{
name: "one file",
Expand Down Expand Up @@ -484,7 +499,16 @@ func TestDownload(t *testing.T) {
files: []string{
"src/",
},
expectedErr: `"%s/src" is a directory, but the recursive option was not passed`,
errCheck: func(t require.TestingT, err error, i ...interface{}) {
require.EqualError(t, err, fmt.Sprintf(`"%s/src" is a directory, but the recursive option was not passed`, i[0]))
},
},
{
name: "non-existent src file",
srcPath: "idontexist",
errCheck: func(t require.TestingT, err error, i ...interface{}) {
require.True(t, errors.Is(err, os.ErrNotExist))
},
},
}

Expand Down Expand Up @@ -515,15 +539,15 @@ func TestDownload(t *testing.T) {

ctx := context.Background()
err = cfg.transfer(ctx)
if tt.expectedErr == "" {
if tt.errCheck == nil {
require.NoError(t, err)
srcPaths := []string{tt.srcPath}
if len(tt.globbedSrcPaths) != 0 {
srcPaths = tt.globbedSrcPaths
}
checkTransfer(t, tt.opts.PreserveAttrs, tt.dstPath, srcPaths...)
} else {
require.EqualError(t, err, fmt.Sprintf(tt.expectedErr, tempDir))
tt.errCheck(t, err, tempDir)
}
})
}
Expand Down