From 6f4d7ad65e109a1ea38f605bf8318bb65518d4c9 Mon Sep 17 00:00:00 2001 From: Edoardo Spadolini Date: Mon, 27 May 2024 20:36:37 +0200 Subject: [PATCH] Fix goroutine leak in (*UploadCompleter).ensureSessionEndEvent --- lib/auth/auth_with_roles_test.go | 20 ++++++++++++-------- lib/auth/grpcserver.go | 18 ++++++++---------- lib/client/player.go | 7 ++++++- lib/events/api.go | 10 +++++++--- lib/events/auditlog.go | 4 +--- lib/events/complete.go | 2 ++ tool/tsh/common/recording_export.go | 2 ++ 7 files changed, 38 insertions(+), 25 deletions(-) diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 4ec75ed87ffe7..b0e8e2b4f0f78 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -1803,10 +1803,12 @@ func TestStreamSessionEventsRBAC(t *testing.T) { clt, err := srv.NewClient(identity) require.NoError(t, err) - _, errC := clt.StreamSessionEvents(context.Background(), "foo", 0) + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, errC := clt.StreamSessionEvents(ctx, "foo", 0) select { case err := <-errC: - require.True(t, trace.IsAccessDenied(err), "expected access denied error, got %v", err) + require.ErrorAs(t, err, new(*trace.AccessDeniedError)) case <-time.After(5 * time.Second): require.FailNow(t, "expected access denied error but stream succeeded") } @@ -1816,7 +1818,8 @@ func TestStreamSessionEventsRBAC(t *testing.T) { func TestStreamSessionEvents_User(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() srv := newTestTLSServer(t) username := "user" @@ -1851,7 +1854,8 @@ func TestStreamSessionEvents_User(t *testing.T) { func TestStreamSessionEvents_Builtin(t *testing.T) { t.Parallel() - ctx := context.Background() + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() srv := newTestTLSServer(t) identity := TestBuiltin(types.RoleProxy) @@ -7554,7 +7558,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { roles: []string{ "identity-requester", "payments-requester", "identity-resource-requester", "payments-resource-requester", - "any-requester"}, + "any-requester", + }, requestedRoles: []string{"payments-access"}, expectedAnnotations: map[string][]string{ "requesting": {"role"}, @@ -7568,7 +7573,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { roles: []string{ "identity-requester", "payments-requester", "identity-resource-requester", "payments-resource-requester", - "any-requester"}, + "any-requester", + }, requestedRoles: []string{"payments-access"}, requestedResourceIDs: []string{"server-payments"}, expectedAnnotations: map[string][]string{ @@ -7667,10 +7673,8 @@ func TestAccessRequestNonGreedyAnnotations(t *testing.T) { } else { tc.errfn(t, err) } - }) } - } func mustAccessRequest(t *testing.T, user string, state types.RequestState, created, expires time.Time, roles []string, resourceIDs []types.ResourceID) types.AccessRequest { diff --git a/lib/auth/grpcserver.go b/lib/auth/grpcserver.go index 927c658f18820..98375037b6d4c 100644 --- a/lib/auth/grpcserver.go +++ b/lib/auth/grpcserver.go @@ -233,16 +233,14 @@ func (g *GRPCServer) EmitAuditEvent(ctx context.Context, req *apievents.OneOf) ( return &emptypb.Empty{}, nil } -var ( - connectedResourceGauges = map[string]prometheus.Gauge{ - constants.KeepAliveNode: connectedResources.WithLabelValues(constants.KeepAliveNode), - constants.KeepAliveKube: connectedResources.WithLabelValues(constants.KeepAliveKube), - constants.KeepAliveApp: connectedResources.WithLabelValues(constants.KeepAliveApp), - constants.KeepAliveDatabase: connectedResources.WithLabelValues(constants.KeepAliveDatabase), - constants.KeepAliveDatabaseService: connectedResources.WithLabelValues(constants.KeepAliveDatabaseService), - constants.KeepAliveWindowsDesktopService: connectedResources.WithLabelValues(constants.KeepAliveWindowsDesktopService), - } -) +var connectedResourceGauges = map[string]prometheus.Gauge{ + constants.KeepAliveNode: connectedResources.WithLabelValues(constants.KeepAliveNode), + constants.KeepAliveKube: connectedResources.WithLabelValues(constants.KeepAliveKube), + constants.KeepAliveApp: connectedResources.WithLabelValues(constants.KeepAliveApp), + constants.KeepAliveDatabase: connectedResources.WithLabelValues(constants.KeepAliveDatabase), + constants.KeepAliveDatabaseService: connectedResources.WithLabelValues(constants.KeepAliveDatabaseService), + constants.KeepAliveWindowsDesktopService: connectedResources.WithLabelValues(constants.KeepAliveWindowsDesktopService), +} // SendKeepAlives allows node to send a stream of keep alive requests func (g *GRPCServer) SendKeepAlives(stream authpb.AuthService_SendKeepAlivesServer) error { diff --git a/lib/client/player.go b/lib/client/player.go index 69c68bcdc9eb7..9032a87fad14e 100644 --- a/lib/client/player.go +++ b/lib/client/player.go @@ -60,7 +60,12 @@ func (p *playFromFileStreamer) StreamSessionEvents( } if i >= startIndex { - evts <- evt + select { + case evts <- evt: + case <-ctx.Done(): + errs <- trace.Wrap(err) + return + } } } }() diff --git a/lib/events/api.go b/lib/events/api.go index 5ad6226f9327b..a8529a4047561 100644 --- a/lib/events/api.go +++ b/lib/events/api.go @@ -957,9 +957,13 @@ type SessionStreamer interface { // after is used to return events after a specified cursor ID GetSessionEvents(namespace string, sid session.ID, after int) ([]EventFields, error) - // StreamSessionEvents streams all events from a given session recording. An error is returned on the first - // channel if one is encountered. Otherwise the event channel is closed when the stream ends. - // The event channel is not closed on error to prevent race conditions in downstream select statements. + // StreamSessionEvents streams all events from a given session recording. An + // error is returned on the first channel if one is encountered. Otherwise + // the event channel is closed when the stream ends. The event channel is + // not closed on error to prevent race conditions in downstream select + // statements. Both returned channels must be driven until the event channel + // is exhausted or the error channel reports an error, or until the context + // is canceled. StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) } diff --git a/lib/events/auditlog.go b/lib/events/auditlog.go index 2c3a34312e15e..d19e59e3ac5b3 100644 --- a/lib/events/auditlog.go +++ b/lib/events/auditlog.go @@ -950,9 +950,7 @@ func (l *AuditLog) SearchSessionEvents(ctx context.Context, req SearchSessionEve return l.localLog.SearchSessionEvents(ctx, req) } -// StreamSessionEvents streams all events from a given session recording. An error is returned on the first -// channel if one is encountered. Otherwise the event channel is closed when the stream ends. -// The event channel is not closed on error to prevent race conditions in downstream select statements. +// StreamSessionEvents implements [SessionStreamer]. func (l *AuditLog) StreamSessionEvents(ctx context.Context, sessionID session.ID, startIndex int64) (chan apievents.AuditEvent, chan error) { l.log.Debugf("StreamSessionEvents(%v)", sessionID) e := make(chan error, 1) diff --git a/lib/events/complete.go b/lib/events/complete.go index b414f508e3969..b335363e667c2 100644 --- a/lib/events/complete.go +++ b/lib/events/complete.go @@ -326,6 +326,8 @@ func (u *UploadCompleter) ensureSessionEndEvent(ctx context.Context, uploadData // for both Desktop and SSH sessions, where as the GetSessionEvents API relies on downloading // a copy of the session and using the SSH-specific index to iterate through events. var lastEvent events.AuditEvent + ctx, cancel := context.WithCancel(ctx) + defer cancel() evts, errors := u.cfg.AuditLog.StreamSessionEvents(ctx, uploadData.SessionID, 0) loop: diff --git a/tool/tsh/common/recording_export.go b/tool/tsh/common/recording_export.go index c0a46bba767ec..6d868557fca30 100644 --- a/tool/tsh/common/recording_export.go +++ b/tool/tsh/common/recording_export.go @@ -78,6 +78,8 @@ func makeAVIFileName(prefix string, currentFile int) string { // writeMovie writes the events for the specified session into one or more movie files // beginning with the specified prefix. It returns the number of frames that were written and an error. func writeMovie(ctx context.Context, ss events.SessionStreamer, sid session.ID, prefix string, write func(format string, args ...any) (int, error), webProxyAddr string) (frames int, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() var screen *image.NRGBA var movie mjpeg.AviWriter