diff --git a/go.mod b/go.mod index 81640f6a61f89..14964cebefa33 100644 --- a/go.mod +++ b/go.mod @@ -38,6 +38,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/rds v1.43.1 github.com/aws/aws-sdk-go-v2/service/s3 v1.32.0 github.com/aws/aws-sdk-go-v2/service/sns v1.20.8 + github.com/aws/aws-sdk-go-v2/service/sqs v1.20.6 github.com/aws/aws-sdk-go-v2/service/sts v1.18.9 github.com/aws/aws-sigv4-auth-cassandra-gocql-driver-plugin v0.0.0-20220331165046-e4d000c0d6a6 github.com/beevik/etree v1.1.0 diff --git a/go.sum b/go.sum index d752982f426e0..e7b3db0d656d9 100644 --- a/go.sum +++ b/go.sum @@ -163,6 +163,7 @@ github.com/aws/aws-sdk-go v1.44.244 h1:QzBWLD5HjZHdRZyTMTOWtD9Pobzf1n8/CeTJB4giX github.com/aws/aws-sdk-go v1.44.244/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/aws/aws-sdk-go-v2 v1.17.3/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2 v1.17.7/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= github.com/aws/aws-sdk-go-v2 v1.17.8 h1:GMupCNNI7FARX27L7GjCJM8NgivWbRgpjNI/hOQjFS8= github.com/aws/aws-sdk-go-v2 v1.17.8/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 h1:dK82zF6kkPeCo8J1e+tGx4JdvDIQzj7ygIoLg8WMuGs= @@ -179,9 +180,11 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.2/go.mod h1:cDh1p6XkSGSwSRIA github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.62 h1:LhVbe/UDWvBT/jp5LYAweFVH8s+DNtT07Qp2riWEovU= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.62/go.mod h1:4xCuu1TSwhW5UH6WOdtS4/x/9UfMr2XplzKc86Ffj78= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27/go.mod h1:a1/UpzeyBBerajpnP5nGZa9mGzsBn5cOKxm6NWQsvoI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.31/go.mod h1:QT0BqUvX1Bh2ABdTGnjqEjvjzrCfIniM9Sc8zn9Yndo= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.32 
h1:dpbVNUjczQ8Ae3QKHbpHBpfvaVkRdesxpTOe9pTouhU= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.32/go.mod h1:RudqOgadTWdcS3t/erPQo24pcVEoYyqj/kKW5Vya21I= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21/go.mod h1:+Gxn8jYn5k9ebfHEqlhrMirFjSW0v0C9fI+KN5vk2kE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.25/go.mod h1:zBHOPwhBc3FlQjQJE/D3IfPWiWaQmT06Vq9aNukDo0k= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.26 h1:QH2kOS3Ht7x+u0gHCh06CXL/h6G8LQJFpZfFBYBNboo= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.26/go.mod h1:vq86l7956VgFr0/FWQ2BWnK07QC3WYsepKzy33qqY5U= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.28/go.mod h1:yRZVr/iT0AqyHeep00SZ4YfBAKojXz08w3XMBscdi0c= @@ -209,6 +212,8 @@ github.com/aws/aws-sdk-go-v2/service/s3 v1.32.0 h1:NAc8WQsVQ3+kz3rU619mlz8NcbpZI github.com/aws/aws-sdk-go-v2/service/s3 v1.32.0/go.mod h1:aSl9/LJltSz1cVusiR/Mu8tvI4Sv/5w/WWrJmmkNii0= github.com/aws/aws-sdk-go-v2/service/sns v1.20.8 h1:wy1jYAot40/Odzpzeq9S3OfSddJJ5RmpaKujvj5Hz7k= github.com/aws/aws-sdk-go-v2/service/sns v1.20.8/go.mod h1:HmCFGnmh0Tx4Onh9xUklrVhNcCsBTeDx4n53WGhp+oY= +github.com/aws/aws-sdk-go-v2/service/sqs v1.20.6 h1:4P/vyx7zCI5yBhlDZ2kwhoLjMJi0X7iR3cxqjNfbego= +github.com/aws/aws-sdk-go-v2/service/sqs v1.20.6/go.mod h1:HQHh1eChX10zDnGmD53WLYk8nPhUKO/JkAUUzDZ530Y= github.com/aws/aws-sdk-go-v2/service/sso v1.12.0/go.mod h1:wo/B7uUm/7zw/dWhBJ4FXuw1sySU5lyIhVg1Bu2yL9A= github.com/aws/aws-sdk-go-v2/service/sso v1.12.8 h1:5cb3D6xb006bPTqEfCNaEA6PPEfBXxxy4NNeX/44kGk= github.com/aws/aws-sdk-go-v2/service/sso v1.12.8/go.mod h1:GNIveDnP+aE3jujyUSH5aZ/rktsTM5EvtKnCqBZawdw= diff --git a/lib/events/athena/athena.go b/lib/events/athena/athena.go index 57419882e4a7f..124db8a5afbbe 100644 --- a/lib/events/athena/athena.go +++ b/lib/events/athena/athena.go @@ -32,14 +32,16 @@ import ( "github.com/gravitational/teleport" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" + 
"github.com/gravitational/teleport/lib/backend" "github.com/gravitational/teleport/lib/utils" ) const ( - // TODO(tobiaszheller): move to batcher.go in other PR. - // maxWaitTimeOnReceiveMessageFromSQS defines how long single - // receiveFromQueue will wait if there is no max events (10). - maxWaitTimeOnReceiveMessageFromSQS = 5 * time.Second + // defaultBatchItems defines default value for batch items count. + // 20000 items, per average 500KB event size = 10MB + defaultBatchItems = 20000 + // defaultBatchInterval defines default batch interval. + defaultBatchInterval = 1 * time.Minute ) // Config structure represents Athena configuration. @@ -67,7 +69,10 @@ type Config struct { TableName string // LocationS3 is location on S3 where Parquet files partitioned by date are // stored (required). - LocationS3 string + LocationS3 string + locationS3Bucket string + locationS3Prefix string + // QueryResultsS3 is location on S3 where Athena stored query results (optional). // Default results path can be defined by in workgroup settings. QueryResultsS3 string @@ -102,6 +107,8 @@ type Config struct { // using aws-sdk-go-v2. AWSConfig *aws.Config + Backend backend.Backend + // TODO(tobiaszheller): add FIPS config in later phase. 
} @@ -139,9 +146,15 @@ func (cfg *Config) CheckAndSetDefaults(ctx context.Context) error { if cfg.LocationS3 == "" { return trace.BadParameter("LocationS3 is not specified") } - if scheme, ok := isValidUrlWithScheme(cfg.LocationS3); !ok || scheme != "s3" { - return trace.BadParameter("LocationS3 must be valid url and start with s3") + locationS3URL, err := url.Parse(cfg.LocationS3) + if err != nil { + return trace.BadParameter("LocationS3 must be valid url") + } + if locationS3URL.Scheme != "s3" { + return trace.BadParameter("LocationS3 must starts with s3://") } + cfg.locationS3Bucket = locationS3URL.Host + cfg.locationS3Prefix = strings.TrimSuffix(strings.TrimPrefix(locationS3URL.Path, "/"), "/") if cfg.LargeEventsS3 == "" { return trace.BadParameter("LargeEventsS3 is not specified") @@ -169,12 +182,11 @@ func (cfg *Config) CheckAndSetDefaults(ctx context.Context) error { } if cfg.BatchMaxItems == 0 { - // 20000 items, per average 500KB event size = 10MB - cfg.BatchMaxItems = 20000 + cfg.BatchMaxItems = defaultBatchItems } if cfg.BatchMaxInterval == 0 { - cfg.BatchMaxInterval = 1 * time.Minute + cfg.BatchMaxInterval = defaultBatchInterval } if cfg.BatchMaxInterval < maxWaitTimeOnReceiveMessageFromSQS { @@ -227,6 +239,10 @@ func (cfg *Config) CheckAndSetDefaults(ctx context.Context) error { cfg.AWSConfig = &awsCfg } + if cfg.Backend == nil { + return trace.BadParameter("Backend cannot be nil") + } + return nil } @@ -316,8 +332,9 @@ func (cfg *Config) SetFromURL(url *url.URL) error { // Parquet and send it to S3 for long term storage. // Athena is used for quering Parquet files on S3. type Log struct { - publisher *publisher - querier *querier + publisher *publisher + querier *querier + consumerStop context.CancelFunc } // New creates an instance of an Athena based audit log. 
@@ -326,12 +343,14 @@ func New(ctx context.Context, cfg Config) (*Log, error) { if err != nil { return nil, trace.Wrap(err) } + + consumerCtx, consumerCancel := context.WithCancel(ctx) + l := &Log{ - publisher: newPublisher(cfg), + publisher: newPublisher(cfg), + consumerStop: consumerCancel, } - // TODO(tobiaszheller): initialize batcher - l.querier, err = newQuerier(querierConfig{ tablename: cfg.TableName, database: cfg.Database, @@ -346,6 +365,13 @@ func New(ctx context.Context, cfg Config) (*Log, error) { return nil, trace.Wrap(err) } + consumer, err := newConsumer(cfg) + if err != nil { + return nil, trace.Wrap(err) + } + + go consumer.run(consumerCtx) + return l, nil } @@ -362,6 +388,7 @@ func (l *Log) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order typ } func (l *Log) Close() error { + l.consumerStop() return nil } diff --git a/lib/events/athena/athena_test.go b/lib/events/athena/athena_test.go index 0ccf6d01b7cb2..0bdf25444a2ee 100644 --- a/lib/events/athena/athena_test.go +++ b/lib/events/athena/athena_test.go @@ -16,16 +16,34 @@ package athena import ( "context" + "errors" + "io" "net/url" + "os" + "sort" + "strings" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" "github.com/stretchr/testify/require" + + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/utils" ) +func TestMain(m *testing.M) { + utils.InitLoggerForTests() + os.Exit(m.Run()) +} + func TestConfig_SetFromURL(t *testing.T) { tests := []struct { name string @@ -104,6 +122,10 @@ func TestConfig_SetFromURL(t *testing.T) { } func TestConfig_CheckAndSetDefaults(t *testing.T) { + type mockBackend struct { + backend.Backend + } + validConfig 
:= Config{ Database: "db", TableName: "tbl", @@ -112,6 +134,7 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) { LocationS3: "s3://events-bucket", QueueURL: "https://queue-url", AWSConfig: &aws.Config{}, + Backend: mockBackend{}, } tests := []struct { name string @@ -131,11 +154,13 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) { LargeEventsS3: "s3://large-payloads-bucket", largeEventsBucket: "large-payloads-bucket", LocationS3: "s3://events-bucket", + locationS3Bucket: "events-bucket", QueueURL: "https://queue-url", GetQueryResultsInterval: 100 * time.Millisecond, BatchMaxItems: 20000, BatchMaxInterval: 1 * time.Minute, AWSConfig: &aws.Config{}, + Backend: mockBackend{}, }, }, { @@ -181,7 +206,7 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) { cfg.LocationS3 = "https://abc" return cfg }, - wantErr: "LocationS3 must be valid url and start with s3", + wantErr: "LocationS3 must starts with s3://", }, { name: "missing QueueURL", @@ -235,3 +260,111 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) { }) } } + +func TestPublisherConsumer(t *testing.T) { + fS3 := newFakeS3manager() + fq := newFakeQueue() + p := &publisher{ + snsPublisher: fq, + uploader: fS3, + } + + smallEvent := &apievents.AppCreate{ + Metadata: apievents.Metadata{ + ID: uuid.NewString(), + Time: time.Now().UTC(), + Type: events.AppCreateEvent, + }, + AppMetadata: apievents.AppMetadata{ + AppName: "app-small", + }, + } + + largeEvent := &apievents.AppCreate{ + Metadata: apievents.Metadata{ + ID: uuid.NewString(), + Time: time.Now().UTC(), + Type: events.AppCreateEvent, + Code: strings.Repeat("d", 2*maxSNSMessageSize), + }, + AppMetadata: apievents.AppMetadata{ + AppName: "app-large", + }, + } + + cfg := validCollectCfgForTests(t) + cfg.sqsReceiver = fq + cfg.payloadDownloader = fS3 + cfg.batchMaxItems = 2 + require.NoError(t, cfg.CheckAndSetDefaults()) + c := newSqsMessagesCollector(cfg) + + eventsChan := c.getEventsChan() + + ctx := context.Background() + readSQSCtx, readCancel := 
context.WithCancel(ctx) + defer readCancel() + + go c.fromSQS(readSQSCtx) + + // receiver is used to read messages from eventsChan. + r := &receiver{} + go r.Do(eventsChan) + + err := p.EmitAuditEvent(ctx, smallEvent) + require.NoError(t, err) + err = p.EmitAuditEvent(ctx, largeEvent) + require.NoError(t, err) + require.Eventually(t, func() bool { + return len(r.GetMsgs()) == 2 + }, 200*time.Millisecond, 1*time.Millisecond, "missing events, got %d", len(r.GetMsgs())) + + requireEventsEqualInAnyOrder(t, []apievents.AuditEvent{smallEvent, largeEvent}, eventAndAckIDToAuditEvents(r.GetMsgs())) + // S3 for uplodad should be called only once. + require.Equal(t, 1, fS3.uploadCount) +} + +// requireEventsEqualInAnyOrder compares slices of auditevents ignoring order. +// It's useful in tests because consumer does not guarantee order. +func requireEventsEqualInAnyOrder(t *testing.T, want, got []apievents.AuditEvent) { + sort.Slice(want, func(i, j int) bool { + return want[i].GetID() < want[j].GetID() + }) + sort.Slice(got, func(i, j int) bool { + return got[i].GetID() < got[j].GetID() + }) + require.Empty(t, cmp.Diff(want, got)) +} + +type fakeS3manager struct { + objects map[string][]byte + uploadCount int +} + +func newFakeS3manager() *fakeS3manager { + return &fakeS3manager{ + objects: map[string][]byte{}, + } +} + +func (f *fakeS3manager) Upload(ctx context.Context, input *s3.PutObjectInput, opts ...func(*manager.Uploader)) (*manager.UploadOutput, error) { + data, err := io.ReadAll(input.Body) + if err != nil { + return nil, err + } + f.objects[*input.Key] = data + f.uploadCount++ + return &manager.UploadOutput{Key: input.Key}, nil +} + +func (f *fakeS3manager) Download(ctx context.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*manager.Downloader)) (int64, error) { + data, ok := f.objects[*input.Key] + if !ok { + return 0, errors.New("object not found") + } + n, err := w.WriteAt(data, 0) + if err != nil { + return 0, err + } + return int64(n), nil +} 
diff --git a/lib/events/athena/consumer.go b/lib/events/athena/consumer.go new file mode 100644 index 0000000000000..083ca9552e03b --- /dev/null +++ b/lib/events/athena/consumer.go @@ -0,0 +1,532 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package athena + +import ( + "context" + "encoding/base64" + "errors" + "io" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/feature/s3/manager" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/sqs" + sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" + + "github.com/gravitational/teleport" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/backend" + awsutils "github.com/gravitational/teleport/lib/utils/aws" +) + +const ( + // maxWaitTimeOnReceiveMessageFromSQS defines how long single + // receiveFromQueue will wait if there is no max events (10). + maxWaitTimeOnReceiveMessageFromSQS = 5 * time.Second + // maxNumberOfMessagesFromReceive defines how many messages single receive + // call can return. Maximum value is 10. 
+ // https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_ReceiveMessage.html + maxNumberOfMessagesFromReceive = 10 + + // maxErrorCountForLogsOnSQSReceive defines maximum number of error log messages + // printed on receiving error from SQS receiver loop. + maxErrorCountForLogsOnSQSReceive = 10 +) + +// consumer is responsible for receiving messages from SQS, batching them up to +// certain size or interval, and writes to s3 as parquet file. +type consumer struct { + logger *log.Entry + backend backend.Backend + storeLocationPrefix string + storeLocationBucket string + batchMaxItems int + batchMaxInterval time.Duration + + collectConfig sqsCollectConfig +} + +type sqsReceiver interface { + ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) +} + +type s3downloader interface { + Download(ctx context.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*manager.Downloader)) (n int64, err error) +} + +func newConsumer(cfg Config) (*consumer, error) { + s3client := s3.NewFromConfig(*cfg.AWSConfig) + sqsReceiver := sqs.NewFromConfig(*cfg.AWSConfig) + + collectCfg := sqsCollectConfig{ + sqsReceiver: sqsReceiver, + queueURL: cfg.QueueURL, + // TODO(tobiaszheller): use s3 manager from teleport observability. 
+ payloadDownloader: manager.NewDownloader(s3client), + payloadBucket: cfg.largeEventsBucket, + visibilityTimeout: int32(cfg.BatchMaxInterval.Seconds()), + batchMaxItems: cfg.BatchMaxItems, + errHandlingFn: errHandlingFnFromSQS(cfg.LogEntry), + logger: cfg.LogEntry, + } + err := collectCfg.CheckAndSetDefaults() + if err != nil { + return nil, trace.Wrap(err) + } + + return &consumer{ + logger: cfg.LogEntry, + backend: cfg.Backend, + storeLocationPrefix: cfg.locationS3Prefix, + storeLocationBucket: cfg.locationS3Bucket, + batchMaxItems: cfg.BatchMaxItems, + batchMaxInterval: cfg.BatchMaxInterval, + collectConfig: collectCfg, + }, nil +} + +// run continuously runs batching job. It is blocking operation. +// It is stopped via canceling context. +func (c *consumer) run(ctx context.Context) { + processBatchOfEventsWithLogging := func(context.Context) (reachedMaxBatch bool) { + reachedMaxBatch, err := c.processBatchOfEvents(ctx) + if err != nil { + // Ctx.Cancel is used to stop batcher + if ctx.Err() != nil { + c.logger.Debug("Batcher has been stopped") + return false + } + c.logger.Errorf("Batcher single run failed: %v", err) + return false + } + return reachedMaxBatch + } + + // If batch took 90% of specified interval, we don't want to wait just little bit. + // It's mainly to avoid cases when we will wait like 10ms. + minInterval := time.Duration(float64(c.batchMaxInterval) * 0.9) + + var stop bool + for { + // We use helper fn [runWithMinInterval] to guarantee that we won't spam + // CPU if processBatchOfEvents will return immediately without processing + // any data. runWithMinInterval guarantees that if fn finished earlier, + // it will wait reaming time of interval before proceeding. + stop = runWithMinInterval(ctx, processBatchOfEventsWithLogging, minInterval) + if stop { + return + } + } +} + +// runWithMinInterval runs fn, if fn returns earlier than minInterval +// it waits reamaning time. 
+// Useful when we don't want to put too much pressure on CPU with constantly running fn.
+func runWithMinInterval(ctx context.Context, fn func(context.Context) bool, minInterval time.Duration) (stop bool) {
+	start := time.Now()
+	reachedMaxBatch := fn(ctx)
+	if ctx.Err() != nil {
+		// stopping
+		return true
+	}
+	if reachedMaxBatch {
+		// reachedMaxBatch means that fn reached maxBatchSize. We don't want
+		// to wait in that case.
+		return false
+	}
+	elapsed := time.Since(start)
+	if elapsed > minInterval {
+		return false
+	}
+	select {
+	case <-ctx.Done():
+		return true
+	case <-time.After(minInterval - elapsed):
+		return false
+	}
+}
+
+// processBatchOfEvents creates single batch of events. It waits either up to BatchMaxInterval
+// or BatchMaxItems while reading events from queue. Batch is sent to s3 as
+// parquet file and at the end events are deleted from queue.
+func (c *consumer) processBatchOfEvents(ctx context.Context) (reachedMaxSize bool, e error) {
+	start := time.Now()
+	var size int
+	// TODO(tobiaszheller): we need some metrics to track it.
+	// And that log message should be deleted.
+	defer func() {
+		if size > 0 {
+			c.logger.Debugf("Batch of %d messages processed in %s", size, time.Since(start))
+		}
+	}()
+
+	msgsCollector := newSqsMessagesCollector(c.collectConfig)
+
+	readSQSCtx, readCancel := context.WithTimeout(ctx, c.batchMaxInterval)
+	defer readCancel()
+
+	// msgsCollector and writeToS3 runs concurrently, and use events channel
+	// to send messages from collector to writeToS3.
+	go func() {
+		msgsCollector.fromSQS(readSQSCtx)
+	}()
+	var err error
+	size, err = c.writeToS3(ctx, msgsCollector.getEventsChan())
+	if err != nil {
+		return false, trace.Wrap(err)
+	}
+	return size >= c.batchMaxItems, nil
+	// TODO(tobiaszheller): delete messages from queue in next PR.
+} + +type sqsCollectConfig struct { + sqsReceiver sqsReceiver + queueURL string + payloadBucket string + payloadDownloader s3downloader + // visibilityTimeout defines how long message won't be available for other + // receiveMessage calls. If timeout happens, and message was not deleted + // it will return to the queue. + visibilityTimeout int32 + // waitOnReceiveDuration defines how long single + // receiveFromQueue will wait if there is no max events (10). + waitOnReceiveDuration time.Duration + // waitOnReceiveTimeout is int32 representation of waitOnReceiveDuration + // required by AWS API. + waitOnReceiveTimeout int32 + + // waitOnReceiveError defines interval used to wait before + // retrying receive message from SQS after getting error. + waitOnReceiveError time.Duration + + batchMaxItems int + + // noOfWorkers defines how many workers are processing messages from queue. + noOfWorkers int + + logger log.FieldLogger + errHandlingFn func(ctx context.Context, errC chan error) +} + +func (cfg *sqsCollectConfig) CheckAndSetDefaults() error { + if cfg.sqsReceiver == nil { + return trace.BadParameter("sqsReceiver is not specified") + } + if cfg.queueURL == "" { + return trace.BadParameter("queueURL is not specified") + } + if cfg.payloadBucket == "" { + return trace.BadParameter("payloadBucket is not specified") + } + if cfg.payloadDownloader == nil { + return trace.BadParameter("payloadDownloader is not specified") + } + if cfg.visibilityTimeout == 0 { + // visibilityTimeout is timeout in seconds, so 1 minute. 
+		cfg.visibilityTimeout = int32(defaultBatchInterval.Seconds())
+	}
+	if cfg.waitOnReceiveDuration == 0 {
+		cfg.waitOnReceiveDuration = maxWaitTimeOnReceiveMessageFromSQS
+	}
+	if cfg.waitOnReceiveTimeout != 0 {
+		return trace.BadParameter("waitOnReceiveTimeout is calculated internally and should not be set")
+	}
+	cfg.waitOnReceiveTimeout = int32(cfg.waitOnReceiveDuration.Seconds())
+
+	if cfg.waitOnReceiveError == 0 {
+		cfg.waitOnReceiveError = 1 * time.Second
+	}
+	if cfg.batchMaxItems == 0 {
+		cfg.batchMaxItems = defaultBatchItems
+	}
+	if cfg.noOfWorkers == 0 {
+		cfg.noOfWorkers = 5
+	}
+	if cfg.logger == nil {
+		cfg.logger = log.WithFields(log.Fields{
+			trace.Component: teleport.ComponentAthena,
+		})
+	}
+	if cfg.errHandlingFn == nil {
+		return trace.BadParameter("errHandlingFn is not specified")
+	}
+	return nil
+}
+
+// sqsMessagesCollector is responsible for collecting messages from SQS and
+// writing to a channel.
+type sqsMessagesCollector struct {
+	cfg        sqsCollectConfig
+	eventsChan chan eventAndAckID
+}
+
+// newSqsMessagesCollector returns message collector.
+// Collector sends collected messages from SQS on events channel.
+func newSqsMessagesCollector(cfg sqsCollectConfig) *sqsMessagesCollector {
+	return &sqsMessagesCollector{
+		cfg:        cfg,
+		eventsChan: make(chan eventAndAckID, cfg.batchMaxItems),
+	}
+}
+
+// getEventsChan returns channel which can be used to read messages from SQS.
+// When collector finishes, channel will be closed.
+func (s *sqsMessagesCollector) getEventsChan() <-chan eventAndAckID {
+	return s.eventsChan
+}
+
+// fromSQS receives messages from SQS and sends it on eventsC channel.
+// It runs until context is canceled (via timeout) or when maxItems is reached.
+// MaxItems is soft limit and can happen that it will return more items than MaxItems.
+func (s *sqsMessagesCollector) fromSQS(ctx context.Context) { + // Errors should be immediately process by error handling loop, so 10 size + // should be enough to not cause blocking. + errorsC := make(chan error, 10) + defer close(errorsC) + + // errhandle loop for receiving single event errors. + go func() { + s.cfg.errHandlingFn(ctx, errorsC) + }() + eventsC := s.eventsChan + + // wokerCtx is mechanism to stop other workers when maxItems is reached. + wokerCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var ( + count int + countMu sync.Mutex + wg sync.WaitGroup + ) + + wg.Add(s.cfg.noOfWorkers) + for i := 0; i < s.cfg.noOfWorkers; i++ { + go func(i int) { + defer wg.Done() + for { + if wokerCtx.Err() != nil { + return + } + // If there is not enough time to process receiveMessage call + // we can return immediately. It's added because if + // receiveMessages is canceled message is marked as not + // processed after VisibilitTimeout (equal to BatchInterval). + if deadline, ok := wokerCtx.Deadline(); ok && time.Until(deadline) <= s.cfg.waitOnReceiveDuration { + return + } + noOfReceived := s.receiveMessagesAndSendOnChan(wokerCtx, eventsC, errorsC) + if noOfReceived == 0 { + // no point of locking and checking for size if nothing was returned. 
+ continue + } + countMu.Lock() + count += noOfReceived + if count >= s.cfg.batchMaxItems { + countMu.Unlock() + cancel() + return + } + countMu.Unlock() + } + }(i) + } + wg.Wait() + close(eventsC) +} + +func (s *sqsMessagesCollector) receiveMessagesAndSendOnChan(ctx context.Context, eventsC chan<- eventAndAckID, errorsC chan<- error) (size int) { + sqsOut, err := s.cfg.sqsReceiver.ReceiveMessage(ctx, &sqs.ReceiveMessageInput{ + QueueUrl: aws.String(s.cfg.queueURL), + MaxNumberOfMessages: maxNumberOfMessagesFromReceive, + WaitTimeSeconds: s.cfg.waitOnReceiveTimeout, + VisibilityTimeout: s.cfg.visibilityTimeout, + MessageAttributeNames: []string{payloadTypeAttr}, + }) + if err != nil { + // We don't need handle canceled errors anyhow. + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return 0 + } + errorsC <- trace.Wrap(err) + + // We don't want to retry receiving message immediately to prevent huge load + // on CPU if calls are contantly failing. + select { + case <-ctx.Done(): + return 0 + case <-time.After(s.cfg.waitOnReceiveError): + return 0 + } + } + if len(sqsOut.Messages) == 0 { + return 0 + } + var noOfValidMessages int + for _, msg := range sqsOut.Messages { + event, err := s.auditEventFromSQSorS3(ctx, msg) + if err != nil { + errorsC <- trace.Wrap(err) + continue + } + eventsC <- eventAndAckID{ + event: event, + receiptHandle: aws.ToString(msg.ReceiptHandle), + } + noOfValidMessages++ + } + return noOfValidMessages +} + +// auditEventFromSQSorS3 returns events either directly from SQS message payload +// or from s3, if event was very large. +func (s *sqsMessagesCollector) auditEventFromSQSorS3(ctx context.Context, msg sqsTypes.Message) (apievents.AuditEvent, error) { + payloadType, err := validateSQSMessage(msg) + if err != nil { + return nil, trace.Wrap(err) + } + + var protoMarshaledOneOf []byte + switch payloadType { + // default case is hanlded in validateSQSMessage. 
+ case payloadTypeS3Based: + protoMarshaledOneOf, err = s.downloadEventFromS3(ctx, *msg.Body) + if err != nil { + return nil, trace.Wrap(err) + } + case payloadTypeRawProtoEvent: + protoMarshaledOneOf, err = base64.StdEncoding.DecodeString(*msg.Body) + if err != nil { + return nil, trace.Wrap(err) + } + } + + var oneOf apievents.OneOf + if err := oneOf.Unmarshal(protoMarshaledOneOf); err != nil { + return nil, trace.Wrap(err) + } + event, err := apievents.FromOneOf(oneOf) + return event, trace.Wrap(err) +} + +func validateSQSMessage(msg sqsTypes.Message) (string, error) { + if msg.Body == nil || msg.MessageAttributes == nil { + // This should not happen. If it happen though, it will be retried + // and go to dead-letter queue after max attempts. + return "", trace.BadParameter("missing Body or MessageAttributes of msg: %v", msg) + } + if msg.ReceiptHandle == nil { + return "", trace.BadParameter("missing ReceiptHandle") + } + v := msg.MessageAttributes[payloadTypeAttr] + if v.StringValue == nil { + // This should not happen. If it happen though, it will be retried + // and go to dead-letter queue after max attempts. 
+ return "", trace.BadParameter("message without %q attribute", payloadTypeAttr) + } + payloadType := *v.StringValue + if !slices.Contains([]string{payloadTypeRawProtoEvent, payloadTypeS3Based}, payloadType) { + return "", trace.BadParameter("unsupported payload type %s", payloadType) + } + return payloadType, nil +} + +type eventAndAckID struct { + event apievents.AuditEvent + receiptHandle string +} + +func errHandlingFnFromSQS(logger log.FieldLogger) func(ctx context.Context, errC chan error) { + return func(ctx context.Context, errC chan error) { + var errorsCount int + + defer func() { + if errorsCount > maxErrorCountForLogsOnSQSReceive { + logger.Errorf("Got %d errors from SQS collector, printed only first %d", errorsCount, maxErrorCountForLogsOnSQSReceive) + } + }() + + for { + select { + case <-ctx.Done(): + // if errorsCount > maxErrorCountForLogs, log will be printed via defer. + return + case err, ok := <-errC: + if !ok { + return + } + errorsCount++ + if errorsCount <= maxErrorCountForLogsOnSQSReceive { + logger.WithError(err).Error("Failure processing SQS messages") + } + } + } + } +} + +func (s *sqsMessagesCollector) downloadEventFromS3(ctx context.Context, payload string) ([]byte, error) { + decoded, err := base64.StdEncoding.DecodeString(payload) + if err != nil { + return nil, trace.Wrap(err) + } + + s3Payload := &apievents.AthenaS3EventPayload{} + if err := s3Payload.Unmarshal(decoded); err != nil { + return nil, trace.Wrap(err) + } + + path := s3Payload.GetPath() + versionID := s3Payload.GetVersionId() + + s.cfg.logger.Debugf("Downloading %v %v [%v].", s.cfg.payloadBucket, path, versionID) + + buf := manager.NewWriteAtBuffer([]byte{}) + written, err := s.cfg.payloadDownloader.Download(ctx, buf, &s3.GetObjectInput{ + Bucket: aws.String(s.cfg.payloadBucket), + Key: aws.String(path), + VersionId: aws.String(versionID), + }) + if err != nil { + return nil, awsutils.ConvertS3Error(err) + } + if written == 0 { + return nil, trace.NotFound("payload for 
%v is not found", path) + } + return buf.Bytes(), nil +} + +// writeToS3 is not doing anything then just receiving from channel and printing +// for now. It will be changed in next PRs to actually write to S3 via parquet writer. +func (c *consumer) writeToS3(ctx context.Context, eventsChan <-chan eventAndAckID) (int, error) { + var size int + for { + select { + case <-ctx.Done(): + return size, trace.Wrap(ctx.Err()) + case eventAndAckID, ok := <-eventsChan: + if !ok { + return size, nil + } + size++ + c.logger.Debugf("Received event: %s %s", eventAndAckID.event.GetID(), eventAndAckID.event.GetType()) + } + } +} diff --git a/lib/events/athena/consumer_test.go b/lib/events/athena/consumer_test.go new file mode 100644 index 0000000000000..a6fdabf3d696a --- /dev/null +++ b/lib/events/athena/consumer_test.go @@ -0,0 +1,502 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package athena + +import ( + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "math/big" + "strings" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sqs" + sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/utils" +) + +func Test_consumer_sqsMessagesCollector(t *testing.T) { + // channelClosedCondition returns function that can be used to check if eventually + // channel was closed. + channelClosedCondition := func(t *testing.T, ch <-chan eventAndAckID) func() bool { + return func() bool { + select { + case _, ok := <-ch: + if ok { + t.Log("Received unexpected message") + t.Fail() + return false + } else { + // channel is closed, that's what we are waiting for. + return true + } + default: + // retry + return false + } + } + } + + maxWaitTimeOnReceiveMessagesInFake := 5 * time.Millisecond + maxWaitOnResults := 200 * time.Millisecond + + t.Run("verify if events are sent over channel", func(t *testing.T) { + // Given SqsMessagesCollector reading from fake sqs with random wait time on receiveMessage call + // When 3 messages are published + // Then 3 messages can be received from eventsChan. 
+ + // Given + fclock := clockwork.NewFakeClock() + fq := &fakeSQS{ + clock: fclock, + maxWaitTime: maxWaitTimeOnReceiveMessagesInFake, + } + cfg := validCollectCfgForTests(t) + cfg.sqsReceiver = fq + require.NoError(t, cfg.CheckAndSetDefaults()) + c := newSqsMessagesCollector(cfg) + eventsChan := c.getEventsChan() + + readSQSCtx, readCancel := context.WithCancel(context.Background()) + defer readCancel() + go c.fromSQS(readSQSCtx) + + // receiver is used to read messages from eventsChan. + r := &receiver{} + go r.Do(eventsChan) + + // When + wantEvents := []apievents.AuditEvent{ + &apievents.AppCreate{Metadata: apievents.Metadata{Type: events.AppCreateEvent}, AppMetadata: apievents.AppMetadata{AppName: "app1"}}, + &apievents.AppCreate{Metadata: apievents.Metadata{Type: events.AppCreateEvent}, AppMetadata: apievents.AppMetadata{AppName: "app2"}}, + &apievents.AppCreate{Metadata: apievents.Metadata{Type: events.AppCreateEvent}, AppMetadata: apievents.AppMetadata{AppName: "app3"}}, + } + fq.addEvents(wantEvents...) + // Advance clock to simulate random wait time on receive messages endpoint. + fclock.BlockUntil(cfg.noOfWorkers) + fclock.Advance(maxWaitTimeOnReceiveMessagesInFake) + + // Then + require.Eventually(t, func() bool { + return len(r.GetMsgs()) == 3 + }, maxWaitOnResults, 1*time.Millisecond) + requireEventsEqualInAnyOrder(t, wantEvents, eventAndAckIDToAuditEvents(r.GetMsgs())) + }) + + t.Run("verify if collector finishes execution (via closing channel) upon ctx.Cancel", func(t *testing.T) { + // Given SqsMessagesCollector reading from fake sqs with random wait time on receiveMessage call + // When ctx is canceled + // Then reading chan is closed. 
+ + // Given + fclock := clockwork.NewFakeClock() + fq := &fakeSQS{ + clock: fclock, + maxWaitTime: maxWaitTimeOnReceiveMessagesInFake, + } + cfg := validCollectCfgForTests(t) + cfg.sqsReceiver = fq + require.NoError(t, cfg.CheckAndSetDefaults()) + c := newSqsMessagesCollector(cfg) + eventsChan := c.getEventsChan() + + readSQSCtx, readCancel := context.WithCancel(context.Background()) + go c.fromSQS(readSQSCtx) + + // When + readCancel() + + // Then + // Make sure that channel is closed. + require.Eventually(t, channelClosedCondition(t, eventsChan), maxWaitOnResults, 1*time.Millisecond) + }) + + t.Run("verify if collector finishes execution (via closing channel) upon reaching batchMaxItems", func(t *testing.T) { + // Given SqsMessagesCollector reading from fake sqs with random wait time on receiveMessage call + // When batchMaxItems is reached. + // Then reading chan is closed. + + // Given + fclock := clockwork.NewFakeClock() + fq := &fakeSQS{ + clock: fclock, + maxWaitTime: maxWaitTimeOnReceiveMessagesInFake, + } + cfg := validCollectCfgForTests(t) + cfg.sqsReceiver = fq + cfg.batchMaxItems = 3 + require.NoError(t, cfg.CheckAndSetDefaults()) + c := newSqsMessagesCollector(cfg) + + eventsChan := c.getEventsChan() + + readSQSCtx, readCancel := context.WithCancel(context.Background()) + defer readCancel() + + go c.fromSQS(readSQSCtx) + + // receiver is used to read messages from eventsChan. + r := &receiver{} + go r.Do(eventsChan) + + // When + wantEvents := []apievents.AuditEvent{ + &apievents.AppCreate{Metadata: apievents.Metadata{Type: events.AppCreateEvent}, AppMetadata: apievents.AppMetadata{AppName: "app1"}}, + &apievents.AppCreate{Metadata: apievents.Metadata{Type: events.AppCreateEvent}, AppMetadata: apievents.AppMetadata{AppName: "app2"}}, + &apievents.AppCreate{Metadata: apievents.Metadata{Type: events.AppCreateEvent}, AppMetadata: apievents.AppMetadata{AppName: "app3"}}, + } + fq.addEvents(wantEvents...) 
+ fclock.BlockUntil(cfg.noOfWorkers) + fclock.Advance(maxWaitTimeOnReceiveMessagesInFake) + require.Eventually(t, func() bool { + return len(r.GetMsgs()) == 3 + }, maxWaitOnResults, 1*time.Millisecond) + + // Then + // Make sure that channel is closed. + require.Eventually(t, channelClosedCondition(t, eventsChan), maxWaitOnResults, 1*time.Millisecond) + requireEventsEqualInAnyOrder(t, wantEvents, eventAndAckIDToAuditEvents(r.GetMsgs())) + }) +} + +func validCollectCfgForTests(t *testing.T) sqsCollectConfig { + return sqsCollectConfig{ + sqsReceiver: &mockReceiver{}, + queueURL: "test-queue", + payloadBucket: "bucket", + payloadDownloader: &fakeS3manager{}, + logger: utils.NewLoggerForTests(), + errHandlingFn: func(ctx context.Context, errC chan error) { + err, ok := <-errC + if ok && err != nil { + // we don't expect error in that test case. + t.Log("Unexpected error", err) + t.Fail() + } + }, + } +} + +type fakeSQS struct { + mu sync.Mutex + msgs []sqsTypes.Message + clock clockwork.Clock + maxWaitTime time.Duration +} + +func (f *fakeSQS) addEvents(events ...apievents.AuditEvent) { + f.mu.Lock() + defer f.mu.Unlock() + for _, e := range events { + f.msgs = append(f.msgs, rawProtoMessage(e)) + } +} + +func (f *fakeSQS) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + // Let's use random sleep duration. That's how sqs works, you could wait up until max wait time but + // it can return earlier. 
+ + randInt, err := rand.Int(rand.Reader, big.NewInt(f.maxWaitTime.Nanoseconds())) + if err != nil { + return nil, trace.Wrap(err) + } + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-f.clock.After(time.Duration(randInt.Int64())): + // continue below + } + f.mu.Lock() + defer f.mu.Unlock() + if len(f.msgs) > 0 { + out := &sqs.ReceiveMessageOutput{ + Messages: f.msgs, + } + f.msgs = nil + return out, nil + } + return &sqs.ReceiveMessageOutput{}, nil +} + +type receiver struct { + mu sync.Mutex + msgs []eventAndAckID +} + +func (f *receiver) Do(eventsChan <-chan eventAndAckID) { + for e := range eventsChan { + f.mu.Lock() + f.msgs = append(f.msgs, e) + f.mu.Unlock() + } +} + +func (f *receiver) GetMsgs() []eventAndAckID { + f.mu.Lock() + defer f.mu.Unlock() + return f.msgs +} + +func eventAndAckIDToAuditEvents(in []eventAndAckID) []apievents.AuditEvent { + var out []apievents.AuditEvent + for _, eventAndAckID := range in { + out = append(out, eventAndAckID.event) + } + return out +} + +func rawProtoMessage(in apievents.AuditEvent) sqsTypes.Message { + oneOf := apievents.MustToOneOf(in) + bb, err := oneOf.Marshal() + if err != nil { + panic(err) + } + return sqsTypes.Message{ + Body: aws.String(base64.StdEncoding.EncodeToString(bb)), + MessageAttributes: map[string]sqsTypes.MessageAttributeValue{ + payloadTypeAttr: {StringValue: aws.String(payloadTypeRawProtoEvent)}, + }, + ReceiptHandle: aws.String(uuid.NewString()), + } +} + +// TestSQSMessagesCollectorErrorsOnReceive verifies that workers fetching events +// from ReceiveMessage endpoint, will wait specified interval before retrying +// after receiving error from API call. 
+func TestSQSMessagesCollectorErrorsOnReceive(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + + mockReceiver := &mockReceiver{ + receiveMessageRespFn: func() (*sqs.ReceiveMessageOutput, error) { + return nil, errors.New("aws error") + }, + } + + errHandlingFn := func(ctx context.Context, errC chan error) { + require.ErrorContains(t, trace.NewAggregateFromChannel(errC, ctx), "aws error") + } + waitIntervalOnReceiveError := 5 * time.Millisecond + noOfWorker := 2 + iterationsToWait := 4 + expectedNoOfCalls := noOfWorker * iterationsToWait + + cfg := validCollectCfgForTests(t) + cfg.sqsReceiver = mockReceiver + cfg.noOfWorkers = noOfWorker + cfg.waitOnReceiveError = waitIntervalOnReceiveError + cfg.errHandlingFn = errHandlingFn + require.NoError(t, cfg.CheckAndSetDefaults()) + c := newSqsMessagesCollector(cfg) + + eventsChan := c.getEventsChan() + sqsCtx, sqsCancel := context.WithCancel(ctx) + go c.fromSQS(sqsCtx) + + <-time.After(time.Duration(iterationsToWait) * waitIntervalOnReceiveError) + sqsCancel() + select { + case <-ctx.Done(): + t.Fatal("Collector never finished") + case _, ok := <-eventsChan: + require.False(t, ok, "No data should be sent on events channel") + } + + gotNoOfCalls := mockReceiver.getNoOfCalls() + // We can't be sure that there will be exactly as many calls as expected, + // because they are processed in an async way, that's why margin in EquateApprox is used. 
+ require.Empty(t, cmp.Diff(float32(gotNoOfCalls), float32(expectedNoOfCalls), cmpopts.EquateApprox(0, 4))) +} + +type mockReceiver struct { + receiveMessageRespFn func() (*sqs.ReceiveMessageOutput, error) + receiveMessageCountMu sync.Mutex + receiveMessageCount int +} + +func (m *mockReceiver) getNoOfCalls() int { + m.receiveMessageCountMu.Lock() + defer m.receiveMessageCountMu.Unlock() + return m.receiveMessageCount +} + +func (m *mockReceiver) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + m.receiveMessageCountMu.Lock() + m.receiveMessageCount++ + m.receiveMessageCountMu.Unlock() + return m.receiveMessageRespFn() +} + +func TestRunWithMinInterval(t *testing.T) { + ctx := context.Background() + t.Run("function returns earlier than minInterval, wait should happen", func(t *testing.T) { + fn := func(ctx context.Context) bool { + // did not reach max size + return false + } + minInterval := 5 * time.Millisecond + start := time.Now() + stop := runWithMinInterval(ctx, fn, minInterval) + elapsed := time.Since(start) + require.False(t, stop) + require.GreaterOrEqual(t, elapsed, minInterval) + }) + + t.Run("function takes longer than minInterval, noting more should happen", func(t *testing.T) { + minInterval := 5 * time.Millisecond + fn := func(ctx context.Context) bool { + // did not reach max size + select { + case <-time.After(2 * minInterval): + return false + case <-ctx.Done(): + return false + } + } + start := time.Now() + stop := runWithMinInterval(ctx, fn, minInterval) + elapsed := time.Since(start) + require.False(t, stop) + require.GreaterOrEqual(t, elapsed, 2*minInterval) + }) + + t.Run("reached maxBatchSize, wait should not happen", func(t *testing.T) { + fn := func(ctx context.Context) bool { + return true + } + minInterval := 5 * time.Millisecond + start := time.Now() + stop := runWithMinInterval(ctx, fn, minInterval) + elapsed := time.Since(start) + 
require.False(t, stop) + require.Less(t, elapsed, minInterval) + }) + + t.Run("context is canceled, make sure that stop is returned.", func(t *testing.T) { + minInterval := 5 * time.Millisecond + fn := func(ctx context.Context) bool { + // did not reach max size + select { + case <-time.After(minInterval): + return false + case <-ctx.Done(): + return false + } + } + ctx, cancel := context.WithCancel(ctx) + cancel() + stop := runWithMinInterval(ctx, fn, minInterval) + require.True(t, stop) + }) +} + +func TestErrHandlingFnFromSQS(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Minute) + defer cancel() + + log := utils.NewLoggerForTests() + // buf is used as the output of logs, which we will use for assertions. + var buf bytes.Buffer + log.SetOutput(&buf) + + t.Run("a lot of errors, make sure only up to maxErrorCountForLogsOnSQSReceive are printed and total count", func(t *testing.T) { + buf.Reset() + noOfErrors := maxErrorCountForLogsOnSQSReceive + 1 + errorC := make(chan error, noOfErrors) + go func() { + for i := 0; i < noOfErrors; i++ { + errorC <- errors.New("some error") + } + close(errorC) + }() + errHandlingFnFromSQS(log)(ctx, errorC) + require.Equal(t, maxErrorCountForLogsOnSQSReceive, strings.Count(buf.String(), "some error"), "number of error log messages does not match") + require.Contains(t, buf.String(), fmt.Sprintf("Got %d errors from SQS collector, printed only first", noOfErrors)) + }) + + t.Run("few errors, no total count should be printed", func(t *testing.T) { + buf.Reset() + noOfErrors := 5 + errorC := make(chan error, noOfErrors) + go func() { + for i := 0; i < noOfErrors; i++ { + errorC <- errors.New("some error") + } + close(errorC) + }() + errHandlingFnFromSQS(log)(ctx, errorC) + require.Equal(t, noOfErrors, strings.Count(buf.String(), "some error"), "number of error log messages does not match") + require.NotContains(t, buf.String(), "printed only first") + }) + t.Run("no errors at all", func(t *testing.T) { + 
buf.Reset() + errorC := make(chan error, 10) + go func() { + // close without any errors sent means receiving loop finished without any err + close(errorC) + }() + errHandlingFnFromSQS(log)(ctx, errorC) + require.Empty(t, buf.String()) + }) + t.Run("no errors at all - stopped via ctx cancel", func(t *testing.T) { + buf.Reset() + errorC := make(chan error, 10) + defer close(errorC) + + ctx, inCancel := context.WithCancel(ctx) + inCancel() + + errHandlingFnFromSQS(log)(ctx, errorC) + require.Empty(t, buf.String()) + }) + + t.Run("there were a lot of errors, stopped via ctx cancel", func(t *testing.T) { + buf.Reset() + // unbuffered channel and a more messages, + // just make sure that errors are processed + // before cancel happen, used to avoid sleeping. + noOfErrors := maxErrorCountForLogsOnSQSReceive + 10 + + errorC := make(chan error) + defer close(errorC) + + ctx, inCancel := context.WithCancel(ctx) + go func() { + for i := 0; i < noOfErrors; i++ { + errorC <- errors.New("some error") + } + inCancel() + }() + + errHandlingFnFromSQS(log)(ctx, errorC) + require.Equal(t, maxErrorCountForLogsOnSQSReceive, strings.Count(buf.String(), "some error"), "number of error log messages does not match") + require.Contains(t, buf.String(), "printed only first") + }) +} diff --git a/lib/events/athena/fakequeue_test.go b/lib/events/athena/fakequeue_test.go index 24bc57d905ce8..f594162071015 100644 --- a/lib/events/athena/fakequeue_test.go +++ b/lib/events/athena/fakequeue_test.go @@ -18,8 +18,12 @@ import ( "context" "sync" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/sns" snsTypes "github.com/aws/aws-sdk-go-v2/service/sns/types" + "github.com/aws/aws-sdk-go-v2/service/sqs" + sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/google/uuid" ) // fakeQueue is used to fake SNS+SQS combination on AWS. 
@@ -57,11 +61,43 @@ func (f *fakeQueue) Publish(ctx context.Context, params *sns.PublishInput, optFn return nil, nil } -func (f *fakeQueue) getMessages() []fakeQueueMessage { +func (f *fakeQueue) ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error) { + msgs := f.dequeue() + if len(msgs) == 0 { + return &sqs.ReceiveMessageOutput{}, nil + } + out := make([]sqsTypes.Message, 0, 10) + for _, msg := range msgs { + out = append(out, sqsTypes.Message{ + Body: aws.String(msg.payload), + MessageAttributes: snsToSqsAttributes(msg.attributes), + ReceiptHandle: aws.String(uuid.NewString()), + }) + } + return &sqs.ReceiveMessageOutput{ + Messages: out, + }, nil +} + +func snsToSqsAttributes(in map[string]snsTypes.MessageAttributeValue) map[string]sqsTypes.MessageAttributeValue { + if in == nil { + return nil + } + out := map[string]sqsTypes.MessageAttributeValue{} + for k, v := range in { + out[k] = sqsTypes.MessageAttributeValue{ + DataType: v.DataType, + StringValue: v.StringValue, + } + } + return out +} + +func (f *fakeQueue) dequeue() []fakeQueueMessage { f.mu.Lock() defer f.mu.Unlock() batchSize := 10 - if len(f.msgs) < 1 { + if len(f.msgs) == 0 { return nil } if len(f.msgs) < batchSize { diff --git a/lib/events/athena/publisher.go b/lib/events/athena/publisher.go index 7840657860e0b..2906dd4351d9c 100644 --- a/lib/events/athena/publisher.go +++ b/lib/events/athena/publisher.go @@ -142,7 +142,7 @@ func (p *publisher) emitViaS3(ctx context.Context, uid string, marshaledEvent [] _, err = p.snsPublisher.Publish(ctx, &sns.PublishInput{ TopicArn: aws.String(p.topicARN), - Message: aws.String(string(buf)), + Message: aws.String(base64.StdEncoding.EncodeToString(buf)), MessageAttributes: map[string]snsTypes.MessageAttributeValue{ payloadTypeAttr: {DataType: aws.String("String"), StringValue: aws.String(payloadTypeS3Based)}, }, diff --git a/lib/events/athena/publisher_test.go 
b/lib/events/athena/publisher_test.go index 9f19399f0469f..122b36fd09d63 100644 --- a/lib/events/athena/publisher_test.go +++ b/lib/events/athena/publisher_test.go @@ -93,7 +93,7 @@ func Test_EmitAuditEvent(t *testing.T) { } err := p.EmitAuditEvent(context.Background(), tt.in) require.NoError(t, err) - out := fq.getMessages() + out := fq.dequeue() tt.wantCheck(t, out) }) } diff --git a/lib/events/s3sessions/s3handler.go b/lib/events/s3sessions/s3handler.go index f6e8f81ea388a..4f10b2b9e2271 100644 --- a/lib/events/s3sessions/s3handler.go +++ b/lib/events/s3sessions/s3handler.go @@ -28,7 +28,6 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/aws/credentials" awssession "github.com/aws/aws-sdk-go/aws/session" "github.com/aws/aws-sdk-go/service/s3" @@ -43,6 +42,7 @@ import ( "github.com/gravitational/teleport/lib/events" s3metrics "github.com/gravitational/teleport/lib/observability/metrics/s3" "github.com/gravitational/teleport/lib/session" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) // s3AllowedACL is the set of canned ACLs that S3 accepts @@ -90,7 +90,7 @@ type Config struct { // Insecure is an optional switch to opt out of https connections Insecure bool - //DisableServerSideEncryption is an optional switch to opt out of SSE in case the provider does not support it + // DisableServerSideEncryption is an optional switch to opt out of SSE in case the provider does not support it DisableServerSideEncryption bool } @@ -263,7 +263,7 @@ func (h *Handler) Upload(ctx context.Context, sessionID session.ID, reader io.Re } _, err = h.uploader.UploadWithContext(ctx, uploadInput) if err != nil { - return "", ConvertS3Error(err) + return "", awsutils.ConvertS3Error(err) } return fmt.Sprintf("%v://%v/%v", teleport.SchemeS3, h.Bucket, path), nil } @@ -286,9 +286,8 @@ func (h *Handler) Download(ctx context.Context, sessionID session.ID, writer io. 
Key: aws.String(h.path(sessionID)), VersionId: aws.String(versionID), }) - if err != nil { - return ConvertS3Error(err) + return awsutils.ConvertS3Error(err) } if written == 0 { return trace.NotFound("recording for %v is not found", sessionID) @@ -325,7 +324,7 @@ func (h *Handler) getOldestVersion(ctx context.Context, bucket string, prefix st return !lastPage }) if err != nil { - return "", ConvertS3Error(err) + return "", awsutils.ConvertS3Error(err) } if len(versions) == 0 { return "", trace.NotFound("%v/%v not found", bucket, prefix) @@ -345,7 +344,7 @@ func (h *Handler) deleteBucket(ctx context.Context) error { Bucket: aws.String(h.Bucket), }) if err != nil { - return ConvertS3Error(err) + return awsutils.ConvertS3Error(err) } for _, ver := range out.Versions { _, err := h.client.DeleteObjectWithContext(ctx, &s3.DeleteObjectInput{ @@ -354,13 +353,13 @@ func (h *Handler) deleteBucket(ctx context.Context) error { VersionId: ver.VersionId, }) if err != nil { - return ConvertS3Error(err) + return awsutils.ConvertS3Error(err) } } _, err = h.client.DeleteBucketWithContext(ctx, &s3.DeleteBucketInput{ Bucket: aws.String(h.Bucket), }) - return ConvertS3Error(err) + return awsutils.ConvertS3Error(err) } func (h *Handler) path(sessionID session.ID) string { @@ -379,7 +378,7 @@ func (h *Handler) ensureBucket(ctx context.Context) error { _, err := h.client.HeadBucketWithContext(ctx, &s3.HeadBucketInput{ Bucket: aws.String(h.Bucket), }) - err = ConvertS3Error(err) + err = awsutils.ConvertS3Error(err) // assumes that bucket is administered by other entity if err == nil { return nil @@ -393,7 +392,7 @@ func (h *Handler) ensureBucket(ctx context.Context) error { ACL: aws.String("private"), } _, err = h.client.CreateBucketWithContext(ctx, input) - err = ConvertS3Error(err, fmt.Sprintf("bucket %v already exists", aws.String(h.Bucket))) + err = awsutils.ConvertS3Error(err, fmt.Sprintf("bucket %v already exists", aws.String(h.Bucket))) if err != nil { if !trace.IsAlreadyExists(err) 
{ return trace.Wrap(err) @@ -410,7 +409,7 @@ func (h *Handler) ensureBucket(ctx context.Context) error { }, } _, err = h.client.PutBucketVersioningWithContext(ctx, ver) - err = ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket)) + err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket)) if err != nil { return trace.Wrap(err) } @@ -427,28 +426,10 @@ func (h *Handler) ensureBucket(ctx context.Context) error { }}, }, }) - err = ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket)) + err = awsutils.ConvertS3Error(err, fmt.Sprintf("failed to set versioning state for bucket %q", h.Bucket)) if err != nil { return trace.Wrap(err) } } return nil } - -// ConvertS3Error wraps S3 error and returns trace equivalent -func ConvertS3Error(err error, args ...interface{}) error { - if err == nil { - return nil - } - if aerr, ok := err.(awserr.Error); ok { - switch aerr.Code() { - case s3.ErrCodeNoSuchKey, s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchUpload, "NotFound": - return trace.NotFound(aerr.Error(), args...) - case s3.ErrCodeBucketAlreadyExists, s3.ErrCodeBucketAlreadyOwnedByYou: - return trace.AlreadyExists(aerr.Error(), args...) - default: - return trace.BadParameter(aerr.Error(), args...) 
- } - } - return err -} diff --git a/lib/events/s3sessions/s3stream.go b/lib/events/s3sessions/s3stream.go index 6595d4fff6d43..87e74a05d01d2 100644 --- a/lib/events/s3sessions/s3stream.go +++ b/lib/events/s3sessions/s3stream.go @@ -33,6 +33,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/session" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) // CreateUpload creates a multipart upload @@ -57,7 +58,7 @@ func (h *Handler) CreateUpload(ctx context.Context, sessionID session.ID) (*even resp, err := h.client.CreateMultipartUploadWithContext(ctx, input) if err != nil { - return nil, ConvertS3Error(err) + return nil, awsutils.ConvertS3Error(err) } return &events.StreamUpload{SessionID: sessionID, ID: *resp.UploadId}, nil @@ -84,7 +85,7 @@ func (h *Handler) UploadPart(ctx context.Context, upload events.StreamUpload, pa resp, err := h.client.UploadPartWithContext(ctx, params) if err != nil { - return nil, ConvertS3Error(err) + return nil, awsutils.ConvertS3Error(err) } return &events.StreamPart{ETag: *resp.ETag, Number: partNumber}, nil @@ -98,7 +99,7 @@ func (h *Handler) abortUpload(ctx context.Context, upload events.StreamUpload) e } _, err := h.client.AbortMultipartUploadWithContext(ctx, req) if err != nil { - return ConvertS3Error(err) + return awsutils.ConvertS3Error(err) } return nil } @@ -133,7 +134,7 @@ func (h *Handler) CompleteUpload(ctx context.Context, upload events.StreamUpload } _, err := h.client.CompleteMultipartUploadWithContext(ctx, params) if err != nil { - return ConvertS3Error(err) + return awsutils.ConvertS3Error(err) } return nil } @@ -150,7 +151,7 @@ func (h *Handler) ListParts(ctx context.Context, upload events.StreamUpload) ([] PartNumberMarker: partNumberMarker, }) if err != nil { - return nil, ConvertS3Error(err) + return nil, awsutils.ConvertS3Error(err) } for _, part := range re.Parts { parts = append(parts, events.StreamPart{ 
@@ -189,7 +190,7 @@ func (h *Handler) ListUploads(ctx context.Context) ([]events.StreamUpload, error } re, err := h.client.ListMultipartUploadsWithContext(ctx, input) if err != nil { - return nil, ConvertS3Error(err) + return nil, awsutils.ConvertS3Error(err) } for _, upload := range re.Uploads { uploads = append(uploads, events.StreamUpload{ diff --git a/lib/service/service.go b/lib/service/service.go index 1c249250d1e21..117847984e7fc 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1407,7 +1407,8 @@ func initAuthExternalAuditLog(ctx context.Context, auditConfig types.ClusterAudi case teleport.ComponentAthena: hasNonFileLog = true cfg := athena.Config{ - Region: auditConfig.Region(), + Region: auditConfig.Region(), + Backend: backend, } err = cfg.SetFromURL(uri) if err != nil { diff --git a/lib/utils/aws/s3.go b/lib/utils/aws/s3.go new file mode 100644 index 0000000000000..cb63d379127ea --- /dev/null +++ b/lib/utils/aws/s3.go @@ -0,0 +1,64 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package aws + +import ( + "errors" + + s3Types "github.com/aws/aws-sdk-go-v2/service/s3/types" + "github.com/aws/aws-sdk-go/aws/awserr" + "github.com/aws/aws-sdk-go/service/s3" + "github.com/gravitational/trace" +) + +// ConvertS3Error wraps S3 error and returns trace equivalent +// It works on both sdk v1 and v2. 
+func ConvertS3Error(err error, args ...interface{}) error { + if err == nil { + return nil + } + if aerr, ok := err.(awserr.Error); ok { + switch aerr.Code() { + case s3.ErrCodeNoSuchKey, s3.ErrCodeNoSuchBucket, s3.ErrCodeNoSuchUpload, "NotFound": + return trace.NotFound(aerr.Error(), args...) + case s3.ErrCodeBucketAlreadyExists, s3.ErrCodeBucketAlreadyOwnedByYou: + return trace.AlreadyExists(aerr.Error(), args...) + default: + return trace.BadParameter(aerr.Error(), args...) + } + } + + var noSuchKey *s3Types.NoSuchKey + if errors.As(err, &noSuchKey) { + return trace.NotFound(noSuchKey.Error(), args...) + } + var noSuchBucket *s3Types.NoSuchBucket + if errors.As(err, &noSuchBucket) { + return trace.NotFound(noSuchBucket.Error(), args...) + } + var noSuchUpload *s3Types.NoSuchUpload + if errors.As(err, &noSuchUpload) { + return trace.NotFound(noSuchUpload.Error(), args...) + } + var bucketAlreadyExists *s3Types.BucketAlreadyExists + if errors.As(err, &bucketAlreadyExists) { + return trace.AlreadyExists(bucketAlreadyExists.Error(), args...) + } + var bucketAlreadyOwned *s3Types.BucketAlreadyOwnedByYou + if errors.As(err, &bucketAlreadyOwned) { + return trace.AlreadyExists(bucketAlreadyOwned.Error(), args...) + } + return err +}