diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index fa05685fcf789..36ec4240d2464 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -1781,10 +1781,12 @@ func TestStreamSessionEventsRBAC(t *testing.T) { clt, err := srv.NewClient(identity) require.NoError(t, err) - _, errC := clt.StreamSessionEvents(context.Background(), "foo", 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, errC := clt.StreamSessionEvents(ctx, "foo", 0) select { case err := <-errC: - require.True(t, trace.IsAccessDenied(err), "expected access denied error, got %v", err) + require.ErrorAs(t, err, new(*trace.AccessDeniedError)) case <-time.After(5 * time.Second): require.FailNow(t, "expected access denied error but stream succeeded") } @@ -1794,7 +1796,8 @@ func TestStreamSessionEventsRBAC(t *testing.T) { func TestStreamSessionEvents_User(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() srv := newTestTLSServer(t) username := "user" @@ -1829,7 +1832,8 @@ func TestStreamSessionEvents_User(t *testing.T) { func TestStreamSessionEvents_Builtin(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() srv := newTestTLSServer(t) identity := TestBuiltin(types.RoleProxy) @@ -7275,7 +7279,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { roles: []string{ "identity-requester", "payments-requester", "identity-resource-requester", "payments-resource-requester", - "any-requester"}, + "any-requester", + }, requestedRoles: []string{"payments-access"}, expectedAnnotations: map[string][]string{ "requesting": {"role"}, @@ -7289,7 +7294,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { roles: []string{ "identity-requester", "payments-requester", "identity-resource-requester", "payments-resource-requester", - "any-requester"}, + "any-requester", + }, requestedRoles: []string{"payments-access"}, requestedResourceIDs: []string{"server-payments"}, expectedAnnotations: map[string][]string{ @@ -7388,10 +7394,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { } else { tc.errfn(t, err) } - }) } - } func mustAccessRequest(t *testing.T, user string, state types.RequestState, created, expires time.Time, roles []string, resourceIDs []types.ResourceID) types.AccessRequest { diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index f3b8db589f2d7..2b9298d9c41b3 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -178,16 +178,14 @@ func (g *GRPCServer) EmitAuditEvent(ctx context.Context, req *apievents.OneOf) ( return &emptypb.Empty{}, nil } -var ( - connectedResourceGauges = map[string]prometheus.Gauge{ - constants.KeepAliveNode: connectedResources.WithLabelValues(constants.KeepAliveNode), - constants.KeepAliveKube: connectedResources.WithLabelValues(constants.KeepAliveKube), - constants.KeepAliveApp: connectedResources.WithLabelValues(constants.KeepAliveApp), - constants.KeepAliveDatabase: connectedResources.WithLabelValues(constants.KeepAliveDatabase), - constants.KeepAliveDatabaseService: connectedResources.WithLabelValues(constants.KeepAliveDatabaseService), - constants.KeepAliveWindowsDesktopService: connectedResources.WithLabelValues(constants.KeepAliveWindowsDesktopService), - } -) +var connectedResourceGauges = map[string]prometheus.Gauge{ + constants.KeepAliveNode: connectedResources.WithLabelValues(constants.KeepAliveNode), + constants.KeepAliveKube: connectedResources.WithLabelValues(constants.KeepAliveKube), + constants.KeepAliveApp: connectedResources.WithLabelValues(constants.KeepAliveApp), + constants.KeepAliveDatabase: connectedResources.WithLabelValues(constants.KeepAliveDatabase), + constants.KeepAliveDatabaseService: connectedResources.WithLabelValues(constants.KeepAliveDatabaseService), + constants.KeepAliveWindowsDesktopService: connectedResources.WithLabelValues(constants.KeepAliveWindowsDesktopService), +} // SendKeepAlives allows node to send a stream of keep alive requests func (g *GRPCServer) SendKeepAlives(stream authpb.AuthService_SendKeepAlivesServer) error { diff --git a/lib/client/player.go b/lib/client/player.go index 3ea4030e1f07c..ddef660436f2d 100644 --- a/lib/client/player.go +++ b/lib/client/player.go @@ -225,7 +225,6 @@ func (p *sessionPlayer) playRange(from, to int) { var i int defer func() { - p.Lock() endRequested := p.state == stateEnding p.setState(stateStopped) diff --git a/lib/events/api.go b/lib/events/api.go index a12b70ca47203..44e9f71a16a7c 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -889,9 +889,13 @@ type SessionStreamer interface { // after is used to return events after a specified cursor ID GetSessionEvents(namespace string, sid session.ID, after int) ([]EventFields, error) - // StreamSessionEvents streams all events from a given session recording. An error is returned on the first - // channel if one is encountered. Otherwise the event channel is closed when the stream ends. - // The event channel is not closed on error to prevent race conditions in downstream select statements. + // StreamSessionEvents streams all events from a given session recording. An + // error is returned on the first channel if one is encountered. Otherwise + // the event channel is closed when the stream ends. The event channel is + // not closed on error to prevent race conditions in downstream select + // statements. Both returned channels must be driven until the event channel + // is exhausted or the error channel reports an error, or until the context + // is canceled. StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) } diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 56f7dbde8d909..292202465114e 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -950,9 +950,7 @@ func (l *AuditLog) SearchSessionEvents(ctx context.Context, req SearchSessionEve return l.localLog.SearchSessionEvents(ctx, req) } -// StreamSessionEvents streams all events from a given session recording. An error is returned on the first -// channel if one is encountered. Otherwise the event channel is closed when the stream ends. -// The event channel is not closed on error to prevent race conditions in downstream select statements. +// StreamSessionEvents implements [SessionStreamer]. func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) { l.log.Debugf("StreamSessionEvents(%v)", sessionID) e := make(chan error, 1) diff --git a/lib/events/complete.go b/lib/events/complete.go index a68d5332a47be..f3f8e40f57a8c 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -324,6 +324,8 @@ func (u *UploadCompleter) ensureSessionEndEvent(ctx context.Context, uploadData // for both Desktop and SSH sessions, where as the GetSessionEvents API relies on downloading // a copy of the session and using the SSH-specific index to iterate through events. var lastEvent events.AuditEvent + ctx, cancel := context.WithCancel(ctx) + defer cancel() evts, errors := u.cfg.AuditLog.StreamSessionEvents(ctx, uploadData.SessionID, 0) loop: diff --git a/tool/tsh/common/recording_export.go b/tool/tsh/common/recording_export.go index a78cc82079bd5..c477eab2f49d1 100644 --- a/tool/tsh/common/recording_export.go +++ b/tool/tsh/common/recording_export.go @@ -78,7 +78,10 @@ func makeAVIFileName(prefix string, currentFile int) string { // writeMovie writes the events for the specified session into one or more movie files // beginning with the specified prefix. It returns the number of frames that were written and an error. func writeMovie(ctx context.Context, ss events.SessionStreamer, sid session.ID, prefix string, - write func(format string, args ...any) (int, error)) (frames int, err error) { + write func(format string, args ...any) (int, error), +) (frames int, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() var screen *image.NRGBA var movie mjpeg.AviWriter