Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions lib/events/athena/athena.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package athena

import (
"context"
"math"
"net/url"
"regexp"
"strconv"
Expand Down Expand Up @@ -81,9 +80,14 @@ type Config struct {
// GetQueryResultsInterval is used to define how long query will wait before
// checking again for results status if previous status was not ready (optional).
GetQueryResultsInterval time.Duration
// LimiterRate defines rate at which search_event rate limiter is filled (optional).
LimiterRate float64
// LimiterBurst defines rate limit bucket capacity (optional).

// LimiterRefillTime determines the duration of time between the addition of tokens to the bucket (optional).
LimiterRefillTime time.Duration
// LimiterRefillAmount is the number of tokens that are added to the bucket during interval
// specified by LimiterRefillTime (optional).
LimiterRefillAmount int
// Burst defines number of available tokens. It's initially full and refilled
// based on LimiterRefillAmount and LimiterRefillTime (optional).
LimiterBurst int

// Batcher settings.
Expand Down Expand Up @@ -198,19 +202,23 @@ func (cfg *Config) CheckAndSetDefaults(ctx context.Context) error {
return trace.BadParameter("BatchMaxInterval too short, must be greater than 5s")
}

if cfg.LimiterRate < 0 {
return trace.BadParameter("LimiterRate cannot be negative")
if cfg.LimiterRefillAmount < 0 {
return trace.BadParameter("LimiterRefillAmount cannot be nagative")
}
if cfg.LimiterBurst < 0 {
return trace.BadParameter("LimiterBurst cannot be negative")
}

if cfg.LimiterRate > 0 && cfg.LimiterBurst == 0 {
return trace.BadParameter("LimiterBurst must be greater than 0 if LimiterRate is used")
if cfg.LimiterRefillAmount > 0 && cfg.LimiterBurst == 0 {
return trace.BadParameter("LimiterBurst must be greater than 0 if LimiterRefillAmount is used")
}

if cfg.LimiterBurst > 0 && cfg.LimiterRefillAmount == 0 {
return trace.BadParameter("LimiterRefillAmount must be greater than 0 if LimiterBurst is used")
}

if cfg.LimiterBurst > 0 && math.Abs(cfg.LimiterRate) < 1e-9 {
return trace.BadParameter("LimiterRate must be greater than 0 if LimiterBurst is used")
if cfg.LimiterRefillAmount > 0 && cfg.LimiterRefillTime == 0 {
cfg.LimiterRefillTime = time.Second
}

if cfg.Clock == nil {
Expand Down Expand Up @@ -283,13 +291,21 @@ func (cfg *Config) SetFromURL(url *url.URL) error {
}
cfg.GetQueryResultsInterval = dur
}
rateInString := url.Query().Get("limiterRate")
if rateInString != "" {
rate, err := strconv.ParseFloat(rateInString, 32)
refillAmountInString := url.Query().Get("limiterRefillAmount")
if refillAmountInString != "" {
refillAmount, err := strconv.Atoi(refillAmountInString)
if err != nil {
return trace.BadParameter("invalid limiterRefillAmount value (it must be int): %v", err)
}
cfg.LimiterRefillAmount = refillAmount
}
refillTimeInString := url.Query().Get("limiterRefillTime")
if refillTimeInString != "" {
dur, err := time.ParseDuration(refillTimeInString)
if err != nil {
return trace.BadParameter("invalid limiterRate value (it must be float32): %v", err)
return trace.BadParameter("invalid limiterRefillTime value: %v", err)
}
cfg.LimiterRate = rate
cfg.LimiterRefillTime = dur
}
burstInString := url.Query().Get("limiterBurst")
if burstInString != "" {
Expand Down
50 changes: 39 additions & 11 deletions lib/events/athena/athena_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,13 @@ func TestConfig_SetFromURL(t *testing.T) {
},
{
name: "params to querier - part 2",
url: "athena://db.tbl/?getQueryResultsInterval=200ms&limiterRate=0.642&limiterBurst=3",
url: "athena://db.tbl/?getQueryResultsInterval=200ms&limiterRefillAmount=2&&limiterRefillTime=2s&limiterBurst=3",
want: Config{
TableName: "tbl",
Database: "db",
GetQueryResultsInterval: 200 * time.Millisecond,
LimiterRate: 0.642,
LimiterRefillAmount: 2,
LimiterRefillTime: 2 * time.Second,
LimiterBurst: 3,
},
},
Expand All @@ -100,9 +101,9 @@ func TestConfig_SetFromURL(t *testing.T) {
wantErr: "invalid athena address, supported format is 'athena://database.table'",
},
{
name: "invalid limiterRate format",
url: "athena://db.tbl/?limiterRate=abc",
wantErr: "invalid limiterRate value (it must be float32)",
name: "invalid limiterRefillAmount format",
url: "athena://db.tbl/?limiterRefillAmount=abc",
wantErr: "invalid limiterRefillAmount value (it must be int)",
},
}
for _, tt := range tests {
Expand Down Expand Up @@ -163,6 +164,33 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) {
Backend: mockBackend{},
},
},
{
name: "valid config with limiter, check defaults refillTime",
input: func() Config {
cfg := validConfig
cfg.LimiterBurst = 10
cfg.LimiterRefillAmount = 5
return cfg
},
want: Config{
Database: "db",
TableName: "tbl",
TopicARN: "arn:topic",
LargeEventsS3: "s3://large-payloads-bucket",
largeEventsBucket: "large-payloads-bucket",
LocationS3: "s3://events-bucket",
locationS3Bucket: "events-bucket",
QueueURL: "https://queue-url",
GetQueryResultsInterval: 100 * time.Millisecond,
BatchMaxItems: 20000,
BatchMaxInterval: 1 * time.Minute,
AWSConfig: &aws.Config{},
Backend: mockBackend{},
LimiterRefillTime: 1 * time.Second,
LimiterBurst: 10,
LimiterRefillAmount: 5,
},
},
{
name: "missing table name",
input: func() Config {
Expand Down Expand Up @@ -227,24 +255,24 @@ func TestConfig_CheckAndSetDefaults(t *testing.T) {
wantErr: "QueueURL must be valid url and start with https",
},
{
name: "invalid LimiterBurst and LimiterRate combination",
name: "invalid LimiterBurst and LimiterRefillAmount combination",
input: func() Config {
cfg := validConfig
cfg.LimiterBurst = 0
cfg.LimiterRate = 2.5
cfg.LimiterRefillAmount = 2
return cfg
},
wantErr: "LimiterBurst must be greater than 0 if LimiterRate is used",
wantErr: "LimiterBurst must be greater than 0 if LimiterRefillAmount is used",
},
{
name: "invalid LimiterRate and LimiterBurst combination",
name: "invalid LimiterRefillAmount and LimiterBurst combination",
input: func() Config {
cfg := validConfig
cfg.LimiterBurst = 3
cfg.LimiterRate = 0
cfg.LimiterRefillAmount = 0
return cfg
},
wantErr: "LimiterRate must be greater than 0 if LimiterBurst is used",
wantErr: "LimiterRefillAmount must be greater than 0 if LimiterBurst is used",
},
}
for _, tt := range tests {
Expand Down
54 changes: 40 additions & 14 deletions lib/events/athena/querier.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,16 @@ func (q *querier) SearchEvents(fromUTC, toUTC time.Time, namespace string,
eventTypes []string, limit int, order types.EventOrder, startKey string,
) ([]apievents.AuditEvent, string, error) {
filter := searchEventsFilter{eventTypes: eventTypes}
return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, "")
events, keyset, err := q.searchEvents(context.TODO(), searchEventsRequest{
fromUTC: fromUTC,
toUTC: toUTC,
limit: limit,
order: order,
startKey: startKey,
filter: filter,
sessionID: "",
})
return events, keyset, trace.Wrap(err)
}

func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
Expand All @@ -133,12 +142,29 @@ func (q *querier) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
}
filter.condition = condFn
}
return q.searchEvents(context.TODO(), fromUTC, toUTC, limit, order, startKey, filter, sessionID)
events, keyset, err := q.searchEvents(context.TODO(), searchEventsRequest{
fromUTC: fromUTC,
toUTC: toUTC,
limit: limit,
order: order,
startKey: startKey,
filter: filter,
sessionID: sessionID,
})
return events, keyset, trace.Wrap(err)
}

func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, limit int,
order types.EventOrder, startKey string, filter searchEventsFilter, sessionID string,
) ([]apievents.AuditEvent, string, error) {
type searchEventsRequest struct {
fromUTC, toUTC time.Time
limit int
order types.EventOrder
startKey string
filter searchEventsFilter
sessionID string
}

func (q *querier) searchEvents(ctx context.Context, req searchEventsRequest) ([]apievents.AuditEvent, string, error) {
limit := req.limit
if limit <= 0 {
limit = defaults.EventsIterationLimit
}
Expand All @@ -147,28 +173,28 @@ func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, li
}

var startKeyset *keyset
if startKey != "" {
if req.startKey != "" {
var err error
startKeyset, err = fromKey(startKey)
startKeyset, err = fromKey(req.startKey)
if err != nil {
return nil, "", trace.Wrap(err)
}
}

query, params := prepareQuery(searchParams{
fromUTC: fromUTC,
toUTC: toUTC,
order: order,
fromUTC: req.fromUTC,
toUTC: req.toUTC,
order: req.order,
limit: limit,
startKeyset: startKeyset,
filter: filter,
sessionID: sessionID,
filter: req.filter,
sessionID: req.sessionID,
tablename: q.tablename,
})

q.logger.WithField("query", query).
WithField("params", params).
WithField("startKey", startKey).
WithField("startKey", req.startKey).
Debug("Executing events query on Athena")

queryId, err := q.startQueryExecution(ctx, query, params)
Expand All @@ -180,7 +206,7 @@ func (q *querier) searchEvents(ctx context.Context, fromUTC, toUTC time.Time, li
return nil, "", trace.Wrap(err)
}

output, nextKey, err := q.fetchResults(ctx, queryId, limit, filter.condition)
output, nextKey, err := q.fetchResults(ctx, queryId, limit, req.filter.condition)
return output, nextKey, trace.Wrap(err)
}

Expand Down
95 changes: 95 additions & 0 deletions lib/events/search_limiter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright 2023 Gravitational, Inc
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package events

import (
"time"

"github.com/gravitational/trace"
"golang.org/x/time/rate"

"github.com/gravitational/teleport/api/types"
apievents "github.com/gravitational/teleport/api/types/events"
)

// SearchEventsLimiter allows to wrap any AuditLogger with rate limit on
// search events endpoints.
// Note it share limiter for both SearchEvents and SearchSessionEvents.
type SearchEventsLimiter struct {
limiter *rate.Limiter
AuditLogger
}

// SearchEventsLimiterConfig is configuration for SearchEventsLimiter.
type SearchEventsLimiterConfig struct {
// RefillTime determines the duration of time between the addition of tokens to the bucket.
RefillTime time.Duration
// RefillAmount is the number of tokens that are added to the bucket during interval
// specified by RefillTime.
RefillAmount int
// Burst defines number of available tokens. It's initially full and refilled
// based on RefillAmount and RefillTime.
Burst int
// AuditLogger is auditLogger that will be wrapped with limiter on search endpoints.
AuditLogger AuditLogger
}

func (cfg *SearchEventsLimiterConfig) CheckAndSetDefaults() error {
if cfg.AuditLogger == nil {
return trace.BadParameter("empty auditLogger")
}
if cfg.Burst <= 0 {
return trace.BadParameter("Burst cannot be less or equal to 0")
}
if cfg.RefillAmount <= 0 {
return trace.BadParameter("RefillAmount cannot be less or equal to 0")
}
if cfg.RefillTime == 0 {
// Default to seconds so it can be just used as rate.
cfg.RefillTime = time.Second
}
return nil
}

// NewSearchEventLimiter returns instance of new SearchEventsLimiter.
func NewSearchEventLimiter(cfg SearchEventsLimiterConfig) (*SearchEventsLimiter, error) {
if err := cfg.CheckAndSetDefaults(); err != nil {
return nil, trace.Wrap(err)
}
return &SearchEventsLimiter{
limiter: rate.NewLimiter(rate.Every(cfg.RefillTime/time.Duration(cfg.RefillAmount)), cfg.Burst),
AuditLogger: cfg.AuditLogger,
}, nil
}

func (s *SearchEventsLimiter) SearchEvents(fromUTC, toUTC time.Time, namespace string,
eventTypes []string, limit int, order types.EventOrder, startKey string,
) ([]apievents.AuditEvent, string, error) {
if !s.limiter.Allow() {
return nil, "", trace.LimitExceeded("rate limit exceeded for searching events")
}
out, keyset, err := s.AuditLogger.SearchEvents(fromUTC, toUTC, namespace, eventTypes, limit, order, startKey)
return out, keyset, trace.Wrap(err)
}

func (s *SearchEventsLimiter) SearchSessionEvents(fromUTC, toUTC time.Time, limit int,
order types.EventOrder, startKey string, cond *types.WhereExpr, sessionID string,
) ([]apievents.AuditEvent, string, error) {
if !s.limiter.Allow() {
return nil, "", trace.LimitExceeded("rate limit exceeded for searching events")
}
out, keyset, err := s.AuditLogger.SearchSessionEvents(fromUTC, toUTC, limit, order, startKey, cond, sessionID)
return out, keyset, trace.Wrap(err)
}
Loading