Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 14 additions & 13 deletions integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7968,18 +7968,19 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// A file not in the request shouldn't be allowed
_, err = sftpClient.Open(filepath.Join(tempDir, "bad-file"))
require.ErrorContains(t, err, `method get is not allowed`)
badFile := filepath.Join(tempDir, "bad-file")
_, err = sftpClient.Open(badFile)
require.ErrorContains(t, err, fmt.Sprintf("operations are only allowed on %s, not %s", reqFile, badFile))
// Since this is a download no files should be allowed to be written to
_, err = sftpClient.OpenFile(filepath.Join(tempDir, reqFile), os.O_WRONLY)
require.ErrorContains(t, err, `method put is not allowed`)
_, err = sftpClient.OpenFile(reqFile, os.O_WRONLY)
require.ErrorContains(t, err, `writing is not allowed`)
// Only stats and reads should be allowed
err = sftpClient.Mkdir(filepath.Join(tempDir, "new-dir"))
require.ErrorContains(t, err, `method mkdir is not allowed`)
err = sftpClient.Mkdir(reqFile)
require.ErrorContains(t, err, `method mkdir is not allowed on `+reqFile)
// Since this is a download no files should be allowed to have
// their permissions changed
err = sftpClient.Chmod(reqFile, 0o777)
require.ErrorContains(t, err, `method setstat is not allowed`)
require.ErrorContains(t, err, `writing is not allowed`)

// Only necessary operations should be allowed
_, err = sftpClient.Stat(reqFile)
Expand Down Expand Up @@ -8032,14 +8033,14 @@ func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) {
require.NoError(t, err)

// A file not in the request shouldn't be allowed
_, err = sftpClient.Open(filepath.Join(tempDir, "bad-file"))
require.ErrorContains(t, err, `method get is not allowed`)
_, err = sftpClient.Open(badFile)
require.ErrorContains(t, err, fmt.Sprintf("operations are only allowed on %s, not %s", reqFile, badFile))
// Since this is an upload no files should be allowed to be read from
_, err = sftpClient.OpenFile(filepath.Join(tempDir, reqFile), os.O_RDONLY)
require.ErrorContains(t, err, `method get is not allowed`)
_, err = sftpClient.OpenFile(reqFile, os.O_RDONLY)
require.ErrorContains(t, err, `reading is not allowed`)
// Only stats, writes, and chmods should be allowed
err = sftpClient.Mkdir(filepath.Join(tempDir, "new-dir"))
require.ErrorContains(t, err, `method mkdir is not allowed`)
err = sftpClient.Mkdir(reqFile)
require.ErrorContains(t, err, `method mkdir is not allowed on `+reqFile)

// Only necessary operations should be allowed
_, err = sftpClient.Stat(reqFile)
Expand Down
69 changes: 8 additions & 61 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,6 @@ import (
"fmt"
"io"
"log/slog"
"os"
"os/user"
"path/filepath"
"slices"
"strconv"
"sync"
Expand Down Expand Up @@ -1900,59 +1897,6 @@ func (s *session) checkIfFileTransferApproved(req *FileTransferRequest) (bool, e
return isApproved, nil
}

// newFileTransferRequest takes FileTransferParams and creates a new fileTransferRequest struct
func (s *session) newFileTransferRequest(params *rsession.FileTransferRequestParams) (*FileTransferRequest, error) {
location, err := s.expandFileTransferRequestPath(params.Location)
if err != nil {
return nil, trace.Wrap(err)
}

req := FileTransferRequest{
ID: uuid.New().String(),
Requester: params.Requester,
Location: location,
Filename: params.Filename,
Download: params.Download,
approvers: make(map[string]*party),
}

return &req, nil
}

func (s *session) expandFileTransferRequestPath(p string) (string, error) {
expanded := filepath.Clean(p)
dir := filepath.Dir(expanded)

tildePrefixed := dir == "~"
noBaseDir := dir == "."
if tildePrefixed || noBaseDir {
localUser, err := user.Lookup(s.login)
if err != nil {
return "", trace.Wrap(err)
}

exists, err := CheckHomeDir(localUser)
if err != nil {
return "", trace.Wrap(err)
}
homeDir := localUser.HomeDir
if !exists {
homeDir = string(os.PathSeparator)
}

if tildePrefixed {
// expand home dir to make an absolute path
expanded = filepath.Join(homeDir, expanded[2:])
} else {
// if no directories are specified SFTP will assume the file
// to be in the user's home dir
expanded = filepath.Join(homeDir, expanded)
}
}

return expanded, nil
}
Comment thread
atburke marked this conversation as resolved.

// addFileTransferRequest will create a new file transfer request and add it to the current session's fileTransferRequests map
// and broadcast the appropriate string to the session.
func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestParams, scx *ServerContext) error {
Expand All @@ -1966,18 +1910,21 @@ func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestPar
return trace.BadParameter("no source file is set for the upload")
}

req, err := s.newFileTransferRequest(params)
if err != nil {
return trace.Wrap(err)
s.fileTransferReq = &FileTransferRequest{
ID: uuid.New().String(),
Requester: params.Requester,
Location: params.Location,
Filename: params.Filename,
Download: params.Download,
approvers: make(map[string]*party),
}
s.fileTransferReq = req

if params.Download {
s.BroadcastMessage("User %s would like to download: %s", params.Requester, params.Location)
} else {
s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location)
}
err = s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx)
err := s.registry.notifyFileTransferRequestUnderLock(s.fileTransferReq, FileTransferUpdate, scx)

return trace.Wrap(err)
}
Expand Down
25 changes: 13 additions & 12 deletions lib/sshutils/sftp/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ func TransferFiles(ctx context.Context, req *FileTransferRequest) error {
return trace.Wrap(err)
}
for i, srcPath := range req.Sources.Paths {
expandedPath, err := expandPath(srcPath)
expandedPath, err := ExpandHomeDir(srcPath)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -338,7 +338,7 @@ func TransferFiles(ctx context.Context, req *FileTransferRequest) error {
if err != nil {
return trace.Wrap(err)
}
expandedPath, err := expandPath(req.Destination.Path)
expandedPath, err := ExpandHomeDir(req.Destination.Path)
if err != nil {
return trace.Wrap(err)
}
Expand All @@ -361,27 +361,28 @@ func (p PathExpansionError) Error() string {
return fmt.Sprintf("expanding remote ~user paths is not supported, specify an absolute path instead of %q", p.path)
}

func expandPath(pathStr string) (string, error) {
// ExpandHomeDir evaluates the home directory ('~') in a path.
func ExpandHomeDir(pathStr string) (string, error) {
pfxLen, ok := homeDirPrefixLen(pathStr)
if !ok {
return pathStr, nil
}

// Removing the home dir prefix would mean returning an empty string,
// which is supported by SFTP but won't be as clear in logs or audit
// events. Since the SFTP server will be rooted at the user's home
// directory, "." and "" are equivalent in this context.
if pathStr == "~" {
return ".", nil
}
if pfxLen == 1 && len(pathStr) > 1 {
return "", trace.Wrap(PathExpansionError{path: pathStr})
}

// if an SFTP path is not absolute, it is assumed to start at the user's
// home directory so just strip the prefix and let the SFTP server
// figure out the correct remote path
return pathStr[pfxLen:], nil
// figure out the correct remote path.
trimmedPath := pathStr[pfxLen:]
// Returning an empty string is supported by SFTP but won't be as clear in
// logs or audit events. Since the SFTP server will be rooted at the user's
// home directory, "." and "" are equivalent in this context.
if trimmedPath == "" {
return ".", nil
}
return trimmedPath, nil
}

// homeDirPrefixLen returns the length of a set of characters that
Expand Down
8 changes: 6 additions & 2 deletions lib/sshutils/sftp/sftp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,11 @@ func TestHomeDirExpansion(t *testing.T) {
path: "~",
expandedPath: ".",
},

{
name: "tilde slash",
path: "~/",
expandedPath: ".",
},
{
name: "~user path",
path: "~user/foo",
Expand All @@ -502,7 +506,7 @@ func TestHomeDirExpansion(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
expanded, err := expandPath(tt.path)
expanded, err := ExpandHomeDir(tt.path)
if tt.errCheck == nil {
require.NoError(t, err)
require.Equal(t, tt.expandedPath, expanded)
Expand Down
33 changes: 22 additions & 11 deletions tool/teleport/common/sftp.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"os"
Expand Down Expand Up @@ -89,7 +88,22 @@ func newSFTPHandler(logger *slog.Logger, req *srv.FileTransferRequest, events ch
}
// TODO(capnspacehook): reject relative paths and symlinks
// make filepaths consistent by ensuring all separators use backslashes
allowed.path = path.Clean(req.Location)
allowedPath, err := sftputils.ExpandHomeDir(req.Location)
if err != nil {
return nil, trace.Wrap(err)
}
if !path.IsAbs(allowedPath) {
currentUser, err := user.Current()
if err != nil {
return nil, trace.Wrap(err)
}
if currentUser.HomeDir != "" {
allowedPath = path.Join(currentUser.HomeDir, allowedPath)
} else {
allowedPath = path.Join(string(os.PathSeparator), allowedPath)
}
}
allowed.path = path.Clean(allowedPath)
}

return &sftpHandler{
Expand All @@ -99,10 +113,6 @@ func newSFTPHandler(logger *slog.Logger, req *srv.FileTransferRequest, events ch
}, nil
}

func newDisallowedErr(req *sftp.Request) error {
return fmt.Errorf("method %s is not allowed on %s", strings.ToLower(req.Method), req.Filepath)
}

// ensureReqIsAllowed returns an error if the SFTP request isn't
// allowed based on the approved file transfer request for this session.
func (s *sftpHandler) ensureReqIsAllowed(req *sftp.Request) error {
Expand All @@ -111,8 +121,9 @@ func (s *sftpHandler) ensureReqIsAllowed(req *sftp.Request) error {
return nil
}

if s.allowed.path != path.Clean(req.Filepath) {
return newDisallowedErr(req)
cleaned := path.Clean(req.Filepath)
if s.allowed.path != cleaned {
return trace.Errorf("operations are only allowed on %s, not %s", s.allowed.path, cleaned)
}

switch req.Method {
Expand All @@ -121,15 +132,15 @@ func (s *sftpHandler) ensureReqIsAllowed(req *sftp.Request) error {
case sftputils.MethodGet:
// only allow reads for downloads
if s.allowed.write {
return newDisallowedErr(req)
return trace.Errorf("reading is not allowed for this request")
}
case sftputils.MethodPut, sftputils.MethodSetStat:
// only allow writes and chmods for uploads
if !s.allowed.write {
return newDisallowedErr(req)
return trace.Errorf("writing is not allowed for this request")
}
default:
return newDisallowedErr(req)
return trace.Errorf("method %s is not allowed on %s", strings.ToLower(req.Method), req.Filepath)
}

return nil
Expand Down
Loading