From d3ee25a152fdbad6cb145812271b8e8b2fa7d157 Mon Sep 17 00:00:00 2001 From: rosstimothy <39066650+rosstimothy@users.noreply.github.com> Date: Fri, 11 Nov 2022 17:10:43 -0500 Subject: [PATCH] Allow non-moderated sessions during outage (#17309) In the event that a SessionTracker wasn't able to be persisted in the backend the session was aborted and the user received an error about being unable to upsert session trackers. While this is required in the event that the session being established is a moderated session, other sessions should be allowed to be created without a SessionTracker resource existing in the backend. To allow non-moderated sessions to be established in the event there was a failure persisting the resource we can proceed in the same manner as when the recording mode does not match the component. This ensures that a local SessionTracker resource is created, but any operations performed on it are not attempted to be sent to the Auth server. The session `log` was also updated to include a `session_id` field so that it doesn't need to be added to every individual log message. Closes #17024 and #17026 --- lib/srv/sess.go | 97 ++++++++++++++++------- lib/srv/sess_test.go | 161 +++++++++++++++++++++++++++++++++++++- lib/srv/sessiontracker.go | 97 +++++++++++++---------- 3 files changed, 282 insertions(+), 73 deletions(-) diff --git a/lib/srv/sess.go b/lib/srv/sess.go index 8b64088214068..d962b4c17c635 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -347,6 +347,15 @@ func (s *SessionRegistry) broadcastResult(sid rsession.ID, r ExecResult) error { return nil } +// SessionAccessEvaluator is the interface that defines criteria needed to be met +// in order to start and join sessions. +type SessionAccessEvaluator interface { + IsModerated() bool + FulfilledFor(participants []auth.SessionAccessContext) (bool, auth.PolicyOptions, error) + PrettyRequirementsList() string + CanJoin(user auth.SessionAccessContext) []types.SessionParticipantMode +} + // session struct describes an active (in progress) SSH session. These sessions // are managed by 'SessionRegistry' containers which are attached to SSH servers. type session struct { @@ -394,7 +403,7 @@ type session struct { // serverCtx is used to control clean up of internal resources serverCtx context.Context - access auth.SessionAccessEvaluator + access SessionAccessEvaluator tracker *SessionTracker @@ -482,9 +491,11 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se policySets = append(policySets, &policySet) } + access := auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind) sess := &session{ log: log.WithFields(log.Fields{ trace.Component: teleport.Component(teleport.ComponentSession, r.Srv.Component()), + "session_id": id, }), id: id, registry: r, @@ -494,7 +505,7 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se stopC: make(chan struct{}), startTime: startTime, serverCtx: scx.srv.Context(), - access: auth.NewSessionAccessEvaluator(policySets, types.SSHSessionKind), + access: &access, scx: scx, presenceEnabled: scx.Identity.Certificate.Extensions[teleport.CertExtensionMFAVerified] != "", io: NewTermManager(), @@ -521,7 +532,7 @@ func newSession(ctx context.Context, id rsession.ID, r *SessionRegistry, scx *Se } }() - if err = sess.trackSession(scx.Identity.TeleportUser, policySets); err != nil { + if err = sess.trackSession(ctx, scx.Identity.TeleportUser, policySets); err != nil { if trace.IsNotImplemented(err) { return nil, trace.NotImplemented("Attempted to use Moderated Sessions with an Auth Server below the minimum version of 9.0.0.") } @@ -572,7 +583,7 @@ func (s *session) Stop() { } s.BroadcastMessage("Stopping session...") - s.log.Infof("Stopping session %v.", s.id) + s.log.Info("Stopping session") // close io copy loops s.io.Close() @@ -601,7 +612,7 @@ func (s *session) Close() error { s.Stop() s.BroadcastMessage("Closing session...") - s.log.Infof("Closing session %v.", s.id) + s.log.Infof("Closing session") serverSessions.Dec() @@ -825,7 +836,7 @@ func (s *session) setHasEnhancedRecording(val bool) { // launch launches the session. // Must be called under session Lock. func (s *session) launch(ctx *ServerContext) error { - s.log.Debugf("Launching session %v.", s.id) + s.log.Debug("Launching session") s.BroadcastMessage("Connecting to %v over SSH", s.serverMeta.ServerHostname) s.io.On() @@ -923,7 +934,7 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser } if cgroupID, err := scx.srv.GetBPF().OpenSession(sessionContext); err != nil { - scx.Errorf("Failed to open enhanced recording (interactive) session: %v: %v.", s.id, err) + s.log.WithError(err).Error("Failed to open enhanced recording (interactive) session") return trace.Wrap(err) } else if cgroupID > 0 { // If a cgroup ID was assigned then enhanced session recording was enabled. @@ -935,17 +946,17 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser scx.srv.GetRestrictedSessionManager().CloseSession(sessionContext, cgroupID) err = scx.srv.GetBPF().CloseSession(sessionContext) if err != nil { - scx.Errorf("Failed to close enhanced recording (interactive) session: %v: %v.", s.id, err) + s.log.WithError(err).Error("Failed to close enhanced recording (interactive) session") } }() } - scx.Debug("Waiting for continue signal") + s.log.Debug("Waiting for continue signal") // Process has been placed in a cgroup, continue execution. s.term.Continue() - scx.Debug("Got continue signal") + s.log.Debug("Got continue signal") // Start a heartbeat that marks this session as active with current members // of party in the backend. @@ -957,7 +968,7 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser go func() { result, err := s.term.Wait() if err != nil { - scx.Errorf("Received error waiting for the interactive session %v to finish: %v.", s.id, err) + s.log.WithError(err).Error("Received error waiting for the interactive session to finish") } // wait for copying from the pty to be complete or a timeout before @@ -965,7 +976,7 @@ func (s *session) startInteractive(ctx context.Context, ch ssh.Channel, scx *Ser // closed already. select { case <-time.After(defaults.WaitCopyTimeout): - s.log.Errorf("Timed out waiting for PTY copy to finish, session data for %v may be missing.", s.id) + s.log.Error("Timed out waiting for PTY copy to finish, session data may be missing.") case <-s.doneCh: } @@ -996,12 +1007,12 @@ func (s *session) startTerminal(ctx context.Context, scx *ServerContext) error { if s.term = scx.GetTerm(); s.term != nil { scx.SetTerm(nil) } else if s.term, err = NewTerminal(scx); err != nil { - scx.Infof("Unable to allocate new terminal: %v", err) + s.log.Infof("Unable to allocate new terminal: %v", err) return trace.Wrap(err) } if err := s.term.Run(ctx); err != nil { - scx.Errorf("Unable to run shell command: %v.", err) + s.log.Errorf("Unable to run shell command: %v.", err) return trace.ConvertSystemError(err) } @@ -1058,7 +1069,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve return trace.Wrap(err) } if result != nil { - scx.Debugf("Exec request (%v) result: %v.", execRequest, result) + s.log.Debugf("Exec request (%v) result: %v.", execRequest, result) scx.SendExecResult(*result) } @@ -1077,7 +1088,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve } cgroupID, err := scx.srv.GetBPF().OpenSession(sessionContext) if err != nil { - scx.Errorf("Failed to open enhanced recording (exec) session: %v: %v.", execRequest.GetCommand(), err) + s.log.WithError(err).Errorf("Failed to open enhanced recording (exec) session: %v", execRequest.GetCommand()) return trace.Wrap(err) } @@ -1107,7 +1118,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve // or running in a recording proxy, this is simply a NOP. err = scx.srv.GetBPF().CloseSession(sessionContext) if err != nil { - scx.Errorf("Failed to close enhanced recording (exec) session: %v: %v.", s.id, err) + s.log.WithError(err).Error("Failed to close enhanced recording (exec) session") } s.emitSessionEndEvent() @@ -1124,7 +1135,7 @@ func (s *session) startExec(ctx context.Context, channel ssh.Channel, scx *Serve func (s *session) newStreamer(ctx *ServerContext) (events.Streamer, error) { mode := ctx.SessionRecordingConfig.GetMode() if services.IsRecordSync(mode) { - s.log.Debugf("Using sync streamer for session %v.", s.id) + s.log.Debug("Using sync streamer for session") return ctx.srv, nil } @@ -1133,7 +1144,7 @@ func (s *session) newStreamer(ctx *ServerContext) (events.Streamer, error) { return events.NewDiscardEmitter(), nil } - s.log.Debugf("Using async streamer for session %v.", s.id) + s.log.Debug("Using async streamer for session") fileStreamer, err := filesessions.NewStreamer(sessionsStreamingUploadDir(ctx)) if err != nil { return nil, trace.Wrap(err) @@ -1169,7 +1180,7 @@ func (s *session) String() string { // if the party is the last in the session or has policies that dictate it to end. // Must be called under session Lock. func (s *session) removePartyUnderLock(p *party) error { - s.log.Infof("Removing party %v from session %v", p, s.id) + s.log.Infof("Removing party %v from session", p) // Remove participant from in-memory map of party members. delete(s.parties, p.id) @@ -1241,11 +1252,11 @@ func (s *session) isStopped() bool { // lingerAndDie will let the party-less session linger for a short // duration, and then die if no parties have joined. func (s *session) lingerAndDie(ctx context.Context, party *party) { - s.log.Debugf("Session %v has no active party members.", s.id) + s.log.Debug("Session has no active party members.") select { case <-s.registry.clock.After(defaults.SessionIdlePeriod): - s.log.Infof("Session %v will be garbage collected.", s.id) + s.log.Info("Session will be garbage collected.") // set closing context to the leaving party to show who ended the session. s.setEndingContext(party.ctx) @@ -1254,7 +1265,7 @@ func (s *session) lingerAndDie(ctx context.Context, party *party) { // complete cleanup and close the session. s.Stop() case <-ctx.Done(): - s.log.Infof("Session %v has become active again.", s.id) + s.log.Info("Session has become active again.") return case <-s.stopC: return @@ -1442,7 +1453,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { s.io.AddWriter(string(p.id), p) s.BroadcastMessage("User %v joined the session.", p.user) - s.log.Infof("New party %v joined session: %v", p.String(), s.id) + s.log.Infof("New party %v joined session", p.String()) if mode == types.SessionPeerMode { s.term.AddParty(1) @@ -1463,7 +1474,7 @@ func (s *session) addParty(p *party, mode types.SessionParticipantMode) error { if canStart { if err := s.launch(s.scx); err != nil { - s.log.Errorf("Failed to launch session %v: %v", s.id, err) + s.log.WithError(err).Error("Failed to launch session") } return nil } @@ -1605,7 +1616,7 @@ func (p *party) closeUnderSessionLock() { // trackSession creates a new session tracker for the ssh session. // While ctx is open, the session tracker's expiration will be extended // on an interval until the session tracker is closed. -func (s *session) trackSession(teleportUser string, policySet []*types.SessionTrackerPolicySet) error { +func (s *session) trackSession(ctx context.Context, teleportUser string, policySet []*types.SessionTrackerPolicySet) error { trackerSpec := types.SessionTrackerSpecV1{ SessionID: s.id.String(), Kind: string(types.SSHSessionKind), @@ -1631,11 +1642,37 @@ func (s *session) trackSession(teleportUser string, policySet []*types.SessionTr } } - s.log.Debug("Creating session tracker") - var err error - s.tracker, err = NewSessionTracker(s.serverCtx, trackerSpec, s.registry.SessionTrackerService) - if err != nil { + svc := s.registry.SessionTrackerService + // only propagate the session tracker when the recording mode and component are in sync + if (s.registry.Srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) || + (s.registry.Srv.Component() == teleport.ComponentProxy && !services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode())) { + svc = nil + } + + s.log.Debug("Attempting to create session tracker") + tracker, err := NewSessionTracker(ctx, trackerSpec, svc) + switch { + // there was an error creating the tracker for a moderated session - terminate the session + case err != nil && svc != nil && s.access.IsModerated(): + s.log.WithError(err).Warn("Failed to create session tracker, unable to proceed for moderated session") + return trace.Wrap(err) + // there was an error creating the tracker for a non-moderated session - permit the session with a local tracker + case err != nil && svc != nil && !s.access.IsModerated(): + s.log.Warn("Failed to create session tracker, proceeding with local session tracker for non-moderated session") + + localTracker, err := NewSessionTracker(ctx, trackerSpec, nil) + // this error means there are problems with the trackerSpec, we need to return it + if err != nil { + return trace.Wrap(err) + } + + s.tracker = localTracker + // there was an error even though the tracker wasn't being propagated - return it + case err != nil && svc == nil: return trace.Wrap(err) + // the tracker was created successfully + case err == nil: + s.tracker = tracker } go func() { diff --git a/lib/srv/sess_test.go b/lib/srv/sess_test.go index 0632d2641a03b..b24fc57acc603 100644 --- a/lib/srv/sess_test.go +++ b/lib/srv/sess_test.go @@ -19,16 +19,22 @@ package srv import ( "context" "io" + "os/user" + "sync/atomic" "testing" "time" + "golang.org/x/crypto/ssh" + "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/services" - "golang.org/x/crypto/ssh" + rsession "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -425,3 +431,156 @@ func testOpenSession(t *testing.T, reg *SessionRegistry, roleSet services.RoleSe require.NotNil(t, scx.session) return scx.session, sshChanOpen } + +type trackerService struct { + created int32 + createError error + services.SessionTrackerService +} + +func (t *trackerService) CreatedCount() int { + return int(atomic.LoadInt32(&t.created)) +} + +func (t *trackerService) CreateSessionTracker(ctx context.Context, tracker types.SessionTracker) (types.SessionTracker, error) { + atomic.AddInt32(&t.created, 1) + + if t.createError != nil { + return nil, t.createError + } + + return t.SessionTrackerService.CreateSessionTracker(ctx, tracker) +} + +type sessionEvaluator struct { + moderated bool + SessionAccessEvaluator +} + +func (s sessionEvaluator) IsModerated() bool { + return s.moderated +} + +func TestTrackingSession(t *testing.T) { + t.Parallel() + ctx := context.Background() + + me, err := user.Current() + require.NoError(t, err) + + cases := []struct { + name string + component string + recordingMode string + createError error + moderated bool + assertion require.ErrorAssertionFunc + createAssertion func(t *testing.T, count int) + }{ + { + name: "node with proxy recording mode", + component: teleport.ComponentNode, + recordingMode: types.RecordAtProxy, + assertion: require.NoError, + createAssertion: func(t *testing.T, count int) { + require.Equal(t, 0, count) + }, + }, + { + name: "node with node recording mode", + component: teleport.ComponentNode, + recordingMode: types.RecordAtNode, + assertion: require.NoError, + createAssertion: func(t *testing.T, count int) { + require.Equal(t, 1, count) + }, + }, + { + name: "proxy with proxy recording mode", + component: teleport.ComponentProxy, + recordingMode: types.RecordAtProxy, + assertion: require.NoError, + createAssertion: func(t *testing.T, count int) { + require.Equal(t, 1, count) + }, + }, + { + name: "proxy with node recording mode", + component: teleport.ComponentProxy, + recordingMode: types.RecordAtNode, + assertion: require.NoError, + createAssertion: func(t *testing.T, count int) { + require.Equal(t, 0, count) + }, + }, + { + name: "auth outage for non moderated session", + component: teleport.ComponentNode, + recordingMode: types.RecordAtNodeSync, + assertion: require.NoError, + createError: trace.ConnectionProblem(context.DeadlineExceeded, ""), + createAssertion: func(t *testing.T, count int) { + require.Equal(t, 1, count) + }, + }, + { + name: "auth outage for moderated session", + component: teleport.ComponentNode, + recordingMode: types.RecordAtNodeSync, + moderated: true, + assertion: require.Error, + createError: trace.ConnectionProblem(context.DeadlineExceeded, ""), + createAssertion: func(t *testing.T, count int) { + require.Equal(t, 1, count) + }, + }, + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + srv := newMockServer(t) + srv.component = tt.component + + trackingService := &trackerService{ + SessionTrackerService: &mockSessiontrackerService{ + trackers: make(map[string]types.SessionTracker), + }, + createError: tt.createError, + } + + scx := newTestServerContext(t, srv, nil) + scx.SessionRecordingConfig = &types.SessionRecordingConfigV2{ + Kind: types.KindSessionRecordingConfig, + Version: types.V2, + Spec: types.SessionRecordingConfigSpecV2{ + Mode: tt.recordingMode, + }, + } + + sess := &session{ + id: rsession.NewID(), + log: utils.NewLoggerForTests().WithField(trace.Component, "test-session"), + registry: &SessionRegistry{ + SessionRegistryConfig: SessionRegistryConfig{ + Srv: srv, + SessionTrackerService: trackingService, + clock: clockwork.NewFakeClock(), //use a fake clock to prevent the update loop from running + }, + }, + serverMeta: apievents.ServerMetadata{ + ServerHostname: "test", + ServerID: "123", + }, + scx: scx, + serverCtx: ctx, + login: me.Name, + access: sessionEvaluator{moderated: tt.moderated}, + } + + err = sess.trackSession(ctx, me.Name, nil) + tt.assertion(t, err) + tt.createAssertion(t, trackingService.CreatedCount()) + }) + } + +} diff --git a/lib/srv/sessiontracker.go b/lib/srv/sessiontracker.go index 95dbedbb4aefe..531d02d522403 100644 --- a/lib/srv/sessiontracker.go +++ b/lib/srv/sessiontracker.go @@ -21,13 +21,13 @@ import ( "sync" "time" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/gravitational/teleport/api/client/proto" apidefaults "github.com/gravitational/teleport/api/defaults" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/services" - "github.com/jonboulle/clockwork" - - "github.com/gravitational/trace" ) // SessionTracker is a session tracker for a specific session. It tracks @@ -45,17 +45,15 @@ type SessionTracker struct { // NewSessionTracker returns a new SessionTracker for the given types.SessionTracker func NewSessionTracker(ctx context.Context, trackerSpec types.SessionTrackerSpecV1, service services.SessionTrackerService) (*SessionTracker, error) { - if service == nil { - return nil, trace.BadParameter("missing parameter service") - } - t, err := types.NewSessionTracker(trackerSpec) if err != nil { return nil, trace.Wrap(err) } - if t, err = service.CreateSessionTracker(ctx, t); err != nil { - return nil, trace.Wrap(err) + if service != nil { + if t, err = service.CreateSessionTracker(ctx, t); err != nil { + return nil, trace.Wrap(err) + } } return &SessionTracker{ @@ -93,7 +91,7 @@ func (s *SessionTracker) updateExpirationLoop(ctx context.Context, ticker clockw return trace.Wrap(err) } case <-ctx.Done(): - return trace.Wrap(ctx.Err()) + return nil case <-s.closeC: return nil } @@ -106,15 +104,18 @@ func (s *SessionTracker) UpdateExpiration(ctx context.Context, expiry time.Time) s.tracker.SetExpiry(expiry) s.trackerCond.Broadcast() - err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ - SessionID: s.tracker.GetSessionID(), - Update: &proto.UpdateSessionTrackerRequest_UpdateExpiry{ - UpdateExpiry: &proto.SessionTrackerUpdateExpiry{ - Expires: &expiry, + if s.service != nil { + err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ + SessionID: s.tracker.GetSessionID(), + Update: &proto.UpdateSessionTrackerRequest_UpdateExpiry{ + UpdateExpiry: &proto.SessionTrackerUpdateExpiry{ + Expires: &expiry, + }, }, - }, - }) - return trace.Wrap(err) + }) + return trace.Wrap(err) + } + return nil } func (s *SessionTracker) AddParticipant(ctx context.Context, p *types.Participant) error { @@ -123,15 +124,19 @@ func (s *SessionTracker) AddParticipant(ctx context.Context, p *types.Participan s.tracker.AddParticipant(*p) s.trackerCond.Broadcast() - err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ - SessionID: s.tracker.GetSessionID(), - Update: &proto.UpdateSessionTrackerRequest_AddParticipant{ - AddParticipant: &proto.SessionTrackerAddParticipant{ - Participant: p, + if s.service != nil { + err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ + SessionID: s.tracker.GetSessionID(), + Update: &proto.UpdateSessionTrackerRequest_AddParticipant{ + AddParticipant: &proto.SessionTrackerAddParticipant{ + Participant: p, + }, }, - }, - }) - return trace.Wrap(err) + }) + return trace.Wrap(err) + } + + return nil } func (s *SessionTracker) RemoveParticipant(ctx context.Context, participantID string) error { @@ -140,15 +145,19 @@ func (s *SessionTracker) RemoveParticipant(ctx context.Context, participantID st s.tracker.RemoveParticipant(participantID) s.trackerCond.Broadcast() - err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ - SessionID: s.tracker.GetSessionID(), - Update: &proto.UpdateSessionTrackerRequest_RemoveParticipant{ - RemoveParticipant: &proto.SessionTrackerRemoveParticipant{ - ParticipantID: participantID, + if s.service != nil { + err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ + SessionID: s.tracker.GetSessionID(), + Update: &proto.UpdateSessionTrackerRequest_RemoveParticipant{ + RemoveParticipant: &proto.SessionTrackerRemoveParticipant{ + ParticipantID: participantID, + }, }, - }, - }) - return trace.Wrap(err) + }) + return trace.Wrap(err) + } + + return nil } func (s *SessionTracker) UpdateState(ctx context.Context, state types.SessionState) error { @@ -157,15 +166,19 @@ func (s *SessionTracker) UpdateState(ctx context.Context, state types.SessionSta s.tracker.SetState(state) s.trackerCond.Broadcast() - err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ - SessionID: s.tracker.GetSessionID(), - Update: &proto.UpdateSessionTrackerRequest_UpdateState{ - UpdateState: &proto.SessionTrackerUpdateState{ - State: state, + if s.service != nil { + err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ + SessionID: s.tracker.GetSessionID(), + Update: &proto.UpdateSessionTrackerRequest_UpdateState{ + UpdateState: &proto.SessionTrackerUpdateState{ + State: state, + }, }, - }, - }) - return trace.Wrap(err) + }) + return trace.Wrap(err) + } + + return nil } // WaitForStateUpdate waits for the tracker's state to be updated and returns the new state.