diff --git a/proxyd/frontend_rate_limiter.go b/proxyd/frontend_rate_limiter.go index d0590f0561da1..bb2114a5222e9 100644 --- a/proxyd/frontend_rate_limiter.go +++ b/proxyd/frontend_rate_limiter.go @@ -26,7 +26,7 @@ type FrontendRateLimiter interface { type limitedKeys struct { truncTS int64 keys map[string]int - mtx sync.Mutex + mtx sync.RWMutex } func newLimitedKeys(t int64) *limitedKeys { @@ -39,6 +39,7 @@ func newLimitedKeys(t int64) *limitedKeys { func (l *limitedKeys) Take(key string, max int) bool { l.mtx.Lock() defer l.mtx.Unlock() + val, ok := l.keys[key] if !ok { l.keys[key] = 0 @@ -60,17 +61,25 @@ type MemoryFrontendRateLimiter struct { dur time.Duration max int mtx sync.Mutex + metrics MemoryFrontendRateLimitMetrics +} + +type MemoryFrontendRateLimitMetrics interface { + IncTakeError() } -func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter { +func NewMemoryFrontendRateLimit(dur time.Duration, max int, metrics MemoryFrontendRateLimitMetrics) FrontendRateLimiter { return &MemoryFrontendRateLimiter{ - dur: dur, - max: max, + dur: dur, + max: max, + metrics: metrics, } } func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) { m.mtx.Lock() + defer m.mtx.Unlock() + // Create truncated timestamp truncTS := truncateNow(m.dur) @@ -82,42 +91,52 @@ func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, // Pull out the limiter so we can unlock before incrementing the limit. limiter := m.currGeneration + taken := limiter.Take(key, m.max) + if !taken { + return false, nil + } - m.mtx.Unlock() - - return limiter.Take(key, m.max), nil + return taken, nil } // RedisFrontendRateLimiter is a rate limiter that stores data in Redis. // It uses the basic rate limiter pattern described on the Redis best // practices website: https://redis.com/redis-best-practices/basic-rate-limiting/. type RedisFrontendRateLimiter struct { - r *redis.Client - dur time.Duration - max int - prefix string + r *redis.Client + dur time.Duration + max int + prefix string + metrics FrontendRateLimitMetrics +} + +type FrontendRateLimitMetrics interface { + IncTakeError() } -func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter { +func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string, metrics FrontendRateLimitMetrics) FrontendRateLimiter { return &RedisFrontendRateLimiter{ - r: r, - dur: dur, - max: max, - prefix: prefix, + r: r, + dur: dur, + max: max, + prefix: prefix, + metrics: metrics, } } func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) { var incr *redis.IntCmd + truncTS := truncateNow(r.dur) fullKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncTS) + _, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error { incr = pipe.Incr(ctx, fullKey) pipe.PExpire(ctx, fullKey, r.dur-time.Millisecond) return nil }) if err != nil { - frontendRateLimitTakeErrors.Inc() + r.metrics.IncTakeError() return false, err }