diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 43f398d08d528..eb727185f303f 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -1182,6 +1182,12 @@ func (a *Server) Close() error { errs = append(errs, err) } + if a.Services.AuditLogSessionStreamer != nil { + if err := a.Services.AuditLogSessionStreamer.Close(); err != nil { + errs = append(errs, err) + } + } + if a.bk != nil { if err := a.bk.Close(); err != nil { errs = append(errs, err) diff --git a/lib/backend/helpers.go b/lib/backend/helpers.go index 10d1e21c2c0bd..892b6a27fe395 100644 --- a/lib/backend/helpers.go +++ b/lib/backend/helpers.go @@ -161,26 +161,52 @@ func (l *Lock) resetTTL(ctx context.Context, backend Backend) error { return nil } +// RunWhileLockedConfig is configuration for RunWhileLocked function. +type RunWhileLockedConfig struct { + // LockConfiguration is configuration for acquire lock. + LockConfiguration + + // ReleaseCtxTimeout defines timeout used for calling lock.Release method (optional). + ReleaseCtxTimeout time.Duration + // RefreshLockInterval defines interval at which lock will be refreshed + // if fn is still running (optional). + RefreshLockInterval time.Duration +} + +func (c *RunWhileLockedConfig) CheckAndSetDefaults() error { + if err := c.LockConfiguration.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + if c.ReleaseCtxTimeout <= 0 { + c.ReleaseCtxTimeout = 300 * time.Millisecond + } + if c.RefreshLockInterval <= 0 { + c.RefreshLockInterval = c.LockConfiguration.TTL / 2 + } + return nil +} + // RunWhileLocked allows you to run a function while a lock is held. -func RunWhileLocked(ctx context.Context, backend Backend, lockName string, ttl time.Duration, fn func(context.Context) error) error { - lock, err := AcquireLock(ctx, LockConfiguration{ - Backend: backend, - LockName: lockName, - TTL: ttl, - }) +func RunWhileLocked(ctx context.Context, cfg RunWhileLockedConfig, fn func(context.Context) error) error { + if err := cfg.CheckAndSetDefaults(); err != nil { + return trace.Wrap(err) + } + + lock, err := AcquireLock(ctx, cfg.LockConfiguration) if err != nil { return trace.Wrap(err) } subContext, cancelFunction := context.WithCancel(ctx) + defer cancelFunction() stopRefresh := make(chan struct{}) go func() { - refreshAfter := ttl / 2 + refreshAfter := cfg.RefreshLockInterval for { select { - case <-backend.Clock().After(refreshAfter): - if err := lock.resetTTL(ctx, backend); err != nil { + case <-cfg.Backend.Clock().After(refreshAfter): + if err := lock.resetTTL(ctx, cfg.Backend); err != nil { cancelFunction() log.Errorf("%v", err) return @@ -194,7 +220,11 @@ func RunWhileLocked(ctx context.Context, backend Backend, lockName string, ttl t fnErr := fn(subContext) close(stopRefresh) - if err := lock.Release(ctx, backend); err != nil { + // lock.Release should be called with separate ctx. If someone cancels via ctx + // RunWhileLocked method, we want to at least try releasing lock. + releaseLockCtx, releaseLockCancel := context.WithTimeout(context.Background(), cfg.ReleaseCtxTimeout) + defer releaseLockCancel() + if err := lock.Release(releaseLockCtx, cfg.Backend); err != nil { return trace.NewAggregate(fnErr, err) } diff --git a/lib/backend/helpers_test.go b/lib/backend/helpers_test.go index 72f6f00eb94ae..1f509a8dbd829 100644 --- a/lib/backend/helpers_test.go +++ b/lib/backend/helpers_test.go @@ -100,3 +100,63 @@ func TestLockConfiguration_CheckAndSetDefaults(t *testing.T) { }) } } + +func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) { + type mockBackend struct { + Backend + } + lockName := "lock" + ttl := 1 * time.Minute + minimumValidConfig := RunWhileLockedConfig{ + LockConfiguration: LockConfiguration{ + Backend: mockBackend{}, + LockName: lockName, + TTL: ttl, + }, + } + tests := []struct { + name string + input func() RunWhileLockedConfig + want RunWhileLockedConfig + wantErr string + }{ + { + name: "minimum valid config", + input: func() RunWhileLockedConfig { + return minimumValidConfig + }, + want: RunWhileLockedConfig{ + LockConfiguration: LockConfiguration{ + Backend: mockBackend{}, + LockName: lockName, + TTL: ttl, + RetryInterval: 250 * time.Millisecond, + }, + ReleaseCtxTimeout: 300 * time.Millisecond, + // defaults to halft of TTL. + RefreshLockInterval: 30 * time.Second, + }, + }, + { + name: "errors from LockConfiguration is passed", + input: func() RunWhileLockedConfig { + cfg := minimumValidConfig + cfg.LockName = "" + return cfg + }, + wantErr: "missing LockName", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := tt.input() + err := cfg.CheckAndSetDefaults() + if tt.wantErr == "" { + require.NoError(t, err, "CheckAndSetDefaults return unexpected err") + require.Empty(t, cmp.Diff(tt.want, cfg)) + } else { + require.ErrorContains(t, err, tt.wantErr) + } + }) + } +} diff --git a/lib/events/athena/athena.go b/lib/events/athena/athena.go index 01076863a8acf..b403dc5d73853 100644 --- a/lib/events/athena/athena.go +++ b/lib/events/athena/athena.go @@ -16,6 +16,7 @@ package athena import ( "context" + "io" "net/url" "regexp" "strconv" @@ -348,9 +349,9 @@ func (cfg *Config) SetFromURL(url *url.URL) error { // Parquet and send it to S3 for long term storage. // Athena is used for quering Parquet files on S3. type Log struct { - publisher *publisher - querier *querier - consumerStop context.CancelFunc + publisher *publisher + querier *querier + consumerCloser io.Closer } // New creates an instance of an Athena based audit log. @@ -360,14 +361,7 @@ func New(ctx context.Context, cfg Config) (*Log, error) { return nil, trace.Wrap(err) } - consumerCtx, consumerCancel := context.WithCancel(ctx) - - l := &Log{ - publisher: newPublisher(cfg), - consumerStop: consumerCancel, - } - - l.querier, err = newQuerier(querierConfig{ + querier, err := newQuerier(querierConfig{ tablename: cfg.TableName, database: cfg.Database, workgroup: cfg.Workgroup, @@ -381,11 +375,19 @@ func New(ctx context.Context, cfg Config) (*Log, error) { return nil, trace.Wrap(err) } - consumer, err := newConsumer(cfg) + consumerCtx, consumerCancel := context.WithCancel(ctx) + + consumer, err := newConsumer(cfg, consumerCancel) if err != nil { return nil, trace.Wrap(err) } + l := &Log{ + publisher: newPublisher(cfg), + querier: querier, + consumerCloser: consumer, + } + go consumer.run(consumerCtx) return l, nil @@ -404,8 +406,7 @@ func (l *Log) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order typ } func (l *Log) Close() error { - l.consumerStop() - return nil + return trace.Wrap(l.consumerCloser.Close()) } var isAlphanumericOrUnderscoreRe = regexp.MustCompile("^[a-zA-Z0-9_]+$") diff --git a/lib/events/athena/consumer.go b/lib/events/athena/consumer.go index 09622ea101719..d5179d9e9a5cd 100644 --- a/lib/events/athena/consumer.go +++ b/lib/events/athena/consumer.go @@ -39,6 +39,7 @@ import ( "github.com/gravitational/teleport" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/retryutils" "github.com/gravitational/teleport/lib/backend" awsutils "github.com/gravitational/teleport/lib/utils/aws" ) @@ -60,7 +61,7 @@ const ( // consumer is responsible for receiving messages from SQS, batching them up to // certain size or interval, and writes to s3 as parquet file. type consumer struct { - logger *log.Entry + logger log.FieldLogger backend backend.Backend storeLocationPrefix string storeLocationBucket string @@ -75,6 +76,13 @@ type consumer struct { sqsDeleter sqsDeleter queueURL string + + // cancelRun is used to cancel consumer.Run + cancelRun context.CancelFunc + + // finished is used to communicate that run (executed in background) has finished. + // It will be closed when run has finished. + finished chan struct{} } type sqsReceiver interface { @@ -89,7 +97,7 @@ type s3downloader interface { Download(ctx context.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*manager.Downloader)) (n int64, err error) } -func newConsumer(cfg Config) (*consumer, error) { +func newConsumer(cfg Config, cancelFn context.CancelFunc) (*consumer, error) { s3client := s3.NewFromConfig(*cfg.AWSConfig) sqsClient := sqs.NewFromConfig(*cfg.AWSConfig) @@ -109,6 +117,10 @@ func newConsumer(cfg Config) (*consumer, error) { return nil, trace.Wrap(err) } + if cancelFn == nil { + return nil, trace.BadParameter("cancelFn must be passed to consumer") + } + return &consumer{ logger: cfg.LogEntry, backend: cfg.Backend, @@ -129,18 +141,43 @@ func newConsumer(cfg Config) (*consumer, error) { } return fw, nil }, + cancelRun: cancelFn, + finished: make(chan struct{}), }, nil } // run continuously runs batching job. It is blocking operation. // It is stopped via canceling context. func (c *consumer) run(ctx context.Context) { + defer func() { + close(c.finished) + c.logger.Debug("Consumer finished") + }() + c.runContinuouslyOnSingleAuth(ctx, c.processEventsContinuously) +} + +// Close terminates the goroutine which is running [c.run] +func (c *consumer) Close() error { + c.cancelRun() + select { + case <-c.finished: + return nil + case <-time.After(1 * time.Second): + // ctx is use through all calls within consumer.Run so it should finished + // very fast, within miliseconds. + return errors.New("consumer not finished in time, returning earlier") + } +} + +// processEventsContinuously runs processBatchOfEvents continuously in a loop. +// It makes sure that the CPU won't be spammed with too many requests if something goes +// wrong with calls to the AWS API. +func (c *consumer) processEventsContinuously(ctx context.Context) { processBatchOfEventsWithLogging := func(context.Context) (reachedMaxBatch bool) { reachedMaxBatch, err := c.processBatchOfEvents(ctx) if err != nil { // Ctx.Cancel is used to stop batcher if ctx.Err() != nil { - c.logger.Debug("Batcher has been stopped") return false } c.logger.Errorf("Batcher single run failed: %v", err) @@ -149,6 +186,9 @@ func (c *consumer) run(ctx context.Context) { return reachedMaxBatch } + c.logger.Debug("Processing of events started on this instance") + defer c.logger.Debug("Processing of events finished on this instance") + // If batch took 90% of specified interval, we don't want to wait just little bit. // It's mainly to avoid cases when we will wait like 10ms. minInterval := time.Duration(float64(c.batchMaxInterval) * 0.9) @@ -166,6 +206,51 @@ func (c *consumer) run(ctx context.Context) { } } +// runContinuouslyOnSingleAuth runs eventsProcessorFn continuously on single auth instance. +// Backend locking is used to make sure that only single auth is running consumer. +func (c *consumer) runContinuouslyOnSingleAuth(ctx context.Context, eventsProcessorFn func(context.Context)) { + // for 1 minute it will be 5s sleep before retry which seems like reasonable value. + waitTimeAfterLockingError := retryutils.NewSeventhJitter()(c.batchMaxInterval / 12) + for { + select { + case <-ctx.Done(): + return + default: + err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ + LockConfiguration: backend.LockConfiguration{ + Backend: c.backend, + LockName: "athena_lock", + // TTL is higher then batchMaxInterval because we want to optimize + // for low backend writes. + TTL: 5 * c.batchMaxInterval, + // RetryInterval means how often instance without lock will check + // backend if lock if ready for grab. We are fine with batchMaxInterval. + RetryInterval: c.batchMaxInterval, + }, + }, func(ctx context.Context) error { + eventsProcessorFn(ctx) + return nil + }) + if err != nil { + if ctx.Err() != nil { + return + } + // Ending up here means something went wrong in the backend while locking/waiting + // for lock. What we can do is log and retry whole operation. + c.logger.WithError(err).Warn("Could not get consumer to run with lock") + select { + // Use wait to make sure we won't spam CPU with a lot requests + // if something goes wrong during acquire lock. + case <-time.After(waitTimeAfterLockingError): + continue + case <-ctx.Done(): + return + } + } + } + } +} + // runWithMinInterval runs fn, if fn returns earlier than minInterval // it waits reamaning time. // Useful when we don't want to put to many pressure on CPU with constantly running fn. diff --git a/lib/events/athena/consumer_test.go b/lib/events/athena/consumer_test.go index ea1d24f94f7af..69246ee3a11f5 100644 --- a/lib/events/athena/consumer_test.go +++ b/lib/events/athena/consumer_test.go @@ -42,6 +42,7 @@ import ( "github.com/xitongsys/parquet-go/source" apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/utils" ) @@ -358,6 +359,89 @@ func (m *mockReceiver) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMe return m.receiveMessageRespFn() } +func TestConsumerRunContinuouslyOnSingleAuth(t *testing.T) { + log := utils.NewLoggerForTests() + backend, err := memory.New(memory.Config{}) + require.NoError(t, err) + defer backend.Close() + + batchInterval := 20 * time.Millisecond + + c1 := consumer{ + logger: log, + backend: backend, + batchMaxInterval: batchInterval, + } + c2 := consumer{ + logger: log, + backend: backend, + batchMaxInterval: batchInterval, + } + m1 := mockEventsProcessor{interval: batchInterval} + m2 := mockEventsProcessor{interval: batchInterval} + + ctx1, cancel1 := context.WithCancel(context.Background()) + defer cancel1() + ctx2, cancel2 := context.WithCancel(context.Background()) + defer cancel2() + + // start two consumer with different mocks in background. + go c1.runContinuouslyOnSingleAuth(ctx1, m1.Run) + go c2.runContinuouslyOnSingleAuth(ctx2, m2.Run) + + // We want wait till we processing of events starts. + // Check if there only single consumer is processing is below. + require.Eventually(t, func() bool { + // let's wait for at least 2 iteration. + return m1.getCount() >= 2 || m2.getCount() >= 2 + }, 5*batchInterval, batchInterval/2, "events were never processed by mock") + + m1Processing := m1.getCount() >= 2 + if m1Processing { + require.Zero(t, m2.getCount(), "expected 0 events by mock2") + } else { + require.Zero(t, m1.getCount(), "expected 0 events by mock1") + } + + // let's cancel ctx of single mock and verify if 2nd take over. + if m1Processing { + cancel1() + require.Eventually(t, func() bool { + return m2.getCount() >= 1 + }, 5*batchInterval, batchInterval/2, "mock2 hasn't started processing") + } else { + cancel2() + require.Eventually(t, func() bool { + return m1.getCount() >= 1 + }, 5*batchInterval, batchInterval/2, "mock1 hasn't started processing") + } +} + +type mockEventsProcessor struct { + mu sync.Mutex + count int + interval time.Duration +} + +func (m *mockEventsProcessor) Run(ctx context.Context) { + for { + select { + case <-ctx.Done(): + return + case <-time.After(m.interval): + m.mu.Lock() + m.count++ + m.mu.Unlock() + } + } +} + +func (m *mockEventsProcessor) getCount() int { + m.mu.Lock() + defer m.mu.Unlock() + return m.count +} + func TestRunWithMinInterval(t *testing.T) { ctx := context.Background() t.Run("function returns earlier than minInterval, wait should happen", func(t *testing.T) { diff --git a/lib/services/local/access.go b/lib/services/local/access.go index 3a0fc1c144fdc..8ada6d97fbb97 100644 --- a/lib/services/local/access.go +++ b/lib/services/local/access.go @@ -238,7 +238,13 @@ func (s *AccessService) DeleteAllLocks(ctx context.Context) error { // ReplaceRemoteLocks replaces the set of locks associated with a remote cluster. func (s *AccessService) ReplaceRemoteLocks(ctx context.Context, clusterName string, newRemoteLocks []types.Lock) error { - return backend.RunWhileLocked(ctx, s.Backend, "ReplaceRemoteLocks/"+clusterName, time.Minute, func(ctx context.Context) error { + return backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{ + LockConfiguration: backend.LockConfiguration{ + Backend: s.Backend, + LockName: "ReplaceRemoteLocks/" + clusterName, + TTL: time.Minute, + }, + }, func(ctx context.Context) error { remoteLocksKey := backend.Key(locksPrefix, clusterName) origRemoteLocks, err := s.GetRange(ctx, remoteLocksKey, backend.RangeEnd(remoteLocksKey), backend.NoLimit) if err != nil { diff --git a/lib/services/local/generic/generic.go b/lib/services/local/generic/generic.go index 72d6d9266d285..4905fa8e15787 100644 --- a/lib/services/local/generic/generic.go +++ b/lib/services/local/generic/generic.go @@ -286,7 +286,14 @@ func (s *Service[T]) MakeKey(name string) []byte { // RunWhileLocked will run the given function in a backend lock. This is a wrapper around the backend.RunWhileLocked function. func (s *Service[T]) RunWhileLocked(ctx context.Context, lockName string, ttl time.Duration, fn func(context.Context, backend.Backend) error) error { - return trace.Wrap(backend.RunWhileLocked(ctx, s.backend, lockName, ttl, func(ctx context.Context) error { - return fn(ctx, s.backend) - })) + return trace.Wrap(backend.RunWhileLocked(ctx, + backend.RunWhileLockedConfig{ + LockConfiguration: backend.LockConfiguration{ + Backend: s.backend, + LockName: lockName, + TTL: ttl, + }, + }, func(ctx context.Context) error { + return fn(ctx, s.backend) + })) }