diff --git a/go.mod b/go.mod index 9b5c2b8518249..2d614b26fd83c 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.65 github.com/aws/aws-sdk-go-v2/service/athena v1.26.1 github.com/aws/aws-sdk-go-v2/service/ec2 v1.97.0 + github.com/aws/aws-sdk-go-v2/service/glue v1.45.3 github.com/aws/aws-sdk-go-v2/service/rds v1.43.3 github.com/aws/aws-sdk-go-v2/service/s3 v1.33.1 github.com/aws/aws-sdk-go-v2/service/sns v1.20.10 diff --git a/go.sum b/go.sum index 8c68920bafff9..7c392b7d74349 100644 --- a/go.sum +++ b/go.sum @@ -244,6 +244,7 @@ github.com/aws/aws-sdk-go v1.44.244/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8 github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/aws/aws-sdk-go-v2 v1.16.2/go.mod h1:ytwTPBG6fXTZLxxeeCCWj2/EMYp/xDUgX+OET6TLNNU= github.com/aws/aws-sdk-go-v2 v1.17.3/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2 v1.17.8/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= github.com/aws/aws-sdk-go-v2 v1.18.0 h1:882kkTpSFhdgYRKVZ/VCgf7sd0ru57p2JCxz4/oN5RY= github.com/aws/aws-sdk-go-v2 v1.18.0/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.1/go.mod h1:n8Bs1ElDD2wJ9kCRTczA83gYbBmjSwZp3umc6zF4EeM= @@ -266,10 +267,12 @@ github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.65 h1:4irvSxFf0u7pQdtpmUoD github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.65/go.mod h1:BAWKiL53LT19UMewYr9YhZ8xPO69u6NwmGUjSjRwUdM= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.9/go.mod h1:AnVH5pvai0pAF4lXRq0bmhbes1u9R8wTE+g+183bZNM= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27/go.mod h1:a1/UpzeyBBerajpnP5nGZa9mGzsBn5cOKxm6NWQsvoI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.32/go.mod h1:RudqOgadTWdcS3t/erPQo24pcVEoYyqj/kKW5Vya21I= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33 
h1:kG5eQilShqmJbv11XL1VpyDbaEJzWxd4zRiCG30GSn4= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.33/go.mod h1:7i0PF1ME/2eUPFcjkVIwq+DOygHEoK92t5cDqNgYbIw= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.3/go.mod h1:ssOhaLpRlh88H3UmEcsBoVKq309quMvm3Ds8e9d4eJM= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21/go.mod h1:+Gxn8jYn5k9ebfHEqlhrMirFjSW0v0C9fI+KN5vk2kE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.26/go.mod h1:vq86l7956VgFr0/FWQ2BWnK07QC3WYsepKzy33qqY5U= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27 h1:vFQlirhuM8lLlpI7imKOMsjdQLuN9CPi+k44F/OFVsk= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.27/go.mod h1:UrHnn3QV/d0pBZ6QBAEQcqFLf8FAzLmoUfPVIueOvoM= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.10/go.mod h1:8DcYQcz0+ZJaSxANlHIsbbi6S+zMwjwdDqwW3r9AzaE= @@ -282,6 +285,8 @@ github.com/aws/aws-sdk-go-v2/service/athena v1.26.1 h1:ztONDoMfRjoIUzp0cmCIzKVzv github.com/aws/aws-sdk-go-v2/service/athena v1.26.1/go.mod h1:97btS9UBEnajlbXXJkaCAFIu1j3vfJKdQCnIhs853xY= github.com/aws/aws-sdk-go-v2/service/ec2 v1.97.0 h1:glGFVlA0MVrOpDF+KsVZZA/QCwykYPanYMW0DoIJN34= github.com/aws/aws-sdk-go-v2/service/ec2 v1.97.0/go.mod h1:L3ZT0N/vBsw77mOAawXmRnREpEjcHd2v5Hzf7AkIH8M= +github.com/aws/aws-sdk-go-v2/service/glue v1.45.3 h1:yWGd1MsH+LtaBuMnSxYB3mXxFIIpKmV9msv8usr0IBs= +github.com/aws/aws-sdk-go-v2/service/glue v1.45.3/go.mod h1:RdegNsxdf+QYMKGrCrEBr8KuojiJeMZ8aGi5iEJqqMI= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.1/go.mod h1:GeUru+8VzrTXV/83XyMJ80KpH8xO89VPoUileyNQ+tc= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 h1:y2+VQzC6Zh2ojtV2LoC0MNwHWc6qXv/j2vrQtlftkdA= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11/go.mod h1:iV4q2hsqtNECrfmlXyord9u4zyuFEJX9eLgLpSPzWA8= diff --git a/lib/events/athena/consumer.go b/lib/events/athena/consumer.go index 09622ea101719..b058508c7ae5e 100644 --- a/lib/events/athena/consumer.go +++ 
b/lib/events/athena/consumer.go @@ -526,11 +526,16 @@ func (s *sqsMessagesCollector) downloadEventFromS3(ctx context.Context, payload s.cfg.logger.Debugf("Downloading %v %v [%v].", s.cfg.payloadBucket, path, versionID) + var versionIDPtr *string + if versionID != "" { + versionIDPtr = aws.String(versionID) + } + buf := manager.NewWriteAtBuffer([]byte{}) written, err := s.cfg.payloadDownloader.Download(ctx, buf, &s3.GetObjectInput{ Bucket: aws.String(s.cfg.payloadBucket), Key: aws.String(path), - VersionId: aws.String(versionID), + VersionId: versionIDPtr, }) if err != nil { return nil, awsutils.ConvertS3Error(err) diff --git a/lib/events/athena/integration_test.go b/lib/events/athena/integration_test.go new file mode 100644 index 0000000000000..44cf2dd0afd72 --- /dev/null +++ b/lib/events/athena/integration_test.go @@ -0,0 +1,455 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package athena + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/athena" + athenaTypes "github.com/aws/aws-sdk-go-v2/service/athena/types" + "github.com/aws/aws-sdk-go-v2/service/glue" + glueTypes "github.com/aws/aws-sdk-go-v2/service/glue/types" + "github.com/aws/aws-sdk-go-v2/service/sns" + "github.com/aws/aws-sdk-go-v2/service/sqs" + sqsTypes "github.com/aws/aws-sdk-go-v2/service/sqs/types" + "github.com/google/go-cmp/cmp" + "github.com/google/uuid" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/api/utils/retryutils" + "github.com/gravitational/teleport/lib/backend/memory" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/test" +) + +type athenaContext struct { + log *Log + clock clockwork.Clock + testID string + database string + tablename string + s3eventsLocation string + s3resultsLocation string + s3largePayloads string + batcherInterval time.Duration +} + +func TestIntegrationAthenaSearchSessionEventsBySessionID(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + ac := setupAthenaContext(t, ctx, athenaContextConfig{}) + auditLogger := &eventuallyConsitentAuditLogger{ + inner: ac.log, + // Additional 5s is used to compensate for uploading parquet on s3. 
+ queryDelay: ac.batcherInterval + 5*time.Second, + } + eventsSuite := test.EventsSuite{ + Log: auditLogger, + Clock: ac.clock, + SearchSessionEvensBySessionIDTimeout: ac.batcherInterval + 10*time.Second, + } + + eventsSuite.SearchSessionEventsBySessionID(t) +} + +func TestIntegrationAthenaSessionEventsCRUD(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + ac := setupAthenaContext(t, ctx, athenaContextConfig{}) + auditLogger := &eventuallyConsitentAuditLogger{ + inner: ac.log, + // Additional 5s is used to compensate for uploading parquet on s3. + queryDelay: ac.batcherInterval + 5*time.Second, + } + eventsSuite := test.EventsSuite{ + Log: auditLogger, + Clock: ac.clock, + } + + eventsSuite.SessionEventsCRUD(t) +} + +func TestIntegrationAthenaEventPagination(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + ac := setupAthenaContext(t, ctx, athenaContextConfig{}) + auditLogger := &eventuallyConsitentAuditLogger{ + inner: ac.log, + // Additional 5s is used to compensate for uploading parquet on s3. 
+ queryDelay: ac.batcherInterval + 5*time.Second, + } + eventsSuite := test.EventsSuite{ + Log: auditLogger, + Clock: ac.clock, + } + + eventsSuite.EventPagination(t) +} + +func TestIntegrationAthenaLargeEvents(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) + defer cancel() + + ac := setupAthenaContext(t, ctx, athenaContextConfig{maxBatchSize: 1}) + in := &apievents.SessionStart{ + Metadata: apievents.Metadata{ + Index: 2, + Type: events.SessionStartEvent, + ID: uuid.NewString(), + Code: strings.Repeat("d", 200000), + Time: ac.clock.Now().UTC(), + }, + } + err := ac.log.EmitAuditEvent(ctx, in) + require.NoError(t, err) + + var history []apievents.AuditEvent + // We have batch time 10s, and 5s for upload and additional buffer for s3 download + err = retryutils.RetryStaticFor(time.Second*20, time.Second*2, func() error { + history, _, err = ac.log.SearchEvents(ac.clock.Now().UTC().Add(-1*time.Minute), ac.clock.Now().UTC(), "", nil, 10, types.EventOrderDescending, "") + if err != nil { + return err + } + if len(history) == 0 { + return errors.New("events not propagated yet") + } + return nil + }) + require.NoError(t, err) + require.Len(t, history, 1) + require.Empty(t, cmp.Diff(in, history[0])) +} + +// athenaContextConfig is optional config to override defaults in athena context. 
+type athenaContextConfig struct { + maxBatchSize int +} + +func setupAthenaContext(t *testing.T, ctx context.Context, cfg athenaContextConfig) *athenaContext { + testEnabled := os.Getenv(teleport.AWSRunTests) + if ok, _ := strconv.ParseBool(testEnabled); !ok { + t.Skip("Skipping AWS-dependent test suite.") + } + + testID := fmt.Sprintf("auditlogs-integrationtests-%v", uuid.New().String()) + + clock := clockwork.NewRealClock() + + backend, err := memory.New(memory.Config{}) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, backend.Close()) + }) + + ac := &athenaContext{ + clock: clock, + testID: testID, + database: "auditlogs_integrationtests", + s3eventsLocation: fmt.Sprintf("%s/%s/events", "s3://auditlogs-integrationtests", testID), + s3resultsLocation: fmt.Sprintf("%s/%s/results", "s3://auditlogs-integrationtests", testID), + s3largePayloads: fmt.Sprintf("%s/%s/large_payloads", "s3://auditlogs-integrationtests", testID), + tablename: strings.ReplaceAll(testID, "-", "_"), + batcherInterval: 10 * time.Second, + } + infraOut := ac.setupInfraWithCleanup(t, ctx) + + region := infraOut.region + if region == "" { + region = "eu-central-1" + } + + log, err := New(ctx, Config{ + Region: region, + Clock: clock, + Database: ac.database, + TableName: ac.tablename, + TopicARN: infraOut.topicARN, + QueueURL: infraOut.queueURL, + LocationS3: ac.s3eventsLocation, + QueryResultsS3: ac.s3resultsLocation, + LargeEventsS3: ac.s3largePayloads, + BatchMaxInterval: ac.batcherInterval, + BatchMaxItems: cfg.maxBatchSize, + Backend: backend, + Workgroup: "primary", + }) + require.NoError(t, err) + + ac.log = log + t.Cleanup(func() { + ac.Close(t) + }) + + t.Logf("Initialized Athena test suite %q\n", testID) + + return ac +} + +type infraOutputs struct { + topicARN string + queueURL string + region string +} + +func (ac *athenaContext) setupInfraWithCleanup(t *testing.T, ctx context.Context) *infraOutputs { + const timeoutDurationOnCleanup = 1 * time.Minute + + awsCfg, 
err := awsconfig.LoadDefaultConfig(ctx) + require.NoError(t, err) + + // Create SNS topic and set cleanup fn. + snsClient := sns.NewFromConfig(awsCfg) + topicCreated, err := snsClient.CreateTopic(ctx, &sns.CreateTopicInput{ + Name: aws.String(ac.testID), + }) + require.NoError(t, err) + t.Cleanup(func() { + cleanupCtx, cancel := context.WithTimeout(context.Background(), timeoutDurationOnCleanup) + defer cancel() + _, err = snsClient.DeleteTopic(cleanupCtx, &sns.DeleteTopicInput{ + TopicArn: topicCreated.TopicArn, + }) + assert.NoError(t, err) + }) + + // Create SQS queue and set cleanup fn. + sqsClient := sqs.NewFromConfig(awsCfg) + queueCreated, err := sqsClient.CreateQueue(ctx, &sqs.CreateQueueInput{ + QueueName: aws.String(ac.testID), + }) + require.NoError(t, err) + t.Cleanup(func() { + cleanupCtx, cancel := context.WithTimeout(context.Background(), timeoutDurationOnCleanup) + defer cancel() + _, err := sqsClient.DeleteQueue(cleanupCtx, &sqs.DeleteQueueInput{ + QueueUrl: queueCreated.QueueUrl, + }) + assert.NoError(t, err) + }) + + // Set created queue as subscriber to topic and use valid permissions. 
+ queueAttr, err := sqsClient.GetQueueAttributes(ctx, &sqs.GetQueueAttributesInput{ + QueueUrl: queueCreated.QueueUrl, + AttributeNames: []sqsTypes.QueueAttributeName{sqsTypes.QueueAttributeNameQueueArn}, + }) + require.NoError(t, err) + queueArn := queueAttr.Attributes["QueueArn"] + type StatementEntry struct { + Effect string + Action []string + Resource string + Principal map[string]string + Condition map[string]map[string]string + } + type PolicyDocument struct { + Version string + Statement []StatementEntry + } + sqsAccessPolicy := PolicyDocument{ + Version: "2012-10-17", + Statement: []StatementEntry{ + { + Effect: "Allow", + Action: []string{"SQS:SendMessage"}, + Resource: queueArn, + Principal: map[string]string{ + "AWS": "*", + }, + Condition: map[string]map[string]string{ + "ArnLike": { + "aws:SourceArn": *topicCreated.TopicArn, + }, + }, + }, + }, + } + marshaledPolicy, err := json.Marshal(sqsAccessPolicy) + require.NoError(t, err) + _, err = sqsClient.SetQueueAttributes(ctx, &sqs.SetQueueAttributesInput{ + Attributes: map[string]string{ + "Policy": string(marshaledPolicy), + }, + QueueUrl: queueCreated.QueueUrl, + }) + require.NoError(t, err) + _, err = snsClient.Subscribe(ctx, &sns.SubscribeInput{ + TopicArn: topicCreated.TopicArn, + Protocol: aws.String("sqs"), + Attributes: map[string]string{ + "RawMessageDelivery": "true", + }, + Endpoint: aws.String(queueArn), + }) + require.NoError(t, err) + + // Create glue db if not exists + glueClient := glue.NewFromConfig(awsCfg) + _, err = glueClient.GetDatabase(ctx, &glue.GetDatabaseInput{ + Name: aws.String(ac.database), + }) + if err != nil { + var notFound *glueTypes.EntityNotFoundException + if errors.As(err, ¬Found) { + _, createErr := glueClient.CreateDatabase(ctx, &glue.CreateDatabaseInput{ + DatabaseInput: &glueTypes.DatabaseInput{ + Name: aws.String(ac.database), + }, + }) + require.NoError(t, createErr) + } else { + assert.Fail(t, "unexpected err: %v", err) + } + } + + // Create athena table + 
athenaClient := athena.NewFromConfig(awsCfg) + startQueryExecResp, err := athenaClient.StartQueryExecution(ctx, &athena.StartQueryExecutionInput{ + QueryString: aws.String(fmt.Sprintf(createTableQuery, ac.tablename, ac.s3eventsLocation, ac.s3eventsLocation)), + ResultConfiguration: &athenaTypes.ResultConfiguration{ + OutputLocation: aws.String(ac.s3resultsLocation), + }, + QueryExecutionContext: &athenaTypes.QueryExecutionContext{ + Database: aws.String(ac.database), + }, + }) + require.NoError(t, err) + // querier is just used here to get helper fn waitForSuccess. + q := querier{ + athenaClient: athenaClient, + querierConfig: querierConfig{ + getQueryResultsInterval: 100 * time.Millisecond, + clock: ac.clock, + }, + } + err = q.waitForSuccess(ctx, aws.ToString(startQueryExecResp.QueryExecutionId)) + require.NoError(t, err) + t.Cleanup(func() { + cleanupCtx, cancel := context.WithTimeout(context.Background(), timeoutDurationOnCleanup) + defer cancel() + _, err = athenaClient.StartQueryExecution(cleanupCtx, &athena.StartQueryExecutionInput{ + QueryString: aws.String(fmt.Sprintf("drop table %s;", ac.tablename)), + ResultConfiguration: &athenaTypes.ResultConfiguration{ + OutputLocation: aws.String(ac.s3resultsLocation), + }, + QueryExecutionContext: &athenaTypes.QueryExecutionContext{ + Database: aws.String(ac.database), + }, + }) + assert.NoError(t, err) + }) + + return &infraOutputs{ + topicARN: aws.ToString(topicCreated.TopicArn), + queueURL: aws.ToString(queueCreated.QueueUrl), + region: awsCfg.Region, + } +} + +// createTableQuery is query used to create athena table using parquet on s3. +// Right now only hardcoded in integration tests, in future it may be moved +// to athena main file if we decide to create table on demand from teleport. 
+var createTableQuery = ` +CREATE EXTERNAL TABLE %s ( + uid string, + session_id string, + event_type string, + event_time timestamp, + event_data string + ) + PARTITIONED BY ( + event_date DATE + ) + ROW FORMAT SERDE 'org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe' + STORED AS INPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat' + OUTPUTFORMAT 'org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat' + LOCATION "%s/" + TBLPROPERTIES ( + "projection.enabled" = "true", + "projection.event_date.type" = "date", + "projection.event_date.format" = "yyyy-MM-dd", + "projection.event_date.range" = "NOW-4YEARS,NOW", + "projection.event_date.interval" = "1", + "projection.event_date.interval.unit" = "DAYS", + "storage.location.template" = "%s/${event_date}/", + "classification" = "parquet", + "parquet.compression" = "SNAPPY" + ) +` + +func (ac *athenaContext) Close(t *testing.T) { + assert.NoError(t, ac.log.Close()) +} + +// eventuallyConsitentAuditLogger is used to add delay before searching for events +// for eventually consistent audit loggers. +type eventuallyConsitentAuditLogger struct { + inner events.AuditLogger + + // queryDelay specifies how long query should wait after last emit event. + queryDelay time.Duration + + // mu protects field below. 
+ mu sync.Mutex + emitWasAfterLastDelay bool +} + +func (e *eventuallyConsitentAuditLogger) EmitAuditEvent(ctx context.Context, in apievents.AuditEvent) error { + e.mu.Lock() + e.emitWasAfterLastDelay = true + e.mu.Unlock() + return e.inner.EmitAuditEvent(ctx, in) +} + +func (e *eventuallyConsitentAuditLogger) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) { + e.mu.Lock() + defer e.mu.Unlock() + if e.emitWasAfterLastDelay { + time.Sleep(e.queryDelay) + // clear emit delay + e.emitWasAfterLastDelay = false + } + return e.inner.SearchEvents(fromUTC, toUTC, namespace, eventTypes, limit, order, startKey) +} + +func (e *eventuallyConsitentAuditLogger) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { + e.mu.Lock() + defer e.mu.Unlock() + if e.emitWasAfterLastDelay { + time.Sleep(e.queryDelay) + // clear emit delay + e.emitWasAfterLastDelay = false + } + return e.inner.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID) +} + +func (e *eventuallyConsitentAuditLogger) Close() error { + return e.inner.Close() +} diff --git a/lib/events/dynamoevents/dynamoevents_test.go b/lib/events/dynamoevents/dynamoevents_test.go index f4ff9849aa228..3dea115ed6e09 100644 --- a/lib/events/dynamoevents/dynamoevents_test.go +++ b/lib/events/dynamoevents/dynamoevents_test.go @@ -109,7 +109,7 @@ func TestSessionEventsCRUD(t *testing.T) { func TestSearchSessionEvensBySessionID(t *testing.T) { tt := setupDynamoContext(t) - tt.suite.SearchSessionEvensBySessionID(t) + tt.suite.SearchSessionEventsBySessionID(t) } func TestSizeBreak(t *testing.T) { diff --git a/lib/events/firestoreevents/firestoreevents_test.go b/lib/events/firestoreevents/firestoreevents_test.go index 1d8b3aa97bd3f..1806d068dd323 100644 --- 
a/lib/events/firestoreevents/firestoreevents_test.go +++ b/lib/events/firestoreevents/firestoreevents_test.go @@ -121,7 +121,7 @@ func (tt *firestoreContext) testPagination(t *testing.T) { func (tt *firestoreContext) testSearchSessionEvensBySessionID(t *testing.T) { tt.setupTest(t) - tt.suite.SearchSessionEvensBySessionID(t) + tt.suite.SearchSessionEventsBySessionID(t) } func TestFirestoreEvents(t *testing.T) { diff --git a/lib/events/test/suite.go b/lib/events/test/suite.go index 9d1bdb8376a03..d393fa69c1637 100644 --- a/lib/events/test/suite.go +++ b/lib/events/test/suite.go @@ -19,7 +19,6 @@ package test import ( "bytes" "context" - "fmt" "io" "os" "testing" @@ -80,12 +79,16 @@ type EventsSuite struct { Log events.AuditLogger Clock clockwork.Clock QueryDelay time.Duration + + // SearchSessionEvensBySessionIDTimeout is used to specify timeout on query + // in SearchSessionEvensBySessionID test case. + SearchSessionEvensBySessionIDTimeout time.Duration } // EventPagination covers event search pagination. func (s *EventsSuite) EventPagination(t *testing.T) { // This serves no special purpose except to make querying easier. 
- baseTime := time.Date(2019, time.May, 10, 14, 43, 0, 0, time.UTC) + baseTime := time.Now().UTC() names := []string{"bob", "jack", "daisy", "evan"} @@ -95,6 +98,7 @@ func (s *EventsSuite) EventPagination(t *testing.T) { Status: apievents.Status{Success: true}, UserMetadata: apievents.UserMetadata{User: name}, Metadata: apievents.Metadata{ + ID: uuid.NewString(), Type: events.UserLoginEvent, Time: baseTime.Add(time.Second * time.Duration(i)), }, @@ -174,6 +178,7 @@ func (s *EventsSuite) EventPagination(t *testing.T) { Status: apievents.Status{Success: true}, UserMetadata: apievents.UserMetadata{User: name}, Metadata: apievents.Metadata{ + ID: uuid.NewString(), Type: events.UserLoginEvent, Time: baseTime2, }, @@ -206,14 +211,16 @@ Outer: // SessionEventsCRUD covers session events func (s *EventsSuite) SessionEventsCRUD(t *testing.T) { + loginTime := s.Clock.Now().UTC() // Bob has logged in err := s.Log.EmitAuditEvent(context.Background(), &apievents.UserLogin{ Method: events.LoginMethodSAML, Status: apievents.Status{Success: true}, UserMetadata: apievents.UserMetadata{User: "bob"}, Metadata: apievents.Metadata{ + ID: uuid.NewString(), Type: events.UserLoginEvent, - Time: s.Clock.Now().UTC(), + Time: loginTime, }, }) require.NoError(t, err) @@ -226,7 +233,10 @@ func (s *EventsSuite) SessionEventsCRUD(t *testing.T) { var history []apievents.AuditEvent err = retryutils.RetryStaticFor(time.Minute*5, time.Second*5, func() error { - history, _, err = s.Log.SearchEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(time.Hour), apidefaults.Namespace, nil, 100, types.EventOrderAscending, "") + history, _, err = s.Log.SearchEvents(loginTime.Add(-1*time.Hour), loginTime.Add(time.Hour), apidefaults.Namespace, nil, 100, types.EventOrderAscending, "") + if err != nil { + t.Logf("Retrying searching of events because of: %v", err) + } return err }) require.NoError(t, err) @@ -235,9 +245,13 @@ func (s *EventsSuite) SessionEventsCRUD(t *testing.T) { // start the session and emit 
data stream to it and wrap it up sessionID := session.NewID() + // sessionStartTime must be greater than loginTime, because in search we assume + // order. + sessionStartTime := loginTime.Add(1 * time.Minute) err = s.Log.EmitAuditEvent(context.Background(), &apievents.SessionStart{ Metadata: apievents.Metadata{ - Time: s.Clock.Now().UTC(), + ID: uuid.NewString(), + Time: sessionStartTime, Index: 0, Type: events.SessionStartEvent, }, @@ -250,9 +264,11 @@ func (s *EventsSuite) SessionEventsCRUD(t *testing.T) { }) require.NoError(t, err) + sessionEndTime := s.Clock.Now().Add(time.Hour).UTC() err = s.Log.EmitAuditEvent(context.Background(), &apievents.SessionEnd{ Metadata: apievents.Metadata{ - Time: s.Clock.Now().Add(time.Hour).UTC(), + ID: uuid.NewString(), + Time: sessionEndTime, Index: 4, Type: events.SessionEndEvent, }, @@ -268,7 +284,10 @@ func (s *EventsSuite) SessionEventsCRUD(t *testing.T) { // search for the session event. err = retryutils.RetryStaticFor(time.Minute*5, time.Second*5, func() error { - history, _, err = s.Log.SearchEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(time.Hour), apidefaults.Namespace, nil, 100, types.EventOrderAscending, "") + history, _, err = s.Log.SearchEvents(s.Clock.Now().UTC().Add(-1*time.Hour), s.Clock.Now().UTC().Add(time.Hour), apidefaults.Namespace, nil, 100, types.EventOrderAscending, "") + if err != nil { + t.Logf("Retrying searching of events because of: %v", err) + } return err }) require.NoError(t, err) @@ -277,7 +296,7 @@ func (s *EventsSuite) SessionEventsCRUD(t *testing.T) { require.Equal(t, history[1].GetType(), events.SessionStartEvent) require.Equal(t, history[2].GetType(), events.SessionEndEvent) - history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", nil, "") + history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().UTC().Add(-1*time.Hour), s.Clock.Now().UTC().Add(2*time.Hour), 100, types.EventOrderAscending, 
"", nil, "") require.NoError(t, err) require.Len(t, history, 1) @@ -288,20 +307,20 @@ func (s *EventsSuite) SessionEventsCRUD(t *testing.T) { }} } - history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", withParticipant("alice"), "") + history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().UTC().Add(-1*time.Hour), s.Clock.Now().UTC().Add(2*time.Hour), 100, types.EventOrderAscending, "", withParticipant("alice"), "") require.NoError(t, err) require.Len(t, history, 1) - history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(2*time.Hour), 100, types.EventOrderAscending, "", withParticipant("cecile"), "") + history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().UTC().Add(-1*time.Hour), s.Clock.Now().UTC().Add(2*time.Hour), 100, types.EventOrderAscending, "", withParticipant("cecile"), "") require.NoError(t, err) require.Len(t, history, 0) - history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().Add(-1*time.Hour), s.Clock.Now().Add(time.Hour-time.Second), 100, types.EventOrderAscending, "", nil, "") + history, _, err = s.Log.SearchSessionEvents(s.Clock.Now().UTC().Add(-1*time.Hour), sessionEndTime.Add(-time.Second), 100, types.EventOrderAscending, "", nil, "") require.NoError(t, err) require.Len(t, history, 0) } -func (s *EventsSuite) SearchSessionEvensBySessionID(t *testing.T) { +func (s *EventsSuite) SearchSessionEventsBySessionID(t *testing.T) { now := time.Now().UTC() firstID := uuid.New().String() secondID := uuid.New().String() @@ -309,7 +328,7 @@ func (s *EventsSuite) SearchSessionEvensBySessionID(t *testing.T) { for i, id := range []string{firstID, secondID, thirdID} { event := &apievents.WindowsDesktopSessionEnd{ Metadata: apievents.Metadata{ - ID: fmt.Sprintf("eventID%d", i), + ID: uuid.NewString(), Type: events.WindowsDesktopSessionEndEvent, Code: events.DesktopSessionEndCode, Time: now.Add(time.Duration(i) * time.Second), 
@@ -324,6 +343,8 @@ func (s *EventsSuite) SearchSessionEvensBySessionID(t *testing.T) { from := time.Time{} to := now.Add(10 * time.Second) + // TODO(tobiaszheller): drop running SearchSessionEvents in goroutine and using select for cancellation + // when ctx is propagated to search calls. + done := make(chan struct{}) go func() { defer close(done) @@ -335,9 +356,14 @@ func (s *EventsSuite) SearchSessionEvensBySessionID(t *testing.T) { require.Equal(t, e.GetSessionID(), secondID) }() + queryTimeout := s.SearchSessionEvensBySessionIDTimeout + if queryTimeout == 0 { + queryTimeout = time.Second * 10 + } + select { - case <-time.After(time.Second * 10): - t.Fatalf("Search event query timeout") + case <-time.After(queryTimeout): + t.Fatalf("Search event query timeout after %s", queryTimeout) case <-done: }