diff --git a/lib/client/client.go b/lib/client/client.go index 12f3c9c9f81cc..0923d4338b82f 100644 --- a/lib/client/client.go +++ b/lib/client/client.go @@ -1806,6 +1806,15 @@ func (c *NodeClient) ExecuteSCP(ctx context.Context, cmd scp.Command) error { } defer s.Close() + // File transfers in a moderated session require these two variablesto check for + // approval on the ssh server. If they exist in the context, set them in our env vars + if moderatedSessionID, ok := ctx.Value(scp.ModeratedSessionID).(string); ok { + s.Setenv(ctx, string(scp.ModeratedSessionID), moderatedSessionID) + } + if fileTransferRequestID, ok := ctx.Value(scp.FileTransferRequestID).(string); ok { + s.Setenv(ctx, string(scp.FileTransferRequestID), fileTransferRequestID) + } + stdin, err := s.StdinPipe() if err != nil { return trace.Wrap(err) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 2f692b3e0edaf..a64f2b26b07e1 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -51,6 +51,7 @@ import ( "github.com/gravitational/teleport/lib/services" rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" + "github.com/gravitational/teleport/lib/sshutils/scp" "github.com/gravitational/teleport/lib/utils" ) @@ -287,12 +288,19 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann } scx.Infof("Creating (exec) session %v.", sessionID) + approved, err := s.isApprovedFileTransfer(scx) + if err != nil { + return trace.Wrap(err) + } + canStart, _, err := sess.checkIfStart() if err != nil { return trace.Wrap(err) } - if !canStart { + // canStart will be true for non-moderated sessions. If canStart is false, check to + // see if the request has been approved through a moderated session next. + if !canStart && !approved { return errCannotStartUnattendedSession } @@ -338,6 +346,47 @@ func (s *SessionRegistry) GetTerminalSize(sessionID string) (*term.Winsize, erro return sess.term.GetWinSize() } +func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, error) { + s.sessionsMux.Lock() + defer s.sessionsMux.Unlock() + + // if a sessID and requestID environment variables were not set, return not approved and no error. + // This means the file transfer came from a non-moderated session. sessionID will be passed after a + // moderated session approval process has completed. + sessID, _ := scx.GetEnv(string(scp.ModeratedSessionID)) + if sessID == "" { + return false, nil + } + // fetch session from registry with sessionID + sess := s.sessions[rsession.ID(sessID)] + if sess == nil { + // If they sent a sessionID and it wasn't found, send an actual error + return false, trace.NotFound("Session not found") + } + + requestID, _ := scx.GetEnv(string(scp.FileTransferRequestID)) + if requestID == "" { + return false, nil + } + // find file transfer request in the session by requestID + req := sess.fileTransferRequests[requestID] + if req == nil { + // If they sent a fileTransferRequestID and it wasn't found, send an actual error + return false, trace.NotFound("File transfer request not found") + } + + if req.requester != scx.Identity.TeleportUser { + return false, trace.AccessDenied("Teleport user does not match original requester") + } + + incomingShellCmd := string(scx.sshRequest.Payload) + if incomingShellCmd != req.shellCmd { + return false, trace.AccessDenied("Incoming request does not match the approved request") + } + + return sess.checkIfFileTransferApproved(req) +} + // NotifyWinChange is called to notify all members in the party that the PTY // size has changed. The notification is sent as a global SSH request and it // is the responsibility of the client to update it's window size upon receipt. @@ -454,6 +503,11 @@ type session struct { // participants at the end of a session. participants map[rsession.ID]*party + // fileTransferRequests is a set of fileTransferRequests that are currently in the approval + // process, or already approved and not yet executed during a moderated session. If a request is + // denied or, once it's been executed, it should be removed from this map. + fileTransferRequests map[string]*fileTransferRequest + io *TermManager inWriter io.Writer @@ -1405,6 +1459,38 @@ func (s *session) checkPresence() error { return nil } +type fileTransferRequest struct { + // requester is the Teleport User that requested the file transfer + requester string + // shellCmd is the requested scp command to run + shellCmd string + // approvers is a list of participants of moderator or peer type that have approved the request + approvers map[string]*party +} + +func (s *session) checkIfFileTransferApproved(req *fileTransferRequest) (bool, error) { + var participants []auth.SessionAccessContext + + for _, party := range req.approvers { + if party.ctx.Identity.TeleportUser == s.initiator { + continue + } + + participants = append(participants, auth.SessionAccessContext{ + Username: party.ctx.Identity.TeleportUser, + Roles: party.ctx.Identity.AccessChecker.Roles(), + Mode: party.mode, + }) + } + + isApproved, _, err := s.access.FulfilledFor(participants) + if err != nil { + return false, trace.Wrap(err) + } + + return isApproved, nil +} + func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) { var participants []auth.SessionAccessContext diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index ba1c0161c983a..1fb3f7e825c3e 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -41,6 +41,7 @@ import ( "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" rsession "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/sshutils/scp" "github.com/gravitational/teleport/lib/utils" ) @@ -87,6 +88,138 @@ func TestParseAccessRequestIDs(t *testing.T) { } } +func TestIsApprovedFileTransfer(t *testing.T) { + // set enterprise for tests + modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise}) + srv := newMockServer(t) + srv.component = teleport.ComponentNode + + // init a session registry + reg, _ := NewSessionRegistry(SessionRegistryConfig{ + Srv: srv, + SessionTrackerService: srv.auth, + }) + t.Cleanup(func() { reg.Close() }) + + // Create the auditorRole and moderator Party + auditorRole, _ := types.NewRole("auditor", types.RoleSpecV6{ + Allow: types.RoleConditions{ + JoinSessions: []*types.SessionJoinPolicy{{ + Name: "foo", + Roles: []string{"access"}, + Kinds: []string{string(types.SSHSessionKind)}, + Modes: []string{string(types.SessionModeratorMode)}, + }}, + }, + }) + auditorRoleSet := services.NewRoleSet(auditorRole) + auditScx := newTestServerContext(t, reg.Srv, auditorRoleSet) + // change the teleport user so we dont match the user in the test cases + auditScx.Identity.TeleportUser = "mod" + auditSess, _ := testOpenSession(t, reg, auditorRoleSet) + approvers := make(map[string]*party) + auditChan := newMockSSHChannel() + approvers["mod"] = newParty(auditSess, types.SessionModeratorMode, auditChan, auditScx) + + // create the accessRole to be used for the requester + accessRole, _ := types.NewRole("access", types.RoleSpecV6{ + Allow: types.RoleConditions{ + RequireSessionJoin: []*types.SessionRequirePolicy{{ + Name: "foo", + Filter: "contains(user.roles, \"auditor\")", // escape to avoid illegal rune + Kinds: []string{string(types.SSHSessionKind)}, + Modes: []string{string(types.SessionModeratorMode)}, + Count: 1, + }}, + }, + }) + accessRoleSet := services.NewRoleSet(accessRole) + + cases := []struct { + name string + expectedResult bool + expectedError string + req *fileTransferRequest + reqID string + }{ + + { + name: "no file request found with supplied ID", + expectedResult: false, + expectedError: "", + reqID: "", + req: nil, + }, + { + name: "no file request found with supplied ID", + expectedResult: false, + expectedError: "File transfer request not found", + reqID: "111", + req: nil, + }, + { + name: "current requester does not match original requester", + expectedResult: false, + expectedError: "Teleport user does not match original requester", + reqID: "123", + req: &fileTransferRequest{ + requester: "michael", + shellCmd: "/usr/bin/scp -f ~/logs.txt", + approvers: make(map[string]*party), + }, + }, + { + name: "current payload does not match original payload", + expectedResult: false, + expectedError: "Incoming request does not match the approved request", + reqID: "123", + req: &fileTransferRequest{ + requester: "teleportUser", + shellCmd: "badcommand", + approvers: make(map[string]*party), + }, + }, + { + name: "approved request", + expectedResult: true, + expectedError: "", + reqID: "123", + req: &fileTransferRequest{ + requester: "teleportUser", + shellCmd: "/usr/bin/scp -f ~/logs.txt", + approvers: approvers, + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + // create and add a session to the registry + sess, _ := testOpenSession(t, reg, accessRoleSet) + + // create a fileTransferRequest. can be nil + sess.fileTransferRequests = map[string]*fileTransferRequest{ + "123": tt.req, + } + + // new exec request context + scx := newTestServerContext(t, reg.Srv, accessRoleSet) + scx.sshRequest = &ssh.Request{ + Payload: []byte("/usr/bin/scp -f ~/logs.txt"), + } + + scx.SetEnv(string(scp.ModeratedSessionID), sess.ID()) + scx.SetEnv(string(scp.FileTransferRequestID), tt.reqID) + result, err := reg.isApprovedFileTransfer(scx) + if err != nil { + require.Equal(t, tt.expectedError, err.Error()) + } + + require.Equal(t, tt.expectedResult, result) + }) + } +} + func TestSession_newRecorder(t *testing.T) { t.Parallel() diff --git a/lib/sshutils/scp/http.go b/lib/sshutils/scp/http.go index 1418659d0ae27..c6563b4455360 100644 --- a/lib/sshutils/scp/http.go +++ b/lib/sshutils/scp/http.go @@ -55,6 +55,11 @@ type HTTPTransferRequest struct { User string // AuditLog is AuditLog log AuditLog events.AuditLogSessionStreamer + // FileTransferRequestID is used to find a FileTransferRequest on a session + FileTransferRequestID string + // ModeratedSessonID is an ID of a moderated session that has completed a + // file transfer request approval process + ModeratedSessionID string } func (r *HTTPTransferRequest) parseRemoteLocation() (string, string, error) { @@ -269,3 +274,17 @@ type nopWriteCloser struct { func (wr *nopWriteCloser) Close() error { return nil } + +const ( + // FileTransferRequestID is an optional parameter id of an file transfer request that has gone through + // an approval process during a moderated session to allow a file transfer scp command to be executed + // used as a value in the file transfer context and env var for exec session + FileTransferRequestID ContextKey = "FILE_TRANSFER_REQUEST_ID" + + // ModeratedSessionID is an optional parameter sent during SCP requests to specify which moderated session + // to check for valid FileTransferRequests + // used as a value in the file transfer context and env var for exec session + ModeratedSessionID ContextKey = "MODERATED_SESSION_ID" +) + +type ContextKey string diff --git a/lib/web/files.go b/lib/web/files.go index 44b38110a5a16..1ff96aa1ead79 100644 --- a/lib/web/files.go +++ b/lib/web/files.go @@ -17,6 +17,7 @@ limitations under the License. package web import ( + "context" "encoding/json" "net/http" "time" @@ -51,18 +52,30 @@ type fileTransferRequest struct { filename string // webauthn is an optional parameter that contains a webauthn response string used to issue single use certs webauthn string + // fileTransferRequestID is used to find a FileTransferRequest on a session + fileTransferRequestID string + // moderatedSessonID is an ID of a moderated session that has completed a + // file transfer request approval process + moderatedSessionID string } func (h *Handler) transferFile(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnel.RemoteSite) (interface{}, error) { query := r.URL.Query() req := fileTransferRequest{ - cluster: site.GetName(), - login: p.ByName("login"), - serverID: p.ByName("server"), - remoteLocation: query.Get("location"), - filename: query.Get("filename"), - namespace: defaults.Namespace, - webauthn: query.Get("webauthn"), + cluster: site.GetName(), + login: p.ByName("login"), + serverID: p.ByName("server"), + remoteLocation: query.Get("location"), + filename: query.Get("filename"), + namespace: defaults.Namespace, + webauthn: query.Get("webauthn"), + fileTransferRequestID: query.Get("file_transfer_request_id"), + moderatedSessionID: query.Get("moderated_session_id"), + } + + // Send an error if only one of these params has been sent. Both should exist or not exist together + if (req.fileTransferRequestID != "") != (req.moderatedSessionID != "") { + return nil, trace.BadParameter("file_transfer_request_id and moderated_session_id must both be included in the same request.") } clt, err := sctx.GetUserClient(r.Context(), site) @@ -116,6 +129,7 @@ type fileTransfer struct { } func (f *fileTransfer) download(req fileTransferRequest, httpReq *http.Request, w http.ResponseWriter) error { + ctx := httpReq.Context() cmd, err := scp.CreateHTTPDownload(scp.HTTPTransferRequest{ RemoteLocation: req.remoteLocation, HTTPResponse: w, @@ -137,7 +151,13 @@ func (f *fileTransfer) download(req fileTransferRequest, httpReq *http.Request, } } - err = tc.ExecuteSCP(httpReq.Context(), req.serverID, cmd) + if req.fileTransferRequestID != "" { + // These values should never exist independently of each other so we can set them at the same time + ctx = context.WithValue(ctx, scp.FileTransferRequestID, req.fileTransferRequestID) + ctx = context.WithValue(ctx, scp.ModeratedSessionID, req.moderatedSessionID) + } + + err = tc.ExecuteSCP(ctx, req.serverID, cmd) if err != nil { return trace.Wrap(err) } @@ -146,6 +166,7 @@ func (f *fileTransfer) download(req fileTransferRequest, httpReq *http.Request, } func (f *fileTransfer) upload(req fileTransferRequest, httpReq *http.Request) error { + ctx := httpReq.Context() cmd, err := scp.CreateHTTPUpload(scp.HTTPTransferRequest{ RemoteLocation: req.remoteLocation, FileName: req.filename, @@ -168,7 +189,13 @@ func (f *fileTransfer) upload(req fileTransferRequest, httpReq *http.Request) er } } - err = tc.ExecuteSCP(httpReq.Context(), req.serverID, cmd) + if req.fileTransferRequestID != "" { + // These values should never exist independently of each other so we can set them at the same time + ctx = context.WithValue(ctx, scp.FileTransferRequestID, req.fileTransferRequestID) + ctx = context.WithValue(ctx, scp.ModeratedSessionID, req.moderatedSessionID) + } + + err = tc.ExecuteSCP(ctx, req.serverID, cmd) if err != nil { return trace.Wrap(err) }