Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 36 additions & 17 deletions proxyd/frontend_rate_limiter.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type FrontendRateLimiter interface {
type limitedKeys struct {
truncTS int64
keys map[string]int
mtx sync.Mutex
mtx sync.RWMutex
}

func newLimitedKeys(t int64) *limitedKeys {
Expand All @@ -39,6 +39,7 @@ func newLimitedKeys(t int64) *limitedKeys {
func (l *limitedKeys) Take(key string, max int) bool {
l.mtx.Lock()
defer l.mtx.Unlock()

val, ok := l.keys[key]
if !ok {
l.keys[key] = 0
Expand All @@ -60,17 +61,25 @@ type MemoryFrontendRateLimiter struct {
dur time.Duration
max int
mtx sync.Mutex
metrics MemoryFrontendRateLimitMetrics
}

type MemoryFrontendRateLimitMetrics interface {
IncTakeError()
}

func NewMemoryFrontendRateLimit(dur time.Duration, max int) FrontendRateLimiter {
func NewMemoryFrontendRateLimit(dur time.Duration, max int, metrics MemoryFrontendRateLimitMetrics) FrontendRateLimiter {
return &MemoryFrontendRateLimiter{
dur: dur,
max: max,
dur: dur,
max: max,
metrics: metrics,
}
}

func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
m.mtx.Lock()
defer m.mtx.Unlock()

// Create truncated timestamp
truncTS := truncateNow(m.dur)

Expand All @@ -82,42 +91,52 @@ func (m *MemoryFrontendRateLimiter) Take(ctx context.Context, key string) (bool,

// Pull out the limiter so we can unlock before incrementing the limit.
limiter := m.currGeneration
taken := limiter.Take(key, m.max)
if !taken {
return false, nil
}

m.mtx.Unlock()

return limiter.Take(key, m.max), nil
return taken, nil
}

// RedisFrontendRateLimiter is a rate limiter that stores data in Redis.
// It uses the basic rate limiter pattern described on the Redis best
// practices website: https://redis.com/redis-best-practices/basic-rate-limiting/.
type RedisFrontendRateLimiter struct {
r *redis.Client
dur time.Duration
max int
prefix string
r *redis.Client
dur time.Duration
max int
prefix string
metrics FrontendRateLimitMetrics
}

type FrontendRateLimitMetrics interface {
IncTakeError()
}

func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string) FrontendRateLimiter {
func NewRedisFrontendRateLimiter(r *redis.Client, dur time.Duration, max int, prefix string, metrics FrontendRateLimitMetrics) FrontendRateLimiter {
return &RedisFrontendRateLimiter{
r: r,
dur: dur,
max: max,
prefix: prefix,
r: r,
dur: dur,
max: max,
prefix: prefix,
metrics: metrics,
}
}

func (r *RedisFrontendRateLimiter) Take(ctx context.Context, key string) (bool, error) {
var incr *redis.IntCmd

truncTS := truncateNow(r.dur)
fullKey := fmt.Sprintf("rate_limit:%s:%s:%d", r.prefix, key, truncTS)

_, err := r.r.Pipelined(ctx, func(pipe redis.Pipeliner) error {
incr = pipe.Incr(ctx, fullKey)
pipe.PExpire(ctx, fullKey, r.dur-time.Millisecond)
return nil
})
if err != nil {
frontendRateLimitTakeErrors.Inc()
r.metrics.IncTakeError()
return false, err
}

Expand Down