diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index d46f1faaacefb..e9d1bf4f80c22 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -1802,10 +1802,12 @@ func TestStreamSessionEventsRBAC(t *testing.T) { clt, err := srv.NewClient(identity) require.NoError(t, err) - _, errC := clt.StreamSessionEvents(context.Background(), "foo", 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, errC := clt.StreamSessionEvents(ctx, "foo", 0) select { case err := <-errC: - require.True(t, trace.IsAccessDenied(err), "expected access denied error, got %v", err) + require.ErrorAs(t, err, new(*trace.AccessDeniedError)) case <-time.After(5 * time.Second): require.FailNow(t, "expected access denied error but stream succeeded") } @@ -1815,7 +1817,8 @@ func TestStreamSessionEventsRBAC(t *testing.T) { func TestStreamSessionEvents_User(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() srv := newTestTLSServer(t) username := "user" @@ -1850,7 +1853,8 @@ func TestStreamSessionEvents_User(t *testing.T) { func TestStreamSessionEvents_Builtin(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() srv := newTestTLSServer(t) identity := TestBuiltin(types.RoleProxy) @@ -7380,7 +7384,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { roles: []string{ "identity-requester", "payments-requester", "identity-resource-requester", "payments-resource-requester", - "any-requester"}, + "any-requester", + }, requestedRoles: []string{"payments-access"}, expectedAnnotations: map[string][]string{ "requesting": {"role"}, @@ -7394,7 +7399,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { roles: []string{ "identity-requester", "payments-requester", "identity-resource-requester", "payments-resource-requester", - "any-requester"}, + "any-requester", + }, requestedRoles: []string{"payments-access"}, requestedResourceIDs: []string{"server-payments"}, expectedAnnotations: map[string][]string{ @@ -7493,10 +7499,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { } else { tc.errfn(t, err) } - }) } - } func mustAccessRequest(t *testing.T, user string, state types.RequestState, created, expires time.Time, roles []string, resourceIDs []types.ResourceID) types.AccessRequest { diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 8af3be7bc3cdd..d3d58c36c0257 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -224,16 +224,14 @@ func (g *GRPCServer) EmitAuditEvent(ctx context.Context, req *apievents.OneOf) ( return &emptypb.Empty{}, nil } -var ( - connectedResourceGauges = map[string]prometheus.Gauge{ - constants.KeepAliveNode: connectedResources.WithLabelValues(constants.KeepAliveNode), - constants.KeepAliveKube: connectedResources.WithLabelValues(constants.KeepAliveKube), - constants.KeepAliveApp: connectedResources.WithLabelValues(constants.KeepAliveApp), - constants.KeepAliveDatabase: connectedResources.WithLabelValues(constants.KeepAliveDatabase), - constants.KeepAliveDatabaseService: connectedResources.WithLabelValues(constants.KeepAliveDatabaseService), - constants.KeepAliveWindowsDesktopService: connectedResources.WithLabelValues(constants.KeepAliveWindowsDesktopService), - } -) +var connectedResourceGauges = map[string]prometheus.Gauge{ + constants.KeepAliveNode: connectedResources.WithLabelValues(constants.KeepAliveNode), + constants.KeepAliveKube: connectedResources.WithLabelValues(constants.KeepAliveKube), + constants.KeepAliveApp: connectedResources.WithLabelValues(constants.KeepAliveApp), + constants.KeepAliveDatabase: connectedResources.WithLabelValues(constants.KeepAliveDatabase), + constants.KeepAliveDatabaseService: connectedResources.WithLabelValues(constants.KeepAliveDatabaseService), + constants.KeepAliveWindowsDesktopService: connectedResources.WithLabelValues(constants.KeepAliveWindowsDesktopService), +} // SendKeepAlives allows node to send a stream of keep alive requests func (g *GRPCServer) SendKeepAlives(stream authpb.AuthService_SendKeepAlivesServer) error { diff --git a/lib/client/player.go b/lib/client/player.go index 69c68bcdc9eb7..9032a87fad14e 100644 --- a/lib/client/player.go +++ b/lib/client/player.go @@ -60,7 +60,12 @@ func (p *playFromFileStreamer) StreamSessionEvents( } if i >= startIndex { - evts <- evt + select { + case evts <- evt: + case <-ctx.Done(): + errs <- trace.Wrap(err) + return + } } } }() diff --git a/lib/events/api.go b/lib/events/api.go index 7dd0db719f6c5..f609e4b95f1ca 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -935,9 +935,13 @@ type SessionStreamer interface { // after is used to return events after a specified cursor ID GetSessionEvents(namespace string, sid session.ID, after int) ([]EventFields, error) - // StreamSessionEvents streams all events from a given session recording. An error is returned on the first - // channel if one is encountered. Otherwise the event channel is closed when the stream ends. - // The event channel is not closed on error to prevent race conditions in downstream select statements. + // StreamSessionEvents streams all events from a given session recording. An + // error is returned on the first channel if one is encountered. Otherwise + // the event channel is closed when the stream ends. The event channel is + // not closed on error to prevent race conditions in downstream select + // statements. Both returned channels must be driven until the event channel + // is exhausted or the error channel reports an error, or until the context + // is canceled. StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) } diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 313bc25b6032f..eef6e42cf9e41 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -950,9 +950,7 @@ func (l *AuditLog) SearchSessionEvents(ctx context.Context, req SearchSessionEve return l.localLog.SearchSessionEvents(ctx, req) } -// StreamSessionEvents streams all events from a given session recording. An error is returned on the first -// channel if one is encountered. Otherwise the event channel is closed when the stream ends. -// The event channel is not closed on error to prevent race conditions in downstream select statements. +// StreamSessionEvents implements [SessionStreamer]. func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) { l.log.Debugf("StreamSessionEvents(%v)", sessionID) e := make(chan error, 1) diff --git a/lib/events/complete.go b/lib/events/complete.go index b414f508e3969..b335363e667c2 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -326,6 +326,8 @@ func (u *UploadCompleter) ensureSessionEndEvent(ctx context.Context, uploadData // for both Desktop and SSH sessions, where as the GetSessionEvents API relies on downloading // a copy of the session and using the SSH-specific index to iterate through events. var lastEvent events.AuditEvent + ctx, cancel := context.WithCancel(ctx) + defer cancel() evts, errors := u.cfg.AuditLog.StreamSessionEvents(ctx, uploadData.SessionID, 0) loop: diff --git a/tool/tsh/common/recording_export.go b/tool/tsh/common/recording_export.go index 7691d89853b3f..e65a17f5a29ad 100644 --- a/tool/tsh/common/recording_export.go +++ b/tool/tsh/common/recording_export.go @@ -81,6 +81,8 @@ func makeAVIFileName(prefix string, currentFile int) string { // writeMovie writes the events for the specified session into one or more movie files // beginning with the specified prefix. It returns the number of frames that were written and an error. func writeMovie(ctx context.Context, ss events.SessionStreamer, sid session.ID, prefix string, write func(format string, args ...any) (int, error), webProxyAddr string) (frames int, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() var screen *image.NRGBA var movie mjpeg.AviWriter