Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 77 additions & 4 deletions lib/events/athena/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,22 +72,29 @@ type consumer struct {
perDateFileParquetWriter func(ctx context.Context, date string) (source.ParquetFile, error)

collectConfig sqsCollectConfig

sqsDeleter sqsDeleter
queueURL string
}

// sqsReceiver is the subset of the AWS SQS client API used by the consumer to
// receive messages. It is declared as a local interface so tests can supply a
// mock instead of a real SQS client.
type sqsReceiver interface {
	ReceiveMessage(ctx context.Context, params *sqs.ReceiveMessageInput, optFns ...func(*sqs.Options)) (*sqs.ReceiveMessageOutput, error)
}

// sqsDeleter is the subset of the AWS SQS client API used by the consumer to
// delete already-processed messages in batches. It is declared as a local
// interface so tests can supply a mock instead of a real SQS client.
type sqsDeleter interface {
	DeleteMessageBatch(ctx context.Context, params *sqs.DeleteMessageBatchInput, optFns ...func(*sqs.Options)) (*sqs.DeleteMessageBatchOutput, error)
}

// s3downloader is the subset of the S3 transfer-manager API used to download
// event payloads referenced by SQS messages; declared as a local interface so
// tests can supply a mock.
type s3downloader interface {
	Download(ctx context.Context, w io.WriterAt, input *s3.GetObjectInput, options ...func(*manager.Downloader)) (n int64, err error)
}

func newConsumer(cfg Config) (*consumer, error) {
s3client := s3.NewFromConfig(*cfg.AWSConfig)
sqsReceiver := sqs.NewFromConfig(*cfg.AWSConfig)
sqsClient := sqs.NewFromConfig(*cfg.AWSConfig)

collectCfg := sqsCollectConfig{
sqsReceiver: sqsReceiver,
sqsReceiver: sqsClient,
queueURL: cfg.QueueURL,
// TODO(tobiaszheller): use s3 manager from teleport observability.
payloadDownloader: manager.NewDownloader(s3client),
Expand All @@ -110,6 +117,8 @@ func newConsumer(cfg Config) (*consumer, error) {
batchMaxItems: cfg.BatchMaxItems,
batchMaxInterval: cfg.BatchMaxInterval,
collectConfig: collectCfg,
sqsDeleter: sqsClient,
queueURL: cfg.QueueURL,
perDateFileParquetWriter: func(ctx context.Context, date string) (source.ParquetFile, error) {
key := fmt.Sprintf("%s/%s/%s.parquet", cfg.locationS3Prefix, date, uuid.NewString())

Expand Down Expand Up @@ -213,8 +222,7 @@ func (c *consumer) processBatchOfEvents(ctx context.Context) (reachedMaxSize boo
return false, trace.Wrap(err)
}
size = len(toDelete)
return size >= c.batchMaxItems, nil
// TODO(tobiaszheller): delete messages from queue in next PR.
return size >= c.batchMaxItems, trace.Wrap(c.deleteMessagesFromQueue(ctx, toDelete))
}

type sqsCollectConfig struct {
Expand Down Expand Up @@ -637,3 +645,68 @@ func (pw *parquetWriter) Close() error {
}
return trace.Wrap(pw.closer.Close())
}

// deleteMessagesFromQueue removes messages, identified by their receipt
// handles, from the consumer's SQS queue. Handles are grouped into batches of
// at most maxDeleteBatchSize and the DeleteMessageBatch calls are fanned out
// over a small pool of worker goroutines. Top-level API errors and per-entry
// failures reported by SQS are collected and returned as a single aggregate;
// a nil handles slice is a no-op.
func (c *consumer) deleteMessagesFromQueue(ctx context.Context, handles []string) error {
	if len(handles) == 0 {
		return nil
	}

	const (
		// maxDeleteBatchSize defines maximum number of handles passed to deleteMessage endpoint, limited by AWS.
		maxDeleteBatchSize = 10
		// noOfWorkers defines number of workers which concurrently process delete batch request.
		noOfWorkers = 5
	)

	// Every handle can yield at most one error (either its batch's top-level
	// call error or its own per-entry failure), so a buffer of len(handles)
	// guarantees workers never block when reporting.
	errsCh := make(chan error, len(handles))
	batchCh := make(chan []string, noOfWorkers)

	var workers sync.WaitGroup

	// Spin up the worker pool; each worker drains batchCh until it is closed.
	for w := 0; w < noOfWorkers; w++ {
		workers.Add(1)
		go func() {
			defer workers.Done()
			for batch := range batchCh {
				input := &sqs.DeleteMessageBatchInput{
					QueueUrl: aws.String(c.queueURL),
					Entries:  make([]sqsTypes.DeleteMessageBatchRequestEntry, 0, len(batch)),
				}
				for _, receiptHandle := range batch {
					// Each entry needs an ID unique within the request; a fresh
					// UUID satisfies that.
					input.Entries = append(input.Entries, sqsTypes.DeleteMessageBatchRequestEntry{
						Id:            aws.String(uuid.NewString()),
						ReceiptHandle: aws.String(receiptHandle),
					})
				}
				out, err := c.sqsDeleter.DeleteMessageBatch(ctx, input)
				if err != nil {
					errsCh <- trace.Wrap(err, "error on calling DeleteMessageBatch")
					continue
				}
				for _, failure := range out.Failed {
					// TODO(tobiaszheller): come back at some point and check if there are errors that we should filter.
					// Deleting the same handle twice does not result in error.
					errsCh <- trace.Errorf("failed to delete message with ID %s, sender fault %v: %s", aws.ToString(failure.Id), failure.SenderFault, aws.ToString(failure.Message))
				}
			}
		}()
	}

	// Feed the pool with slices of at most maxDeleteBatchSize handles.
	for start := 0; start < len(handles); start += maxDeleteBatchSize {
		end := start + maxDeleteBatchSize
		if end > len(handles) {
			end = len(handles)
		}
		batchCh <- handles[start:end]
	}
	close(batchCh)

	workers.Wait()
	// All workers have exited, so no further sends on errsCh are possible;
	// close it and aggregate whatever was reported.
	close(errsCh)

	return trace.Wrap(trace.NewAggregateFromChannel(errsCh, ctx))
}
115 changes: 115 additions & 0 deletions lib/events/athena/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,3 +573,118 @@ func TestConsumerWriteToS3(t *testing.T) {
})
}
}

// TestDeleteMessagesFromQueue verifies consumer.deleteMessagesFromQueue:
// that receipt handles are split into DeleteMessageBatch calls of at most 10
// entries, that top-level API errors are propagated, and that per-entry
// failures reported by SQS are surfaced as a trace.Aggregate.
func TestDeleteMessagesFromQueue(t *testing.T) {
	t.Parallel()
	ctx := context.Background()

	// handlesGen produces n distinct fake receipt handles.
	handlesGen := func(n int) []string {
		out := make([]string, 0, n)
		for i := 0; i < n; i++ {
			out = append(out, fmt.Sprintf("handle-%d", i))
		}
		return out
	}
	// 18 handles must be split into 2 batches given the AWS cap of 10
	// entries per DeleteMessageBatch request.
	noOfHandles := 18
	handles := handlesGen(noOfHandles)

	tests := []struct {
		name string
		// mockRespFn is the canned DeleteMessageBatch behavior for the case.
		mockRespFn func(ctx context.Context, params *sqs.DeleteMessageBatchInput) (*sqs.DeleteMessageBatchOutput, error)
		// wantCheck asserts on the returned error and the mock's counters.
		wantCheck func(t *testing.T, err error, mock *mockSQSDeleter)
	}{
		{
			name: "delete returns no error, expect 2 calls to delete",
			mockRespFn: func(ctx context.Context, params *sqs.DeleteMessageBatchInput) (*sqs.DeleteMessageBatchOutput, error) {
				// Sanity-check the request shape before returning success.
				if aws.ToString(params.QueueUrl) == "" {
					return nil, errors.New("mock called with empty QueueUrl")
				}
				if noOfEntries := len(params.Entries); noOfEntries > 10 || noOfEntries == 0 {
					return nil, fmt.Errorf("mock called with invalid number of entries %d", noOfEntries)
				}
				return &sqs.DeleteMessageBatchOutput{}, nil
			},
			wantCheck: func(t *testing.T, err error, mock *mockSQSDeleter) {
				require.NoError(t, err)
				// 18 handles / 10 per batch => 2 calls covering all entries.
				require.Equal(t, 2, mock.calls)
				require.Equal(t, noOfHandles, mock.noOfEntries)
			},
		},
		{
			name: "delete returns top level error, make sure it's returned",
			mockRespFn: func(ctx context.Context, params *sqs.DeleteMessageBatchInput) (*sqs.DeleteMessageBatchOutput, error) {
				if aws.ToString(params.QueueUrl) == "" {
					return nil, errors.New("mock called with empty QueueUrl")
				}
				if noOfEntries := len(params.Entries); noOfEntries > 10 || noOfEntries == 0 {
					return nil, fmt.Errorf("mock called with invalid number of entries %d", noOfEntries)
				}
				// Simulate the whole DeleteMessageBatch call failing.
				return nil, errors.New("AWS API err")
			},
			wantCheck: func(t *testing.T, err error, _ *mockSQSDeleter) {
				require.ErrorContains(t, err, "AWS API err")
			},
		},
		{
			name: "half of entries returns error",
			mockRespFn: func(ctx context.Context, params *sqs.DeleteMessageBatchInput) (*sqs.DeleteMessageBatchOutput, error) {
				// Alternate entries between the Successful and Failed result
				// lists to simulate partial failure of a batch.
				success := make([]sqsTypes.DeleteMessageBatchResultEntry, 0)
				failed := make([]sqsTypes.BatchResultErrorEntry, 0)
				for i, e := range params.Entries {
					if i%2 == 0 {
						success = append(success, sqsTypes.DeleteMessageBatchResultEntry{
							Id: e.Id,
						})
					} else {
						failed = append(failed, sqsTypes.BatchResultErrorEntry{
							Id:      e.Id,
							Message: aws.String("entry failed"),
						})
					}
				}
				return &sqs.DeleteMessageBatchOutput{
					Failed:     failed,
					Successful: success,
				}, nil
			},
			wantCheck: func(t *testing.T, err error, mock *mockSQSDeleter) {
				require.Error(t, err)
				// The per-entry failures must come back as a trace.Aggregate,
				// each carrying the mock's failure message.
				agg, ok := trace.Unwrap(err).(trace.Aggregate)
				require.True(t, ok)
				for _, errFromAgg := range agg.Errors() {
					require.ErrorContains(t, errFromAgg, "entry failed")
				}
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			mock := &mockSQSDeleter{
				respFn: tt.mockRespFn,
			}
			// Only the fields used by deleteMessagesFromQueue are populated.
			c := consumer{
				sqsDeleter: mock,
				queueURL:   "queue-url",
			}
			err := c.deleteMessagesFromQueue(ctx, handles)
			tt.wantCheck(t, err, mock)
		})
	}
}

// mockSQSDeleter is a test double implementing the sqsDeleter interface. It
// counts DeleteMessageBatch invocations and the total number of entries seen,
// and delegates the response to the injected respFn.
type mockSQSDeleter struct {
	// respFn supplies the canned response/error for each call.
	respFn func(ctx context.Context, params *sqs.DeleteMessageBatchInput) (*sqs.DeleteMessageBatchOutput, error)

	// mu protects fields below
	mu sync.Mutex
	// calls counts DeleteMessageBatch invocations.
	calls int
	// noOfEntries accumulates len(params.Entries) across all calls.
	noOfEntries int
}

// DeleteMessageBatch implements sqsDeleter. The mutex is held for the whole
// call (including respFn) because the consumer's workers invoke this
// concurrently.
func (m *mockSQSDeleter) DeleteMessageBatch(ctx context.Context, params *sqs.DeleteMessageBatchInput, optFns ...func(*sqs.Options)) (*sqs.DeleteMessageBatchOutput, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.calls++
	m.noOfEntries += len(params.Entries)
	return m.respFn(ctx, params)
}