diff --git a/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go b/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go index b928adc2cd..2b7796fada 100644 --- a/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go +++ b/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go @@ -147,4 +147,106 @@ func TestRatelimitResponse(t *testing.T) { require.Equal(t, int64(7), rl.Remaining, "Should have 7 remaining (10-3)") require.False(t, rl.AutoApply, "Custom rate limit should not be auto-applied") }) + + t.Run("multiple rate limits with accurate remaining counters", func(t *testing.T) { + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeyAuthID: api.KeyAuthID.String, + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "fast-limit", + WorkspaceID: workspace.ID, + AutoApply: true, + Duration: time.Minute.Milliseconds(), + Limit: 3, + }, + { + Name: "slow-limit", + WorkspaceID: workspace.ID, + AutoApply: true, + Duration: time.Hour.Milliseconds(), + Limit: 10, + }, + }, + }) + + req := handler.Request{ + Key: key.Key, + } + + // Helper function to find rate limit by name + findRatelimit := func(ratelimits []openapi.VerifyKeyRatelimitData, name string) *openapi.VerifyKeyRatelimitData { + for _, rl := range ratelimits { + if rl.Name == name { + return &rl + } + } + return nil + } + + // Request 1: Both limits should decrement + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + + ratelimits := *res.Body.Data.Ratelimits + require.Len(t, ratelimits, 2, "Should have two rate limits") + + fastLimit := findRatelimit(ratelimits, "fast-limit") + slowLimit := findRatelimit(ratelimits, "slow-limit") + require.NotNil(t, fastLimit, "fast-limit should be present") + require.NotNil(t, slowLimit, "slow-limit should be present") + + require.Equal(t, int64(2), fastLimit.Remaining, "fast-limit: expected remaining=2 after 1st request") + require.Equal(t, int64(9), slowLimit.Remaining, "slow-limit: expected remaining=9 after 1st request") + require.False(t, fastLimit.Exceeded) + require.False(t, slowLimit.Exceeded) + + // Request 2: Both limits should decrement again + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + + ratelimits = *res.Body.Data.Ratelimits + fastLimit = findRatelimit(ratelimits, "fast-limit") + slowLimit = findRatelimit(ratelimits, "slow-limit") + + require.Equal(t, int64(1), fastLimit.Remaining, "fast-limit: expected remaining=1 after 2nd request") + require.Equal(t, int64(8), slowLimit.Remaining, "slow-limit: expected remaining=8 after 2nd request") + require.False(t, fastLimit.Exceeded) + require.False(t, slowLimit.Exceeded) + + // Request 3: Both limits should decrement, fast-limit reaches 0 + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.Equal(t, openapi.VALID, res.Body.Data.Code) + require.True(t, res.Body.Data.Valid) + + ratelimits = *res.Body.Data.Ratelimits + fastLimit = findRatelimit(ratelimits, "fast-limit") + slowLimit = findRatelimit(ratelimits, "slow-limit") + + require.Equal(t, int64(0), fastLimit.Remaining, "fast-limit: expected remaining=0 after 3rd request") + require.Equal(t, int64(7), slowLimit.Remaining, "slow-limit: expected remaining=7 after 3rd request") + require.False(t, fastLimit.Exceeded) + require.False(t, slowLimit.Exceeded) + + // Request 4: fast-limit should be exceeded, slow-limit continues + res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code) + require.False(t, res.Body.Data.Valid, "Key should be rate limited") + + ratelimits = *res.Body.Data.Ratelimits + fastLimit = findRatelimit(ratelimits, "fast-limit") + slowLimit = findRatelimit(ratelimits, "slow-limit") + + require.Equal(t, int64(0), fastLimit.Remaining, "fast-limit: expected remaining=0 when exceeded") + require.True(t, fastLimit.Exceeded, "fast-limit should be exceeded") + // slow-limit should NOT increment since the request was denied + require.Equal(t, int64(7), slowLimit.Remaining, "slow-limit: should not decrement when request is denied") + require.False(t, slowLimit.Exceeded, "slow-limit should not be exceeded") + }) } diff --git a/go/apps/api/routes/v2_ratelimit_limit/handler.go b/go/apps/api/routes/v2_ratelimit_limit/handler.go index 1ae3cd3295..79e77cfa9f 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/handler.go +++ b/go/apps/api/routes/v2_ratelimit_limit/handler.go @@ -272,14 +272,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { } } - results, err := h.Ratelimit.Ratelimit(ctx, []ratelimit.RatelimitRequest{limitReq}) + result, err := h.Ratelimit.Ratelimit(ctx, limitReq) if err != nil { return fault.Wrap(err, fault.Internal("rate limit failed"), fault.Public("We're unable to process the rate limit request."), ) } - result := results[0] if s.ShouldLogRequestToClickHouse() { h.ClickHouse.BufferRatelimit(schema.RatelimitRequestV1{ diff --git a/go/internal/services/keys/validation.go b/go/internal/services/keys/validation.go index 344821a241..838aa806ea 100644 --- a/go/internal/services/keys/validation.go +++ b/go/internal/services/keys/validation.go @@ -204,7 +204,20 @@ func (k *KeyVerifier) withRateLimits(ctx context.Context, specifiedLimits []open }) } - resp, err := k.rateLimiter.Ratelimit(ctx, ratelimitRequests) + // Use different rate limiting paths based on number of limits + var resp []ratelimit.RatelimitResponse + var err error + + if len(ratelimitRequests) == 1 { + // Single rate limit - use fast path + singleResp, singleErr := k.rateLimiter.Ratelimit(ctx, ratelimitRequests[0]) + resp = []ratelimit.RatelimitResponse{singleResp} + err = singleErr + } else { + // Multiple rate limits - use atomic all-or-nothing path + resp, err = k.rateLimiter.RatelimitMany(ctx, ratelimitRequests) + } + if err != nil { k.logger.Error("Failed to ratelimit", "key_id", k.Key.ID, @@ -214,13 +227,14 @@ func (k *KeyVerifier) withRateLimits(ctx context.Context, specifiedLimits []open // We will just allow the request to proceed, but log the error return nil } - for i, response := range resp { + + for i := range resp { // Write response back to config to be passed to the client config := ratelimitsToCheck[names[i]] - config.Response = &response + config.Response = &resp[i] ratelimitsToCheck[names[i]] = config - if !response.Success { + if !resp[i].Success { k.setInvalid(StatusRateLimited, fmt.Sprintf("key exceeded rate limit %s", names[i])) } } diff --git a/go/internal/services/ratelimit/bucket.go b/go/internal/services/ratelimit/bucket.go index e9f3dd285e..5217c4ac4b 100644 --- a/go/internal/services/ratelimit/bucket.go +++ b/go/internal/services/ratelimit/bucket.go @@ -31,6 +31,9 @@ type bucket struct { // mu protects all bucket operations mu sync.RWMutex + // identifier is the rate limit subject (user ID, API key, etc) + identifier string + // limit is the maximum number of requests allowed per duration limit int64 @@ -48,6 +51,14 @@ type bucket struct { strictUntil time.Time } +func (b *bucket) key() bucketKey { + return bucketKey{ + identifier: b.identifier, + limit: b.limit, + duration: b.duration, + } +} + // bucketKey uniquely identifies a rate limit bucket by combining the // identifier, limit, and duration. This ensures separate tracking when // the same identifier has different rate limit configurations. @@ -101,6 +112,7 @@ func (s *service) getOrCreateBucket(key bucketKey) (*bucket, bool) { metrics.RatelimitBucketsCreated.Inc() b = &bucket{ mu: sync.RWMutex{}, + identifier: key.identifier, limit: key.limit, duration: key.duration, windows: make(map[int64]*window), diff --git a/go/internal/services/ratelimit/interface.go b/go/internal/services/ratelimit/interface.go index 0f15a20599..a322042e1b 100644 --- a/go/internal/services/ratelimit/interface.go +++ b/go/internal/services/ratelimit/interface.go @@ -33,21 +33,15 @@ type Service interface { // Performance: O(1) time complexity for local decisions // // Example Usage: - // responses, err := svc.Ratelimit(ctx, []RatelimitRequest{ - // { - // Identifier: "user-123", - // Limit: 100, - // Duration: time.Minute, - // Cost: 1, - // }, - // { - // Identifier: "user-456", - // Limit: 50, - // Duration: time.Minute, - // Cost: 2, - // }, + // response, err := svc.Ratelimit(ctx, RatelimitRequest{ + // Identifier: "user-123", + // Limit: 100, + // Duration: time.Minute, + // Cost: 1, // }) - Ratelimit(context.Context, []RatelimitRequest) ([]RatelimitResponse, error) + Ratelimit(context.Context, RatelimitRequest) (RatelimitResponse, error) + + RatelimitMany(context.Context, []RatelimitRequest) ([]RatelimitResponse, error) } // RatelimitRequest represents a request to check or consume rate limit tokens. diff --git a/go/internal/services/ratelimit/service.go b/go/internal/services/ratelimit/service.go index 2d27c90e4e..1fa0c7a764 100644 --- a/go/internal/services/ratelimit/service.go +++ b/go/internal/services/ratelimit/service.go @@ -2,6 +2,7 @@ package ratelimit import ( "context" + "sort" "sync" "github.com/unkeyed/unkey/go/pkg/assert" @@ -139,21 +140,27 @@ func (s *service) calculateRateLimit(req RatelimitRequest, currentWindow, previo return exceeded, effectiveCount, remaining } -// Ratelimit checks multiple rate limits atomically. +// Type to track request with its key and index +type reqWithKey struct { + req RatelimitRequest + key bucketKey + index int +} + +// RatelimitMany checks multiple rate limits atomically. // // All rate limit checks must pass for the request to be allowed. If any limit fails, // none of the counters are incremented. This all-or-nothing behavior prevents counter // leaks when a key has multiple rate limits (e.g., per-minute and per-month). // -// The method tries to make decisions using local cached data when possible. If local -// data is insufficient (first request or after strictUntil period), it fetches current -// counts from Redis. When all checks pass, counters are incremented locally and the -// changes are asynchronously propagated to Redis via the replay buffer. +// The method acquires locks on all unique buckets (sorted to prevent deadlock) and +// holds them while checking limits and incrementing counters. This ensures no race +// conditions occur between check and increment. // // Returns validation errors for invalid request parameters (empty identifier, zero limit, // negative cost, or duration less than 1 second). -func (s *service) Ratelimit(ctx context.Context, reqs []RatelimitRequest) ([]RatelimitResponse, error) { - _, span := tracing.Start(ctx, "Ratelimit") +func (s *service) RatelimitMany(ctx context.Context, reqs []RatelimitRequest) ([]RatelimitResponse, error) { + _, span := tracing.Start(ctx, "RatelimitMany") defer span.End() for i := range reqs { @@ -173,123 +180,166 @@ func (s *service) Ratelimit(ctx context.Context, reqs []RatelimitRequest) ([]Rat } } - responses := make([]RatelimitResponse, len(reqs)) - + // Build and sort keys first (before getting buckets) + reqsWithKeys := make([]reqWithKey, len(reqs)) for i, req := range reqs { - res, err := s.handleBucket(ctx, req) - if err != nil { - return nil, err + key := bucketKey{req.Identifier, req.Limit, req.Duration} + reqsWithKeys[i] = reqWithKey{ + req: req, + key: key, + index: i, } - responses[i] = res } + // Sort by key to ensure consistent ordering (prevents deadlock) + sort.Slice(reqsWithKeys, func(i, j int) bool { + return reqsWithKeys[i].key.toString() < reqsWithKeys[j].key.toString() + }) + + // Get unique buckets in sorted order and deduplicate + uniqueBuckets := make([]*bucket, 0, len(reqs)) + bucketMap := make(map[bucketKey]*bucket) + + for _, rwk := range reqsWithKeys { + if _, exists := bucketMap[rwk.key]; !exists { + b, _ := s.getOrCreateBucket(rwk.key) + bucketMap[rwk.key] = b + uniqueBuckets = append(uniqueBuckets, b) + } + } + + // Acquire locks on unique buckets only (already sorted) + for _, b := range uniqueBuckets { + b.mu.Lock() + defer b.mu.Unlock() + } + + // Check all limits while holding locks + responses := make([]RatelimitResponse, len(reqs)) allPassed := true - for i, res := range responses { + for _, rwk := range reqsWithKeys { + bucket := bucketMap[rwk.key] + + // Check limit with lock already held + res, err := s.checkBucketWithLockHeld(ctx, rwk.req, bucket) + if err != nil { + return nil, err + } + responses[rwk.index] = res + if !res.Success { allPassed = false - for j := range i { - s.rollback(reqs[j]) - } - span.SetAttributes(attribute.Bool("denied", true)) - break + // Don't break - check all limits to return complete status } } + span.SetAttributes(attribute.Bool("passed", allPassed)) + + // If all passed, increment all counters (still holding locks!) if allPassed { - span.SetAttributes(attribute.Bool("passed", true)) + for _, rwk := range reqsWithKeys { + bucket := bucketMap[rwk.key] + currentWindow, _ := bucket.getCurrentWindow(rwk.req.Time) + currentWindow.counter += rwk.req.Cost - for _, req := range reqs { - s.replayBuffer.Buffer(req) + // Buffer for async replay to Redis + s.replayBuffer.Buffer(rwk.req) } } else { - // When batch fails, add back the cost since we're not consuming tokens - for i := range responses { - responses[i].Remaining += reqs[i].Cost + + // At least one failed - adjust remaining values + for i, res := range responses { + if res.Success { + responses[i].Remaining += reqs[i].Cost + } } } - // Clamp all Remaining values to 0 before returning + // Clamp all remaining values for i := range responses { responses[i].Remaining = max(0, responses[i].Remaining) } return responses, nil - } -// rollback decrements the local counter for a request that was optimistically incremented -// during checking but later failed due to another rate limit in the batch failing. -// -// This is only called when processing a batch of rate limits where earlier checks passed -// but a later check failed, requiring us to undo the optimistic counter increments. -func (s *service) rollback(req RatelimitRequest) { +func (s *service) Ratelimit(ctx context.Context, req RatelimitRequest) (RatelimitResponse, error) { + _, span := tracing.Start(ctx, "Ratelimit") + defer span.End() - key := bucketKey{req.Identifier, req.Limit, req.Duration} + if req.Time.IsZero() { + req.Time = s.clock.Now() + } - b, bucketExisted := s.getOrCreateBucket(key) - if !bucketExisted { - return + err := assert.All( + assert.NotEmpty(req.Identifier, "ratelimit identifier must not be empty"), + assert.Greater(req.Limit, 0, "ratelimit limit must be greater than zero"), + assert.GreaterOrEqual(req.Cost, 0, "ratelimit cost must not be negative"), + assert.GreaterOrEqual(req.Duration.Milliseconds(), 1000, "ratelimit duration must be at least 1s"), + assert.False(req.Time.IsZero(), "request time must not be zero"), + ) + if err != nil { + return RatelimitResponse{}, err } + + key := bucketKey{req.Identifier, req.Limit, req.Duration} + span.SetAttributes(attribute.String("key", key.toString())) + b, _ := s.getOrCreateBucket(key) b.mu.Lock() defer b.mu.Unlock() - currentWindow, existed := b.getCurrentWindow(req.Time) - if existed { - currentWindow.counter = max(0, currentWindow.counter-req.Cost) - } -} + // Use the shared method + res, err := s.checkBucketWithLockHeld(ctx, req, b) + if err != nil { + return RatelimitResponse{}, err + } + span.SetAttributes(attribute.Bool("passed", res.Success)) -// handleBucket evaluates a single rate limit request against its bucket. -// -// The method attempts to make decisions using only local cached data when possible. -// If both the current and previous windows exist locally, it can determine whether -// the request would exceed the limit without contacting Redis. -// -// If local data is insufficient (first request or during strictUntil period after -// a denial), it fetches the current counts from Redis to ensure accuracy. -// -// The strictUntil mechanism forces a Redis lookup for a full window duration after -// any rate limit is exceeded. This prevents over-admission during the decay period -// of the sliding window when relying only on stale local data. -func (s *service) handleBucket(ctx context.Context, req RatelimitRequest) (RatelimitResponse, error) { + // If successful, increment counter and buffer + if res.Success { + currentWindow, _ := b.getCurrentWindow(req.Time) + currentWindow.counter += req.Cost + s.replayBuffer.Buffer(req) - key := bucketKey{req.Identifier, req.Limit, req.Duration} + } - b, _ := s.getOrCreateBucket(key) + return res, nil +} - b.mu.Lock() - defer b.mu.Unlock() +// checkBucketWithLockHeld evaluates a rate limit request with the bucket lock already held. +// The caller MUST hold bucket.mu.Lock() before calling this. +func (s *service) checkBucketWithLockHeld(ctx context.Context, req RatelimitRequest, b *bucket) (RatelimitResponse, error) { currentWindow, currentWindowExisted := b.getCurrentWindow(req.Time) previousWindow, previousWindowExisted := b.getPreviousWindow(req.Time) decisionSource := "local" + // First, try to make a decision based only on local data if currentWindowExisted && previousWindowExisted { - // Check if we can reject based on local data alone - // exceeded, effectiveCount, remaining := s.calculateRateLimit(req, currentWindow, previousWindow) + if exceeded { b.strictUntil = req.Time.Add(req.Duration) - metrics.RatelimitDecision.WithLabelValues(decisionSource, "denied").Inc() - + } else { + metrics.RatelimitDecision.WithLabelValues(decisionSource, "allowed").Inc() } + return RatelimitResponse{ Success: !exceeded, - Remaining: remaining, + Remaining: max(0, remaining), Reset: currentWindow.start.Add(currentWindow.duration), Limit: req.Limit, Current: effectiveCount, }, nil - } - // If we couldn't make a local rejection decision, proceed with Redis checks if needed + // If we couldn't make a local decision, proceed with Redis checks if needed goToOrigin := req.Time.UnixMilli() < b.strictUntil.UnixMilli() if goToOrigin || !currentWindowExisted { decisionSource = "origin" - currentKey := counterKey(key, currentWindow.sequence) + currentKey := counterKey(b.key(), currentWindow.sequence) res, err := s.counter.Get(ctx, currentKey) if err != nil { s.logger.Error("unable to get counter value", @@ -303,7 +353,7 @@ func (s *service) handleBucket(ctx context.Context, req RatelimitRequest) (Ratel if goToOrigin || !previousWindowExisted { decisionSource = "origin" - previousKey := counterKey(key, previousWindow.sequence) + previousKey := counterKey(b.key(), previousWindow.sequence) res, err := s.counter.Get(ctx, previousKey) if err != nil { s.logger.Error("unable to get counter value", @@ -319,13 +369,12 @@ func (s *service) handleBucket(ctx context.Context, req RatelimitRequest) (Ratel exceeded, effectiveCount, remaining := s.calculateRateLimit(req, currentWindow, previousWindow) if exceeded { - // Set strictUntil to prevent further requests b.strictUntil = req.Time.Add(req.Duration) - metrics.RatelimitDecision.WithLabelValues(decisionSource, "denied").Inc() + return RatelimitResponse{ Success: false, - Remaining: remaining, + Remaining: max(0, remaining), Reset: currentWindow.start.Add(currentWindow.duration), Limit: req.Limit, Current: effectiveCount, @@ -333,12 +382,12 @@ func (s *service) handleBucket(ctx context.Context, req RatelimitRequest) (Ratel } metrics.RatelimitDecision.WithLabelValues(decisionSource, "allowed").Inc() + return RatelimitResponse{ Success: true, Remaining: remaining, Reset: currentWindow.start.Add(currentWindow.duration), Limit: req.Limit, - Current: currentWindow.counter, + Current: effectiveCount, }, nil - }