diff --git a/integrations/event-handler/session_events_job.go b/integrations/event-handler/session_events_job.go
index 3abca19f68d2a..4fb2dca00a296 100644
--- a/integrations/event-handler/session_events_job.go
+++ b/integrations/event-handler/session_events_job.go
@@ -16,7 +16,6 @@ package main
 
 import (
 	"context"
-	"log/slog"
 	"sync/atomic"
 	"time"
 
@@ -45,6 +44,10 @@ type session struct {
 	UploadTime time.Time
 }
 
+// processSessionFunc mimics the signature of
+// [SessionEventsJob.processSession].
+type processSessionFunc func(ctx context.Context, s session, processingAttempt int) error
+
 // SessionEventsJob incapsulates session events consumption logic
 type SessionEventsJob struct {
 	lib.ServiceJob
@@ -54,6 +57,7 @@ type SessionEventsJob struct {
 	logLimiter             *rate.Limiter
 	backpressureLogLimiter *rate.Limiter
 	sessionsProcessed      atomic.Uint64
+	processSessionFunc     processSessionFunc
 }
 
 // NewSessionEventsJob creates new EventsJob structure
@@ -66,6 +70,7 @@ func NewSessionEventsJob(app *App) *SessionEventsJob {
 		backpressureLogLimiter: rate.NewLimiter(rate.Every(time.Minute), 1),
 	}
 
+	j.processSessionFunc = j.processSession
 	j.ServiceJob = lib.NewServiceJob(j.run)
 
 	return j
@@ -108,33 +113,9 @@ func (j *SessionEventsJob) run(ctx context.Context) error {
 	for {
 		select {
 		case s := <-j.sessions:
-			log := j.app.log.With(
-				"id", s.ID,
-				"index", s.Index,
-			)
-
-			if j.logLimiter.Allow() {
-				log.DebugContext(ctx, "Starting session ingest")
-			}
-
-			select {
-			case j.semaphore <- struct{}{}:
-			case <-ctx.Done():
-				log.ErrorContext(ctx, "Failed to acquire semaphore", "error", ctx.Err())
-				return nil
+			if err := j.ingestSession(ctx, s, 0, nil); err != nil {
+				j.app.log.WarnContext(ctx, "Unable to ingest session event", "error", err)
 			}
-
-			func(s session, log *slog.Logger) {
-				j.app.SpawnCritical(func(ctx context.Context) error {
-					defer func() { <-j.semaphore }()
-
-					if err := j.processSession(ctx, s, 0); err != nil {
-						return trace.Wrap(err)
-					}
-
-					return nil
-				})
-			}(s, log)
 		case <-ctx.Done():
 			if lib.IsCanceled(ctx.Err()) {
 				return nil
@@ -144,6 +125,39 @@ func (j *SessionEventsJob) run(ctx context.Context) error {
 	}
 }
 
+// ingestSession waits for a free slot in semaphore (j.semaphore when nil is
+// passed) and then processes s in a new goroutine, releasing the slot once
+// processing finishes. It returns an error only if ctx is done before a slot
+// is acquired; processing failures are logged rather than returned.
+func (j *SessionEventsJob) ingestSession(ctx context.Context, s session, attempt int, semaphore chan struct{}) error {
+	log := j.app.log.With(
+		"id", s.ID,
+		"index", s.Index,
+	)
+	if j.logLimiter.Allow() {
+		log.DebugContext(ctx, "Starting session ingest")
+	}
+
+	if semaphore == nil {
+		semaphore = j.semaphore
+	}
+	select {
+	case semaphore <- struct{}{}:
+	case <-ctx.Done():
+		return trace.Wrap(ctx.Err())
+	}
+
+	go func() {
+		defer func() { <-semaphore }()
+
+		if err := j.processSessionFunc(ctx, s, attempt); err != nil {
+			log.ErrorContext(ctx, "Failed processing session recording", "error", err)
+		}
+	}()
+
+	return nil
+}
+
 func (j *SessionEventsJob) processSession(ctx context.Context, s session, processingAttempt int) error {
 	const (
 		// maxNumberOfProcessingAttempts is the number of times a non-existent
@@ -237,9 +251,8 @@ func (j *SessionEventsJob) processSession(ctx context.Context, s session, proces
 // from session recordings that were previously not found.
 func (j *SessionEventsJob) processMissingRecordings(ctx context.Context) error {
 	const (
-		initialProcessingDelay      = time.Minute
-		processingInterval          = 3 * time.Minute
-		maxNumberOfInflightSessions = 10
+		initialProcessingDelay = time.Minute
+		processingInterval     = 3 * time.Minute
 	)
 
 	ctx, cancel := context.WithCancel(ctx)
@@ -252,7 +265,9 @@ func (j *SessionEventsJob) processMissingRecordings(ctx context.Context) error {
 	timer := time.NewTimer(jitter(initialProcessingDelay))
 	defer timer.Stop()
 
-	semaphore := make(chan struct{}, maxNumberOfInflightSessions)
+	// One semaphore is shared by every retry so that the number of in-flight
+	// sessions stays bounded; a per-session channel would never block.
+	semaphore := make(chan struct{}, j.app.Config.Concurrency*2)
 	for {
 		select {
 		case <-timer.C:
@@ -261,21 +276,7 @@ func (j *SessionEventsJob) processMissingRecordings(ctx context.Context) error {
 			}
 
 			err := j.app.State.IterateMissingRecordings(func(sess session, attempts int) error {
-				select {
-				case semaphore <- struct{}{}:
-				case <-ctx.Done():
-					return trace.Wrap(ctx.Err())
-				}
-
-				go func() {
-					defer func() { <-semaphore }()
-
-					if err := j.processSession(ctx, sess, attempts); err != nil {
-						j.app.log.DebugContext(ctx, "Failed processing session recording", "error", err)
-					}
-				}()
-
-				return nil
+				return j.ingestSession(ctx, sess, attempts, semaphore)
 			})
 			if err != nil && !lib.IsCanceled(err) {
 				j.app.log.WarnContext(ctx, "Unable to load previously failed sessions for processing", "error", err)
diff --git a/integrations/event-handler/session_events_job_test.go b/integrations/event-handler/session_events_job_test.go
index 79ec5bd497dec..25722bbe7bcde 100644
--- a/integrations/event-handler/session_events_job_test.go
+++ b/integrations/event-handler/session_events_job_test.go
@@ -15,10 +15,14 @@
 package main
 
 import (
+	"bytes"
 	"context"
 	"log/slog"
 	"testing"
+	"testing/synctest"
+	"time"
 
+	"github.com/gravitational/trace"
 	"github.com/peterbourgon/diskv/v3"
 	"github.com/stretchr/testify/require"
 
@@ -40,10 +44,54 @@ func TestConsumeSessionNoEventsFound(t *testing.T) {
 		client: &mockClient{},
 		log:    slog.Default(),
 	})
-	_, err := j.consumeSession(context.Background(), session{ID: sessionID})
+	_, err := j.consumeSession(t.Context(), session{ID: sessionID})
 	require.NoError(t, err)
 }
 
+// TestIngestSession tests that the ingestSession method returns without error if a malformed
+// session event is processed.
+func TestIngestSession(t *testing.T) {
+	synctest.Test(t, func(t *testing.T) {
+		startTime := time.Now().Add(-time.Minute)
+		out := &bytes.Buffer{}
+		log := slog.New(slog.NewTextHandler(out, &slog.HandlerOptions{Level: slog.LevelError}))
+
+		j := NewSessionEventsJob(&App{
+			Config: &StartCmdConfig{
+				IngestConfig: IngestConfig{
+					StorageDir:       t.TempDir(),
+					Timeout:          time.Second,
+					BatchSize:        100,
+					Concurrency:      5,
+					StartTime:        &startTime,
+					SkipSessionTypes: map[string]struct{}{"print": {}, "desktop.recording": {}},
+					WindowSize:       time.Hour * 24,
+					DryRun:           true,
+				},
+			},
+			State: &State{
+				dv: diskv.New(diskv.Options{
+					BasePath: t.TempDir(),
+				}),
+			},
+			client: &mockClient{},
+			log:    log,
+		})
+
+		j.processSessionFunc = func(ctx context.Context, s session, processingAttempt int) error {
+			return trace.LimitExceeded("Session ingestion exceeded attempt limit")
+		}
+
+		err := j.ingestSession(t.Context(), session{ID: "test"}, 0, nil)
+		require.NoError(t, err)
+
+		synctest.Wait()
+
+		require.Contains(t, out.String(), "Failed processing session recording")
+		require.Contains(t, out.String(), "Session ingestion exceeded attempt limit")
+	})
+}
+
 type mockClient struct {
 	client.Client
 }