diff --git a/lib/client/api.go b/lib/client/api.go index b45dd46e19811..5dec11f87b662 100644 --- a/lib/client/api.go +++ b/lib/client/api.go @@ -76,7 +76,6 @@ import ( "github.com/gravitational/teleport/lib/shell" alpncommon "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/sshutils" - "github.com/gravitational/teleport/lib/sshutils/scp" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -1854,53 +1853,6 @@ func PlayFile(ctx context.Context, tarFile io.Reader, sid string) error { return playSession(sessionEvents, stream) } -// ExecuteSCP executes SCP command. It executes scp.Command using -// lower-level API integrations that mimic SCP CLI command behavior -func (tc *TeleportClient) ExecuteSCP(ctx context.Context, serverAddr string, cmd scp.Command) error { - ctx, span := tc.Tracer.Start( - ctx, - "teleportClient/ExecuteSCP", - oteltrace.WithSpanKind(oteltrace.SpanKindClient), - ) - defer span.End() - - // connect to proxy first: - if !tc.Config.ProxySpecified() { - return trace.BadParameter("proxy server is not specified") - } - - proxyClient, err := tc.ConnectToProxy(ctx) - if err != nil { - return trace.Wrap(err) - } - defer proxyClient.Close() - - nodeClient, err := tc.ConnectToNode( - ctx, - proxyClient, - // We append the ":0" to tell the server to figure out the port for us. - NodeDetails{Addr: serverAddr + ":0", Namespace: tc.Namespace, Cluster: tc.SiteName}, - tc.Config.HostLogin, - ) - if err != nil { - tc.ExitStatus = 1 - return trace.Wrap(err) - } - - err = nodeClient.ExecuteSCP(ctx, cmd) - if err != nil { - // converts SSH error code to tc.ExitStatus - exitError, _ := trace.Unwrap(err).(*ssh.ExitError) - if exitError != nil { - tc.ExitStatus = exitError.ExitStatus() - } - return err - - } - - return nil -} - // SFTP securely copies files between Nodes or SSH servers using SFTP func (tc *TeleportClient) SFTP(ctx context.Context, args []string, port int, opts sftp.Options, quiet bool) (err error) { ctx, span := tc.Tracer.Start( @@ -1958,7 +1910,7 @@ func (tc *TeleportClient) uploadConfig(args []string, port int, opts sftp.Option // copy everything except the last arg (the destination) dstPath := args[len(args)-1] - dst, addr, err := getSCPDestination(dstPath, port) + dst, addr, err := getSFTPDestination(dstPath, port) if err != nil { return nil, trace.Wrap(err) } @@ -1980,7 +1932,7 @@ func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Opti } // args are guaranteed to have len(args) > 1 - src, addr, err := getSCPDestination(args[0], port) + src, addr, err := getSFTPDestination(args[0], port) if err != nil { return nil, trace.Wrap(err) } @@ -1996,8 +1948,8 @@ func (tc *TeleportClient) downloadConfig(args []string, port int, opts sftp.Opti }, nil } -func getSCPDestination(target string, port int) (dest *scp.Destination, addr string, err error) { - dest, err = scp.ParseSCPDestination(target) +func getSFTPDestination(target string, port int) (dest *sftp.Destination, addr string, err error) { + dest, err = sftp.ParseDestination(target) if err != nil { return nil, "", trace.Wrap(err) } @@ -2034,7 +1986,11 @@ func (tc *TeleportClient) TransferFiles(ctx context.Context, hostLogin, nodeAddr client, err := tc.ConnectToNode( ctx, proxyClient, - NodeDetails{Addr: nodeAddr, Namespace: tc.Namespace, Cluster: tc.SiteName}, + NodeDetails{ + Addr: nodeAddr, + Namespace: tc.Namespace, + Cluster: tc.SiteName, + }, hostLogin, ) if err != nil { diff --git a/lib/client/client.go b/lib/client/client.go index ff2110e6fcc8d..b471c227d2c8a 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -21,11 +21,9 @@ import ( "context" "crypto/tls" "encoding/json" - "errors" "fmt" "io" "net" - "os" "strconv" "strings" "time" @@ -52,7 +50,6 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/sshutils" - "github.com/gravitational/teleport/lib/sshutils/scp" "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/socks" @@ -1711,97 +1708,6 @@ func (proxy *ProxyClient) Close() error { return trace.NewAggregate(proxy.Client.Close(), proxy.currentCluster.Close()) } -// ExecuteSCP runs remote scp command(shellCmd) on the remote server and -// runs local scp handler using SCP Command -func (c *NodeClient) ExecuteSCP(ctx context.Context, cmd scp.Command) error { - ctx, span := c.Tracer.Start( - ctx, - "nodeClient/ExecuteSCP", - oteltrace.WithSpanKind(oteltrace.SpanKindClient), - ) - defer span.End() - - shellCmd, err := cmd.GetRemoteShellCmd() - if err != nil { - return trace.Wrap(err) - } - - s, err := c.Client.NewSession(ctx) - if err != nil { - return trace.Wrap(err) - } - defer s.Close() - - stdin, err := s.StdinPipe() - if err != nil { - return trace.Wrap(err) - } - - stdout, err := s.StdoutPipe() - if err != nil { - return trace.Wrap(err) - } - - // Stream scp's stderr so tsh gets the verbose remote error - // if the command fails - stderr, err := s.StderrPipe() - if err != nil { - return trace.Wrap(err) - } - go io.Copy(os.Stderr, stderr) - - ch := utils.NewPipeNetConn( - stdout, - stdin, - utils.MultiCloser(), - &net.IPAddr{}, - &net.IPAddr{}, - ) - - execC := make(chan error, 1) - go func() { - err := cmd.Execute(ch) - if err != nil && !trace.IsEOF(err) { - log.WithError(err).Warn("Failed to execute SCP command.") - } - stdin.Close() - execC <- err - }() - - runC := make(chan error, 1) - go func() { - err := s.Run(ctx, shellCmd) - if err != nil && errors.Is(err, &ssh.ExitMissingError{}) { - // TODO(dmitri): currently, if the session is aborted with (*session).Close, - // the remote side cannot send exit-status and this error results. - // To abort the session properly, Teleport needs to support `signal` request - err = nil - } - runC <- err - }() - - var runErr error - select { - case <-ctx.Done(): - if err := s.Close(); err != nil { - log.WithError(err).Debug("Failed to close the SSH session.") - } - err, runErr = <-execC, <-runC - case err = <-execC: - runErr = <-runC - case runErr = <-runC: - err = <-execC - } - - if runErr != nil && (err == nil || trace.IsEOF(err)) { - err = runErr - } - if trace.IsEOF(err) { - err = nil - } - return trace.Wrap(err) -} - // TransferFiles transfers files over SFTP. func (c *NodeClient) TransferFiles(ctx context.Context, cfg *sftp.Config) error { ctx, span := c.Tracer.Start( diff --git a/lib/sshutils/scp/http.go b/lib/sshutils/scp/http.go deleted file mode 100644 index 1418659d0ae27..0000000000000 --- a/lib/sshutils/scp/http.go +++ /dev/null @@ -1,271 +0,0 @@ -/* -Copyright 2018 Gravitational, Inc. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package scp - -import ( - "fmt" - "io" - "net/http" - "net/url" - "os" - "path/filepath" - "strconv" - "time" - - "github.com/gravitational/trace" - - "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/httplib" -) - -const ( - // 644 means that files are readable and writeable by the owner of - // the file and readable by users in the group owner of that file - // and readable by everyone else. - httpUploadFileMode = 0o644 -) - -// HTTPTransferRequest describes HTTP file transfer request -type HTTPTransferRequest struct { - // RemoteLocation is a destination location of the file - RemoteLocation string - // FileName is a file name - FileName string - // HTTPRequest is HTTP request - HTTPRequest *http.Request - // HTTPRequest is HTTP request - HTTPResponse http.ResponseWriter - // Progress is a writer for printing the progress - Progress io.Writer - // User is a username - User string - // AuditLog is AuditLog log - AuditLog events.AuditLogSessionStreamer -} - -func (r *HTTPTransferRequest) parseRemoteLocation() (string, string, error) { - dir, filename := filepath.Split(r.RemoteLocation) - if filename == "" { - return "", "", trace.BadParameter("failed to parse file remote location: %q", r.RemoteLocation) - } - - return dir, filename, nil -} - -// CreateHTTPUpload creates an HTTP upload command -func CreateHTTPUpload(req HTTPTransferRequest) (Command, error) { - if req.HTTPRequest == nil { - return nil, trace.BadParameter("missing parameter HTTPRequest") - } - - if req.FileName == "" { - return nil, trace.BadParameter("missing file name") - } - - if req.RemoteLocation == "" { - return nil, trace.BadParameter("missing remote location") - } - - contentLength := req.HTTPRequest.Header.Get("Content-Length") - fileSize, err := strconv.ParseInt(contentLength, 10, 0) - if err != nil { - return nil, trace.BadParameter("failed to parse Content-Length header: %q", contentLength) - } - - fs := &httpFileSystem{ - reader: req.HTTPRequest.Body, - fileName: req.FileName, - fileSize: fileSize, - } - - flags := Flags{ - // scp treats it as a list of files to upload - Target: []string{req.FileName}, - } - - cfg := Config{ - Flags: flags, - FileSystem: fs, - User: req.User, - ProgressWriter: req.Progress, - RemoteLocation: req.RemoteLocation, - AuditLog: req.AuditLog, - } - - cmd, err := CreateUploadCommand(cfg) - if err != nil { - return nil, trace.Wrap(err) - } - - return cmd, nil -} - -// CreateHTTPDownload creates an HTTP download command -func CreateHTTPDownload(req HTTPTransferRequest) (Command, error) { - _, filename, err := req.parseRemoteLocation() - if err != nil { - return nil, trace.Wrap(err) - } - - flags := Flags{ - Target: []string{filename}, - } - - cfg := Config{ - Flags: flags, - User: req.User, - ProgressWriter: req.Progress, - RemoteLocation: req.RemoteLocation, - FileSystem: &httpFileSystem{ - writer: req.HTTPResponse, - }, - } - - cmd, err := CreateDownloadCommand(cfg) - if err != nil { - return nil, trace.Wrap(err) - } - - return cmd, nil -} - -// httpFileSystem simulates file system calls while using HTTP response/request streams. -type httpFileSystem struct { - writer http.ResponseWriter - reader io.ReadCloser - fileName string - fileSize int64 -} - -// Chmod sets file permissions. It does nothing as there are no permissions -// while processing HTTP downloads -func (l *httpFileSystem) Chmod(path string, mode int) error { - return nil -} - -// Chtimes sets file access and modification time. -// It is a no-op for the HTTP file system implementation -func (l *httpFileSystem) Chtimes(path string, atime, mtime time.Time) error { - return nil -} - -// MkDir creates a directory. This method is not implemented as creating directories -// is not supported during HTTP downloads. -func (l *httpFileSystem) MkDir(path string, mode int) error { - return trace.BadParameter("directories are not supported in http file transfer") -} - -// IsDir tells if this file is a directory. It always returns false as -// directories are not supported in HTTP file transfer -func (l *httpFileSystem) IsDir(path string) bool { - return false -} - -// OpenFile returns file reader -func (l *httpFileSystem) OpenFile(filePath string) (io.ReadCloser, error) { - if l.reader == nil { - return nil, trace.BadParameter("missing reader") - } - - return l.reader, nil -} - -// CreateFile sets proper HTTP headers and returns HTTP writer to stream incoming -// file content -func (l *httpFileSystem) CreateFile(filePath string, length uint64) (io.WriteCloser, error) { - _, filename := filepath.Split(filePath) - contentLength := strconv.FormatUint(length, 10) - header := l.writer.Header() - - httplib.SetNoCacheHeaders(header) - httplib.SetDefaultSecurityHeaders(header) - header.Set("Content-Length", contentLength) - header.Set("Content-Type", "application/octet-stream") - filename = url.QueryEscape(filename) - header.Set("Content-Disposition", fmt.Sprintf(`attachment;filename="%v"`, filename)) - - return &nopWriteCloser{Writer: l.writer}, nil -} - -// GetFileInfo returns file information -func (l *httpFileSystem) GetFileInfo(filePath string) (FileInfo, error) { - return &httpFileInfo{ - name: l.fileName, - path: l.fileName, - size: l.fileSize, - }, nil -} - -// httpFileInfo is implementation of FileInfo interface used during HTTP -// file transfer -type httpFileInfo struct { - path string - name string - size int64 -} - -// IsDir tells if this file in a directory -func (l *httpFileInfo) IsDir() bool { - return false -} - -// GetName returns file name -func (l *httpFileInfo) GetName() string { - return l.name -} - -// GetPath returns file path -func (l *httpFileInfo) GetPath() string { - return l.path -} - -// GetSize returns file size -func (l *httpFileInfo) GetSize() int64 { - return l.size -} - -// ReadDir returns an slice of files in the directory. -// This method is not supported in HTTP file transfer -func (l *httpFileInfo) ReadDir() ([]FileInfo, error) { - return nil, trace.BadParameter("directories are not supported in http file transfer") -} - -// GetModePerm returns file permissions that will be set on the -// file created on the remote host during HTTP upload. -func (l *httpFileInfo) GetModePerm() os.FileMode { - return httpUploadFileMode -} - -// GetModTime returns file modification time. -// It is a no-op for HTTP file information -func (l *httpFileInfo) GetModTime() time.Time { - return time.Time{} -} - -// GetAccessTime returns file last access time. -// It is a no-op for HTTP file information -func (l *httpFileInfo) GetAccessTime() time.Time { - return time.Time{} -} - -type nopWriteCloser struct { - io.Writer -} - -func (wr *nopWriteCloser) Close() error { - return nil -} diff --git a/lib/sshutils/scp/scp.go b/lib/sshutils/scp/scp.go index 14170f1796fe6..42ef2998e1722 100644 --- a/lib/sshutils/scp/scp.go +++ b/lib/sshutils/scp/scp.go @@ -28,7 +28,6 @@ import ( "io" "os" "path/filepath" - "regexp" "strconv" "strings" "time" @@ -779,58 +778,3 @@ func (r *reader) read() error { } return trace.BadParameter("unrecognized command: %v", r.b) } - -var reSCP = regexp.MustCompile( - // optional username, note that outside group - // is a non-capturing as it includes @ signs we don't want - `(?:(?P.+)@)?` + - // either some stuff in brackets - [ipv6] - // or some stuff without brackets and colons - `(?P` + - // this says: [stuff in brackets that is not brackets] - loose definition of the IP address - `(?:\[[^@\[\]]+\])` + - // or - `|` + - // some stuff without brackets or colons to make sure the OR condition - // is not ambiguous - `(?:[^@\[\:\]]+)` + - `)` + - // after colon, there is a path that could consist technically of - // any char including empty which stands for the implicit home directory - `:(?P.*)`, -) - -// Destination is SCP destination to copy to or from -type Destination struct { - // Login is an optional login username - Login string - // Host is a host to copy to/from - Host utils.NetAddr - // Path is a path to copy to/from. - // An empty path name is valid, and it refers to the user's default directory (usually - // the user's home directory). - // See https://tools.ietf.org/html/draft-ietf-secsh-filexfer-09#page-14, 'File Names' - Path string -} - -// ParseSCPDestination takes a string representing a remote resource for SCP -// to download/upload, like "user@host:/path/to/resource.txt" and parses it into -// a structured form. -// -// See https://tools.ietf.org/html/draft-ietf-secsh-filexfer-09#page-14, 'File Names' -// section about details on file names. -func ParseSCPDestination(s string) (*Destination, error) { - out := reSCP.FindStringSubmatch(s) - if len(out) < 4 { - return nil, trace.BadParameter("failed to parse %q, try form user@host:/path", s) - } - addr, err := utils.ParseAddr(out[2]) - if err != nil { - return nil, trace.Wrap(err) - } - path := out[3] - if path == "" { - path = "." - } - return &Destination{Login: out[1], Host: *addr, Path: path}, nil -} diff --git a/lib/sshutils/scp/scp_test.go b/lib/sshutils/scp/scp_test.go index 4ee27fd7acc4e..93034eb91acd7 100644 --- a/lib/sshutils/scp/scp_test.go +++ b/lib/sshutils/scp/scp_test.go @@ -19,12 +19,9 @@ import ( "bytes" "fmt" "io" - "net/http" - "net/http/httptest" "os" "os/exec" "path/filepath" - "strconv" "testing" "time" @@ -41,63 +38,6 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func TestHTTPSendFile(t *testing.T) { - outDir := t.TempDir() - - expectedBytes := []byte("hello") - buf := bytes.NewReader(expectedBytes) - req, err := http.NewRequest("POST", "/", buf) - require.NoError(t, err) - - req.Header.Set("Content-Length", strconv.Itoa(len(expectedBytes))) - - stdOut := bytes.NewBufferString("") - cmd, err := CreateHTTPUpload( - HTTPTransferRequest{ - FileName: "filename", - RemoteLocation: outDir, - HTTPRequest: req, - Progress: stdOut, - User: "test-user", - }) - require.NoError(t, err) - err = runSCP(cmd, "-v", "-t", outDir) - require.NoError(t, err) - bytesReceived, err := os.ReadFile(filepath.Join(outDir, "filename")) - require.NoError(t, err) - require.Empty(t, cmp.Diff(string(bytesReceived), string(expectedBytes))) -} - -func TestHTTPReceiveFile(t *testing.T) { - source := filepath.Join(t.TempDir(), "target") - - contents := []byte("hello, file contents!") - err := os.WriteFile(source, contents, 0666) - require.NoError(t, err) - - w := httptest.NewRecorder() - stdOut := bytes.NewBufferString("") - cmd, err := CreateHTTPDownload( - HTTPTransferRequest{ - RemoteLocation: "/home/robots.txt", - HTTPResponse: w, - User: "test-user", - Progress: stdOut, - }) - require.NoError(t, err) - - err = runSCP(cmd, "-v", "-f", source) - require.NoError(t, err) - - data, err := io.ReadAll(w.Body) - contentLengthStr := strconv.Itoa(len(data)) - require.NoError(t, err) - require.Empty(t, cmp.Diff(string(data), string(contents))) - require.Empty(t, cmp.Diff(contentLengthStr, w.Header().Get("Content-Length"))) - require.Empty(t, cmp.Diff("application/octet-stream", w.Header().Get("Content-Type"))) - require.Empty(t, cmp.Diff(`attachment;filename="robots.txt"`, w.Header().Get("Content-Disposition"))) -} - func TestSend(t *testing.T) { t.Parallel() modtime := testNow @@ -105,7 +45,7 @@ func TestSend(t *testing.T) { dirModtime := testNow.Add(2 * time.Second) dirAtime := testNow.Add(3 * time.Second) logger := logrus.WithField(trace.Component, "t:send") - var testCases = []struct { + testCases := []struct { desc string config Config fs *testFS @@ -169,7 +109,7 @@ func TestReceive(t *testing.T) { dirModtime := testNow.Add(2 * time.Second) dirAtime := testNow.Add(3 * time.Second) logger := logrus.WithField(trace.Component, "t:recv") - var testCases = []struct { + testCases := []struct { desc string config Config source string @@ -450,7 +390,7 @@ func TestVerifyDirectoryModeFailsWithFile(t *testing.T) { // Create temporary directory with a file "target" in it. dir := t.TempDir() target := filepath.Join(dir, "target") - err := os.WriteFile(target, []byte{}, 0666) + err := os.WriteFile(target, []byte{}, 0o666) require.NoError(t, err) cmd, err := CreateCommand( @@ -476,7 +416,7 @@ func TestVerifyDirectoryModeIsRequiredForDirectory(t *testing.T) { // Create temporary directory with a file "target" in it. dir := t.TempDir() target := filepath.Join(dir, "target") - err := os.WriteFile(target, []byte{}, 0666) + err := os.WriteFile(target, []byte{}, 0o666) require.NoError(t, err) cmd, err := CreateCommand( @@ -496,66 +436,6 @@ func TestVerifyDirectoryModeIsRequiredForDirectory(t *testing.T) { require.Regexp(t, fmt.Sprintf("%s is a directory, use -r flag to copy recursively", filepath.Base(dir)), err) } -func TestSCPParsing(t *testing.T) { - t.Parallel() - - var testCases = []struct { - comment string - in string - dest Destination - err error - }{ - { - comment: "full spec of the remote destination", - in: "root@remote.host:/etc/nginx.conf", - dest: Destination{Login: "root", Host: utils.NetAddr{Addr: "remote.host", AddrNetwork: "tcp"}, Path: "/etc/nginx.conf"}, - }, - { - comment: "spec with just the remote host", - in: "remote.host:/etc/nginx.co:nf", - dest: Destination{Host: utils.NetAddr{Addr: "remote.host", AddrNetwork: "tcp"}, Path: "/etc/nginx.co:nf"}, - }, - { - comment: "ipv6 remote destination address", - in: "[::1]:/etc/nginx.co:nf", - dest: Destination{Host: utils.NetAddr{Addr: "[::1]", AddrNetwork: "tcp"}, Path: "/etc/nginx.co:nf"}, - }, - { - comment: "full spec of the remote destination using ipv4 address", - in: "root@123.123.123.123:/var/www/html/", - dest: Destination{Login: "root", Host: utils.NetAddr{Addr: "123.123.123.123", AddrNetwork: "tcp"}, Path: "/var/www/html/"}, - }, - { - comment: "target location using wildcard", - in: "myusername@myremotehost.com:/home/hope/*", - dest: Destination{Login: "myusername", Host: utils.NetAddr{Addr: "myremotehost.com", AddrNetwork: "tcp"}, Path: "/home/hope/*"}, - }, - { - comment: "complex login", - in: "complex@example.com@remote.com:/anything.txt", - dest: Destination{Login: "complex@example.com", Host: utils.NetAddr{Addr: "remote.com", AddrNetwork: "tcp"}, Path: "/anything.txt"}, - }, - { - comment: "implicit user's home directory", - in: "root@remote.host:", - dest: Destination{Login: "root", Host: utils.NetAddr{Addr: "remote.host", AddrNetwork: "tcp"}, Path: "."}, - }, - } - for _, tt := range testCases { - tt := tt - t.Run(tt.comment, func(t *testing.T) { - resp, err := ParseSCPDestination(tt.in) - if tt.err != nil { - require.IsType(t, err, tt.err) - return - } - require.NoError(t, err) - require.Empty(t, cmp.Diff(resp, &tt.dest)) - }) - - } -} - func runSCP(cmd Command, args ...string) error { scp, stdin, stdout, _ := newCmd("scp", args...) rw := &readWriter{r: stdout, w: stdin} @@ -811,7 +691,7 @@ func (r *testFS) CreateFile(path string, length uint64) (io.WriteCloser, error) fi := &testFileInfo{ path: path, size: int64(length), - perms: 0666, + perms: 0o666, contents: new(bytes.Buffer), } r.fs[path] = fi @@ -924,14 +804,14 @@ func newDir(name string, ents ...*testFileInfo) *testFileInfo { path: name, ents: ents, dir: true, - perms: 0755, + perms: 0o755, } } func newFile(name string, contents string) *testFileInfo { return &testFileInfo{ path: name, - perms: 0666, + perms: 0o666, size: int64(len(contents)), contents: bytes.NewBufferString(contents), } @@ -944,7 +824,7 @@ func newDirTimes(name string, modtime, atime time.Time, ents ...*testFileInfo) * modtime: modtime, atime: atime, dir: true, - perms: 0755, + perms: 0o755, } } @@ -953,7 +833,7 @@ func newFileTimes(name string, modtime, atime time.Time, contents string) *testF path: name, modtime: modtime, atime: atime, - perms: 0666, + perms: 0o666, size: int64(len(contents)), contents: bytes.NewBufferString(contents), } diff --git a/lib/sshutils/sftp/http.go b/lib/sshutils/sftp/http.go new file mode 100644 index 0000000000000..ae2c37904f88a --- /dev/null +++ b/lib/sshutils/sftp/http.go @@ -0,0 +1,181 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sftp + +import ( + "context" + "fmt" + "io" + "io/fs" + "net/http" + "net/url" + "os" + "path" + "strconv" + "time" + + "github.com/gravitational/teleport/lib/httplib" + "github.com/gravitational/trace" +) + +var errDirsNotSupported = trace.BadParameter("directories are not supported when transferring files over HTTP") + +// httpFS provides API for accessing the a file over HTTP. +type httpFS struct { + reader io.ReadCloser + writer http.ResponseWriter + + fileName string + fileSize int64 +} + +func (h *httpFS) Type() string { + return "local" +} + +func (h *httpFS) Glob(ctx context.Context, _ string) ([]string, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + return []string{h.fileName}, nil +} + +func (h *httpFS) Stat(ctx context.Context, _ string) (fs.FileInfo, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + return httpFileInfo{ + name: path.Base(h.fileName), + size: h.fileSize, + }, nil +} + +func (h *httpFS) ReadDir(_ context.Context, _ string) ([]fs.FileInfo, error) { + return nil, errDirsNotSupported +} + +func (h *httpFS) Open(ctx context.Context, path string) (fs.File, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + if h.reader == nil { + return nil, trace.BadParameter("missing reader") + } + + return &httpFile{ + reader: h.reader, + fileInfo: httpFileInfo{ + name: h.fileName, + size: h.fileSize, + }, + }, nil +} + +func (h *httpFS) Create(ctx context.Context, p string, size int64) (io.WriteCloser, error) { + if err := ctx.Err(); err != nil { + return nil, err + } + + filename := path.Base(p) + contentLength := strconv.FormatInt(size, 10) + header := h.writer.Header() + + httplib.SetNoCacheHeaders(header) + httplib.SetDefaultSecurityHeaders(header) + header.Set("Content-Length", contentLength) + header.Set("Content-Type", "application/octet-stream") + filename = url.QueryEscape(filename) + header.Set("Content-Disposition", fmt.Sprintf(`attachment;filename="%s"`, filename)) + + return &nopWriteCloser{ + Writer: h.writer, + }, nil +} + +func (h *httpFS) Mkdir(_ context.Context, _ string) error { + return errDirsNotSupported +} + +func (h *httpFS) Chmod(_ context.Context, _ string, _ os.FileMode) error { + return nil +} + +func (h *httpFS) Chtimes(_ context.Context, _ string, _, _ time.Time) error { + return nil +} + +type nopWriteCloser struct { + io.Writer +} + +func (wr *nopWriteCloser) Close() error { + return nil +} + +// httpFile implements [fs.File]. +type httpFile struct { + reader io.ReadCloser + fileInfo httpFileInfo +} + +func (h *httpFile) Read(p []byte) (int, error) { + return h.reader.Read(p) +} + +func (h *httpFile) Stat() (fs.FileInfo, error) { + return h.fileInfo, nil +} + +func (h *httpFile) Close() error { + return h.reader.Close() +} + +// httpFileInfo is a simple implementation of [fs.FileMode] that only +// knows its file's name and size. +type httpFileInfo struct { + name string + size int64 +} + +func (h httpFileInfo) Name() string { + return h.name +} + +func (h httpFileInfo) Size() int64 { + return h.size +} + +func (h httpFileInfo) Mode() fs.FileMode { + // return sensible default file permissions so when uploading files + // the created destination file will have sensible permissions set + return 0o644 +} + +func (h httpFileInfo) ModTime() time.Time { + return time.Time{} +} + +func (h httpFileInfo) IsDir() bool { + return false +} + +func (h httpFileInfo) Sys() any { + return nil +} diff --git a/lib/sshutils/sftp/local.go b/lib/sshutils/sftp/local.go index 2f660cf50b8bf..6d142551fe030 100644 --- a/lib/sshutils/sftp/local.go +++ b/lib/sshutils/sftp/local.go @@ -29,9 +29,16 @@ import ( // localFS provides API for accessing the files on // the local file system -type localFS struct{} +type localFS struct { + partOfHTTPTransfer bool +} func (l localFS) Type() string { + // if this localFS is being used in a HTTP file transfer, from the + // user's perspective this filesystem is remote + if l.partOfHTTPTransfer { + return "remote" + } return "local" } @@ -103,7 +110,7 @@ func (l localFS) Open(ctx context.Context, p string) (fs.File, error) { return &fileWrapper{file: f}, nil } -func (l localFS) Create(ctx context.Context, p string) (io.WriteCloser, error) { +func (l localFS) Create(ctx context.Context, p string, _ int64) (io.WriteCloser, error) { if err := ctx.Err(); err != nil { return nil, err } diff --git a/lib/sshutils/sftp/parse.go b/lib/sshutils/sftp/parse.go new file mode 100644 index 0000000000000..fadfc267c677f --- /dev/null +++ b/lib/sshutils/sftp/parse.go @@ -0,0 +1,79 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sftp + +import ( + "regexp" + + "github.com/gravitational/teleport/lib/utils" + "github.com/gravitational/trace" +) + +var reSFTP = regexp.MustCompile( + // optional username, note that outside group + // is a non-capturing as it includes @ signs we don't want + `(?:(?P.+)@)?` + + // either some stuff in brackets - [ipv6] + // or some stuff without brackets and colons + `(?P` + + // this says: [stuff in brackets that is not brackets] - loose definition of the IP address + `(?:\[[^@\[\]]+\])` + + // or + `|` + + // some stuff without brackets or colons to make sure the OR condition + // is not ambiguous + `(?:[^@\[\:\]]+)` + + `)` + + // after colon, there is a path that could consist technically of + // any char including empty which stands for the implicit home directory + `:(?P.*)`, +) + +// Destination is SCP destination to copy to or from +type Destination struct { + // Login is an optional login username + Login string + // Host is a host to copy to/from + Host utils.NetAddr + // Path is a path to copy to/from. + // An empty path name is valid, and it refers to the user's default directory (usually + // the user's home directory). + // See https://tools.ietf.org/html/draft-ietf-secsh-filexfer-09#page-14, 'File Names' + Path string +} + +// ParseSCPDestination takes a string representing a remote resource for SFTP +// to download/upload, like "user@host:/path/to/resource.txt" and parses it into +// a structured form. +// +// See https://tools.ietf.org/html/draft-ietf-secsh-filexfer-09#page-14, 'File Names' +// section about details on file names. +func ParseDestination(s string) (*Destination, error) { + out := reSFTP.FindStringSubmatch(s) + if len(out) < 4 { + return nil, trace.BadParameter("failed to parse %q, try form user@host:/path", s) + } + addr, err := utils.ParseAddr(out[2]) + if err != nil { + return nil, trace.Wrap(err) + } + path := out[3] + if path == "" { + path = "." + } + return &Destination{Login: out[1], Host: *addr, Path: path}, nil +} diff --git a/lib/sshutils/sftp/parse_test.go b/lib/sshutils/sftp/parse_test.go new file mode 100644 index 0000000000000..585612f171389 --- /dev/null +++ b/lib/sshutils/sftp/parse_test.go @@ -0,0 +1,85 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package sftp + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/lib/utils" +) + +func TestDestinationParsing(t *testing.T) { + t.Parallel() + + testCases := []struct { + comment string + in string + dest Destination + err error + }{ + { + comment: "full spec of the remote destination", + in: "root@remote.host:/etc/nginx.conf", + dest: Destination{Login: "root", Host: utils.NetAddr{Addr: "remote.host", AddrNetwork: "tcp"}, Path: "/etc/nginx.conf"}, + }, + { + comment: "spec with just the remote host", + in: "remote.host:/etc/nginx.co:nf", + dest: Destination{Host: utils.NetAddr{Addr: "remote.host", AddrNetwork: "tcp"}, Path: "/etc/nginx.co:nf"}, + }, + { + comment: "ipv6 remote destination address", + in: "[::1]:/etc/nginx.co:nf", + dest: Destination{Host: utils.NetAddr{Addr: "[::1]", AddrNetwork: "tcp"}, Path: "/etc/nginx.co:nf"}, + }, + { + comment: "full spec of the remote destination using ipv4 address", + in: "root@123.123.123.123:/var/www/html/", + dest: Destination{Login: "root", Host: utils.NetAddr{Addr: "123.123.123.123", AddrNetwork: "tcp"}, Path: "/var/www/html/"}, + }, + { + comment: "target location using wildcard", + in: "myusername@myremotehost.com:/home/hope/*", + dest: Destination{Login: "myusername", Host: utils.NetAddr{Addr: "myremotehost.com", AddrNetwork: "tcp"}, Path: "/home/hope/*"}, + }, + { + comment: "complex login", + in: "complex@example.com@remote.com:/anything.txt", + dest: Destination{Login: "complex@example.com", Host: utils.NetAddr{Addr: "remote.com", AddrNetwork: "tcp"}, Path: "/anything.txt"}, + }, + { + comment: "implicit user's home directory", + in: "root@remote.host:", + dest: Destination{Login: "root", Host: utils.NetAddr{Addr: "remote.host", AddrNetwork: "tcp"}, Path: "."}, + }, + } + for _, tt := range testCases { + tt := tt + t.Run(tt.comment, func(t *testing.T) { + resp, err := ParseDestination(tt.in) + if tt.err != nil { + require.IsType(t, err, tt.err) + return + } + require.NoError(t, err) + require.Empty(t, cmp.Diff(resp, &tt.dest)) + }) + } +} diff --git a/lib/sshutils/sftp/remote.go b/lib/sshutils/sftp/remote.go index 5287c22dc6e94..1fe9207282674 100644 --- a/lib/sshutils/sftp/remote.go +++ b/lib/sshutils/sftp/remote.go @@ -99,7 +99,7 @@ func (r *remoteFS) Open(ctx context.Context, p string) (fs.File, error) { return f, nil } -func (r *remoteFS) Create(ctx context.Context, p string) (io.WriteCloser, error) { +func (r *remoteFS) Create(ctx context.Context, p string, _ int64) (io.WriteCloser, error) { if err := ctx.Err(); err != nil { return nil, err } diff --git a/lib/sshutils/sftp/sftp.go b/lib/sshutils/sftp/sftp.go index d7a5da3414594..e16e942d35d2e 100644 --- a/lib/sshutils/sftp/sftp.go +++ b/lib/sshutils/sftp/sftp.go @@ -14,7 +14,7 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package sftp handles file transfers client-side via SFTP +// Package sftp handles file transfers client-side via SFTP or HTTP. package sftp import ( @@ -24,13 +24,16 @@ import ( "fmt" "io" "io/fs" + "net/http" "os" "os/user" "path" // SFTP requires UNIX-style path separators "runtime" + "strconv" "strings" "time" + "github.com/gravitational/teleport/lib/sshutils/scp" "github.com/gravitational/trace" "github.com/pkg/sftp" "github.com/schollz/progressbar/v3" @@ -38,7 +41,6 @@ import ( "golang.org/x/crypto/ssh" "github.com/gravitational/teleport" - "github.com/gravitational/teleport/lib/sshutils/scp" ) // Options control aspects of a file transfer @@ -63,6 +65,9 @@ type Config struct { // getHomeDir returns the home directory of the remote user of the // SSH session getHomeDir homeDirRetriever + // isHttpTransfer denotes whether this is a file transfer over + // HTTP instead of SFTP + isHttpTransfer bool // ProgressStream is a callback to return a read/writer for printing the progress // (used only on the client) @@ -84,7 +89,7 @@ type FileSystem interface { // Open opens a file Open(ctx context.Context, path string) (fs.File, error) // Create creates a new file - Create(ctx context.Context, path string) (io.WriteCloser, error) + Create(ctx context.Context, path string, size int64) (io.WriteCloser, error) // Mkdir creates a directory Mkdir(ctx context.Context, path string) error // Chmod sets file permissions @@ -93,7 +98,7 @@ type FileSystem interface { Chtimes(ctx context.Context, path string, atime, mtime time.Time) error } -// CreateUploadConfig returns a Config ready to upload files +// CreateUploadConfig returns a Config ready to upload files over SFTP. func CreateUploadConfig(src []string, dst string, opts Options) (*Config, error) { for _, srcPath := range src { if srcPath == "" { @@ -116,7 +121,7 @@ func CreateUploadConfig(src []string, dst string, opts Options) (*Config, error) return c, nil } -// CreateDownloadConfig returns a Config ready to download files +// CreateDownloadConfig returns a Config ready to download files over SFTP. func CreateDownloadConfig(src, dst string, opts Options) (*Config, error) { if src == "" { return nil, trace.BadParameter("source path is empty") @@ -137,6 +142,87 @@ func CreateDownloadConfig(src, dst string, opts Options) (*Config, error) { return c, nil } +// HTTPTransferRequest describes file transfer request over HTTP. +type HTTPTransferRequest struct { + // Dst is the source file name + Src string + // Dst is the destination file name + Dst string + // HTTPRequest is where the source file will be read from for + // file upload transfers + HTTPRequest *http.Request + // HTTPResponse is where the destination file will be written to for + // file download transfers + HTTPResponse http.ResponseWriter +} + +// CreateHTTPUploadConfig returns a Config ready to upload a file over +// HTTP. +func CreateHTTPUploadConfig(req HTTPTransferRequest) (*Config, error) { + if req.Src == "" { + return nil, trace.BadParameter("source path is empty") + } + if req.Dst == "" { + return nil, trace.BadParameter("destination path is empty") + } + if req.HTTPRequest == nil { + return nil, trace.BadParameter("HTTP request is empty") + } + + contentLength := req.HTTPRequest.Header.Get("Content-Length") + fileSize, err := strconv.ParseInt(contentLength, 10, 0) + if err != nil { + return nil, trace.Errorf("failed to parse Content-Length header: %w", err) + } + + c := &Config{ + srcPaths: []string{req.Src}, + dstPath: req.Dst, + srcFS: &httpFS{ + reader: req.HTTPRequest.Body, + fileName: req.Src, + fileSize: fileSize, + }, + dstFS: &localFS{ + partOfHTTPTransfer: true, + }, + isHttpTransfer: true, + } + c.setDefaults() + + return c, nil +} + +// CreateHTTPDownloadConfig returns a Config ready to download a file +// over HTTP. +func CreateHTTPDownloadConfig(req HTTPTransferRequest) (*Config, error) { + if req.Src == "" { + return nil, trace.BadParameter("source path is empty") + } + if req.Dst == "" { + return nil, trace.BadParameter("destination path is empty") + } + if req.HTTPResponse == nil { + return nil, trace.BadParameter("HTTP response is empty") + } + + c := &Config{ + srcPaths: []string{req.Src}, + dstPath: req.Dst, + srcFS: &localFS{ + partOfHTTPTransfer: true, + }, + dstFS: &httpFS{ + writer: req.HTTPResponse, + fileName: req.Dst, + }, + isHttpTransfer: true, + } + c.setDefaults() + + return c, nil +} + // setDefaults sets default values func (c *Config) setDefaults() { logger := c.Log @@ -155,8 +241,17 @@ func (c *Config) setDefaults() { } // TransferFiles transfers files from the configured source paths to the -// configured destination path over SFTP +// configured destination path over SFTP or HTTP depending on the Config. func (c *Config) TransferFiles(ctx context.Context, sshClient *ssh.Client) error { + // if this is a transfer over HTTP, no SFTP client is needed + if c.isHttpTransfer { + err := c.expandPaths(false, false) + if err != nil { + return trace.Wrap(err) + } + return trace.Wrap(c.transfer(ctx)) + } + sftpClient, err := sftp.NewClient(sshClient, // Use concurrent stream to speed up transfer on slow networks as described in // https://github.com/gravitational/teleport/issues/20579 @@ -181,7 +276,6 @@ func (c *Config) TransferFiles(ctx context.Context, sshClient *ssh.Client) error // initFS ensures the source and destination filesystems are ready to transfer func (c *Config) initFS(sshClient *ssh.Client, client *sftp.Client) error { var haveRemoteFS bool - srcFS, srcOK := c.srcFS.(*remoteFS) if srcOK { srcFS.c = client @@ -427,7 +521,7 @@ func (c *Config) transferFile(ctx context.Context, dstPath, srcPath string, srcF } defer srcFile.Close() - dstFile, err := c.dstFS.Create(ctx, dstPath) + dstFile, err := c.dstFS.Create(ctx, dstPath, srcFileInfo.Size()) if err != nil { return trace.Errorf("error creating %s file %q: %w", c.dstFS.Type(), dstPath, err) } @@ -443,9 +537,11 @@ func (c *Config) transferFile(ctx context.Context, dstPath, srcPath string, srcF } reader, writer := prepareStreams(ctx, srcFile, dstFile, progressBar) - - if err := assertStreamsType(reader, writer); err != nil { - return trace.Wrap(err) + // skip this SFTP-specific check if we're transferring over HTTP + if !c.isHttpTransfer { + if err := assertStreamsType(reader, writer); err != nil { + return trace.Wrap(err) + } } n, err := io.Copy(writer, reader) diff --git a/lib/sshutils/sftp/sftp_test.go b/lib/sshutils/sftp/sftp_test.go index 64f0d1fcb57c1..e907c4dfb44d4 100644 --- a/lib/sshutils/sftp/sftp_test.go +++ b/lib/sshutils/sftp/sftp_test.go @@ -22,12 +22,16 @@ import ( "fmt" "io" "math/rand" + "net/http" + "net/http/httptest" "os" "path/filepath" + "strconv" "strings" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib/utils" @@ -583,6 +587,83 @@ func TestCopyingSymlinkedFile(t *testing.T) { checkTransfer(t, false, dstPath, linkPath) } +func TestHTTPUpload(t *testing.T) { + tempDir := t.TempDir() + src := "source" + dst := filepath.Join(tempDir, "destination") + + createFile(t, tempDir, src) + src = filepath.Join(tempDir, src) + f, err := os.Open(src) + require.NoError(t, err) + t.Cleanup(func() { + f.Close() + }) + + req, err := http.NewRequest("POST", "/", f) + require.NoError(t, err) + + fi, err := f.Stat() + require.NoError(t, err) + req.Header.Set("Content-Length", strconv.FormatInt(fi.Size(), 10)) + + cfg, err := CreateHTTPUploadConfig( + HTTPTransferRequest{ + Src: "source", + Dst: dst, + HTTPRequest: req, + }, + ) + require.NoError(t, err) + + err = cfg.transfer(req.Context()) + require.NoError(t, err) + + srcContents, err := os.ReadFile(src) + require.NoError(t, err) + dstContents, err := os.ReadFile(dst) + require.NoError(t, err) + require.Empty(t, cmp.Diff(string(srcContents), string(dstContents))) +} + +func TestHTTPDownload(t *testing.T) { + tempDir := t.TempDir() + src := "source" + + createFile(t, tempDir, src) + src = filepath.Join(tempDir, src) + f, err := os.Open(src) + require.NoError(t, err) + t.Cleanup(func() { + f.Close() + }) + + contents, err := os.ReadFile(src) + require.NoError(t, err) + + w := httptest.NewRecorder() + cfg, err := CreateHTTPDownloadConfig( + HTTPTransferRequest{ + Src: src, + Dst: "/home/robots.txt", + HTTPResponse: w, + }, + ) + require.NoError(t, err) + + err = cfg.transfer(context.Background()) + require.NoError(t, err) + + data, err := io.ReadAll(w.Body) + require.NoError(t, err) + contentLengthStr := strconv.Itoa(len(data)) + + require.Empty(t, cmp.Diff(string(contents), string(data))) + require.Empty(t, cmp.Diff(contentLengthStr, w.Header().Get("Content-Length"))) + require.Empty(t, cmp.Diff("application/octet-stream", w.Header().Get("Content-Type"))) + require.Empty(t, cmp.Diff(`attachment;filename="robots.txt"`, w.Header().Get("Content-Disposition"))) +} + func createFile(t *testing.T, rootDir, path string) { dir := filepath.Dir(path) if dir != path { @@ -677,7 +758,7 @@ func compareFiles(t *testing.T, preserveAttrs bool, dstInfo, srcInfo os.FileInfo require.NoError(t, err) srcBytes, err := os.ReadFile(src) require.NoError(t, err) - require.True(t, bytes.Equal(dstBytes, srcBytes), "%q and %q contents not equal", dst, src[0]) + require.True(t, bytes.Equal(dstBytes, srcBytes), "%q and %q contents not equal", dst, src) } func compareFileInfos(t *testing.T, preserveAttrs bool, dstInfo, srcInfo os.FileInfo, dst, src string) { diff --git a/lib/web/files.go b/lib/web/files.go index 44b38110a5a16..e0f057e514745 100644 --- a/lib/web/files.go +++ b/lib/web/files.go @@ -32,7 +32,7 @@ import ( wanlib "github.com/gravitational/teleport/lib/auth/webauthn" "github.com/gravitational/teleport/lib/client" "github.com/gravitational/teleport/lib/reversetunnel" - "github.com/gravitational/teleport/lib/sshutils/scp" + "github.com/gravitational/teleport/lib/sshutils/sftp" ) // fileTransferRequest describes HTTP file transfer request @@ -92,88 +92,52 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou return nil, trace.AccessDenied("MFA required for file transfer") } + var cfg *sftp.Config isUpload := r.Method == http.MethodPost if isUpload { - err = ft.upload(req, r) + cfg, err = sftp.CreateHTTPUploadConfig(sftp.HTTPTransferRequest{ + Src: req.filename, + Dst: req.remoteLocation, + HTTPRequest: r, + }) } else { - err = ft.download(req, r, w) + cfg, err = sftp.CreateHTTPDownloadConfig(sftp.HTTPTransferRequest{ + Src: req.remoteLocation, + Dst: req.filename, + HTTPResponse: w, + }) } - if err != nil { return nil, trace.Wrap(err) } - // We must return nil so that we don't write anything to - // the response, which would corrupt the downloaded file. - return nil, nil -} - -type fileTransfer struct { - // ctx is a web session context for the currently logged in user. - ctx *SessionContext - authClient auth.ClientI - proxyHostPort string -} - -func (f *fileTransfer) download(req fileTransferRequest, httpReq *http.Request, w http.ResponseWriter) error { - cmd, err := scp.CreateHTTPDownload(scp.HTTPTransferRequest{ - RemoteLocation: req.remoteLocation, - HTTPResponse: w, - User: f.ctx.GetUser(), - }) - if err != nil { - return trace.Wrap(err) - } - - tc, err := f.createClient(req, httpReq) + tc, err := ft.createClient(req, r) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } if req.webauthn != "" { - err = f.issueSingleUseCert(req.webauthn, httpReq, tc) + err = ft.issueSingleUseCert(req.webauthn, r, tc) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } } - err = tc.ExecuteSCP(httpReq.Context(), req.serverID, cmd) + err = tc.TransferFiles(r.Context(), req.login, req.serverID+":0", cfg) if err != nil { - return trace.Wrap(err) + return nil, trace.Wrap(err) } - return nil + // We must return nil so that we don't write anything to + // the response, which would corrupt the downloaded file. + return nil, nil } -func (f *fileTransfer) upload(req fileTransferRequest, httpReq *http.Request) error { - cmd, err := scp.CreateHTTPUpload(scp.HTTPTransferRequest{ - RemoteLocation: req.remoteLocation, - FileName: req.filename, - HTTPRequest: httpReq, - User: f.ctx.GetUser(), - }) - if err != nil { - return trace.Wrap(err) - } - - tc, err := f.createClient(req, httpReq) - if err != nil { - return trace.Wrap(err) - } - - if req.webauthn != "" { - err = f.issueSingleUseCert(req.webauthn, httpReq, tc) - if err != nil { - return trace.Wrap(err) - } - } - - err = tc.ExecuteSCP(httpReq.Context(), req.serverID, cmd) - if err != nil { - return trace.Wrap(err) - } - - return nil +type fileTransfer struct { + // ctx is a web session context for the currently logged in user. + ctx *SessionContext + authClient auth.ClientI + proxyHostPort string } func (f *fileTransfer) createClient(req fileTransferRequest, httpReq *http.Request) (*client.TeleportClient, error) {