diff --git a/go.mod b/go.mod index ac20e0b8bbc9a..c1809ea358a0e 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/aws/aws-sdk-go-v2/credentials v1.13.20 github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.2 github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.62 + github.com/aws/aws-sdk-go-v2/service/athena v1.25.0 github.com/aws/aws-sdk-go-v2/service/ec2 v1.93.2 github.com/aws/aws-sdk-go-v2/service/s3 v1.31.3 github.com/aws/aws-sdk-go-v2/service/sns v1.20.8 diff --git a/go.sum b/go.sum index 821a30bd8dbec..7a0b554cc87d7 100644 --- a/go.sum +++ b/go.sum @@ -163,6 +163,7 @@ github.com/aws/aws-sdk-go v1.44.244 h1:QzBWLD5HjZHdRZyTMTOWtD9Pobzf1n8/CeTJB4giX github.com/aws/aws-sdk-go v1.44.244/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= github.com/aws/aws-sdk-go-v2 v1.17.3/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= +github.com/aws/aws-sdk-go-v2 v1.17.7/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= github.com/aws/aws-sdk-go-v2 v1.17.8 h1:GMupCNNI7FARX27L7GjCJM8NgivWbRgpjNI/hOQjFS8= github.com/aws/aws-sdk-go-v2 v1.17.8/go.mod h1:uzbQtefpm44goOPmdKyAlXSNcwlRgF3ePWVW6EtJvvw= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.4.10 h1:dK82zF6kkPeCo8J1e+tGx4JdvDIQzj7ygIoLg8WMuGs= @@ -179,9 +180,11 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.13.2/go.mod h1:cDh1p6XkSGSwSRIA github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.62 h1:LhVbe/UDWvBT/jp5LYAweFVH8s+DNtT07Qp2riWEovU= github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.62/go.mod h1:4xCuu1TSwhW5UH6WOdtS4/x/9UfMr2XplzKc86Ffj78= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.27/go.mod h1:a1/UpzeyBBerajpnP5nGZa9mGzsBn5cOKxm6NWQsvoI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.31/go.mod h1:QT0BqUvX1Bh2ABdTGnjqEjvjzrCfIniM9Sc8zn9Yndo= github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.32 h1:dpbVNUjczQ8Ae3QKHbpHBpfvaVkRdesxpTOe9pTouhU= 
github.com/aws/aws-sdk-go-v2/internal/configsources v1.1.32/go.mod h1:RudqOgadTWdcS3t/erPQo24pcVEoYyqj/kKW5Vya21I= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.21/go.mod h1:+Gxn8jYn5k9ebfHEqlhrMirFjSW0v0C9fI+KN5vk2kE= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.25/go.mod h1:zBHOPwhBc3FlQjQJE/D3IfPWiWaQmT06Vq9aNukDo0k= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.26 h1:QH2kOS3Ht7x+u0gHCh06CXL/h6G8LQJFpZfFBYBNboo= github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.4.26/go.mod h1:vq86l7956VgFr0/FWQ2BWnK07QC3WYsepKzy33qqY5U= github.com/aws/aws-sdk-go-v2/internal/ini v1.3.28/go.mod h1:yRZVr/iT0AqyHeep00SZ4YfBAKojXz08w3XMBscdi0c= @@ -189,6 +192,8 @@ github.com/aws/aws-sdk-go-v2/internal/ini v1.3.33 h1:HbH1VjUgrCdLJ+4lnnuLI4iVNRv github.com/aws/aws-sdk-go-v2/internal/ini v1.3.33/go.mod h1:zG2FcwjQarWaqXSCGpgcr3RSjZ6dHGguZSppUL0XR7Q= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.24 h1:zsg+5ouVLLbePknVZlUMm1ptwyQLkjjLMWnN+kVs5dA= github.com/aws/aws-sdk-go-v2/internal/v4a v1.0.24/go.mod h1:+fFaIjycTmpV6hjmPTbyU9Kp5MI/lA+bbibcAtmlhYA= +github.com/aws/aws-sdk-go-v2/service/athena v1.25.0 h1:1UrjO+5xowkNmN9YirL+K3u2bVSUe5JYdkIFzdQm8Ps= +github.com/aws/aws-sdk-go-v2/service/athena v1.25.0/go.mod h1:eAiA/Po1i6D8kVj4nLnlfIQxTE1AYn4C0VUvtO+Qflw= github.com/aws/aws-sdk-go-v2/service/ec2 v1.93.2 h1:c6a19AjfhEXKlEX63cnlWtSQ4nzENihHZOG0I3wH6BE= github.com/aws/aws-sdk-go-v2/service/ec2 v1.93.2/go.mod h1:VX22JN3HQXDtQ3uS4h4TtM+K11vydq58tpHTlsm8TL8= github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.11 h1:y2+VQzC6Zh2ojtV2LoC0MNwHWc6qXv/j2vrQtlftkdA= diff --git a/lib/events/athena/athena.go b/lib/events/athena/athena.go index 95bf3d2d91903..4fb85382e01dd 100644 --- a/lib/events/athena/athena.go +++ b/lib/events/athena/athena.go @@ -317,6 +317,7 @@ func (cfg *Config) SetFromURL(url *url.URL) error { // Athena is used for quering Parquet files on S3. 
type Log struct { publisher *publisher + querier *querier } // New creates an instance of an Athena based audit log. @@ -330,7 +331,20 @@ func New(ctx context.Context, cfg Config) (*Log, error) { } // TODO(tobiaszheller): initialize batcher - // TODO(tobiaszheller): initialize querier + + l.querier, err = newQuerier(querierConfig{ + tablename: cfg.TableName, + database: cfg.Database, + workgroup: cfg.Workgroup, + queryResultsS3: cfg.QueryResultsS3, + getQueryResultsInterval: cfg.GetQueryResultsInterval, + awsCfg: cfg.AWSConfig, + logger: cfg.LogEntry, + clock: cfg.Clock, + }) + if err != nil { + return nil, trace.Wrap(err) + } return l, nil } @@ -340,11 +354,11 @@ func (l *Log) EmitAuditEvent(ctx context.Context, in apievents.AuditEvent) error } func (l *Log) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) { - return nil, "", trace.NotImplemented("not implemented") + return l.querier.SearchEvents(fromUTC, toUTC, namespace, eventTypes, limit, order, startKey) } func (l *Log) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { - return nil, "", trace.NotImplemented("not implemented") + return l.querier.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID) } func (l *Log) Close() error { diff --git a/lib/events/athena/querier.go b/lib/events/athena/querier.go new file mode 100644 index 0000000000000..2e0db58b831a8 --- /dev/null +++ b/lib/events/athena/querier.go @@ -0,0 +1,517 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package athena + +import ( + "context" + "encoding/base64" + "encoding/binary" + "fmt" + "strconv" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/athena" + athenaTypes "github.com/aws/aws-sdk-go-v2/service/athena/types" + "github.com/google/uuid" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + log "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/utils" +) + +const ( + athenaTimestampFormat = "2006-01-02 15:04:05.999" + // getQueryResultsInitialDelay defines how long querier will wait before asking + // first time for status of execution query. + getQueryResultsInitialDelay = 600 * time.Millisecond + // getQueryResultsMaxTime defines what's maximum time for running a query. + getQueryResultsMaxTime = 1 * time.Minute +) + +// querier allows searching events on s3 using Athena engine. +// Data on s3 is stored in parquet files and partitioned by date using folders. 
type querier struct {
	querierConfig

	athenaClient athenaClient
}

// athenaClient is the subset of the AWS Athena API used by the querier.
// It is an interface so tests can substitute a fake implementation.
type athenaClient interface {
	StartQueryExecution(ctx context.Context, params *athena.StartQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.StartQueryExecutionOutput, error)
	GetQueryExecution(ctx context.Context, params *athena.GetQueryExecutionInput, optFns ...func(*athena.Options)) (*athena.GetQueryExecutionOutput, error)
	GetQueryResults(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error)
}

// querierConfig contains the settings and dependencies of the querier.
type querierConfig struct {
	tablename               string
	database                string
	workgroup               string
	queryResultsS3          string
	getQueryResultsInterval time.Duration

	clock  clockwork.Clock
	awsCfg *aws.Config
	logger log.FieldLogger
}

// CheckAndSetDefaults validates that required fields were wired in and fills
// in defaults for the logger and clock.
func (cfg *querierConfig) CheckAndSetDefaults() error {
	// Proper format of those fields is already validated in athena.Config.
	// Here we just check if they were "wired" at all.
	switch {
	case cfg.tablename == "":
		return trace.BadParameter("empty tablename in athena querier")
	case cfg.database == "":
		return trace.BadParameter("empty database in athena querier")
	case cfg.queryResultsS3 == "":
		return trace.BadParameter("empty queryResultsS3 in athena querier")
	case cfg.getQueryResultsInterval == 0:
		return trace.BadParameter("empty getQueryResultsInterval in athena querier")
	case cfg.awsCfg == nil:
		return trace.BadParameter("empty awsCfg in athena querier")
	}

	if cfg.logger == nil {
		cfg.logger = log.WithFields(log.Fields{
			trace.Component: teleport.ComponentAthena,
		})
	}
	if cfg.clock == nil {
		cfg.clock = clockwork.NewRealClock()
	}

	return nil
}

// newQuerier returns a querier backed by a real Athena client built from the
// provided AWS config.
func newQuerier(cfg querierConfig) (*querier, error) {
	err := cfg.CheckAndSetDefaults()
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return &querier{
		athenaClient:  athena.NewFromConfig(*cfg.awsCfg),
		querierConfig: cfg,
	}, nil
}

// SearchEvents returns audit events within the given time range, optionally
// filtered by event types, together with a pagination key to resume from.
func (q *querier) SearchEvents(fromUTC, toUTC time.Time, namespace string,
	eventTypes []string, limit int, order types.EventOrder, startKey string,
) ([]apievents.AuditEvent, string, error) {
	filter := searchEventsFilter{eventTypes: eventTypes}
	return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, "")
}

// SearchSessionEvents returns session-end audit events, optionally narrowed
// by a where-condition and a session ID.
func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
	order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string,
) ([]apievents.AuditEvent, string, error) {
	// TODO(tobiaszheller): maybe if fromUTC is 0000-00-00, ask first last 30days and fallback to -inf - now-30
	// for sessionID != "". This kind of call is done on RBAC to check if user can access that session.
	filter := searchEventsFilter{eventTypes: []string{events.SessionEndEvent, events.WindowsDesktopSessionEndEvent}}
	if cond != nil {
		condFn, err := utils.ToFieldsCondition(cond)
		if err != nil {
			return nil, "", trace.Wrap(err)
		}
		filter.condition = condFn
	}
	return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, sessionID)
}

// searchEvents prepares a parametrized Athena query, starts its execution,
// waits for it to succeed, and fetches the results.
func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, limit int,
	order types.EventOrder, startKey string, filter searchEventsFilter, sessionID string,
) ([]apievents.AuditEvent, string, error) {
	if limit <= 0 {
		limit = defaults.EventsIterationLimit
	}
	if limit > defaults.EventsMaxIterationLimit {
		return nil, "", trace.BadParameter("limit %v exceeds %v", limit, defaults.EventsMaxIterationLimit)
	}

	var startKeyset *keyset
	if startKey != "" {
		var err error
		startKeyset, err = fromKey(startKey)
		if err != nil {
			return nil, "", trace.Wrap(err)
		}
	}

	query, params := prepareQuery(searchParams{
		fromUTC:     fromUTC,
		toUTC:       toUTC,
		order:       order,
		limit:       limit,
		startKeyset: startKeyset,
		filter:      filter,
		sessionID:   sessionID,
		tablename:   q.tablename,
	})

	q.logger.WithField("query", query).
		WithField("params", params).
		WithField("startKey", startKey).
		Debug("Executing events query on Athena")

	queryId, err := q.startQueryExecution(ctx, query, params)
	if err != nil {
		return nil, "", trace.Wrap(err)
	}

	if err := q.waitForSuccess(ctx, queryId); err != nil {
		return nil, "", trace.Wrap(err)
	}

	output, nextKey, err := q.fetchResults(ctx, queryId, limit, filter.condition)
	return output, nextKey, trace.Wrap(err)
}

// searchEventsFilter narrows results to the given event types and an optional
// in-memory condition applied after fetching.
type searchEventsFilter struct {
	eventTypes []string
	condition  utils.FieldsCondition
}

// queryBuilder accumulates a query string together with its execution
// parameters.
type queryBuilder struct {
	builder strings.Builder
	args    []string
}

// withTicks wraps string with ticks.
// string params in athena need to be wrapped by "ticks".
func withTicks(in string) string {
	return fmt.Sprintf("'%s'", in)
}

// sliceWithTicks wraps every element of the slice in ticks.
func sliceWithTicks(ss []string) []string {
	out := make([]string, 0, len(ss))
	for _, s := range ss {
		out = append(out, withTicks(s))
	}
	return out
}

func (q *queryBuilder) Append(s string, args ...string) {
	q.builder.WriteString(s)
	q.args = append(q.args, args...)
}

func (q *queryBuilder) String() string {
	return q.builder.String()
}

func (q *queryBuilder) Args() []string {
	return q.args
}

// searchParams describes a single events search.
type searchParams struct {
	fromUTC, toUTC time.Time
	limit          int
	order          types.EventOrder
	startKeyset    *keyset
	filter         searchEventsFilter
	sessionID      string
	tablename      string
}

// prepareQuery returns query string with parameter placeholders and execution parameters.
// To prevent SQL injection, Athena supports parametrized query.
// As parameter placeholder '?' should be used.
func prepareQuery(params searchParams) (query string, execParams []string) {
	qb := &queryBuilder{}
	qb.Append(`SELECT DISTINCT uid, event_time, event_data FROM `)
	// tablename is validated during config validation.
	// It can only contain characters defined by Athena, which are safe from SQL
	// Injection.
	// Athena does not support passing table name as query parameters.
	qb.Append(params.tablename)
	qb.Append(` WHERE event_date BETWEEN date(?) AND date(?)`, withTicks(params.fromUTC.Format(time.DateOnly)), withTicks(params.toUTC.Format(time.DateOnly)))
	qb.Append(` AND event_time BETWEEN ? and ?`,
		fmt.Sprintf("timestamp '%s'", params.fromUTC.Format(athenaTimestampFormat)), fmt.Sprintf("timestamp '%s'", params.toUTC.Format(athenaTimestampFormat)))

	if params.sessionID != "" {
		qb.Append(" AND session_id = ?", withTicks(params.sessionID))
	}

	if len(params.filter.eventTypes) > 0 {
		// Athena does not support IN with single `?` and multiple parameters.
		// Based on the number of eventTypes, the query is prepared with that
		// number of placeholders. It's safe because we only use the len of the
		// event types here; the event type values themselves are passed as
		// parameters.
		eventsTypesInQuery := fmt.Sprintf(" AND event_type IN (%s)",
			// Create following part: `?,?,?,?` based on len of eventTypes.
			strings.TrimSuffix(strings.Repeat("?,", len(params.filter.eventTypes)), ","))
		qb.Append(eventsTypesInQuery,
			sliceWithTicks(params.filter.eventTypes)...,
		)
	}

	if params.order == types.EventOrderAscending {
		if params.startKeyset != nil {
			qb.Append(` AND (event_time, uid) > (?,?)`,
				fmt.Sprintf("timestamp '%s'", params.startKeyset.t.Format(athenaTimestampFormat)), fmt.Sprintf("'%s'", params.startKeyset.uid.String()))
		}

		qb.Append(` ORDER BY event_time ASC, uid ASC`)
	} else {
		if params.startKeyset != nil {
			qb.Append(` AND (event_time, uid) < (?,?)`,
				fmt.Sprintf("timestamp '%s'", params.startKeyset.t.Format(athenaTimestampFormat)), fmt.Sprintf("'%s'", params.startKeyset.uid.String()))
		}
		qb.Append(` ORDER BY event_time DESC, uid DESC`)
	}

	qb.Append(` LIMIT ?`, strconv.Itoa(params.limit))

	return qb.String(), qb.Args()
}

// startQueryExecution submits the query to Athena and returns its execution ID.
func (q *querier) startQueryExecution(ctx context.Context, query string, params []string) (string, error) {
	startQueryInput := &athena.StartQueryExecutionInput{
		QueryExecutionContext: &athenaTypes.QueryExecutionContext{
			Database: aws.String(q.database),
		},
		ExecutionParameters: params,
		QueryString:         aws.String(query),
	}
	if q.workgroup != "" {
		startQueryInput.WorkGroup = aws.String(q.workgroup)
	}

	if q.queryResultsS3 != "" {
		startQueryInput.ResultConfiguration = &athenaTypes.ResultConfiguration{
			OutputLocation: aws.String(q.queryResultsS3),
		}
	}

	startQueryOut, err := q.athenaClient.StartQueryExecution(ctx, startQueryInput)
	if err != nil {
		return "", trace.Wrap(err)
	}
	return aws.ToString(startQueryOut.QueryExecutionId), nil
}

// waitForSuccess polls the query execution state until it succeeds, returning
// an error if the query fails, is cancelled, or getQueryResultsMaxTime passes.
func (q *querier) waitForSuccess(ctx context.Context, queryId string) error {
	ctx, cancel := context.WithTimeout(ctx, getQueryResultsMaxTime)
	defer cancel()

	for i := 0; ; i++ {
		interval := q.getQueryResultsInterval
		if i == 0 {
			// we want a longer initial delay because processing execution on athena takes some time
			// and there's no real benefit to asking earlier.
			interval = getQueryResultsInitialDelay
		}
		select {
		case <-ctx.Done():
			return trace.Wrap(ctx.Err())
		case <-q.clock.After(interval):
			// continue below
		}

		resp, err := q.athenaClient.GetQueryExecution(ctx, &athena.GetQueryExecutionInput{QueryExecutionId: aws.String(queryId)})
		if err != nil {
			return trace.Wrap(err)
		}
		state := resp.QueryExecution.Status.State
		switch state {
		case athenaTypes.QueryExecutionStateSucceeded:
			return nil
		case athenaTypes.QueryExecutionStateCancelled, athenaTypes.QueryExecutionStateFailed:
			return trace.Errorf("got unexpected state: %s", state)
		case athenaTypes.QueryExecutionStateQueued, athenaTypes.QueryExecutionStateRunning:
			continue
		default:
			return trace.Errorf("got unknown state: %s", state)
		}
	}
}

// fetchResults returns query results for given queryID.
// The Athena API returns at most 1000 results per call, so if the client asks
// for more, multiple calls to GetQueryResults are necessary.
func (q *querier) fetchResults(ctx context.Context, queryId string, limit int, condition utils.FieldsCondition) ([]apievents.AuditEvent, string, error) {
	rb := &responseBuilder{}
	// nextToken is used as offset to next calls for GetQueryResults.
	var nextToken string
	for {
		var nextTokenPtr *string
		if nextToken != "" {
			nextTokenPtr = aws.String(nextToken)
		}
		resultResp, err := q.athenaClient.GetQueryResults(ctx, &athena.GetQueryResultsInput{
			// AWS SDK allows only 1000 results.
			MaxResults:       aws.Int32(1000),
			QueryExecutionId: aws.String(queryId),
			NextToken:        nextTokenPtr,
		})
		if err != nil {
			return nil, "", trace.Wrap(err)
		}

		sizeLimit, err := rb.appendUntilSizeLimit(resultResp, condition)
		if err != nil {
			return nil, "", trace.Wrap(err)
		}

		if sizeLimit {
			endkeySet, err := rb.endKeyset()
			if err != nil {
				return nil, "", trace.Wrap(err)
			}
			return rb.output, endkeySet.ToKey(), nil
		}

		// It means that there are no more results to fetch from athena results
		// output location.
		if resultResp.NextToken == nil {
			output := rb.output
			// We have the same amount of results as requested, return keyset
			// because there could be more results.
			if len(output) >= limit {
				endkeySet, err := rb.endKeyset()
				if err != nil {
					return nil, "", trace.Wrap(err)
				}
				return output, endkeySet.ToKey(), nil
			}
			// output is smaller than limit, no keyset needed.
			return output, "", nil
		}
		nextToken = *resultResp.NextToken
	}
}

// responseBuilder accumulates decoded audit events while tracking the total
// size of the response payload.
type responseBuilder struct {
	output []apievents.AuditEvent
	// totalSize is used to track size of output
	totalSize int
}

// endKeyset returns the pagination keyset derived from the last collected
// event, or nil if no events were collected.
func (r *responseBuilder) endKeyset() (*keyset, error) {
	if len(r.output) < 1 {
		// Search can return 0 events, it means we don't have keyset to return
		// but it is also not an error.
		return nil, nil
	}
	lastEvent := r.output[len(r.output)-1]

	endKeyset, err := eventToKeyset(lastEvent)
	return endKeyset, trace.Wrap(err)
}

// eventToKeyset builds a keyset from an event's time and UUID.
func eventToKeyset(in apievents.AuditEvent) (*keyset, error) {
	var out keyset
	var err error
	out.t = in.GetTime()
	out.uid, err = uuid.Parse(in.GetID())
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return &out, nil
}

// appendUntilSizeLimit converts events from json blob to apievents.AuditEvent.
// It stops if events.MaxEventBytesInResponse is reached or if there are no more
// events. It returns true if size limit was reached.
func (rb *responseBuilder) appendUntilSizeLimit(resultResp *athena.GetQueryResultsOutput, condition utils.FieldsCondition) (bool, error) {
	if resultResp == nil || resultResp.ResultSet == nil {
		return false, nil
	}
	for i, row := range resultResp.ResultSet.Rows {
		if len(row.Data) != 3 {
			return false, trace.BadParameter("invalid number of row at response, got %d", len(row.Data))
		}
		// GetQueryResults returns as first row header from CSV.
		// We don't need it, so we will just ignore first row if it contains
		// header.
		if i == 0 && aws.ToString(row.Data[0].VarCharValue) == "uid" {
			continue
		}
		eventData := aws.ToString(row.Data[2].VarCharValue)

		var fields events.EventFields
		if err := utils.FastUnmarshal([]byte(eventData), &fields); err != nil {
			return false, trace.Wrap(err, "failed to unmarshal event, %s", eventData)
		}
		event, err := events.FromEventFields(fields)
		if err != nil {
			return false, trace.Wrap(err)
		}
		// TODO(tobiaszheller): encode filter as query params and remove it in next PRs.
		if condition != nil && !condition(utils.Fields(fields)) {
			continue
		}

		// An event that would push the response over the size limit is left
		// for the next page; the keyset built from the last appended event
		// lets the follow-up query pick it up.
		if len(eventData)+rb.totalSize > events.MaxEventBytesInResponse {
			return true, nil
		}
		rb.totalSize += len(eventData)
		rb.output = append(rb.output, event)
	}
	return false, nil
}

// keyset is a point at which the searchEvents pagination ended, and can be
// resumed from.
type keyset struct {
	t   time.Time
	uid uuid.UUID
}

// keySetLen defines len of keyset. 8 bytes from timestamp + 16 for UUID.
const keySetLen = 24

// fromKey attempts to parse a keyset from a string. The string is a URL-safe
// base64 encoding of the time in microseconds as an int64, the event UUID;
// numbers are encoded in little-endian.
func fromKey(key string) (*keyset, error) {
	if key == "" {
		return nil, trace.BadParameter("missing key")
	}

	b, err := base64.URLEncoding.DecodeString(key)
	if err != nil {
		return nil, trace.Wrap(err)
	}
	if len(b) != keySetLen {
		return nil, trace.BadParameter("malformed pagination key")
	}
	ks := &keyset{
		t: time.UnixMicro(int64(binary.LittleEndian.Uint64(b[0:8]))).UTC(),
	}
	ks.uid, err = uuid.FromBytes(b[8:24])
	if err != nil {
		return nil, trace.Wrap(err)
	}
	return ks, nil
}

// ToKey converts the keyset into a URL-safe string.
func (ks *keyset) ToKey() string {
	if ks == nil {
		return ""
	}
	var b [keySetLen]byte
	binary.LittleEndian.PutUint64(b[0:8], uint64(ks.t.UnixMicro()))
	copy(b[8:24], ks.uid[:])
	return base64.URLEncoding.EncodeToString(b[:])
}
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package athena

import (
	"context"
	"errors"
	"strings"
	"testing"
	"time"

	"github.com/aws/aws-sdk-go-v2/aws"
	"github.com/aws/aws-sdk-go-v2/service/athena"
	athenaTypes "github.com/aws/aws-sdk-go-v2/service/athena/types"
	"github.com/google/go-cmp/cmp"
	"github.com/google/go-cmp/cmp/cmpopts"
	"github.com/google/uuid"
	"github.com/jonboulle/clockwork"
	"github.com/stretchr/testify/require"

	"github.com/gravitational/teleport/api/types"
	apievents "github.com/gravitational/teleport/api/types/events"
	"github.com/gravitational/teleport/lib/events"
	"github.com/gravitational/teleport/lib/utils"
)

// Test_querier_prepareQuery verifies the generated SQL and execution
// parameters for various search inputs.
func Test_querier_prepareQuery(t *testing.T) {
	const (
		tablename        = "test_table"
		selectFromPrefix = `SELECT DISTINCT uid, event_time, event_data FROM test_table`
		whereTimeRange   = ` WHERE event_date BETWEEN date(?) AND date(?) AND event_time BETWEEN ? and ?`
	)
	fromTimeUTC := time.Date(2023, 2, 1, 0, 0, 0, 0, time.UTC)
	toTimeUTC := time.Date(2023, 3, 1, 0, 0, 0, 0, time.UTC)
	fromDateParam := "'2023-02-01'"
	fromTimestampParam := "timestamp '2023-02-01 00:00:00'"
	toDateParam := "'2023-03-01'"
	toTimestampParam := "timestamp '2023-03-01 00:00:00'"
	timeRangeParams := []string{fromDateParam, toDateParam, fromTimestampParam, toTimestampParam}

	otherTimeUTC := time.Date(2023, 2, 15, 0, 0, 0, 0, time.UTC)
	otherTimestampParam := "timestamp '2023-02-15 00:00:00'"

	tests := []struct {
		name         string
		searchParams searchParams
		wantQuery    string
		wantParams   []string
	}{
		{
			name: "query on time range",
			searchParams: searchParams{
				fromUTC:   fromTimeUTC,
				toUTC:     toTimeUTC,
				limit:     100,
				tablename: tablename,
			},
			wantQuery: selectFromPrefix + whereTimeRange +
				` ORDER BY event_time ASC, uid ASC LIMIT ?`,
			wantParams: append(timeRangeParams, "100"),
		},
		{
			name: "query on time range order DESC",
			searchParams: searchParams{
				fromUTC:   fromTimeUTC,
				toUTC:     toTimeUTC,
				limit:     100,
				order:     types.EventOrderDescending,
				tablename: tablename,
			},
			wantQuery: selectFromPrefix + whereTimeRange +
				` ORDER BY event_time DESC, uid DESC LIMIT ?`,
			wantParams: append(timeRangeParams, "100"),
		},
		{
			name: "query with event types",
			searchParams: searchParams{
				fromUTC:   fromTimeUTC,
				toUTC:     toTimeUTC,
				filter:    searchEventsFilter{eventTypes: []string{"app.create", "app.delete"}},
				limit:     100,
				tablename: tablename,
			},
			wantQuery: selectFromPrefix + whereTimeRange +
				` AND event_type IN (?,?) ORDER BY event_time ASC, uid ASC LIMIT ?`,
			wantParams: append(timeRangeParams, "'app.create'", "'app.delete'", "100"),
		},
		{
			name: "session id",
			searchParams: searchParams{
				fromUTC:   fromTimeUTC,
				toUTC:     toTimeUTC,
				sessionID: "9762a4fe-ac4b-47b5-ba4f-5f70d065849a",
				limit:     100,
				tablename: tablename,
			},
			wantQuery: selectFromPrefix + whereTimeRange +
				` AND session_id = ? ORDER BY event_time ASC, uid ASC LIMIT ?`,
			wantParams: append(timeRangeParams, "'9762a4fe-ac4b-47b5-ba4f-5f70d065849a'", "100"),
		},
		{
			name: "query on time range with keyset",
			searchParams: searchParams{
				fromUTC: fromTimeUTC,
				toUTC:   toTimeUTC,
				limit:   100,
				startKeyset: &keyset{
					t:   otherTimeUTC,
					uid: uuid.MustParse("9762a4fe-ac4b-47b5-ba4f-5f70d065849a"),
				},
				tablename: tablename,
			},
			wantQuery: selectFromPrefix + whereTimeRange +
				` AND (event_time, uid) > (?,?) ORDER BY event_time ASC, uid ASC LIMIT ?`,
			wantParams: append(timeRangeParams, otherTimestampParam, "'9762a4fe-ac4b-47b5-ba4f-5f70d065849a'", "100"),
		},
		{
			name: "query on time range DESC with keyset",
			searchParams: searchParams{
				fromUTC: fromTimeUTC,
				toUTC:   toTimeUTC,
				limit:   100,
				order:   types.EventOrderDescending,
				startKeyset: &keyset{
					t:   otherTimeUTC,
					uid: uuid.MustParse("9762a4fe-ac4b-47b5-ba4f-5f70d065849a"),
				},
				tablename: tablename,
			},
			wantQuery: selectFromPrefix + whereTimeRange +
				` AND (event_time, uid) < (?,?) ORDER BY event_time DESC, uid DESC LIMIT ?`,
			wantParams: append(timeRangeParams, otherTimestampParam, "'9762a4fe-ac4b-47b5-ba4f-5f70d065849a'", "100"),
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotQuery, gotParams := prepareQuery(tt.searchParams)
			require.Empty(t, cmp.Diff(gotQuery, tt.wantQuery), "query")
			require.Empty(t, cmp.Diff(gotParams, tt.wantParams), "params")
		})
	}
}

// Test_keyset checks that a keyset survives a ToKey/fromKey round trip.
func Test_keyset(t *testing.T) {
	// keyset uses microsecond precision, that's why truncate is needed.
	wantT := clockwork.NewFakeClock().Now().UTC().Truncate(time.Microsecond)
	wantUID := uuid.New()
	ks := &keyset{
		t:   wantT,
		uid: wantUID,
	}
	key := ks.ToKey()
	fromKs, err := fromKey(key)
	require.NoError(t, err)
	require.Equal(t, wantT, fromKs.t)
	require.Equal(t, wantUID, fromKs.uid)
}

// Test_querier_fetchResults exercises pagination, the response size limit, and
// in-memory filtering against a fake Athena results API.
func Test_querier_fetchResults(t *testing.T) {
	const tableName = "test_table"
	event1 := &apievents.AppCreate{
		Metadata: apievents.Metadata{
			ID:   uuid.NewString(),
			Time: time.Now().UTC(),
			Type: events.AppCreateEvent,
		},
		AppMetadata: apievents.AppMetadata{
			AppName: "app-1",
		},
	}
	event2 := &apievents.AppCreate{
		Metadata: apievents.Metadata{
			ID:   uuid.NewString(),
			Time: time.Now().UTC(),
			Type: events.AppCreateEvent,
		},
		AppMetadata: apievents.AppMetadata{
			AppName: "app-2",
		},
	}
	event3 := &apievents.AppCreate{
		Metadata: apievents.Metadata{
			ID:   uuid.NewString(),
			Time: time.Now().UTC(),
			Type: events.AppCreateEvent,
		},
		AppMetadata: apievents.AppMetadata{
			AppName: "app-3",
		},
	}
	event4 := &apievents.AppCreate{
		Metadata: apievents.Metadata{
			ID:   uuid.NewString(),
			Time: time.Now().UTC(),
			Type: events.AppCreateEvent,
		},
		AppMetadata: apievents.AppMetadata{
			AppName: "app-4",
		},
	}
	veryBigEvent := &apievents.AppCreate{
		Metadata: apievents.Metadata{
			ID:   uuid.NewString(),
			Time: time.Now().UTC(),
			Type: events.AppCreateEvent,
		},
		AppMetadata: apievents.AppMetadata{
			AppName: strings.Repeat("aaaaa", events.MaxEventBytesInResponse),
		},
	}
	tests := []struct {
		name      string
		limit     int
		condition utils.FieldsCondition
		// fakeResp defines responses which will be returned based on given
		// input token to GetQueryResults. Note that due to limit of GetQueryResults
		// we are doing multiple calls, first always with empty token.
		fakeResp   map[string]eventsWithToken
		wantEvents []apievents.AuditEvent
		wantKeyset string
	}{
		{
			name:  "no data returned from query, return empty results",
			limit: 10,
		},
		{
			name: "events < then limit, mock returns data in multiple calls",
			fakeResp: map[string]eventsWithToken{
				// empty means what is returned in first call.
				"":       {returnToken: "token1", events: []apievents.AuditEvent{event1}},
				"token1": {returnToken: "", events: []apievents.AuditEvent{event2, event3, event4}},
			},
			limit:      10,
			wantEvents: []apievents.AuditEvent{event1, event2, event3, event4},
		},
		{
			name: "events with veryBigEvent exceeding > MaxEventBytesInResponse",
			fakeResp: map[string]eventsWithToken{
				"":       {returnToken: "token1", events: []apievents.AuditEvent{event1}},
				"token1": {returnToken: "", events: []apievents.AuditEvent{event2, event3, veryBigEvent}},
			},
			limit: 10,
			// we don't expect veryBigEvent because it should go to next batch
			wantEvents: []apievents.AuditEvent{event1, event2, event3},
			wantKeyset: mustEventToKey(t, event3),
		},
		{
			// TODO(tobiaszheller): right now if we have event that's > 1 MiB, it will be silently ignored (due to gRPC unary limit).
			// Come back later when we have decision what to do with it.
			name: "only 1 very big event",
			fakeResp: map[string]eventsWithToken{
				"": {returnToken: "", events: []apievents.AuditEvent{veryBigEvent}},
			},
			limit:      10,
			wantEvents: []apievents.AuditEvent{},
		},
		{
			name: "number of events equals limit in req, make sure that pagination keyset is returned",
			fakeResp: map[string]eventsWithToken{
				"":       {returnToken: "token1", events: []apievents.AuditEvent{event1}},
				"token1": {returnToken: "", events: []apievents.AuditEvent{event2, event3}},
			},
			limit:      3,
			wantEvents: []apievents.AuditEvent{event1, event2, event3},
			wantKeyset: mustEventToKey(t, event3),
		},
		{
			name: "filter events based on condition",
			fakeResp: map[string]eventsWithToken{
				"": {returnToken: "", events: []apievents.AuditEvent{event1, event2, event3, event4}},
			},
			condition: func(f utils.Fields) bool {
				return f.GetString("app_name") != event3.AppName
			},
			limit:      10,
			wantEvents: []apievents.AuditEvent{event1, event2, event4},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			q := &querier{
				querierConfig: querierConfig{
					tablename: tableName,
					logger:    utils.NewLoggerForTests(),
				},
				athenaClient: &fakeAthenaResultsGetter{
					resp: tt.fakeResp,
				},
			}
			gotEvents, gotKeyset, err := q.fetchResults(context.Background(), "queryid", tt.limit, tt.condition)
			require.NoError(t, err)
			require.Empty(t, cmp.Diff(tt.wantEvents, gotEvents, cmpopts.EquateEmpty()))
			require.Equal(t, tt.wantKeyset, gotKeyset)
		})
	}
}

// mustEventToKey converts an event to its pagination key, failing the test on
// error.
func mustEventToKey(t *testing.T, in apievents.AuditEvent) string {
	ks, err := eventToKeyset(in)
	if err != nil {
		t.Fatal(err)
	}
	return ks.ToKey()
}

// fakeAthenaResultsGetter fakes the GetQueryResults API based on a map from
// pagination token to canned events. Embedding athenaClient leaves the other
// interface methods unimplemented (calling them would panic).
type fakeAthenaResultsGetter struct {
	athenaClient
	iteration int
	resp      map[string]eventsWithToken
}

type eventsWithToken struct {
	events      []apievents.AuditEvent
	returnToken string
}

func (f *fakeAthenaResultsGetter) GetQueryResults(ctx context.Context, params *athena.GetQueryResultsInput, optFns ...func(*athena.Options)) (*athena.GetQueryResultsOutput, error) {
	if f.resp == nil {
		return &athena.GetQueryResultsOutput{}, nil
	}

	eventsWithToken, ok := f.resp[aws.ToString(params.NextToken)]
	if !ok {
		return nil, errors.New("not defined return param in fake")
	}

	var rows []athenaTypes.Row
	if f.iteration == 0 {
		// That's what AWS API does, always adds header on first call.
		rows = append(rows, athenaTypes.Row{
			Data: []athenaTypes.Datum{{VarCharValue: aws.String("uid")}, {VarCharValue: aws.String("event_time")}, {VarCharValue: aws.String("event_data")}},
		})
	}

	for _, event := range eventsWithToken.events {
		fields, err := events.ToEventFields(event)
		if err != nil {
			return nil, err
		}
		marshaled, err := utils.FastMarshal(fields)
		if err != nil {
			return nil, err
		}
		rows = append(rows, athenaTypes.Row{
			Data: []athenaTypes.Datum{
				// The first 2 fields are ignored in our code, they are returned only because Athena requires
				// to return parameters used in ordering.
				{VarCharValue: aws.String("ignored")},
				{VarCharValue: aws.String("ignored")},
				{VarCharValue: aws.String(string(marshaled))},
			},
		})
	}

	f.iteration++

	var nextToken *string
	if eventsWithToken.returnToken != "" {
		nextToken = aws.String(eventsWithToken.returnToken)
	}

	return &athena.GetQueryResultsOutput{
		NextToken: nextToken,
		ResultSet: &athenaTypes.ResultSet{
			Rows: rows,
		},
	}, nil
}