diff --git a/integration/integration_test.go b/integration/integration_test.go index d3cdbf66d6159..818a8e2b3c05c 100644 --- a/integration/integration_test.go +++ b/integration/integration_test.go @@ -97,6 +97,7 @@ import ( rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/srv/alpnproxy/common" "github.com/gravitational/teleport/lib/sshutils" + telesftp "github.com/gravitational/teleport/lib/sshutils/sftp" "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" @@ -190,6 +191,7 @@ func TestIntegrations(t *testing.T) { t.Run("DifferentPinnedIP", suite.bind(testDifferentPinnedIP)) t.Run("JoinOverReverseTunnelOnly", suite.bind(testJoinOverReverseTunnelOnly)) t.Run("SFTP", suite.bind(testSFTP)) + t.Run("ModeratedSFTP", suite.bind(testModeratedSFTP)) t.Run("EscapeSequenceTriggers", suite.bind(testEscapeSequenceTriggers)) t.Run("AuthLocalNodeControlStream", suite.bind(testAuthLocalNodeControlStream)) t.Run("AgentlessConnection", suite.bind(testAgentlessConnection)) @@ -7934,6 +7936,332 @@ func getRemoteAddrString(sshClientString string) string { return fmt.Sprintf("%s:%s", parts[0], parts[1]) } +func isNilOrEOFErr(t *testing.T, err error) { + t.Helper() + + if err != nil { + require.ErrorIs(t, err, io.EOF) + } +} + +func testModeratedSFTP(t *testing.T, suite *integrationTestSuite) { + modules.SetTestModules(t, &modules.TestModules{ + TestBuildType: modules.BuildEnterprise, + }) + + // Create Teleport instance + instance := suite.newTeleport(t, nil, true) + t.Cleanup(func() { + instance.StopAll() + }) + + ctx := context.Background() + authServer := instance.Process.GetAuthServer() + + // Create peer and moderator users and roles + username := suite.Me.Username + peerUsername := username + "-peer" + sshAccessRole, err := types.NewRole("ssh-access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + Logins: []string{username}, + NodeLabels: types.Labels{ + types.Wildcard: []string{types.Wildcard}, + }, + }, + }) + require.NoError(t, err) + err = authServer.CreateRole(ctx, sshAccessRole) + require.NoError(t, err) + + peerRole, err := types.NewRole("peer", types.RoleSpecV6{ + Allow: types.RoleConditions{ + RequireSessionJoin: []*types.SessionRequirePolicy{ + { + Name: "Requires oversight", + Filter: `equals("true", "true")`, + Kinds: []string{ + string(types.SSHSessionKind), + }, + Count: 1, + Modes: []string{ + string(types.SessionModeratorMode), + }, + OnLeave: string(types.OnSessionLeaveTerminate), + }, + }, + }, + }) + require.NoError(t, err) + err = authServer.CreateRole(ctx, peerRole) + require.NoError(t, err) + + peerUser, err := types.NewUser(peerUsername) + require.NoError(t, err) + peerUser.SetLogins([]string{username}) + peerUser.SetRoles([]string{sshAccessRole.GetName(), peerRole.GetName()}) + err = authServer.CreateUser(ctx, peerUser) + require.NoError(t, err) + + modUsername := username + "-moderator" + modRole, err := types.NewRole("moderator", types.RoleSpecV6{ + Allow: types.RoleConditions{ + JoinSessions: []*types.SessionJoinPolicy{{ + Name: "Session moderator", + Roles: []string{peerRole.GetName()}, + Kinds: []string{string(types.SSHSessionKind)}, + Modes: []string{string(types.SessionModeratorMode), string(types.SessionObserverMode)}, + }}, + }, + }) + require.NoError(t, err) + err = authServer.CreateRole(ctx, modRole) + require.NoError(t, err) + + moderatorUser, err := types.NewUser(modUsername) + require.NoError(t, err) + moderatorUser.SetLogins([]string{username}) + moderatorUser.SetRoles([]string{sshAccessRole.GetName(), modRole.GetName()}) + err = authServer.CreateUser(ctx, moderatorUser) + require.NoError(t, err) + + waitForNodesToRegister(t, instance, helpers.Site) + + // Start a shell so a moderated session is created + peerClient, err := instance.NewClient(helpers.ClientConfig{ + TeleportUser: peerUsername, + Login: username, + Cluster: helpers.Site, + Host: Host, + }) + require.NoError(t, err) + + peerClusterClient, err := peerClient.ConnectToCluster(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, peerClusterClient.Close()) + }) + + nodeDetails := client.NodeDetails{ + Addr: instance.Config.SSH.Addr.Addr, + Namespace: peerClient.Namespace, + Cluster: helpers.Site, + } + peerNodeClient, err := peerClient.ConnectToNode( + ctx, + peerClusterClient, + nodeDetails, + username, + ) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, peerNodeClient.Close()) + }) + + peerSSH := peerNodeClient.Client + peerSess, err := peerSSH.NewSession(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, peerSess.Close()) + }) + + peerTerm := NewTerminal(250) + peerSess.Stdin = peerTerm + peerSess.Stdout = peerTerm + peerSess.Stderr = peerTerm + err = peerSess.Shell(ctx) + require.NoError(t, err) + + var sessTracker types.SessionTracker + require.EventuallyWithT(t, func(t *assert.CollectT) { + trackers, err := peerClusterClient.AuthClient.GetActiveSessionTrackers(ctx) + assert.NoError(t, err) + if assert.Len(t, trackers, 1) { + sessTracker = trackers[0] + } + }, 5*time.Second, 100*time.Millisecond) + + // Join the waiting session so it is approved + modTC, err := instance.NewClient(helpers.ClientConfig{ + TeleportUser: modUsername, + Login: username, + Cluster: helpers.Site, + Host: Host, + }) + require.NoError(t, err) + + modClusterClient, err := modTC.ConnectToCluster(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, modClusterClient.Close()) + }) + + conn, details, err := modClusterClient.ProxyClient.DialHost(ctx, nodeDetails.Addr, nodeDetails.Cluster, modTC.LocalAgent().ExtendedAgent) + require.NoError(t, err) + sshConfig := modClusterClient.ProxyClient.SSHConfig(username) + modSSHConn, modSSHChans, modSSHReqs, err := tracessh.NewClientConn(ctx, conn, nodeDetails.ProxyFormat(), sshConfig) + require.NoError(t, err) + + // We pass an empty channel which we close right away to ssh.NewClient + // because the client need to handle requests itself. + emptyCh := make(chan *ssh.Request) + close(emptyCh) + modNodeCli := client.NodeClient{ + Client: tracessh.NewClient(modSSHConn, modSSHChans, emptyCh), + Namespace: nodeDetails.Namespace, + TC: modTC, + Tracer: modTC.Tracer, + FIPSEnabled: details.FIPS, + ProxyPublicAddr: modTC.WebProxyAddr, + } + + modSess, err := modNodeCli.Client.NewSession(ctx) + require.NoError(t, err) + err = modSess.Setenv(ctx, sshutils.SessionEnvVar, sessTracker.GetSessionID()) + require.NoError(t, err) + err = modSess.Setenv(ctx, teleport.EnvSSHJoinMode, string(types.SessionModeratorMode)) + require.NoError(t, err) + + modTerm := NewTerminal(250) + modSess.Stdin = modTerm + modSess.Stdout = modTerm + modSess.Stderr = modTerm + err = modSess.Shell(ctx) + require.NoError(t, err) + + // Create and approve a file download request + tempDir := t.TempDir() + reqFile := filepath.Join(tempDir, "req-file") + err = os.WriteFile(reqFile, []byte("contents"), 0o666) + require.NoError(t, err) + + err = peerSess.RequestFileTransfer(ctx, tracessh.FileTransferReq{ + Download: true, + Location: reqFile, + }) + require.NoError(t, err) + + sshReq := <-modSSHReqs + var joinEvent apievents.SessionJoin + err = json.Unmarshal(sshReq.Payload, &joinEvent) + require.NoError(t, err) + + sshReq = <-modSSHReqs + var fileReq apievents.FileTransferRequestEvent + err = json.Unmarshal(sshReq.Payload, &fileReq) + require.NoError(t, err) + + err = modSess.ApproveFileTransferRequest(ctx, fileReq.RequestID) + require.NoError(t, err) + + // Ignore file transfer request approve event + <-modSSHReqs + + // Test that only operations needed to complete the download + // are allowed + transferSess, err := peerSSH.NewSession(ctx) + require.NoError(t, err) + t.Cleanup(func() { + isNilOrEOFErr(t, transferSess.Close()) + }) + + err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID()) + require.NoError(t, err) + + err = transferSess.RequestSubsystem(ctx, teleport.SFTPSubsystem) + require.NoError(t, err) + w, err := transferSess.StdinPipe() + require.NoError(t, err) + r, err := transferSess.StdoutPipe() + require.NoError(t, err) + sftpClient, err := sftp.NewClientPipe(r, w) + require.NoError(t, err) + + // A file not in the request shouldn't be allowed + _, err = sftpClient.Open(filepath.Join(tempDir, "bad-file")) + require.ErrorContains(t, err, `method get is not allowed`) + // Since this is a download no files should be allowed to be written to + _, err = sftpClient.OpenFile(filepath.Join(tempDir, reqFile), os.O_WRONLY) + require.ErrorContains(t, err, `method put is not allowed`) + // Only stats and reads should be allowed + err = sftpClient.Mkdir(filepath.Join(tempDir, "new-dir")) + require.ErrorContains(t, err, `method mkdir is not allowed`) + // Since this is a download no files should be allowed to have + // their permissions changed + err = sftpClient.Chmod(reqFile, 0o777) + require.ErrorContains(t, err, `method setstat is not allowed`) + + // Only necessary operations should be allowed + _, err = sftpClient.Stat(reqFile) + require.NoError(t, err) + _, err = sftpClient.Lstat(reqFile) + require.NoError(t, err) + rf, err := sftpClient.Open(reqFile) + require.NoError(t, err) + require.NoError(t, rf.Close()) + + require.NoError(t, sftpClient.Close()) + + // Create and approve a file upload request + err = peerSess.RequestFileTransfer(ctx, tracessh.FileTransferReq{ + Download: false, + Filename: "upload-file", + Location: reqFile, + }) + require.NoError(t, err) + + sshReq = <-modSSHReqs + err = json.Unmarshal(sshReq.Payload, &fileReq) + require.NoError(t, err) + + err = modSess.ApproveFileTransferRequest(ctx, fileReq.RequestID) + require.NoError(t, err) + + // Ignore file transfer request approve event + <-modSSHReqs + + isNilOrEOFErr(t, transferSess.Close()) + transferSess, err = peerSSH.NewSession(ctx) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, transferSess.Close()) + }) + + err = transferSess.Setenv(ctx, string(telesftp.ModeratedSessionID), sessTracker.GetSessionID()) + require.NoError(t, err) + + // Test that only operations needed to complete the download + // are allowed + err = transferSess.RequestSubsystem(ctx, teleport.SFTPSubsystem) + require.NoError(t, err) + w, err = transferSess.StdinPipe() + require.NoError(t, err) + r, err = transferSess.StdoutPipe() + require.NoError(t, err) + sftpClient, err = sftp.NewClientPipe(r, w) + require.NoError(t, err) + + // A file not in the request shouldn't be allowed + _, err = sftpClient.Open(filepath.Join(tempDir, "bad-file")) + require.ErrorContains(t, err, `method get is not allowed`) + // Since this is an upload no files should be allowed to be read from + _, err = sftpClient.OpenFile(filepath.Join(tempDir, reqFile), os.O_RDONLY) + require.ErrorContains(t, err, `method get is not allowed`) + // Only stats, writes, and chmods should be allowed + err = sftpClient.Mkdir(filepath.Join(tempDir, "new-dir")) + require.ErrorContains(t, err, `method mkdir is not allowed`) + + // Only necessary operations should be allowed + _, err = sftpClient.Stat(reqFile) + require.NoError(t, err) + _, err = sftpClient.Lstat(reqFile) + require.NoError(t, err) + err = sftpClient.Chmod(reqFile, 0o777) + require.NoError(t, err) + wf, err := sftpClient.OpenFile(reqFile, os.O_WRONLY) + require.NoError(t, err) + require.NoError(t, wf.Close()) +} + func testSFTP(t *testing.T, suite *integrationTestSuite) { // Create Teleport instance. teleport := suite.newTeleport(t, nil, true) @@ -7971,6 +8299,9 @@ func testSFTP(t *testing.T, suite *integrationTestSuite) { suite.Me.Username, ) require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, nodeClient.Close()) + }) sftpClient, err := sftp.NewClient(nodeClient.Client.Client) require.NoError(t, err) diff --git a/lib/srv/ctx.go b/lib/srv/ctx.go index 2790b346142f4..30d387da645b9 100644 --- a/lib/srv/ctx.go +++ b/lib/srv/ctx.go @@ -454,6 +454,10 @@ type ServerContext struct { // UserCreatedByTeleport is true when the system user was created by Teleport user auto-provision. UserCreatedByTeleport bool + + // approvedFileReq is an approved file transfer request that will only be + // set when the session's pending file transfer request is approved. + approvedFileReq *FileTransferRequest } // NewServerContext creates a new *ServerContext which is used to pass and @@ -1371,3 +1375,23 @@ func (c *ServerContext) GetExecRequest() (Exec, error) { } return c.execRequest, nil } + +func (c *ServerContext) setApprovedFileTransferRequest(req *FileTransferRequest) { + c.mu.Lock() + c.approvedFileReq = req + c.mu.Unlock() +} + +// ConsumeApprovedFileTransferRequest will return the approved file transfer +// request for this session if there is one present. Note that if an +// approved request is returned future calls to this method will return +// nil to prevent an approved request getting reused incorrectly. +func (c *ServerContext) ConsumeApprovedFileTransferRequest() *FileTransferRequest { + c.mu.Lock() + defer c.mu.Unlock() + + req := c.approvedFileReq + c.approvedFileReq = nil + + return req +} diff --git a/lib/srv/regular/sftp.go b/lib/srv/regular/sftp.go index df61a6f705833..a3d691c9f3362 100644 --- a/lib/srv/regular/sftp.go +++ b/lib/srv/regular/sftp.go @@ -19,6 +19,7 @@ package regular import ( "bufio" "context" + "encoding/json" "errors" "io" "os" @@ -40,18 +41,20 @@ import ( const copyingGoroutines = 2 type sftpSubsys struct { - sftpCmd *exec.Cmd - serverCtx *srv.ServerContext - errCh chan error - log *logrus.Entry + log *logrus.Entry + + fileTransferReq *srv.FileTransferRequest + sftpCmd *exec.Cmd + serverCtx *srv.ServerContext + errCh chan error } -func newSFTPSubsys() (*sftpSubsys, error) { - // TODO: add prometheus collectors? +func newSFTPSubsys(fileTransferReq *srv.FileTransferRequest) (*sftpSubsys, error) { return &sftpSubsys{ log: logrus.WithFields(logrus.Fields{ trace.Component: teleport.ComponentSubsystemSFTP, }), + fileTransferReq: fileTransferReq, }, nil } @@ -113,9 +116,25 @@ func (s *sftpSubsys) Start(ctx context.Context, if err != nil { return trace.Wrap(err) } - // TODO: put in cgroup? execRequest.Continue() + // Send the file transfer request if applicable. The SFTP process + // expects the file transfer request data will end with a null byte, + // so if there is no request to send just send a null byte so the + // SFTP process can detect that no request was sent. + encodedReq := []byte{0x0} + if s.fileTransferReq != nil { + encodedReq, err = json.Marshal(s.fileTransferReq) + if err != nil { + return trace.Wrap(err) + } + encodedReq = append(encodedReq, 0x0) + } + _, err = chReadPipeIn.Write(encodedReq) + if err != nil { + return trace.Wrap(err) + } + // Copy the SSH channel to and from the anonymous pipes s.errCh = make(chan error, copyingGoroutines) go func() { diff --git a/lib/srv/regular/sshserver.go b/lib/srv/regular/sshserver.go index a0decc619e461..d1b1757ac96c0 100644 --- a/lib/srv/regular/sshserver.go +++ b/lib/srv/regular/sshserver.go @@ -2216,7 +2216,7 @@ func (s *Server) parseSubsystemRequest(req *ssh.Request, ctx *srv.ServerContext) return nil, trace.Wrap(err) } - return newSFTPSubsys() + return newSFTPSubsys(ctx.ConsumeApprovedFileTransferRequest()) default: return nil, trace.BadParameter("unrecognized subsystem: %v", r.Name) } diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 0aba96e6fe3be..6a6d9c6a180a9 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -22,6 +22,9 @@ import ( "errors" "fmt" "io" + "os" + "os/user" + "path" "path/filepath" "sync" "sync/atomic" @@ -436,48 +439,54 @@ func (s *SessionRegistry) GetTerminalSize(sessionID string) (*term.Winsize, erro } func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, error) { - s.sessionsMux.Lock() - defer s.sessionsMux.Unlock() - - // get the requested location from env vars - location, _ := scx.GetEnv(sftp.FileTransferDstPath) - if location == "" { - return false, nil - } - // if a sessID and requestID environment variables were not set, return not approved and no error. - // This means the file transfer came from a non-moderated session. sessionID will be passed after a - // moderated session approval process has completed. + // If the TELEPORT_MODERATED_SESSION_ID environment variable was not + // set, return not approved and no error. This means the file + // transfer came from a non-moderated session. sessionID will be + // passed after a moderated session approval process has completed. sessID, _ := scx.GetEnv(string(sftp.ModeratedSessionID)) if sessID == "" { return false, nil } + // fetch session from registry with sessionID + s.sessionsMux.Lock() sess := s.sessions[rsession.ID(sessID)] + s.sessionsMux.Unlock() if sess == nil { // If they sent a sessionID and it wasn't found, send an actual error return false, trace.NotFound("Session not found") } - requestID, _ := scx.GetEnv(string(sftp.FileTransferRequestID)) - if requestID == "" { - return false, nil - } - // find file transfer request in the session by requestID - req := sess.fileTransferRequests[requestID] - if req == nil { - // If they sent a fileTransferRequestID and it wasn't found, send an actual error - return false, trace.NotFound("File transfer request not found") - } + // acquire the session mutex lock so sess.fileTransferReq doesn't get + // written while we're reading it + sess.mu.Lock() + defer sess.mu.Unlock() - if req.location != location { - return false, trace.AccessDenied("requested destination path does not match the current request") + if sess.fileTransferReq == nil { + return false, trace.NotFound("Session does not have a pending file transfer request") } + if sess.fileTransferReq.Requester != scx.Identity.TeleportUser { + // to be safe deny and remove the pending request if the user + // doesn't match what we expect + req := sess.fileTransferReq + sess.fileTransferReq = nil + + sess.BroadcastMessage("file transfer request %s denied due to %s attempting to transfer files", req.ID, scx.Identity.TeleportUser) + _ = s.NotifyFileTransferRequest(req, FileTransferDenied, scx) - if req.requester != scx.Identity.TeleportUser { return false, trace.AccessDenied("Teleport user does not match original requester") } - return sess.checkIfFileTransferApproved(req) + approved, err := sess.checkIfFileTransferApproved(sess.fileTransferReq) + if err != nil { + return false, trace.Wrap(err) + } + if approved { + scx.setApprovedFileTransferRequest(sess.fileTransferReq) + sess.fileTransferReq = nil + } + + return approved, nil } // FileTransferRequestEvent is an event used to Notify party members during File Transfer Request approval process @@ -498,7 +507,7 @@ const ( // NotifyFileTransferRequest is called to notify all members of a party that a file transfer request has been created/approved/denied. // The notification is a global ssh request and requires the client to update its UI state accordingly. -func (s *SessionRegistry) NotifyFileTransferRequest(req *fileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { +func (s *SessionRegistry) NotifyFileTransferRequest(req *FileTransferRequest, res FileTransferRequestEvent, scx *ServerContext) error { session := scx.getSession() if session == nil { s.log.Debugf("Unable to notify %s, no session found in context.", res) @@ -514,11 +523,11 @@ func (s *SessionRegistry) NotifyFileTransferRequest(req *fileTransferRequest, re SessionMetadata: apievents.SessionMetadata{ SessionID: string(sid), }, - RequestID: req.id, - Requester: req.requester, - Location: req.location, - Filename: req.filename, - Download: req.download, + RequestID: req.ID, + Requester: req.Requester, + Location: req.Location, + Filename: req.Filename, + Download: req.Download, Approvers: make([]string, 0), } @@ -661,10 +670,10 @@ type session struct { // participants at the end of a session. participants map[rsession.ID]*party - // fileTransferRequests is a set of fileTransferRequests that are currently in the approval - // process, or already approved and not yet executed during a moderated session. If a request is - // denied or, once it's been executed, it should be removed from this map. - fileTransferRequests map[string]*fileTransferRequest + // fileTransferReq a pending file transfer request for this session. + // If the request is denied or approved it should be set to nil to + // prevent its reuse. + fileTransferReq *FileTransferRequest io *TermManager inWriter io.WriteCloser @@ -761,7 +770,6 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se id: id, registry: r, parties: make(map[rsession.ID]*party), - fileTransferRequests: make(map[string]*fileTransferRequest), participants: make(map[rsession.ID]*party), login: scx.Identity.Login, stopC: make(chan struct{}), @@ -1676,22 +1684,24 @@ func (s *session) checkPresence(ctx context.Context) error { return nil } -// fileTransferRequest is a request to upload or download a file from the node. -type fileTransferRequest struct { - id string - // requester is the Teleport User that requested the file transfer - requester string - // download is true if the request is a download, false if its an upload - download bool - // filename the name of the file to upload. - filename string - // location of the requested download or where a file will be uploaded - location string +// FileTransferRequest is a request to upload or download a file from a node. +type FileTransferRequest struct { + // ID is a UUID that uniquely identifies a file transfer request + // and is unlikely to collide with another file transfer request + ID string + // Requester is the Teleport User that requested the file transfer + Requester string + // Download is true if the request is a download, false if its an upload + Download bool + // Filename is the name of the file to upload. + Filename string + // Location of the requested download or where a file will be uploaded + Location string // approvers is a list of participants of moderator or peer type that have approved the request approvers map[string]*party } -func (s *session) checkIfFileTransferApproved(req *fileTransferRequest) (bool, error) { +func (s *session) checkIfFileTransferApproved(req *FileTransferRequest) (bool, error) { var participants []auth.SessionAccessContext for _, party := range req.approvers { @@ -1715,44 +1725,104 @@ func (s *session) checkIfFileTransferApproved(req *fileTransferRequest) (bool, e } // newFileTransferRequest takes FileTransferParams and creates a new fileTransferRequest struct -func (s *session) newFileTransferRequest(params *rsession.FileTransferRequestParams) *fileTransferRequest { - return &fileTransferRequest{ - id: uuid.New().String(), - requester: params.Requester, - location: params.Location, - filename: params.Filename, - download: params.Download, +func (s *session) newFileTransferRequest(params *rsession.FileTransferRequestParams) (*FileTransferRequest, error) { + location, err := s.expandFileTransferRequestPath(params.Location) + if err != nil { + return nil, trace.Wrap(err) + } + + req := FileTransferRequest{ + ID: uuid.New().String(), + Requester: params.Requester, + Location: location, + Filename: params.Filename, + Download: params.Download, approvers: make(map[string]*party), } + + return &req, nil +} + +func (s *session) expandFileTransferRequestPath(p string) (string, error) { + expanded := path.Clean(p) + dir := path.Dir(expanded) + + var tildePrefixed bool + var noBaseDir bool + if dir == "~" { + tildePrefixed = true + } else if dir == "." { + noBaseDir = true + } + + if tildePrefixed || noBaseDir { + localUser, err := user.Lookup(s.login) + if err != nil { + return "", trace.Wrap(err) + } + + exists, err := CheckHomeDir(localUser) + if err != nil { + return "", trace.Wrap(err) + } + homeDir := localUser.HomeDir + if !exists { + homeDir = string(os.PathSeparator) + } + + if tildePrefixed { + // expand home dir to make an absolute path + expanded = path.Join(homeDir, expanded[2:]) + } else { + // if no directories are specified SFTP will assume the file + // to be in the user's home dir + expanded = path.Join(homeDir, expanded) + } + } + + return expanded, nil } // addFileTransferRequest will create a new file transfer request and add it to the current session's fileTransferRequests map // and broadcast the appropriate string to the session. -func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestParams, scx *ServerContext) *fileTransferRequest { +func (s *session) addFileTransferRequest(params *rsession.FileTransferRequestParams, scx *ServerContext) error { s.mu.Lock() defer s.mu.Unlock() - req := s.newFileTransferRequest(params) - s.fileTransferRequests[req.id] = req + if s.fileTransferReq != nil { + return trace.AlreadyExists("a file transfer request already exists for this session") + } + if !params.Download && params.Filename == "" { + return trace.BadParameter("no source file is set for the upload") + } + + req, err := s.newFileTransferRequest(params) + if err != nil { + return trace.Wrap(err) + } + s.fileTransferReq = req + if params.Download { s.BroadcastMessage("User %s would like to download: %s", params.Requester, params.Location) } else { s.BroadcastMessage("User %s would like to upload %s to: %s", params.Requester, params.Filename, params.Location) } + err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, FileTransferUpdate, scx) - s.registry.NotifyFileTransferRequest(req, FileTransferUpdate, scx) - return req + return trace.Wrap(err) } // approveFileTransferRequest will add the approver to the approvers map of a file transfer request and notify the members // of the session if the updated approvers map would fulfill the moderated policy. -func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisionParams, scx *ServerContext) (*fileTransferRequest, error) { +func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisionParams, scx *ServerContext) error { s.mu.Lock() defer s.mu.Unlock() - fileTransferReq := s.fileTransferRequests[params.RequestID] - if fileTransferReq == nil { - return nil, trace.NotFound("File Transfer Request %s not found", params.RequestID) + if s.fileTransferReq == nil { + return trace.NotFound("File Transfer Request %s not found", params.RequestID) + } + if s.fileTransferReq.ID != params.RequestID { + return trace.BadParameter("current file transfer request is not %s", params.RequestID) } var approver *party @@ -1762,16 +1832,16 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi } } if approver == nil { - return nil, trace.AccessDenied("cannot approve file transfer requests if not in the current moderated session") + return trace.AccessDenied("cannot approve file transfer requests if not in the current moderated session") } - fileTransferReq.approvers[approver.user] = approver - s.BroadcastMessage("%s approved file transfer request %s", scx.Identity.TeleportUser, fileTransferReq.id) + s.fileTransferReq.approvers[approver.user] = approver + s.BroadcastMessage("%s approved file transfer request %s", scx.Identity.TeleportUser, s.fileTransferReq.ID) // check if policy is fulfilled - approved, err := s.checkIfFileTransferApproved(fileTransferReq) + approved, err := s.checkIfFileTransferApproved(s.fileTransferReq) if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } var eventType FileTransferRequestEvent @@ -1780,22 +1850,25 @@ func (s *session) approveFileTransferRequest(params *rsession.FileTransferDecisi } else { eventType = FileTransferUpdate } + err = s.registry.NotifyFileTransferRequest(s.fileTransferReq, eventType, scx) - s.registry.NotifyFileTransferRequest(fileTransferReq, eventType, scx) - - return fileTransferReq, nil + return trace.Wrap(err) } // denyFileTransferRequest will deny a file transfer request and remove it from the current session's file transfer requests map. // A file transfer request does not persist after deny, so there is no "denied" state. Deny in this case is synonymous with delete // with the addition of checking for a valid denier. -func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionParams, scx *ServerContext) (*fileTransferRequest, error) { +func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionParams, scx *ServerContext) error { s.mu.Lock() defer s.mu.Unlock() - fileTransferReq := s.fileTransferRequests[params.RequestID] - if fileTransferReq == nil { - return nil, trace.NotFound("file transfer request %s not found", params.RequestID) + + if s.fileTransferReq == nil { + return trace.NotFound("file transfer request %s not found", params.RequestID) + } + if s.fileTransferReq.ID != params.RequestID { + return trace.BadParameter("current file transfer request is not %s", params.RequestID) } + var denier *party for _, p := range s.parties { if p.ctx.ID() == scx.ID() { @@ -1803,15 +1876,16 @@ func (s *session) denyFileTransferRequest(params *rsession.FileTransferDecisionP } } if denier == nil { - return nil, trace.AccessDenied("cannot deny file transfer requests if not in the current moderated session") + return trace.AccessDenied("cannot deny file transfer requests if not in the current moderated session") } - delete(s.fileTransferRequests, fileTransferReq.id) + req := s.fileTransferReq + s.fileTransferReq = nil - s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, fileTransferReq.id) - s.registry.NotifyFileTransferRequest(fileTransferReq, FileTransferDenied, scx) + s.BroadcastMessage("%s denied file transfer request %s", scx.Identity.TeleportUser, req.ID) + err := s.registry.NotifyFileTransferRequest(req, FileTransferDenied, scx) - return fileTransferReq, nil + return trace.Wrap(err) } func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) { diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index 627abdfd07a52..c7c4bf01dec4f 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -139,32 +139,26 @@ func TestIsApprovedFileTransfer(t *testing.T) { name string expectedResult bool expectedError string - req *fileTransferRequest + req *FileTransferRequest reqID string location string }{ { - name: "no file request found with supplied ID", + name: "no pending file request", expectedResult: false, - expectedError: "", + expectedError: "Session does not have a pending file transfer request", reqID: "", req: nil, }, - { - name: "no file request found with supplied ID", - expectedResult: false, - expectedError: "File transfer request not found", - reqID: "111", - req: nil, - }, { name: "current requester does not match original requester", expectedResult: false, expectedError: "Teleport user does not match original requester", reqID: "123", - req: &fileTransferRequest{ - requester: "michael", + req: &FileTransferRequest{ + ID: "123", + Requester: "michael", approvers: make(map[string]*party), }, }, @@ -174,10 +168,11 @@ func TestIsApprovedFileTransfer(t *testing.T) { expectedError: "requested destination path does not match the current request", reqID: "123", location: "~/Downloads", - req: &fileTransferRequest{ - requester: "michael", + req: &FileTransferRequest{ + ID: "123", + Requester: "teleportUser", approvers: make(map[string]*party), - location: "~/badlocation", + Location: "~/badlocation", }, }, { @@ -186,10 +181,11 @@ func TestIsApprovedFileTransfer(t *testing.T) { expectedError: "", reqID: "123", location: "~/Downloads", - req: &fileTransferRequest{ - requester: "teleportUser", + req: &FileTransferRequest{ + ID: "123", + Requester: "teleportUser", approvers: approvers, - location: "~/Downloads", + Location: "~/Downloads", }, }, } @@ -199,16 +195,12 @@ func TestIsApprovedFileTransfer(t *testing.T) { // create and add a session to the registry sess, _ := testOpenSession(t, reg, accessRoleSet) - // create a fileTransferRequest. can be nil - sess.fileTransferRequests = map[string]*fileTransferRequest{ - "123": tt.req, - } + // create a FileTransferRequest. can be nil + sess.fileTransferReq = tt.req // new exec request context scx := newTestServerContext(t, reg.Srv, accessRoleSet) scx.SetEnv(string(sftp.ModeratedSessionID), sess.ID()) - scx.SetEnv(string(sftp.FileTransferRequestID), tt.reqID) - scx.SetEnv(sftp.FileTransferDstPath, tt.location) result, err := reg.isApprovedFileTransfer(scx) if err != nil { require.Equal(t, tt.expectedError, err.Error()) diff --git a/lib/srv/termhandlers.go b/lib/srv/termhandlers.go index 597eebafd82b2..e051d8171b1da 100644 --- a/lib/srv/termhandlers.go +++ b/lib/srv/termhandlers.go @@ -146,12 +146,10 @@ func (t *TermHandlers) HandleFileTransferDecision(ctx context.Context, ch ssh.Ch } if params.Approved { - _, err := session.approveFileTransferRequest(params, scx) - return trace.Wrap(err) + return trace.Wrap(session.approveFileTransferRequest(params, scx)) } - _, err = session.denyFileTransferRequest(params, scx) - return trace.Wrap(err) + return trace.Wrap(session.denyFileTransferRequest(params, scx)) } // HandleFileTransferRequest handles requests of type "file-transfer-request" which will @@ -170,8 +168,7 @@ func (t *TermHandlers) HandleFileTransferRequest(ctx context.Context, ch ssh.Cha return nil } - session.addFileTransferRequest(params, scx) - return nil + return trace.Wrap(session.addFileTransferRequest(params, scx)) } // HandleWinChange handles requests of type "window-change" which update the diff --git a/lib/sshutils/sftp/http.go b/lib/sshutils/sftp/http.go index 11181c727e66c..3f1f71753de34 100644 --- a/lib/sshutils/sftp/http.go +++ b/lib/sshutils/sftp/http.go @@ -36,14 +36,6 @@ import ( type contextKey string const ( - // FileTransferDstPath is the dstPath (location) for the requested file transfer. This would be equal - // to the file to be downloaded, or location for a file to be uploaded. - FileTransferDstPath string = "TELEPORT_FILE_TRANSFER_DST_PATH" - // FileTransferRequestID is an optional parameter id of an file transfer request that has gone through - // an approval process during a moderated session to allow a file transfer scp command to be executed - // used as a value in the file transfer context and env var for exec session - FileTransferRequestID contextKey = "TELEPORT_FILE_TRANSFER_REQUEST_ID" - // ModeratedSessionID is an optional parameter sent during SCP requests to specify which moderated session // to check for valid FileTransferRequests // used as a value in the file transfer context and env var for exec session diff --git a/lib/sshutils/sftp/sftp.go b/lib/sshutils/sftp/sftp.go index dc8975f20f758..a43b12e6433e2 100644 --- a/lib/sshutils/sftp/sftp.go +++ b/lib/sshutils/sftp/sftp.go @@ -237,17 +237,11 @@ func (c *Config) TransferFiles(ctx context.Context, sshClient *ssh.Client) error } defer s.Close() - // File transfers in a moderated session require these two variables - // to check for approval on the ssh server. If they exist in the - // context, set them in our env vars + // File transfers in a moderated session require this variable + // to check for approval on the ssh server if moderatedSessionID, ok := ctx.Value(ModeratedSessionID).(string); ok { s.Setenv(string(ModeratedSessionID), moderatedSessionID) } - if fileTransferRequestID, ok := ctx.Value(FileTransferRequestID).(string); ok { - s.Setenv(string(FileTransferRequestID), fileTransferRequestID) - } - // set dstPath in env var to check against file transfer request location - s.Setenv(FileTransferDstPath, c.dstPath) pe, err := s.StderrPipe() if err != nil { diff --git a/lib/web/files.go b/lib/web/files.go index d2ffc4d92bff3..da43dd0bc5b0e 100644 --- a/lib/web/files.go +++ b/lib/web/files.go @@ -139,8 +139,6 @@ func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprou ctx := r.Context() if req.fileTransferRequestID != "" { - // These values should never exist independently of each other so we can set them at the same time - ctx = context.WithValue(ctx, sftp.FileTransferRequestID, req.fileTransferRequestID) ctx = context.WithValue(ctx, sftp.ModeratedSessionID, req.moderatedSessionID) } diff --git a/tool/teleport/common/sftp.go b/tool/teleport/common/sftp.go index c14e8dffd6516..fcb09af394a5d 100644 --- a/tool/teleport/common/sftp.go +++ b/tool/teleport/common/sftp.go @@ -17,13 +17,17 @@ limitations under the License. package common import ( + "bufio" "bytes" + "encoding/json" "errors" "fmt" "io" "io/fs" "os" "os/user" + "path" + "strings" "time" "github.com/gogo/protobuf/jsonpb" @@ -36,6 +40,7 @@ import ( apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/srv" "github.com/gravitational/teleport/lib/utils" ) @@ -73,17 +78,70 @@ func (c compositeCh) Close() error { return trace.NewAggregate(c.r.Close(), c.w.Close()) } +type allowedOps struct { + write bool + path string +} + // sftpHandler provides handlers for a SFTP server. type sftpHandler struct { - logger *log.Entry - events chan<- *apievents.SFTP + logger *log.Entry + allowed *allowedOps + events chan<- *apievents.SFTP } -func newSFTPHandler(logger *log.Entry, events chan<- *apievents.SFTP) *sftpHandler { +func newSFTPHandler(logger *log.Entry, req *srv.FileTransferRequest, events chan<- *apievents.SFTP) (*sftpHandler, error) { + var allowed *allowedOps + if req != nil { + allowed = &allowedOps{ + write: !req.Download, + } + // TODO(capnspacehook): reject relative paths and symlinks + // make filepaths consistent by ensuring all separators use backslashes + allowed.path = path.Clean(req.Location) + } + return &sftpHandler{ - logger: logger, - events: events, + logger: logger, + allowed: allowed, + events: events, + }, nil +} + +func newDisallowedErr(req *sftp.Request) error { + return fmt.Errorf("method %s is not allowed on %s", strings.ToLower(req.Method), req.Filepath) +} + +// ensureReqIsAllowed returns an error if the SFTP request isn't +// allowed based on the approved file transfer request for this session. +func (s *sftpHandler) ensureReqIsAllowed(req *sftp.Request) error { + // no specifically allowed operations, all requests are allowed + if s.allowed == nil { + return nil } + + if s.allowed.path != path.Clean(req.Filepath) { + return newDisallowedErr(req) + } + + switch req.Method { + case methodLstat, methodStat: + // these methods are allowed + case methodGet: + // only allow reads for downloads + if s.allowed.write { + return newDisallowedErr(req) + } + case methodPut, methodSetStat: + // only allow writes and chmods for uploads + if !s.allowed.write { + return newDisallowedErr(req) + } + default: + return newDisallowedErr(req) + } + + return nil } // OpenFile handles 'open' requests when opening a file for reading @@ -129,6 +187,10 @@ func (s *sftpHandler) Filewrite(req *sftp.Request) (_ io.WriterAt, retErr error) } func (s *sftpHandler) openFile(req *sftp.Request) (*os.File, error) { + if err := s.ensureReqIsAllowed(req); err != nil { + return nil, err + } + var flags int pflags := req.Pflags() if pflags.Append { @@ -172,6 +234,9 @@ func (s *sftpHandler) Filecmd(req *sftp.Request) (retErr error) { if req.Filepath == "" { return os.ErrInvalid } + if err := s.ensureReqIsAllowed(req); err != nil { + return err + } switch req.Method { case methodSetStat: @@ -306,6 +371,9 @@ func (s *sftpHandler) Filelist(req *sftp.Request) (_ sftp.ListerAt, retErr error if req.Filepath == "" { return nil, os.ErrInvalid } + if err := s.ensureReqIsAllowed(req); err != nil { + return nil, err + } switch req.Method { case methodList: @@ -344,6 +412,9 @@ func (s *sftpHandler) Lstat(req *sftp.Request) (sftp.ListerAt, error) { if req.Filepath == "" { return nil, os.ErrInvalid } + if err := s.ensureReqIsAllowed(req); err != nil { + return nil, err + } fi, err := os.Lstat(req.Filepath) if err != nil { @@ -492,7 +563,6 @@ func onSFTP() error { return trace.Wrap(err) } defer chw.Close() - ch := compositeCh{chr, chw} auditFile, err := openFD(5, "audit") if err != nil { return trace.Wrap(err) @@ -501,7 +571,6 @@ func onSFTP() error { // Ensure the parent process will receive log messages from us l := utils.NewLogger() - l.SetOutput(os.Stderr) logger := l.WithField(trace.Component, teleport.ComponentSubsystemSFTP) currentUser, err := user.Current() @@ -513,8 +582,34 @@ func onSFTP() error { return trace.Wrap(err) } + // Read the file transfer request for this session if one exists + bufferedReader := bufio.NewReader(chr) + var encodedReq []byte + var fileTransferReq *srv.FileTransferRequest + for { + b, err := bufferedReader.ReadByte() + if err != nil { + return trace.Wrap(err) + } + // the encoded request will end with a null byte + if b == 0x0 { + break + } + encodedReq = append(encodedReq, b) + } + if len(encodedReq) != 0 { + fileTransferReq = new(srv.FileTransferRequest) + if err := json.Unmarshal(encodedReq, fileTransferReq); err != nil { + return trace.Wrap(err) + } + } + ch := compositeCh{io.NopCloser(bufferedReader), chw} + sftpEvents := make(chan *apievents.SFTP, 1) - h := newSFTPHandler(logger, sftpEvents) + h, err := newSFTPHandler(logger, fileTransferReq, sftpEvents) + if err != nil { + return trace.Wrap(err) + } handler := sftp.Handlers{ FileGet: h, FilePut: h,