6 changes: 6 additions & 0 deletions lib/auth/auth.go
@@ -1182,6 +1182,12 @@ func (a *Server) Close() error {
errs = append(errs, err)
}

if a.Services.AuditLogSessionStreamer != nil {
if err := a.Services.AuditLogSessionStreamer.Close(); err != nil {
errs = append(errs, err)
}
}

if a.bk != nil {
if err := a.bk.Close(); err != nil {
errs = append(errs, err)
50 changes: 40 additions & 10 deletions lib/backend/helpers.go
@@ -161,26 +161,52 @@ func (l *Lock) resetTTL(ctx context.Context, backend Backend) error {
return nil
}

// RunWhileLockedConfig is the configuration for the RunWhileLocked function.
type RunWhileLockedConfig struct {
// LockConfiguration is the configuration used to acquire the lock.
LockConfiguration

// ReleaseCtxTimeout defines the timeout used when calling the lock.Release method (optional).
ReleaseCtxTimeout time.Duration
// RefreshLockInterval defines the interval at which the lock is refreshed
// while fn is still running (optional).
RefreshLockInterval time.Duration
}

func (c *RunWhileLockedConfig) CheckAndSetDefaults() error {
if err := c.LockConfiguration.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}
if c.ReleaseCtxTimeout <= 0 {
c.ReleaseCtxTimeout = 300 * time.Millisecond
}
if c.RefreshLockInterval <= 0 {
c.RefreshLockInterval = c.LockConfiguration.TTL / 2
}
return nil
}

// RunWhileLocked allows you to run a function while a lock is held.
func RunWhileLocked(ctx context.Context, backend Backend, lockName string, ttl time.Duration, fn func(context.Context) error) error {
lock, err := AcquireLock(ctx, LockConfiguration{
Backend: backend,
LockName: lockName,
TTL: ttl,
})
func RunWhileLocked(ctx context.Context, cfg RunWhileLockedConfig, fn func(context.Context) error) error {
if err := cfg.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}

lock, err := AcquireLock(ctx, cfg.LockConfiguration)
if err != nil {
return trace.Wrap(err)
}

subContext, cancelFunction := context.WithCancel(ctx)
defer cancelFunction()

stopRefresh := make(chan struct{})
go func() {
refreshAfter := ttl / 2
refreshAfter := cfg.RefreshLockInterval
for {
select {
case <-backend.Clock().After(refreshAfter):
if err := lock.resetTTL(ctx, backend); err != nil {
case <-cfg.Backend.Clock().After(refreshAfter):
if err := lock.resetTTL(ctx, cfg.Backend); err != nil {
cancelFunction()
log.Errorf("%v", err)
return
@@ -194,7 +220,11 @@ func RunWhileLocked(ctx context.Context, backend Backend, lockName string, ttl t
fnErr := fn(subContext)
close(stopRefresh)

if err := lock.Release(ctx, backend); err != nil {
// lock.Release should be called with a separate ctx: if the caller cancels the ctx
// passed to RunWhileLocked, we still want to at least try to release the lock.
releaseLockCtx, releaseLockCancel := context.WithTimeout(context.Background(), cfg.ReleaseCtxTimeout)
defer releaseLockCancel()
if err := lock.Release(releaseLockCtx, cfg.Backend); err != nil {
return trace.NewAggregate(fnErr, err)
}

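For orientation, here is a minimal sketch of how a caller migrates to the reworked API; the lock name, TTL, `bk`, and `doWork` are illustrative placeholders rather than code from this PR:

```go
// Old call shape: RunWhileLocked(ctx, bk, "my_lock", ttl, fn).
// New call shape: everything moves into RunWhileLockedConfig;
// ReleaseCtxTimeout and RefreshLockInterval are optional and are
// defaulted by CheckAndSetDefaults (300ms and TTL/2 respectively).
err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{
	LockConfiguration: backend.LockConfiguration{
		Backend:  bk,          // any backend.Backend implementation
		LockName: "my_lock",   // illustrative
		TTL:      time.Minute, // illustrative
	},
}, func(ctx context.Context) error {
	// ctx is canceled if the lock TTL cannot be refreshed.
	return doWork(ctx)
})
```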
60 changes: 60 additions & 0 deletions lib/backend/helpers_test.go
@@ -100,3 +100,63 @@ func TestLockConfiguration_CheckAndSetDefaults(t *testing.T) {
})
}
}

func TestRunWhileLockedConfigCheckAndSetDefaults(t *testing.T) {
type mockBackend struct {
Backend
}
lockName := "lock"
ttl := 1 * time.Minute
minimumValidConfig := RunWhileLockedConfig{
LockConfiguration: LockConfiguration{
Backend: mockBackend{},
LockName: lockName,
TTL: ttl,
},
}
tests := []struct {
name string
input func() RunWhileLockedConfig
want RunWhileLockedConfig
wantErr string
}{
{
name: "minimum valid config",
input: func() RunWhileLockedConfig {
return minimumValidConfig
},
want: RunWhileLockedConfig{
LockConfiguration: LockConfiguration{
Backend: mockBackend{},
LockName: lockName,
TTL: ttl,
RetryInterval: 250 * time.Millisecond,
},
ReleaseCtxTimeout: 300 * time.Millisecond,
// defaults to half of the TTL.
RefreshLockInterval: 30 * time.Second,
},
},
{
name: "errors from LockConfiguration is passed",
input: func() RunWhileLockedConfig {
cfg := minimumValidConfig
cfg.LockName = ""
return cfg
},
wantErr: "missing LockName",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cfg := tt.input()
err := cfg.CheckAndSetDefaults()
if tt.wantErr == "" {
require.NoError(t, err, "CheckAndSetDefaults returned unexpected error")
require.Empty(t, cmp.Diff(tt.want, cfg))
} else {
require.ErrorContains(t, err, tt.wantErr)
}
})
}
}
29 changes: 15 additions & 14 deletions lib/events/athena/athena.go
@@ -16,6 +16,7 @@ package athena

import (
"context"
"io"
"net/url"
"regexp"
"strconv"
@@ -348,9 +349,9 @@ func (cfg *Config) SetFromURL(url *url.URL) error {
// Parquet and send it to S3 for long-term storage.
// Athena is used for querying Parquet files on S3.
type Log struct {
publisher *publisher
querier *querier
consumerStop context.CancelFunc
publisher *publisher
querier *querier
consumerCloser io.Closer
}

// New creates an instance of an Athena based audit log.
@@ -360,14 +361,7 @@ func New(ctx context.Context, cfg Config) (*Log, error) {
return nil, trace.Wrap(err)
}

consumerCtx, consumerCancel := context.WithCancel(ctx)

l := &Log{
publisher: newPublisher(cfg),
consumerStop: consumerCancel,
}

l.querier, err = newQuerier(querierConfig{
querier, err := newQuerier(querierConfig{
tablename: cfg.TableName,
database: cfg.Database,
workgroup: cfg.Workgroup,
@@ -381,11 +375,19 @@
return nil, trace.Wrap(err)
}

consumer, err := newConsumer(cfg)
consumerCtx, consumerCancel := context.WithCancel(ctx)

consumer, err := newConsumer(cfg, consumerCancel)
if err != nil {
return nil, trace.Wrap(err)
}

l := &Log{
publisher: newPublisher(cfg),
querier: querier,
consumerCloser: consumer,
}

go consumer.run(consumerCtx)

return l, nil
@@ -404,8 +406,7 @@ func (l *Log) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order typ
}

func (l *Log) Close() error {
l.consumerStop()
return nil
return trace.Wrap(l.consumerCloser.Close())
}

var isAlphanumericOrUnderscoreRe = regexp.MustCompile("^[a-zA-Z0-9_]+$")
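Log now owns the consumer through the io.Closer interface rather than a bare context.CancelFunc, so Close can both stop the background goroutine and surface its shutdown error. Below is a condensed sketch of the cancel-and-wait pattern the consumer implements, with simplified, hypothetical names; the real code follows in consumer.go:

```go
type worker struct {
	cancel   context.CancelFunc // cancels the ctx that run was started with
	finished chan struct{}      // closed by run when it returns
}

func (w *worker) run(ctx context.Context) {
	defer close(w.finished) // lets Close observe that the goroutine exited
	for {
		select {
		case <-ctx.Done():
			return
		case <-time.After(time.Second):
			// ... process one batch per tick ...
		}
	}
}

// Close implements io.Closer: cancel run's context, then wait (bounded)
// for the goroutine to acknowledge.
func (w *worker) Close() error {
	w.cancel()
	select {
	case <-w.finished:
		return nil
	case <-time.After(time.Second):
		return errors.New("worker did not stop in time")
	}
}
```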
91 changes: 88 additions & 3 deletions lib/events/athena/consumer.go
@@ -39,6 +39,7 @@ import (

"github.com/gravitational/teleport"
apievents "github.com/gravitational/teleport/api/types/events"
"github.com/gravitational/teleport/api/utils/retryutils"
"github.com/gravitational/teleport/lib/backend"
awsutils "github.com/gravitational/teleport/lib/utils/aws"
)
@@ -60,7 +61,7 @@ const (
// consumer is responsible for receiving messages from SQS, batching them up to
// a certain size or interval, and writing them to S3 as a Parquet file.
type consumer struct {
logger *log.Entry
logger log.FieldLogger
backend backend.Backend
storeLocationPrefix string
storeLocationBucket string
@@ -75,6 +76,13 @@ type consumer struct {

sqsDeleter sqsDeleter
queueURL string

// cancelRun is used to cancel consumer.run
cancelRun context.CancelFunc

// finished is used to communicate that run (executed in the background) has finished.
// It is closed when run returns.
finished chan struct{}
}

type sqsReceiver interface {
@@ -89,7 +97,7 @@ type s3downloader interface {
Download(ctx context.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*manager.Downloader)) (n int64, err error)
}

func newConsumer(cfg Config) (*consumer, error) {
func newConsumer(cfg Config, cancelFn context.CancelFunc) (*consumer, error) {
s3client := s3.NewFromConfig(*cfg.AWSConfig)
sqsClient := sqs.NewFromConfig(*cfg.AWSConfig)

@@ -109,6 +117,10 @@ func newConsumer(cfg Config) (*consumer, error) {
return nil, trace.Wrap(err)
}

if cancelFn == nil {
return nil, trace.BadParameter("cancelFn must be passed to consumer")
}

return &consumer{
logger: cfg.LogEntry,
backend: cfg.Backend,
@@ -129,18 +141,43 @@
}
return fw, nil
},
cancelRun: cancelFn,
finished: make(chan struct{}),
}, nil
}

// run continuously runs the batching job. It is a blocking operation,
// stopped by canceling the context.
func (c *consumer) run(ctx context.Context) {
defer func() {
close(c.finished)
c.logger.Debug("Consumer finished")
}()
c.runContinuouslyOnSingleAuth(ctx, c.processEventsContinuously)
}

// Close terminates the goroutine that is running [c.run].
func (c *consumer) Close() error {
c.cancelRun()
select {
case <-c.finished:
return nil
case <-time.After(1 * time.Second):
// ctx is used throughout all calls within consumer.run, so it should finish
// very fast, within milliseconds.
return errors.New("consumer did not finish in time, returning early")
}
}

// processEventsContinuously runs processBatchOfEvents continuously in a loop.
// It makes sure that the CPU won't be spammed with too many requests if something goes
// wrong with calls to the AWS API.
func (c *consumer) processEventsContinuously(ctx context.Context) {
processBatchOfEventsWithLogging := func(context.Context) (reachedMaxBatch bool) {
reachedMaxBatch, err := c.processBatchOfEvents(ctx)
if err != nil {
// Context cancellation is used to stop the batcher.
if ctx.Err() != nil {
c.logger.Debug("Batcher has been stopped")
return false
}
c.logger.Errorf("Batcher single run failed: %v", err)
Expand All @@ -149,6 +186,9 @@ func (c *consumer) run(ctx context.Context) {
return reachedMaxBatch
}

c.logger.Debug("Processing of events started on this instance")
defer c.logger.Debug("Processing of events finished on this instance")

// If a batch took at least 90% of the specified interval, we don't want to wait
// just a little bit more. This mainly avoids cases where we would wait only ~10ms.
minInterval := time.Duration(float64(c.batchMaxInterval) * 0.9)
Expand All @@ -166,6 +206,51 @@ func (c *consumer) run(ctx context.Context) {
}
}

// runContinuouslyOnSingleAuth runs eventsProcessorFn continuously on a single auth instance.
// Backend locking is used to make sure that only a single auth instance runs the consumer.
func (c *consumer) runContinuouslyOnSingleAuth(ctx context.Context, eventsProcessorFn func(context.Context)) {
// For a batchMaxInterval of 1 minute this yields a ~5s (jittered) sleep before
// retrying, which seems like a reasonable value.
waitTimeAfterLockingError := retryutils.NewSeventhJitter()(c.batchMaxInterval / 12)
for {
select {
case <-ctx.Done():
return
default:
err := backend.RunWhileLocked(ctx, backend.RunWhileLockedConfig{
LockConfiguration: backend.LockConfiguration{
Backend: c.backend,
LockName: "athena_lock",
// TTL is higher than batchMaxInterval because we want to optimize
// for fewer backend writes.
TTL: 5 * c.batchMaxInterval,
// RetryInterval defines how often an instance without the lock checks
// the backend to see if the lock is up for grabs. batchMaxInterval is fine here.
RetryInterval: c.batchMaxInterval,
},
}, func(ctx context.Context) error {
eventsProcessorFn(ctx)
return nil
})
if err != nil {
if ctx.Err() != nil {
return
}
// Ending up here means something went wrong in the backend while acquiring or
// waiting for the lock. All we can do is log it and retry the whole operation.
c.logger.WithError(err).Warn("Could not get consumer to run with lock")
select {
// Wait, to make sure we won't spam the CPU with a lot of requests
// if something goes wrong while acquiring the lock.
case <-time.After(waitTimeAfterLockingError):
continue
case <-ctx.Done():
return
}
}
}
}
}

// runWithMinInterval runs fn; if fn returns earlier than minInterval,
// it waits for the remaining time.
// Useful when we don't want to put too much pressure on the CPU by running fn constantly.
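The body of runWithMinInterval is collapsed in this diff view. Based on its doc comment, a plausible shape is the following, as a hypothetical reconstruction rather than the PR's actual code:

```go
// Hypothetical reconstruction of the collapsed function.
func runWithMinInterval(ctx context.Context, fn func(context.Context) bool, minInterval time.Duration) bool {
	start := time.Now()
	reachedMaxBatch := fn(ctx)
	if elapsed := time.Since(start); elapsed < minInterval {
		// fn finished early; sleep off the remainder unless canceled.
		select {
		case <-ctx.Done():
			return false
		case <-time.After(minInterval - elapsed):
		}
	}
	return reachedMaxBatch
}
```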