Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 75 additions & 39 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ import (
"sync"
"time"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/moby/term"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"

"github.com/gravitational/teleport"
"github.com/gravitational/teleport/api/constants"
"github.com/gravitational/teleport/api/types"
Expand All @@ -39,14 +47,6 @@ import (
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"

"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/moby/term"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)

const sessionRecorderID = "session-recorder"
Expand Down Expand Up @@ -414,6 +414,15 @@ func (s *SessionRegistry) broadcastResult(sid rsession.ID, r ExecResult) error {
return nil
}

// SessionAccessEvaluator is the interface that defines criteria needed to be met
// in order to start and join sessions.
type SessionAccessEvaluator interface {
IsModerated() bool
FulfilledFor(participants []auth.SessionAccessContext) (bool, auth.PolicyOptions, error)
PrettyRequirementsList() string
CanJoin(user auth.SessionAccessContext) []types.SessionParticipantMode
}

// session struct describes an active (in progress) SSH session. These sessions
// are managed by 'SessionRegistry' containers which are attached to SSH servers.
type session struct {
Expand Down Expand Up @@ -462,7 +471,7 @@ type session struct {
// serverCtx is used to control clean up of internal resources
serverCtx context.Context

access auth.SessionAccessEvaluator
access SessionAccessEvaluator

tracker *SessionTracker

Expand Down Expand Up @@ -545,10 +554,11 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se
}

policySets := scx.Identity.AccessChecker.SessionPolicySets()

access := auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind, scx.Identity.TeleportUser)
sess := &session{
log: log.WithFields(log.Fields{
trace.Component: teleport.Component(teleport.ComponentSession, r.Srv.Component()),
"session_id": id,
}),
id: id,
registry: r,
Expand All @@ -558,7 +568,7 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se
stopC: make(chan struct{}),
startTime: startTime,
serverCtx: scx.srv.Context(),
access: auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind, scx.Identity.TeleportUser),
access: &access,
scx: scx,
presenceEnabled: scx.Identity.Certificate.Extensions[teleport.CertExtensionMFAVerified] != "",
io: NewTermManager(),
Expand All @@ -579,7 +589,7 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se
}
}()

if err = sess.trackSession(scx.Identity.TeleportUser, policySets); err != nil {
if err = sess.trackSession(ctx, scx.Identity.TeleportUser, policySets); err != nil {
if trace.IsNotImplemented(err) {
return nil, trace.NotImplemented("Attempted to use Moderated Sessions with an Auth Server below the minimum version of 9.0.0.")
}
Expand Down Expand Up @@ -636,7 +646,7 @@ func (s *session) Stop() {
}

s.BroadcastMessage("Stopping session...")
s.log.Infof("Stopping session %v.", s.id)
s.log.Info("Stopping session")

// close io copy loops
s.io.Close()
Expand Down Expand Up @@ -665,7 +675,7 @@ func (s *session) Close() error {
s.Stop()

s.BroadcastMessage("Closing session...")
s.log.Infof("Closing session %v.", s.id)
s.log.Infof("Closing session")

serverSessions.Dec()

Expand Down Expand Up @@ -894,7 +904,7 @@ func (s *session) setHasEnhancedRecording(val bool) {
// launch launches the session.
// Must be called under session Lock.
func (s *session) launch(ctx *ServerContext) error {
s.log.Debugf("Launching session %v.", s.id)
s.log.Debug("Launching session")
s.BroadcastMessage("Connecting to %v over SSH", s.serverMeta.ServerHostname)

s.io.On()
Expand Down Expand Up @@ -992,7 +1002,7 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser
}

if cgroupID, err := scx.srv.GetBPF().OpenSession(sessionContext); err != nil {
scx.Errorf("Failed to open enhanced recording (interactive) session: %v: %v.", s.id, err)
s.log.WithError(err).Error("Failed to open enhanced recording (interactive) session")
return trace.Wrap(err)
} else if cgroupID > 0 {
// If a cgroup ID was assigned then enhanced session recording was enabled.
Expand All @@ -1004,17 +1014,17 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser
scx.srv.GetRestrictedSessionManager().CloseSession(sessionContext, cgroupID)
err = scx.srv.GetBPF().CloseSession(sessionContext)
if err != nil {
scx.Errorf("Failed to close enhanced recording (interactive) session: %v: %v.", s.id, err)
s.log.WithError(err).Error("Failed to close enhanced recording (interactive) session")
}
}()
}

scx.Debug("Waiting for continue signal")
s.log.Debug("Waiting for continue signal")

// Process has been placed in a cgroup, continue execution.
s.term.Continue()

scx.Debug("Got continue signal")
s.log.Debug("Got continue signal")

// Start a heartbeat that marks this session as active with current members
// of party in the backend.
Expand All @@ -1026,15 +1036,15 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser
go func() {
result, err := s.term.Wait()
if err != nil {
scx.Errorf("Received error waiting for the interactive session %v to finish: %v.", s.id, err)
s.log.WithError(err).Error("Received error waiting for the interactive session to finish")
}

// wait for copying from the pty to be complete or a timeout before
// broadcasting the result (which will close the pty) if it has not been
// closed already.
select {
case <-time.After(defaults.WaitCopyTimeout):
s.log.Errorf("Timed out waiting for PTY copy to finish, session data for %v may be missing.", s.id)
s.log.Error("Timed out waiting for PTY copy to finish, session data may be missing.")
case <-s.doneCh:
}

Expand Down Expand Up @@ -1065,12 +1075,12 @@ func (s *session) startTerminal(ctx context.Context, scx *ServerContext) error {
if s.term = scx.GetTerm(); s.term != nil {
scx.SetTerm(nil)
} else if s.term, err = NewTerminal(scx); err != nil {
scx.Infof("Unable to allocate new terminal: %v", err)
s.log.Infof("Unable to allocate new terminal: %v", err)
return trace.Wrap(err)
}

if err := s.term.Run(ctx); err != nil {
scx.Errorf("Unable to run shell command: %v.", err)
s.log.Errorf("Unable to run shell command: %v.", err)
return trace.ConvertSystemError(err)
}

Expand Down Expand Up @@ -1166,7 +1176,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
return trace.Wrap(err)
}
if result != nil {
scx.Debugf("Exec request (%v) result: %v.", execRequest, result)
s.log.Debugf("Exec request (%v) result: %v.", execRequest, result)
scx.SendExecResult(*result)
}

Expand All @@ -1185,7 +1195,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
}
cgroupID, err := scx.srv.GetBPF().OpenSession(sessionContext)
if err != nil {
scx.Errorf("Failed to open enhanced recording (exec) session: %v: %v.", execRequest.GetCommand(), err)
s.log.WithError(err).Errorf("Failed to open enhanced recording (exec) session: %v", execRequest.GetCommand())
return trace.Wrap(err)
}

Expand Down Expand Up @@ -1215,7 +1225,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
// or running in a recording proxy, this is simply a NOP.
err = scx.srv.GetBPF().CloseSession(sessionContext)
if err != nil {
scx.Errorf("Failed to close enhanced recording (exec) session: %v: %v.", s.id, err)
s.log.WithError(err).Error("Failed to close enhanced recording (exec) session")
}

s.emitSessionEndEvent()
Expand All @@ -1232,7 +1242,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve
func (s *session) newStreamer(ctx *ServerContext) (events.Streamer, error) {
mode := ctx.SessionRecordingConfig.GetMode()
if services.IsRecordSync(mode) {
s.log.Debugf("Using sync streamer for session %v.", s.id)
s.log.Debug("Using sync streamer for session")
return ctx.srv, nil
}

Expand All @@ -1241,7 +1251,7 @@ func (s *session) newStreamer(ctx *ServerContext) (events.Streamer, error) {
return events.NewTeeStreamer(events.NewDiscardEmitter(), ctx.srv), nil
}

s.log.Debugf("Using async streamer for session %v.", s.id)
s.log.Debug("Using async streamer for session")
fileStreamer, err := filesessions.NewStreamer(sessionsStreamingUploadDir(ctx))
if err != nil {
return nil, trace.Wrap(err)
Expand Down Expand Up @@ -1277,7 +1287,7 @@ func (s *session) String() string {
// if the party is the last in the session or has policies that dictate it to end.
// Must be called under session Lock.
func (s *session) removePartyUnderLock(p *party) error {
s.log.Infof("Removing party %v from session %v", p, s.id)
s.log.Infof("Removing party %v from session", p)

// Remove participant from in-memory map of party members.
delete(s.parties, p.id)
Expand Down Expand Up @@ -1349,11 +1359,11 @@ func (s *session) isStopped() bool {
// lingerAndDie will let the party-less session linger for a short
// duration, and then die if no parties have joined.
func (s *session) lingerAndDie(ctx context.Context, party *party) {
s.log.Debugf("Session %v has no active party members.", s.id)
s.log.Debug("Session has no active party members.")

select {
case <-s.registry.clock.After(defaults.SessionIdlePeriod):
s.log.Infof("Session %v will be garbage collected.", s.id)
s.log.Info("Session will be garbage collected.")

// set closing context to the leaving party to show who ended the session.
s.setEndingContext(party.ctx)
Expand All @@ -1362,7 +1372,7 @@ func (s *session) lingerAndDie(ctx context.Context, party *party) {
// complete cleanup and close the session.
s.Stop()
case <-ctx.Done():
s.log.Infof("Session %v has become active again.", s.id)
s.log.Info("Session has become active again.")
return
case <-s.stopC:
return
Expand Down Expand Up @@ -1544,7 +1554,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {
s.io.AddWriter(string(p.id), p)

s.BroadcastMessage("User %v joined the session.", p.user)
s.log.Infof("New party %v joined session: %v", p.String(), s.id)
s.log.Infof("New party %v joined session", p.String())

if mode == types.SessionPeerMode {
s.term.AddParty(1)
Expand All @@ -1565,7 +1575,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error {

if canStart {
if err := s.launch(s.scx); err != nil {
s.log.Errorf("Failed to launch session %v: %v", s.id, err)
s.log.WithError(err).Error("Failed to launch session")
}
return nil
}
Expand Down Expand Up @@ -1707,7 +1717,7 @@ func (p *party) closeUnderSessionLock() {
// trackSession creates a new session tracker for the ssh session.
// While ctx is open, the session tracker's expiration will be extended
// on an interval until the session tracker is closed.
func (s *session) trackSession(teleportUser string, policySet []*types.SessionTrackerPolicySet) error {
func (s *session) trackSession(ctx context.Context, teleportUser string, policySet []*types.SessionTrackerPolicySet) error {
trackerSpec := types.SessionTrackerSpecV1{
SessionID: s.id.String(),
Kind: string(types.SSHSessionKind),
Expand All @@ -1728,11 +1738,37 @@ func (s *session) trackSession(teleportUser string, policySet []*types.SessionTr
}
}

s.log.Debug("Creating session tracker")
var err error
s.tracker, err = NewSessionTracker(s.serverCtx, trackerSpec, s.registry.SessionTrackerService)
if err != nil {
svc := s.registry.SessionTrackerService
// only propagate the session tracker when the recording mode and component are in sync
if (s.registry.Srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) ||
(s.registry.Srv.Component() == teleport.ComponentProxy && !services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) {
svc = nil
}

s.log.Debug("Attempting to create session tracker")
tracker, err := NewSessionTracker(ctx, trackerSpec, svc)
switch {
// there was an error creating the tracker for a moderated session - terminate the session
case err != nil && svc != nil && s.access.IsModerated():
s.log.WithError(err).Warn("Failed to create session tracker, unable to proceed for moderated session")
return trace.Wrap(err)
// there was an error creating the tracker for a non-moderated session - permit the session with a local tracker
case err != nil && svc != nil && !s.access.IsModerated():
s.log.Warn("Failed to create session tracker, proceeding with local session tracker for non-moderated session")

localTracker, err := NewSessionTracker(ctx, trackerSpec, nil)
// this error means there are problems with the trackerSpec, we need to return it
if err != nil {
return trace.Wrap(err)
}

s.tracker = localTracker
// there was an error even though the tracker wasn't being propagated - return it
case err != nil && svc == nil:
return trace.Wrap(err)
// the tracker was created successfully
case err == nil:
s.tracker = tracker
}

go func() {
Expand Down
Loading