Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 41 additions & 45 deletions integrations/event-handler/session_events_job.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package main

import (
"context"
"log/slog"
"sync/atomic"
"time"

Expand Down Expand Up @@ -45,6 +44,10 @@ type session struct {
UploadTime time.Time
}

// processSessionFunc mimics the signature of
// [SessionEventsJob.processSession]. The job keeps a value of this type in
// its processSessionFunc field, which defaults to the real method, so that
// tests can swap in a stub for the processing step.
type processSessionFunc func(ctx context.Context, s session, processingAttempt int) error

// SessionEventsJob encapsulates the session events consumption logic.
type SessionEventsJob struct {
lib.ServiceJob
Expand All @@ -54,6 +57,7 @@ type SessionEventsJob struct {
logLimiter *rate.Limiter
backpressureLogLimiter *rate.Limiter
sessionsProcessed atomic.Uint64
processSessionFunc processSessionFunc
}

// NewSessionEventsJob creates a new SessionEventsJob.
Expand All @@ -66,6 +70,7 @@ func NewSessionEventsJob(app *App) *SessionEventsJob {
backpressureLogLimiter: rate.NewLimiter(rate.Every(time.Minute), 1),
}

j.processSessionFunc = j.processSession
j.ServiceJob = lib.NewServiceJob(j.run)

return j
Expand Down Expand Up @@ -108,33 +113,9 @@ func (j *SessionEventsJob) run(ctx context.Context) error {
for {
select {
case s := <-j.sessions:
log := j.app.log.With(
"id", s.ID,
"index", s.Index,
)

if j.logLimiter.Allow() {
log.DebugContext(ctx, "Starting session ingest")
}

select {
case j.semaphore <- struct{}{}:
case <-ctx.Done():
log.ErrorContext(ctx, "Failed to acquire semaphore", "error", ctx.Err())
return nil
if err := j.ingestSession(ctx, s, 0, nil); err != nil {
j.app.log.WarnContext(ctx, "Unable to ingest session event", "error", err)
}

func(s session, log *slog.Logger) {
j.app.SpawnCritical(func(ctx context.Context) error {
defer func() { <-j.semaphore }()

if err := j.processSession(ctx, s, 0); err != nil {
return trace.Wrap(err)
}

return nil
})
}(s, log)
case <-ctx.Done():
if lib.IsCanceled(ctx.Err()) {
return nil
Expand All @@ -144,6 +125,35 @@ func (j *SessionEventsJob) run(ctx context.Context) error {
}
}

// ingestSession schedules processing of a single session in a background
// goroutine, bounded by a semaphore.
//
// The call blocks until a semaphore slot is acquired or ctx is done. When
// semaphore is nil the job's own semaphore (j.semaphore) is used. The slot is
// released once processing finishes.
//
// The returned error only reports a failure to acquire a slot (ctx done).
// Errors from the processing step itself happen after this method has
// returned; they are logged at error level and never propagated to the caller.
func (j *SessionEventsJob) ingestSession(ctx context.Context, s session, attempt int, semaphore chan struct{}) error {
	logger := j.app.log.With("id", s.ID, "index", s.Index)
	if j.logLimiter.Allow() {
		logger.DebugContext(ctx, "Starting session ingest")
	}

	// Fall back to the job-wide semaphore when the caller did not supply one.
	slots := semaphore
	if slots == nil {
		slots = j.semaphore
	}

	select {
	case <-ctx.Done():
		return trace.Wrap(ctx.Err())
	case slots <- struct{}{}:
	}

	go func() {
		// Free the slot regardless of how processing ends.
		defer func() { <-slots }()

		err := j.processSessionFunc(ctx, s, attempt)
		if err == nil {
			return
		}
		logger.ErrorContext(ctx, "Failed processing session recording", "error", err)
	}()

	return nil
}

func (j *SessionEventsJob) processSession(ctx context.Context, s session, processingAttempt int) error {
const (
// maxNumberOfProcessingAttempts is the number of times a non-existent
Expand Down Expand Up @@ -237,9 +247,8 @@ func (j *SessionEventsJob) processSession(ctx context.Context, s session, proces
// from session recordings that were previously not found.
func (j *SessionEventsJob) processMissingRecordings(ctx context.Context) error {
const (
initialProcessingDelay = time.Minute
processingInterval = 3 * time.Minute
maxNumberOfInflightSessions = 10
initialProcessingDelay = time.Minute
processingInterval = 3 * time.Minute
)

ctx, cancel := context.WithCancel(ctx)
Expand All @@ -252,7 +261,6 @@ func (j *SessionEventsJob) processMissingRecordings(ctx context.Context) error {
timer := time.NewTimer(jitter(initialProcessingDelay))
defer timer.Stop()

semaphore := make(chan struct{}, maxNumberOfInflightSessions)
for {
select {
case <-timer.C:
Expand All @@ -261,21 +269,9 @@ func (j *SessionEventsJob) processMissingRecordings(ctx context.Context) error {
}

err := j.app.State.IterateMissingRecordings(func(sess session, attempts int) error {
select {
case semaphore <- struct{}{}:
case <-ctx.Done():
return trace.Wrap(ctx.Err())
}

go func() {
defer func() { <-semaphore }()
semaphore := make(chan struct{}, j.app.Config.Concurrency*2)

if err := j.processSession(ctx, sess, attempts); err != nil {
j.app.log.DebugContext(ctx, "Failed processing session recording", "error", err)
}
}()

return nil
return j.ingestSession(ctx, sess, attempts, semaphore)
})
if err != nil && !lib.IsCanceled(err) {
j.app.log.WarnContext(ctx, "Unable to load previously failed sessions for processing", "error", err)
Expand Down
50 changes: 49 additions & 1 deletion integrations/event-handler/session_events_job_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,14 @@
package main

import (
"bytes"
"context"
"log/slog"
"testing"
"testing/synctest"
"time"

"github.com/gravitational/trace"
"github.com/peterbourgon/diskv/v3"
"github.com/stretchr/testify/require"

Expand All @@ -40,10 +44,54 @@ func TestConsumeSessionNoEventsFound(t *testing.T) {
client: &mockClient{},
log: slog.Default(),
})
_, err := j.consumeSession(context.Background(), session{ID: sessionID})
_, err := j.consumeSession(t.Context(), session{ID: sessionID})
require.NoError(t, err)
}

// TestIngestSession tests that the ingestSession method returns without error if a malformed
// session event is processed.
func TestIngestSession(t *testing.T) {
synctest.Test(t, func(t *testing.T) {
startTime := time.Now().Add(-time.Minute)
out := &bytes.Buffer{}
log := slog.New(slog.NewTextHandler(out, &slog.HandlerOptions{Level: slog.LevelError}))

j := NewSessionEventsJob(&App{
Config: &StartCmdConfig{
IngestConfig: IngestConfig{
StorageDir: t.TempDir(),
Timeout: time.Second,
BatchSize: 100,
Concurrency: 5,
StartTime: &startTime,
SkipSessionTypes: map[string]struct{}{"print": {}, "desktop.recording": {}},
WindowSize: time.Hour * 24,
DryRun: true,
},
},
State: &State{
dv: diskv.New(diskv.Options{
BasePath: t.TempDir(),
}),
},
client: &mockClient{},
log: log,
})

j.processSessionFunc = func(ctx context.Context, s session, processingAttempt int) error {
return trace.LimitExceeded("Session ingestion exceeded attempt limit")
}

err := j.ingestSession(t.Context(), session{ID: "test"}, 0, nil)
require.NoError(t, err)

synctest.Wait()

require.Contains(t, out.String(), "Failed processing session recording")
require.Contains(t, out.String(), "Session ingestion exceeded attempt limit")
})
}

// mockClient is the client stand-in used by the tests in this file. It embeds
// client.Client and is constructed as &mockClient{}, i.e. with the embedded
// Client left at its zero value.
type mockClient struct {
client.Client
}
Expand Down
Loading