Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1806,6 +1806,15 @@ func (c *NodeClient) ExecuteSCP(ctx context.Context, cmd scp.Command) error {
}
defer s.Close()

// File transfers in a moderated session require these two variablesto check for
// approval on the ssh server. If they exist in the context, set them in our env vars
if moderatedSessionID, ok := ctx.Value(scp.ModeratedSessionID).(string); ok {
s.Setenv(ctx, string(scp.ModeratedSessionID), moderatedSessionID)
}
if fileTransferRequestID, ok := ctx.Value(scp.FileTransferRequestID).(string); ok {
s.Setenv(ctx, string(scp.FileTransferRequestID), fileTransferRequestID)
}

stdin, err := s.StdinPipe()
if err != nil {
return trace.Wrap(err)
Expand Down
88 changes: 87 additions & 1 deletion lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/sshutils/scp"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -287,12 +288,19 @@ func (s *SessionRegistry) OpenExecSession(ctx context.Context, channel ssh.Chann
}
scx.Infof("Creating (exec) session %v.", sessionID)

approved, err := s.isApprovedFileTransfer(scx)
if err != nil {
return trace.Wrap(err)
}

canStart, _, err := sess.checkIfStart()
if err != nil {
return trace.Wrap(err)
}

if !canStart {
// canStart will be true for non-moderated sessions. If canStart is false, check to
// see if the request has been approved through a moderated session next.
if !canStart && !approved {
return errCannotStartUnattendedSession
}

Expand Down Expand Up @@ -338,6 +346,47 @@ func (s *SessionRegistry) GetTerminalSize(sessionID string) (*term.Winsize, erro
return sess.term.GetWinSize()
}

func (s *SessionRegistry) isApprovedFileTransfer(scx *ServerContext) (bool, error) {
s.sessionsMux.Lock()
defer s.sessionsMux.Unlock()

// if a sessID and requestID environment variables were not set, return not approved and no error.
// This means the file transfer came from a non-moderated session. sessionID will be passed after a
// moderated session approval process has completed.
sessID, _ := scx.GetEnv(string(scp.ModeratedSessionID))
if sessID == "" {
return false, nil
}
// fetch session from registry with sessionID
sess := s.sessions[rsession.ID(sessID)]
if sess == nil {
// If they sent a sessionID and it wasn't found, send an actual error
return false, trace.NotFound("Session not found")
}

requestID, _ := scx.GetEnv(string(scp.FileTransferRequestID))
if requestID == "" {
return false, nil
}
// find file transfer request in the session by requestID
req := sess.fileTransferRequests[requestID]
if req == nil {
// If they sent a fileTransferRequestID and it wasn't found, send an actual error
return false, trace.NotFound("File transfer request not found")
}

if req.requester != scx.Identity.TeleportUser {
return false, trace.AccessDenied("Teleport user does not match original requester")
}

incomingShellCmd := string(scx.sshRequest.Payload)
if incomingShellCmd != req.shellCmd {
return false, trace.AccessDenied("Incoming request does not match the approved request")
}

return sess.checkIfFileTransferApproved(req)
}

// NotifyWinChange is called to notify all members in the party that the PTY
// size has changed. The notification is sent as a global SSH request and it
// is the responsibility of the client to update it's window size upon receipt.
Expand Down Expand Up @@ -454,6 +503,11 @@ type session struct {
// participants at the end of a session.
participants map[rsession.ID]*party

// fileTransferRequests is a set of fileTransferRequests that are currently in the approval
// process, or already approved and not yet executed during a moderated session. If a request is
// denied or, once it's been executed, it should be removed from this map.
fileTransferRequests map[string]*fileTransferRequest

io *TermManager
inWriter io.Writer

Expand Down Expand Up @@ -1405,6 +1459,38 @@ func (s *session) checkPresence() error {
return nil
}

type fileTransferRequest struct {
// requester is the Teleport User that requested the file transfer
requester string
// shellCmd is the requested scp command to run
shellCmd string
// approvers is a list of participants of moderator or peer type that have approved the request
approvers map[string]*party
}

func (s *session) checkIfFileTransferApproved(req *fileTransferRequest) (bool, error) {
var participants []auth.SessionAccessContext

for _, party := range req.approvers {
if party.ctx.Identity.TeleportUser == s.initiator {
continue
}

participants = append(participants, auth.SessionAccessContext{
Username: party.ctx.Identity.TeleportUser,
Roles: party.ctx.Identity.AccessChecker.Roles(),
Mode: party.mode,
})
}

isApproved, _, err := s.access.FulfilledFor(participants)
if err != nil {
return false, trace.Wrap(err)
}

return isApproved, nil
}

func (s *session) checkIfStart() (bool, auth.PolicyOptions, error) {
var participants []auth.SessionAccessContext

Expand Down
133 changes: 133 additions & 0 deletions lib/srv/sess_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import (
"github.com/gravitational/teleport/lib/modules"
"github.com/gravitational/teleport/lib/services"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils/scp"
"github.com/gravitational/teleport/lib/utils"
)

Expand Down Expand Up @@ -87,6 +88,138 @@ func TestParseAccessRequestIDs(t *testing.T) {
}
}

func TestIsApprovedFileTransfer(t *testing.T) {
// set enterprise for tests
modules.SetTestModules(t, &modules.TestModules{TestBuildType: modules.BuildEnterprise})
srv := newMockServer(t)
srv.component = teleport.ComponentNode

// init a session registry
reg, _ := NewSessionRegistry(SessionRegistryConfig{
Srv: srv,
SessionTrackerService: srv.auth,
})
t.Cleanup(func() { reg.Close() })

// Create the auditorRole and moderator Party
auditorRole, _ := types.NewRole("auditor", types.RoleSpecV6{
Allow: types.RoleConditions{
JoinSessions: []*types.SessionJoinPolicy{{
Name: "foo",
Roles: []string{"access"},
Kinds: []string{string(types.SSHSessionKind)},
Modes: []string{string(types.SessionModeratorMode)},
}},
},
})
auditorRoleSet := services.NewRoleSet(auditorRole)
auditScx := newTestServerContext(t, reg.Srv, auditorRoleSet)
// change the teleport user so we dont match the user in the test cases
auditScx.Identity.TeleportUser = "mod"
auditSess, _ := testOpenSession(t, reg, auditorRoleSet)
approvers := make(map[string]*party)
auditChan := newMockSSHChannel()
approvers["mod"] = newParty(auditSess, types.SessionModeratorMode, auditChan, auditScx)

// create the accessRole to be used for the requester
accessRole, _ := types.NewRole("access", types.RoleSpecV6{
Allow: types.RoleConditions{
RequireSessionJoin: []*types.SessionRequirePolicy{{
Name: "foo",
Filter: "contains(user.roles, \"auditor\")", // escape to avoid illegal rune
Kinds: []string{string(types.SSHSessionKind)},
Modes: []string{string(types.SessionModeratorMode)},
Count: 1,
}},
},
})
accessRoleSet := services.NewRoleSet(accessRole)

cases := []struct {
name string
expectedResult bool
expectedError string
req *fileTransferRequest
reqID string
}{

{
name: "no file request found with supplied ID",
expectedResult: false,
expectedError: "",
reqID: "",
req: nil,
},
{
name: "no file request found with supplied ID",
expectedResult: false,
expectedError: "File transfer request not found",
reqID: "111",
req: nil,
},
{
name: "current requester does not match original requester",
expectedResult: false,
expectedError: "Teleport user does not match original requester",
reqID: "123",
req: &fileTransferRequest{
requester: "michael",
shellCmd: "/usr/bin/scp -f ~/logs.txt",
approvers: make(map[string]*party),
},
},
{
name: "current payload does not match original payload",
expectedResult: false,
expectedError: "Incoming request does not match the approved request",
reqID: "123",
req: &fileTransferRequest{
requester: "teleportUser",
shellCmd: "badcommand",
approvers: make(map[string]*party),
},
},
{
name: "approved request",
expectedResult: true,
expectedError: "",
reqID: "123",
req: &fileTransferRequest{
requester: "teleportUser",
shellCmd: "/usr/bin/scp -f ~/logs.txt",
approvers: approvers,
},
},
}

for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
// create and add a session to the registry
sess, _ := testOpenSession(t, reg, accessRoleSet)

// create a fileTransferRequest. can be nil
sess.fileTransferRequests = map[string]*fileTransferRequest{
"123": tt.req,
}

// new exec request context
scx := newTestServerContext(t, reg.Srv, accessRoleSet)
scx.sshRequest = &ssh.Request{
Payload: []byte("/usr/bin/scp -f ~/logs.txt"),
}

scx.SetEnv(string(scp.ModeratedSessionID), sess.ID())
scx.SetEnv(string(scp.FileTransferRequestID), tt.reqID)
result, err := reg.isApprovedFileTransfer(scx)
if err != nil {
require.Equal(t, tt.expectedError, err.Error())
}

require.Equal(t, tt.expectedResult, result)
})
}
}

func TestSession_newRecorder(t *testing.T) {
t.Parallel()

Expand Down
19 changes: 19 additions & 0 deletions lib/sshutils/scp/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ type HTTPTransferRequest struct {
User string
// AuditLog is AuditLog log
AuditLog events.AuditLogSessionStreamer
// FileTransferRequestID is used to find a FileTransferRequest on a session
FileTransferRequestID string
// ModeratedSessonID is an ID of a moderated session that has completed a
// file transfer request approval process
ModeratedSessionID string
}

func (r *HTTPTransferRequest) parseRemoteLocation() (string, string, error) {
Expand Down Expand Up @@ -269,3 +274,17 @@ type nopWriteCloser struct {
func (wr *nopWriteCloser) Close() error {
return nil
}

const (
// FileTransferRequestID is an optional parameter id of an file transfer request that has gone through
// an approval process during a moderated session to allow a file transfer scp command to be executed
// used as a value in the file transfer context and env var for exec session
FileTransferRequestID ContextKey = "FILE_TRANSFER_REQUEST_ID"

// ModeratedSessionID is an optional parameter sent during SCP requests to specify which moderated session
// to check for valid FileTransferRequests
// used as a value in the file transfer context and env var for exec session
ModeratedSessionID ContextKey = "MODERATED_SESSION_ID"
)

type ContextKey string
Loading