diff --git a/lib/srv/sess.go b/lib/srv/sess.go index cceaec943503d..4ff86c7b1436d 100644 --- a/lib/srv/sess.go +++ b/lib/srv/sess.go @@ -31,11 +31,13 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/events/filesessions" "github.com/gravitational/teleport/lib/observability/metrics" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" rsession "github.com/gravitational/teleport/lib/session" "github.com/gravitational/teleport/lib/sshutils" "github.com/gravitational/teleport/lib/utils" @@ -1604,16 +1606,36 @@ func (s *session) trackSession(teleportUser string, policySet []*types.SessionTr } s.log.Debug("Creating session tracker") - var err error // if doing proxy recording, don't propagate tracker to the cluster level if s.registry.Srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode()) { - s.tracker, err = NewSessionTracker(s.serverCtx, trackerSpec, nil) + bk, err := memory.New(memory.Config{}) + if err != nil { + return trace.Wrap(err) + } + trackerRemote, err := local.NewSessionTrackerService(bk) + if err != nil { + return trace.Wrap(err) + } + s.tracker, err = NewSessionTracker(s.serverCtx, trackerSpec, trackerRemote) + if err != nil { + return trace.Wrap(err) + } + + go func() { + err := s.tracker.WaitOnState(s.serverCtx, types.SessionState_SessionStateTerminated) + if err != nil { + s.log.WithError(err).Error("Failed to wait on session state.") + } + + bk.Close() + }() } else { + var err error s.tracker, err = NewSessionTracker(s.serverCtx, trackerSpec, s.registry.SessionTrackerService) - } - if err != nil { - return trace.Wrap(err) + if err != nil { + return trace.Wrap(err) + } } go func() { diff --git a/lib/srv/sessiontracker.go b/lib/srv/sessiontracker.go index 872b0f9ae1137..b8ce54bc33aa6 100644 --- a/lib/srv/sessiontracker.go +++ b/lib/srv/sessiontracker.go @@ -45,15 +45,17 @@ type SessionTracker struct { // NewSessionTracker returns a new SessionTracker for the given types.SessionTracker func NewSessionTracker(ctx context.Context, trackerSpec types.SessionTrackerSpecV1, service services.SessionTrackerService) (*SessionTracker, error) { + if service == nil { + return nil, trace.BadParameter("missing parameter service") + } + t, err := types.NewSessionTracker(trackerSpec) if err != nil { return nil, trace.Wrap(err) } - if service != nil { - if t, err = service.CreateSessionTracker(ctx, t); err != nil { - return nil, trace.Wrap(err) - } + if t, err = service.CreateSessionTracker(ctx, t); err != nil { + return nil, trace.Wrap(err) } return &SessionTracker{ @@ -104,20 +106,16 @@ func (s *SessionTracker) UpdateExpiration(ctx context.Context, expiry time.Time) s.tracker.SetExpiry(expiry) s.trackerCond.Broadcast() - if s.service != nil { - err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ - SessionID: s.tracker.GetSessionID(), - Update: &proto.UpdateSessionTrackerRequest_UpdateExpiry{ - UpdateExpiry: &proto.SessionTrackerUpdateExpiry{ - Expires: &expiry, - }, + err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ + SessionID: s.tracker.GetSessionID(), + Update: &proto.UpdateSessionTrackerRequest_UpdateExpiry{ + UpdateExpiry: &proto.SessionTrackerUpdateExpiry{ + Expires: &expiry, }, - }) + }, + }) - return trace.Wrap(err) - } - - return nil + return trace.Wrap(err) } func (s *SessionTracker) AddParticipant(ctx context.Context, p *types.Participant) error { @@ -126,20 +124,16 @@ func (s *SessionTracker) AddParticipant(ctx context.Context, p *types.Participan s.tracker.AddParticipant(*p) s.trackerCond.Broadcast() - if s.service != nil { - err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ - SessionID: s.tracker.GetSessionID(), - Update: &proto.UpdateSessionTrackerRequest_AddParticipant{ - AddParticipant: &proto.SessionTrackerAddParticipant{ - Participant: p, - }, + err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ + SessionID: s.tracker.GetSessionID(), + Update: &proto.UpdateSessionTrackerRequest_AddParticipant{ + AddParticipant: &proto.SessionTrackerAddParticipant{ + Participant: p, }, - }) - - return trace.Wrap(err) - } + }, + }) - return nil + return trace.Wrap(err) } func (s *SessionTracker) RemoveParticipant(ctx context.Context, participantID string) error { @@ -148,20 +142,16 @@ func (s *SessionTracker) RemoveParticipant(ctx context.Context, participantID st s.tracker.RemoveParticipant(participantID) s.trackerCond.Broadcast() - if s.service != nil { - err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ - SessionID: s.tracker.GetSessionID(), - Update: &proto.UpdateSessionTrackerRequest_RemoveParticipant{ - RemoveParticipant: &proto.SessionTrackerRemoveParticipant{ - ParticipantID: participantID, - }, + err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ + SessionID: s.tracker.GetSessionID(), + Update: &proto.UpdateSessionTrackerRequest_RemoveParticipant{ + RemoveParticipant: &proto.SessionTrackerRemoveParticipant{ + ParticipantID: participantID, }, - }) - - return trace.Wrap(err) - } + }, + }) - return nil + return trace.Wrap(err) } func (s *SessionTracker) UpdateState(ctx context.Context, state types.SessionState) error { @@ -170,20 +160,16 @@ func (s *SessionTracker) UpdateState(ctx context.Context, state types.SessionSta s.tracker.SetState(state) s.trackerCond.Broadcast() - if s.service != nil { - err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ - SessionID: s.tracker.GetSessionID(), - Update: &proto.UpdateSessionTrackerRequest_UpdateState{ - UpdateState: &proto.SessionTrackerUpdateState{ - State: state, - }, + err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{ + SessionID: s.tracker.GetSessionID(), + Update: &proto.UpdateSessionTrackerRequest_UpdateState{ + UpdateState: &proto.SessionTrackerUpdateState{ + State: state, }, - }) - - return trace.Wrap(err) - } + }, + }) - return nil + return trace.Wrap(err) } // WaitForStateUpdate waits for the tracker's state to be updated and returns the new state.