Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions lib/srv/sess.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ import (
"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/lib/auth"
"github.com/gravitational/teleport/lib/backend/memory"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/events"
"github.com/gravitational/teleport/lib/events/filesessions"
"github.com/gravitational/teleport/lib/observability/metrics"
"github.com/gravitational/teleport/lib/services"
"github.com/gravitational/teleport/lib/services/local"
rsession "github.com/gravitational/teleport/lib/session"
"github.com/gravitational/teleport/lib/sshutils"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -1604,16 +1606,36 @@ func (s *session) trackSession(teleportUser string, policySet []*types.SessionTr
}

s.log.Debug("Creating session tracker")
var err error

// if doing proxy recording, don't propagate tracker to the cluster level
if s.registry.Srv.Component() == teleport.ComponentNode && services.IsRecordAtProxy(s.scx.SessionRecordingConfig.GetMode()) {
s.tracker, err = NewSessionTracker(s.serverCtx, trackerSpec, nil)
bk, err := memory.New(memory.Config{})
if err != nil {
return trace.Wrap(err)
}
trackerRemote, err := local.NewSessionTrackerService(bk)
if err != nil {
return trace.Wrap(err)
}
s.tracker, err = NewSessionTracker(s.serverCtx, trackerSpec, trackerRemote)
if err != nil {
return trace.Wrap(err)
}

go func() {
err := s.tracker.WaitOnState(s.serverCtx, types.SessionState_SessionStateTerminated)
if err != nil {
s.log.WithError(err).Error("Failed to wait on session state.")
}

bk.Close()
}()
} else {
var err error
s.tracker, err = NewSessionTracker(s.serverCtx, trackerSpec, s.registry.SessionTrackerService)
}
if err != nil {
return trace.Wrap(err)
if err != nil {
return trace.Wrap(err)
}
}

go func() {
Expand Down
90 changes: 38 additions & 52 deletions lib/srv/sessiontracker.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,17 @@ type SessionTracker struct {

// NewSessionTracker returns a new SessionTracker for the given types.SessionTracker
func NewSessionTracker(ctx context.Context, trackerSpec types.SessionTrackerSpecV1, service services.SessionTrackerService) (*SessionTracker, error) {
if service == nil {
return nil, trace.BadParameter("missing parameter service")
}

t, err := types.NewSessionTracker(trackerSpec)
if err != nil {
return nil, trace.Wrap(err)
}

if service != nil {
if t, err = service.CreateSessionTracker(ctx, t); err != nil {
return nil, trace.Wrap(err)
}
if t, err = service.CreateSessionTracker(ctx, t); err != nil {
return nil, trace.Wrap(err)
}

return &SessionTracker{
Expand Down Expand Up @@ -104,20 +106,16 @@ func (s *SessionTracker) UpdateExpiration(ctx context.Context, expiry time.Time)
s.tracker.SetExpiry(expiry)
s.trackerCond.Broadcast()

if s.service != nil {
err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{
SessionID: s.tracker.GetSessionID(),
Update: &proto.UpdateSessionTrackerRequest_UpdateExpiry{
UpdateExpiry: &proto.SessionTrackerUpdateExpiry{
Expires: &expiry,
},
err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{
SessionID: s.tracker.GetSessionID(),
Update: &proto.UpdateSessionTrackerRequest_UpdateExpiry{
UpdateExpiry: &proto.SessionTrackerUpdateExpiry{
Expires: &expiry,
},
})
},
})

return trace.Wrap(err)
}

return nil
return trace.Wrap(err)
}

func (s *SessionTracker) AddParticipant(ctx context.Context, p *types.Participant) error {
Expand All @@ -126,20 +124,16 @@ func (s *SessionTracker) AddParticipant(ctx context.Context, p *types.Participan
s.tracker.AddParticipant(*p)
s.trackerCond.Broadcast()

if s.service != nil {
err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{
SessionID: s.tracker.GetSessionID(),
Update: &proto.UpdateSessionTrackerRequest_AddParticipant{
AddParticipant: &proto.SessionTrackerAddParticipant{
Participant: p,
},
err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{
SessionID: s.tracker.GetSessionID(),
Update: &proto.UpdateSessionTrackerRequest_AddParticipant{
AddParticipant: &proto.SessionTrackerAddParticipant{
Participant: p,
},
})

return trace.Wrap(err)
}
},
})

return nil
return trace.Wrap(err)
}

func (s *SessionTracker) RemoveParticipant(ctx context.Context, participantID string) error {
Expand All @@ -148,20 +142,16 @@ func (s *SessionTracker) RemoveParticipant(ctx context.Context, participantID st
s.tracker.RemoveParticipant(participantID)
s.trackerCond.Broadcast()

if s.service != nil {
err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{
SessionID: s.tracker.GetSessionID(),
Update: &proto.UpdateSessionTrackerRequest_RemoveParticipant{
RemoveParticipant: &proto.SessionTrackerRemoveParticipant{
ParticipantID: participantID,
},
err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{
SessionID: s.tracker.GetSessionID(),
Update: &proto.UpdateSessionTrackerRequest_RemoveParticipant{
RemoveParticipant: &proto.SessionTrackerRemoveParticipant{
ParticipantID: participantID,
},
})

return trace.Wrap(err)
}
},
})

return nil
return trace.Wrap(err)
}

func (s *SessionTracker) UpdateState(ctx context.Context, state types.SessionState) error {
Expand All @@ -170,20 +160,16 @@ func (s *SessionTracker) UpdateState(ctx context.Context, state types.SessionSta
s.tracker.SetState(state)
s.trackerCond.Broadcast()

if s.service != nil {
err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{
SessionID: s.tracker.GetSessionID(),
Update: &proto.UpdateSessionTrackerRequest_UpdateState{
UpdateState: &proto.SessionTrackerUpdateState{
State: state,
},
err := s.service.UpdateSessionTracker(ctx, &proto.UpdateSessionTrackerRequest{
SessionID: s.tracker.GetSessionID(),
Update: &proto.UpdateSessionTrackerRequest_UpdateState{
UpdateState: &proto.SessionTrackerUpdateState{
State: state,
},
})

return trace.Wrap(err)
}
},
})

return nil
return trace.Wrap(err)
}

// WaitForStateUpdate waits for the tracker's state to be updated and returns the new state.
Expand Down