diff --git a/lib/events/athena/querier.go b/lib/events/athena/querier.go index af44909591edb..fa28875d1c3e3 100644 --- a/lib/events/athena/querier.go +++ b/lib/events/athena/querier.go @@ -400,18 +400,24 @@ func (q *querier) streamEventsFromChunk(ctx context.Context, date, chunk string) reader.Close() } + var prevErr error + return stream.Func(func() (eventParquet, error) { + if prevErr != nil { + return eventParquet{}, prevErr + } // conventional wisdom says that we should use a larger persistent buffer here // but in loadtesting this API was observed to have almost twice the throughput // with a single element local buf variable instead. var buf [1]eventParquet - _, err := reader.Read(buf[:]) - if err != nil { + n, err := reader.Read(buf[:]) + if n == 0 && err != nil { if errors.Is(err, io.EOF) { return eventParquet{}, io.EOF } return eventParquet{}, trace.Wrap(err) } + prevErr = err return buf[0], nil }, closer) } diff --git a/lib/events/athena/querier_test.go b/lib/events/athena/querier_test.go index 551999f7bef77..731f32c6a07b4 100644 --- a/lib/events/athena/querier_test.go +++ b/lib/events/athena/querier_test.go @@ -19,25 +19,33 @@ package athena import ( + "bytes" "context" "errors" "fmt" + "io" "log/slog" + "path/filepath" "strings" + "sync" "testing" "time" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/athena" athenaTypes "github.com/aws/aws-sdk-go-v2/service/athena/types" + "github.com/aws/aws-sdk-go-v2/service/s3" + s3Types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/dustin/go-humanize" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/jonboulle/clockwork" + "github.com/parquet-go/parquet-go" "github.com/stretchr/testify/require" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/internalutils/stream" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" 
"github.com/gravitational/teleport/lib/events" @@ -955,3 +963,121 @@ func (f *fakeAthenaResultsGetter) GetQueryResults(ctx context.Context, params *a }, }, nil } + +func Test_querier_streamEventsFromChunk(t *testing.T) { + const ( + tableName = "test_table" + bucketName = "test_bucket" + prefix = "test_prefix" + date = "2025-12-01" + chunkID = "test_chunk" + ) + + event1 := &apievents.AppCreate{ + Metadata: apievents.Metadata{ + ID: uuid.NewString(), + Time: time.Now().UTC(), + Type: events.AppCreateEvent, + }, + AppMetadata: apievents.AppMetadata{ + AppName: "app-1", + }, + } + event2 := &apievents.AppCreate{ + Metadata: apievents.Metadata{ + ID: uuid.NewString(), + Time: time.Now().UTC(), + Type: events.AppCreateEvent, + }, + AppMetadata: apievents.AppMetadata{ + AppName: "app-2", + }, + } + + tests := []struct { + name string + events []apievents.AuditEvent + }{ + { + name: "single-event chunk", + events: []apievents.AuditEvent{event1}, + }, + { + name: "multi-event chunk", + events: []apievents.AuditEvent{event1, event2}, + }, + } + for _, tt := range tests { + // add event parquets (in one .parquet file) to mock S3 getter + payloads, err := auditEventsToParquet(tt.events) + require.NoError(t, err) + + buf := new(bytes.Buffer) + writer := parquet.NewGenericWriter[eventParquet](buf) + _, err = writer.Write(payloads) + require.NoError(t, err) + require.NoError(t, writer.Close()) + + key := fmt.Sprintf("%s/%s/%s.parquet", prefix, date, chunkID) + file := filepath.Join(bucketName, key) + mockS3 := &mockS3Getter{ + files: map[string][]byte{ + file: buf.Bytes(), + }, + } + + q := &querier{ + querierConfig: querierConfig{ + tablename: tableName, + locationS3Bucket: bucketName, + locationS3Prefix: prefix, + logger: slog.Default(), + tracer: tracing.NoopTracer(teleport.ComponentAthena), + }, + s3Getter: mockS3, + } + + eventStream := q.streamEventsFromChunk(t.Context(), date, chunkID) + eventParquets, err := stream.Collect(eventStream) + require.NoError(t, err) + 
require.Len(t, eventParquets, len(payloads)) + for i, e := range eventParquets { + require.Equal(t, payloads[i].UID, e.UID) + } + } +} + +func auditEventsToParquet(in []apievents.AuditEvent) ([]eventParquet, error) { + out := make([]eventParquet, 0, len(in)) + + for _, e := range in { + p, err := auditEventToParquet(e) + if err != nil { + return nil, err + } + out = append(out, *p) + } + return out, nil +} + +type mockS3Getter struct { + mu sync.Mutex + files map[string][]byte +} + +func (m *mockS3Getter) GetObject(ctx context.Context, params *s3.GetObjectInput, optFns ...func(*s3.Options)) (*s3.GetObjectOutput, error) { + m.mu.Lock() + defer m.mu.Unlock() + + file := filepath.Join(*params.Bucket, *params.Key) + if obj, ok := m.files[file]; ok { + return &s3.GetObjectOutput{ + Body: io.NopCloser(bytes.NewReader(obj)), + }, nil + } + return nil, &s3Types.NoSuchKey{Message: aws.String("key does not exist")} +} + +func (m *mockS3Getter) ListObjectsV2(ctx context.Context, params *s3.ListObjectsV2Input, optFns ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) { + return &s3.ListObjectsV2Output{}, nil +}