diff --git a/lib/events/athena/athena.go b/lib/events/athena/athena.go index 124db8a5afbbe..01076863a8acf 100644 --- a/lib/events/athena/athena.go +++ b/lib/events/athena/athena.go @@ -16,7 +16,6 @@ package athena import ( "context" - "math" "net/url" "regexp" "strconv" @@ -81,9 +80,14 @@ type Config struct { // GetQueryResultsInterval is used to define how long query will wait before // checking again for results status if previous status was not ready (optional). GetQueryResultsInterval time.Duration - // LimiterRate defines rate at which search_event rate limiter is filled (optional). - LimiterRate float64 - // LimiterBurst defines rate limit bucket capacity (optional). + + // LimiterRefillTime determines the duration of time between the addition of tokens to the bucket (optional). + LimiterRefillTime time.Duration + // LimiterRefillAmount is the number of tokens that are added to the bucket during interval + // specified by LimiterRefillTime (optional). + LimiterRefillAmount int + // Burst defines number of available tokens. It's initially full and refilled + // based on LimiterRefillAmount and LimiterRefillTime (optional). LimiterBurst int // Batcher settings. @@ -198,19 +202,23 @@ func (cfg *Config) CheckAndSetDefaults(ctx context.Context) error { return trace.BadParameter("BatchMaxInterval too short, must be greater than 5s") } - if cfg.LimiterRate < 0 { - return trace.BadParameter("LimiterRate cannot be negative") + if cfg.LimiterRefillAmount < 0 { + return trace.BadParameter("LimiterRefillAmount cannot be nagative") } if cfg.LimiterBurst < 0 { return trace.BadParameter("LimiterBurst cannot be negative") } - if cfg.LimiterRate > 0 && cfg.LimiterBurst == 0 { - return trace.BadParameter("LimiterBurst must be greater than 0 if LimiterRate is used") + if cfg.LimiterRefillAmount > 0 && cfg.LimiterBurst == 0 { + return trace.BadParameter("LimiterBurst must be greater than 0 if LimiterRefillAmount is used") + } + + if cfg.LimiterBurst > 0 && cfg.LimiterRefillAmount == 0 { + return trace.BadParameter("LimiterRefillAmount must be greater than 0 if LimiterBurst is used") } - if cfg.LimiterBurst > 0 && math.Abs(cfg.LimiterRate) < 1e-9 { - return trace.BadParameter("LimiterRate must be greater than 0 if LimiterBurst is used") + if cfg.LimiterRefillAmount > 0 && cfg.LimiterRefillTime == 0 { + cfg.LimiterRefillTime = time.Second } if cfg.Clock == nil { @@ -283,13 +291,21 @@ func (cfg *Config) SetFromURL(url *url.URL) error { } cfg.GetQueryResultsInterval = dur } - rateInString := url.Query().Get("limiterRate") - if rateInString != "" { - rate, err := strconv.ParseFloat(rateInString, 32) + refillAmountInString := url.Query().Get("limiterRefillAmount") + if refillAmountInString != "" { + refillAmount, err := strconv.Atoi(refillAmountInString) + if err != nil { + return trace.BadParameter("invalid limiterRefillAmount value (it must be int): %v", err) + } + cfg.LimiterRefillAmount = refillAmount + } + refillTimeInString := url.Query().Get("limiterRefillTime") + if refillTimeInString != "" { + dur, err := time.ParseDuration(refillTimeInString) if err != nil { - return trace.BadParameter("invalid limiterRate value (it must be float32): %v", err) + return trace.BadParameter("invalid limiterRefillTime value: %v", err) } - cfg.LimiterRate = rate + cfg.LimiterRefillTime = dur } burstInString := url.Query().Get("limiterBurst") if burstInString != "" { diff --git a/lib/events/athena/athena_test.go b/lib/events/athena/athena_test.go index 0bdf25444a2ee..ad242626a55bc 100644 --- a/lib/events/athena/athena_test.go +++ b/lib/events/athena/athena_test.go @@ -74,12 +74,13 @@ func TestConfig_SetFromURL(t *testing.T) { }, { name: "params to querier - part 2", - url: "athena://db.tbl/?getQueryResultsInterval=200ms&limiterRate=0.642&limiterBurst=3", + url: "athena://db.tbl/?getQueryResultsInterval=200ms&limiterRefillAmount=2&&limiterRefillTime=2s&limiterBurst=3", want: Config{ TableName: "tbl", Database: "db", GetQueryResultsInterval: 200 * time.Millisecond, - LimiterRate: 0.642, + LimiterRefillAmount: 2, + LimiterRefillTime: 2 * time.Second, LimiterBurst: 3, }, }, @@ -100,9 +101,9 @@ func TestConfig_SetFromURL(t *testing.T) { wantErr: "invalid athena address, supported format is 'athena://database.table'", }, { - name: "invalid limiterRate format", - url: "athena://db.tbl/?limiterRate=abc", - wantErr: "invalid limiterRate value (it must be float32)", + name: "invalid limiterRefillAmount format", + url: "athena://db.tbl/?limiterRefillAmount=abc", + wantErr: "invalid limiterRefillAmount value (it must be int)", }, } for _, tt := range tests { @@ -163,6 +164,33 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) { Backend: mockBackend{}, }, }, + { + name: "valid config with limiter, check defaults refillTime", + input: func() Config { + cfg := validConfig + cfg.LimiterBurst = 10 + cfg.LimiterRefillAmount = 5 + return cfg + }, + want: Config{ + Database: "db", + TableName: "tbl", + TopicARN: "arn:topic", + LargeEventsS3: "s3://large-payloads-bucket", + largeEventsBucket: "large-payloads-bucket", + LocationS3: "s3://events-bucket", + locationS3Bucket: "events-bucket", + QueueURL: "https://queue-url", + GetQueryResultsInterval: 100 * time.Millisecond, + BatchMaxItems: 20000, + BatchMaxInterval: 1 * time.Minute, + AWSConfig: &aws.Config{}, + Backend: mockBackend{}, + LimiterRefillTime: 1 * time.Second, + LimiterBurst: 10, + LimiterRefillAmount: 5, + }, + }, { name: "missing table name", input: func() Config { @@ -227,24 +255,24 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) { wantErr: "QueueURL must be valid url and start with https", }, { - name: "invalid LimiterBurst and LimiterRate combination", + name: "invalid LimiterBurst and LimiterRefillAmount combination", input: func() Config { cfg := validConfig cfg.LimiterBurst = 0 - cfg.LimiterRate = 2.5 + cfg.LimiterRefillAmount = 2 return cfg }, - wantErr: "LimiterBurst must be greater than 0 if LimiterRate is used", + wantErr: "LimiterBurst must be greater than 0 if LimiterRefillAmount is used", }, { - name: "invalid LimiterRate and LimiterBurst combination", + name: "invalid LimiterRefillAmount and LimiterBurst combination", input: func() Config { cfg := validConfig cfg.LimiterBurst = 3 - cfg.LimiterRate = 0 + cfg.LimiterRefillAmount = 0 return cfg }, - wantErr: "LimiterRate must be greater than 0 if LimiterBurst is used", + wantErr: "LimiterRefillAmount must be greater than 0 if LimiterBurst is used", }, } for _, tt := range tests { diff --git a/lib/events/athena/querier.go b/lib/events/athena/querier.go index 2e0db58b831a8..fc134cf58d51f 100644 --- a/lib/events/athena/querier.go +++ b/lib/events/athena/querier.go @@ -117,7 +117,16 @@ func (q *querier) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string, ) ([]apievents.AuditEvent, string, error) { filter := searchEventsFilter{eventTypes: eventTypes} - return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, "") + events, keyset, err := q.searchEvents(context.TODO(), searchEventsRequest{ + fromUTC: fromUTC, + toUTC: toUTC, + limit: limit, + order: order, + startKey: startKey, + filter: filter, + sessionID: "", + }) + return events, keyset, trace.Wrap(err) } func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, @@ -133,12 +142,29 @@ func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, } filter.condition = condFn } - return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, sessionID) + events, keyset, err := q.searchEvents(context.TODO(), searchEventsRequest{ + fromUTC: fromUTC, + toUTC: toUTC, + limit: limit, + order: order, + startKey: startKey, + filter: filter, + sessionID: sessionID, + }) + return events, keyset, trace.Wrap(err) } -func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, limit int, - order types.EventOrder, startKey string, filter searchEventsFilter, sessionID string, -) ([]apievents.AuditEvent, string, error) { +type searchEventsRequest struct { + fromUTC, toUTC time.Time + limit int + order types.EventOrder + startKey string + filter searchEventsFilter + sessionID string +} + +func (q *querier) searchEvents(ctx context.Context, req searchEventsRequest) ([]apievents.AuditEvent, string, error) { + limit := req.limit if limit <= 0 { limit = defaults.EventsIterationLimit } @@ -147,28 +173,28 @@ func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, li } var startKeyset *keyset - if startKey != "" { + if req.startKey != "" { var err error - startKeyset, err = fromKey(startKey) + startKeyset, err = fromKey(req.startKey) if err != nil { return nil, "", trace.Wrap(err) } } query, params := prepareQuery(searchParams{ - fromUTC: fromUTC, - toUTC: toUTC, - order: order, + fromUTC: req.fromUTC, + toUTC: req.toUTC, + order: req.order, limit: limit, startKeyset: startKeyset, - filter: filter, - sessionID: sessionID, + filter: req.filter, + sessionID: req.sessionID, tablename: q.tablename, }) q.logger.WithField("query", query). WithField("params", params). - WithField("startKey", startKey). + WithField("startKey", req.startKey). Debug("Executing events query on Athena") queryId, err := q.startQueryExecution(ctx, query, params) @@ -180,7 +206,7 @@ func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, li return nil, "", trace.Wrap(err) } - output, nextKey, err := q.fetchResults(ctx, queryId, limit, filter.condition) + output, nextKey, err := q.fetchResults(ctx, queryId, limit, req.filter.condition) return output, nextKey, trace.Wrap(err) } diff --git a/lib/events/search_limiter.go b/lib/events/search_limiter.go new file mode 100644 index 0000000000000..644e9f69f5724 --- /dev/null +++ b/lib/events/search_limiter.go @@ -0,0 +1,95 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package events + +import ( + "time" + + "github.com/gravitational/trace" + "golang.org/x/time/rate" + + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" +) + +// SearchEventsLimiter allows to wrap any AuditLogger with rate limit on +// search events endpoints. +// Note it share limiter for both SearchEvents and SearchSessionEvents. +type SearchEventsLimiter struct { + limiter *rate.Limiter + AuditLogger +} + +// SearchEventsLimiterConfig is configuration for SearchEventsLimiter. +type SearchEventsLimiterConfig struct { + // RefillTime determines the duration of time between the addition of tokens to the bucket. + RefillTime time.Duration + // RefillAmount is the number of tokens that are added to the bucket during interval + // specified by RefillTime. + RefillAmount int + // Burst defines number of available tokens. It's initially full and refilled + // based on RefillAmount and RefillTime. + Burst int + // AuditLogger is auditLogger that will be wrapped with limiter on search endpoints. + AuditLogger AuditLogger +} + +func (cfg *SearchEventsLimiterConfig) CheckAndSetDefaults() error { + if cfg.AuditLogger == nil { + return trace.BadParameter("empty auditLogger") + } + if cfg.Burst <= 0 { + return trace.BadParameter("Burst cannot be less or equal to 0") + } + if cfg.RefillAmount <= 0 { + return trace.BadParameter("RefillAmount cannot be less or equal to 0") + } + if cfg.RefillTime == 0 { + // Default to seconds so it can be just used as rate. + cfg.RefillTime = time.Second + } + return nil +} + +// NewSearchEventLimiter returns instance of new SearchEventsLimiter. +func NewSearchEventLimiter(cfg SearchEventsLimiterConfig) (*SearchEventsLimiter, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + return &SearchEventsLimiter{ + limiter: rate.NewLimiter(rate.Every(cfg.RefillTime/time.Duration(cfg.RefillAmount)), cfg.Burst), + AuditLogger: cfg.AuditLogger, + }, nil +} + +func (s *SearchEventsLimiter) SearchEvents(fromUTC, toUTC time.Time, namespace string, + eventTypes []string, limit int, order types.EventOrder, startKey string, +) ([]apievents.AuditEvent, string, error) { + if !s.limiter.Allow() { + return nil, "", trace.LimitExceeded("rate limit exceeded for searching events") + } + out, keyset, err := s.AuditLogger.SearchEvents(fromUTC, toUTC, namespace, eventTypes, limit, order, startKey) + return out, keyset, trace.Wrap(err) +} + +func (s *SearchEventsLimiter) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, + order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string, +) ([]apievents.AuditEvent, string, error) { + if !s.limiter.Allow() { + return nil, "", trace.LimitExceeded("rate limit exceeded for searching events") + } + out, keyset, err := s.AuditLogger.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID) + return out, keyset, trace.Wrap(err) +} diff --git a/lib/events/search_limiter_test.go b/lib/events/search_limiter_test.go new file mode 100644 index 0000000000000..829ce13c01f43 --- /dev/null +++ b/lib/events/search_limiter_test.go @@ -0,0 +1,168 @@ +// Copyright 2023 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package events_test + +import ( + "context" + "testing" + "time" + + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/types" + apievents "github.com/gravitational/teleport/api/types/events" + "github.com/gravitational/teleport/lib/events" +) + +func TestSearchEventsLimiter(t *testing.T) { + t.Parallel() + t.Run("emitting events happen without any limiting", func(t *testing.T) { + s, err := events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{ + RefillAmount: 1, + Burst: 1, + AuditLogger: &mockAuditLogger{ + emitAuditEventRespFn: func() error { return nil }, + }, + }) + require.NoError(t, err) + for i := 0; i < 20; i++ { + require.NoError(t, s.EmitAuditEvent(context.Background(), &apievents.AccessRequestCreate{})) + } + }) + + t.Run("with limiter", func(t *testing.T) { + burst := 20 + s, err := events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{ + RefillTime: 20 * time.Millisecond, + RefillAmount: 1, + Burst: burst, + AuditLogger: &mockAuditLogger{ + searchEventsRespFn: func() ([]apievents.AuditEvent, string, error) { return nil, "", nil }, + }, + }) + require.NoError(t, err) + + someDate := clockwork.NewFakeClock().Now().UTC() + // searchEvents and searchSessionEvents are helper fn to avoid coping those methods with huge + // number of attributes multiple times in that test case. + searchEvents := func() ([]apievents.AuditEvent, string, error) { + return s.SearchEvents(someDate, someDate, "default", nil /* eventTypes */, 100 /* limit */, types.EventOrderAscending, "" /* startKey */) + } + searchSessionEvents := func() ([]apievents.AuditEvent, string, error) { + return s.SearchSessionEvents(someDate, someDate, 100 /* limit */, types.EventOrderAscending, "" /* startKey */, nil /* cond */, "" /* sessionID */) + } + + for i := 0; i < burst; i++ { + var err error + // rate limit is shared between both search endpoints. + if i%2 == 0 { + _, _, err = searchEvents() + } else { + _, _, err = searchSessionEvents() + } + require.NoError(t, err) + } + // Now all tokens from rate limit should be used + _, _, err = searchEvents() + require.True(t, trace.IsLimitExceeded(err)) + // Also on SearchSessionEvents + _, _, err = searchSessionEvents() + require.True(t, trace.IsLimitExceeded(err)) + + // After 20ms 1 token should be added according to rate. + require.Eventually(t, func() bool { + _, _, err := searchEvents() + return err == nil + }, 40*time.Millisecond, 5*time.Millisecond) + }) +} + +func TestSearchEventsLimiterConfig(t *testing.T) { + tests := []struct { + name string + cfg events.SearchEventsLimiterConfig + wantFn func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) + }{ + { + name: "valid config", + cfg: events.SearchEventsLimiterConfig{ + AuditLogger: &mockAuditLogger{}, + RefillAmount: 1, + Burst: 1, + }, + wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) { + require.NoError(t, err) + require.Equal(t, time.Second, cfg.RefillTime) + }, + }, + { + name: "empty rate in config", + cfg: events.SearchEventsLimiterConfig{ + AuditLogger: &mockAuditLogger{}, + Burst: 1, + }, + wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) { + require.ErrorContains(t, err, "RefillAmount cannot be less or equal to 0") + }, + }, + + { + name: "empty burst in config", + cfg: events.SearchEventsLimiterConfig{ + AuditLogger: &mockAuditLogger{}, + RefillAmount: 1, + }, + wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) { + require.ErrorContains(t, err, "Burst cannot be less or equal to 0") + }, + }, + { + name: "empty logger", + cfg: events.SearchEventsLimiterConfig{ + RefillAmount: 1, + Burst: 1, + }, + wantFn: func(t *testing.T, err error, cfg events.SearchEventsLimiterConfig) { + require.ErrorContains(t, err, "empty auditLogger") + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.cfg.CheckAndSetDefaults() + tt.wantFn(t, err, tt.cfg) + }) + } +} + +type mockAuditLogger struct { + searchEventsRespFn func() ([]apievents.AuditEvent, string, error) + emitAuditEventRespFn func() error + events.AuditLogger +} + +func (m *mockAuditLogger) SearchEvents(fromUTC, toUTC time.Time, namespace string, eventTypes []string, limit int, order types.EventOrder, startKey string) ([]apievents.AuditEvent, string, error) { + return m.searchEventsRespFn() +} + +func (m *mockAuditLogger) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string) ([]apievents.AuditEvent, string, error) { + return m.searchEventsRespFn() +} + +func (m *mockAuditLogger) EmitAuditEvent(context.Context, apievents.AuditEvent) error { + return m.emitAuditEventRespFn() +} diff --git a/lib/service/service.go b/lib/service/service.go index a06ace5889d53..e458033c2e024 100644 --- a/lib/service/service.go +++ b/lib/service/service.go @@ -1414,10 +1414,23 @@ func initAuthExternalAuditLog(ctx context.Context, auditConfig types.ClusterAudi if err != nil { return nil, trace.Wrap(err) } - logger, err := athena.New(ctx, cfg) + var logger events.AuditLogger + logger, err = athena.New(ctx, cfg) if err != nil { return nil, trace.Wrap(err) } + if cfg.LimiterBurst > 0 { + // Wrap athena logger with rate limiter on search events. + logger, err = events.NewSearchEventLimiter(events.SearchEventsLimiterConfig{ + RefillTime: cfg.LimiterRefillTime, + RefillAmount: cfg.LimiterRefillAmount, + Burst: cfg.LimiterBurst, + AuditLogger: logger, + }) + if err != nil { + return nil, trace.Wrap(err) + } + } loggers = append(loggers, logger) case teleport.SchemeFile: if uri.Path == "" { diff --git a/lib/service/service_test.go b/lib/service/service_test.go index db81431dd7682..1668de2e7849d 100644 --- a/lib/service/service_test.go +++ b/lib/service/service_test.go @@ -52,6 +52,8 @@ import ( "github.com/gravitational/teleport/lib/backend/memory" "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/events/athena" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/modules" "github.com/gravitational/teleport/lib/reversetunnel" @@ -390,6 +392,47 @@ func TestServiceInitExternalLog(t *testing.T) { } } +func TestAthenaAuditLogSetup(t *testing.T) { + sampleValidConfig := "athena://db.table?topicArn=arn:aws:sns:eu-central-1:accnr:topicName&queryResultsS3=s3://testbucket/query-result/&workgroup=workgroup&locationS3=s3://testbucket/events-location&queueURL=https://sqs.eu-central-1.amazonaws.com/accnr/sqsname&largeEventsS3=s3://testbucket/largeevents" + tests := []struct { + name string + uri string + wantFn func(*testing.T, events.AuditLogger, error) + }{ + { + name: "valid athena config", + uri: sampleValidConfig, + wantFn: func(t *testing.T, alog events.AuditLogger, err error) { + require.NoError(t, err) + v, ok := alog.(*athena.Log) + require.True(t, ok, "invalid logger type, got %T", v) + }, + }, + { + name: "config with rate limit - should use events.SearchEventsLimiter", + uri: sampleValidConfig + "&limiterRefillAmount=3&limiterBurst=2", + wantFn: func(t *testing.T, alog events.AuditLogger, err error) { + require.NoError(t, err) + _, ok := alog.(*events.SearchEventsLimiter) + require.True(t, ok, "invalid logger type, got %T", alog) + }, + }, + } + backend, err := memory.New(memory.Config{}) + require.NoError(t, err) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + auditConfig, err := types.NewClusterAuditConfig(types.ClusterAuditConfigSpecV2{ + AuditEventsURI: []string{tt.uri}, + AuditSessionsURI: "s3://testbucket/sessions-rec", + }) + require.NoError(t, err) + log, err := initAuthExternalAuditLog(context.Background(), auditConfig, backend) + tt.wantFn(t, log, err) + }) + } +} + func TestGetAdditionalPrincipals(t *testing.T) { p := &TeleportProcess{ Config: &servicecfg.Config{