diff --git a/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go b/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go index 371705b1d0..1a9119fe6c 100644 --- a/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go +++ b/go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go @@ -249,4 +249,88 @@ func TestRatelimitResponse(t *testing.T) { require.Equal(t, int64(7), slowLimit.Remaining, "slow-limit: should not decrement when request is denied") require.False(t, slowLimit.Exceeded, "slow-limit should not be exceeded") }) + + t.Run("identity rate limits with same config but different names are isolated", func(t *testing.T) { + // Create an identity with two rate limits that have identical duration and limit but different names + // This tests that the rate limit name is included in the identifier to prevent shared counters + identity := h.CreateIdentity(seed.CreateIdentityRequest{ + WorkspaceID: workspace.ID, + ExternalID: "user_with_multiple_limits", + Ratelimits: []seed.CreateRatelimitRequest{ + { + Name: "api_requests", + WorkspaceID: workspace.ID, + AutoApply: true, + Duration: time.Minute.Milliseconds(), + Limit: 5, + }, + { + Name: "data_access", + WorkspaceID: workspace.ID, + AutoApply: true, + Duration: time.Minute.Milliseconds(), // Same duration as api_requests + Limit: 5, // Same limit as api_requests + }, + }, + }) + + // Create a key for this identity + key := h.CreateKey(seed.CreateKeyRequest{ + WorkspaceID: workspace.ID, + KeySpaceID: api.KeyAuthID.String, + IdentityID: ptr.P(identity.ID), + }) + + req := handler.Request{ + Key: key.Key, + } + + // Helper function to find rate limit by name + findRatelimit := func(ratelimits []openapi.VerifyKeyRatelimitData, name string) *openapi.VerifyKeyRatelimitData { + for _, rl := range ratelimits { + if rl.Name == name { + return &rl + } + } + return nil + } + + // Make 5 requests - should use up api_requests limit + for i := 0; i < 5; i++ { + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.Equal(t, openapi.VALID, res.Body.Data.Code, "Request %d should be valid", i+1) + require.True(t, res.Body.Data.Valid) + + ratelimits := *res.Body.Data.Ratelimits + require.Len(t, ratelimits, 2, "Should have two identity rate limits") + + apiLimit := findRatelimit(ratelimits, "api_requests") + dataLimit := findRatelimit(ratelimits, "data_access") + require.NotNil(t, apiLimit, "api_requests rate limit should be present") + require.NotNil(t, dataLimit, "data_access rate limit should be present") + + // Both limits should decrement independently + require.Equal(t, int64(5-i-1), apiLimit.Remaining, "api_requests: expected remaining=%d after request %d", 5-i-1, i+1) + require.Equal(t, int64(5-i-1), dataLimit.Remaining, "data_access: expected remaining=%d after request %d", 5-i-1, i+1) + require.False(t, apiLimit.Exceeded, "api_requests should not be exceeded yet") + require.False(t, dataLimit.Exceeded, "data_access should not be exceeded yet") + } + + // 6th request should be rate limited + res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req) + require.Equal(t, 200, res.Status) + require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code, "6th request should be rate limited") + require.False(t, res.Body.Data.Valid) + + ratelimits := *res.Body.Data.Ratelimits + apiLimit := findRatelimit(ratelimits, "api_requests") + dataLimit := findRatelimit(ratelimits, "data_access") + + // Both limits should be exceeded (since they have the same config and both were checked) + require.True(t, apiLimit.Exceeded, "api_requests should be exceeded") + require.True(t, dataLimit.Exceeded, "data_access should be exceeded") + require.Equal(t, int64(0), apiLimit.Remaining, "api_requests: should have 0 remaining") + require.Equal(t, int64(0), dataLimit.Remaining, "data_access: should have 0 remaining") + }) } diff --git a/go/apps/api/routes/v2_ratelimit_limit/handler.go b/go/apps/api/routes/v2_ratelimit_limit/handler.go index 1742444daa..c877d148eb 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/handler.go +++ b/go/apps/api/routes/v2_ratelimit_limit/handler.go @@ -257,7 +257,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { // Apply rate limit limitReq := ratelimit.RatelimitRequest{ - Identifier: namespace.ID + ":" + req.Identifier, + Name: namespace.ID, + Identifier: req.Identifier, Duration: time.Duration(duration) * time.Millisecond, Limit: limit, Cost: ptr.SafeDeref(req.Cost, 1), diff --git a/go/internal/services/keys/validation.go b/go/internal/services/keys/validation.go index c20964759e..00bb337eeb 100644 --- a/go/internal/services/keys/validation.go +++ b/go/internal/services/keys/validation.go @@ -188,6 +188,7 @@ func (k *KeyVerifier) withRateLimits(ctx context.Context, specifiedLimits []open for name, config := range ratelimitsToCheck { names = append(names, name) ratelimitRequests = append(ratelimitRequests, ratelimit.RatelimitRequest{ + Name: config.Name, Identifier: config.Identifier, // Use the pre-determined identifier Limit: config.Limit, Duration: config.Duration, diff --git a/go/internal/services/ratelimit/bucket.go b/go/internal/services/ratelimit/bucket.go index 5217c4ac4b..5276b73703 100644 --- a/go/internal/services/ratelimit/bucket.go +++ b/go/internal/services/ratelimit/bucket.go @@ -31,6 +31,9 @@ type bucket struct { // mu protects all bucket operations mu sync.RWMutex + // name is the name of the bucket + name string + // identifier is the rate limit subject (user ID, API key, etc) identifier string @@ -53,6 +56,7 @@ type bucket struct { func (b *bucket) key() bucketKey { return bucketKey{ + name: b.name, identifier: b.identifier, limit: b.limit, duration: b.duration, @@ -76,6 +80,9 @@ func (b *bucket) key() bucketKey { // } // bucketID := key.toString() type bucketKey struct { + // name is an arbitrary name for the bucket + name string + // identifier is the rate limit subject (user ID, API key, etc) identifier string @@ -87,7 +94,7 @@ type bucketKey struct { } func (b bucketKey) toString() string { - return fmt.Sprintf("%s-%d-%d", b.identifier, b.limit, b.duration.Milliseconds()) + return fmt.Sprintf("%s-%s-%d-%d", b.name, b.identifier, b.limit, b.duration.Milliseconds()) } // getOrCreateBucket retrieves a rate limiting bucket for the given key. @@ -112,6 +119,7 @@ func (s *service) getOrCreateBucket(key bucketKey) (*bucket, bool) { metrics.RatelimitBucketsCreated.Inc() b = &bucket{ mu: sync.RWMutex{}, + name: key.name, identifier: key.identifier, limit: key.limit, duration: key.duration, diff --git a/go/internal/services/ratelimit/interface.go b/go/internal/services/ratelimit/interface.go index a322042e1b..e03c9ec3c3 100644 --- a/go/internal/services/ratelimit/interface.go +++ b/go/internal/services/ratelimit/interface.go @@ -54,6 +54,9 @@ type Service interface { // // Thread Safety: This type is immutable and safe for concurrent use. type RatelimitRequest struct { + // Name is an arbitrary string that identifies the rate limit topic. + Name string + // Identifier uniquely identifies the rate limit subject. // This could be: // - A user ID diff --git a/go/internal/services/ratelimit/replay.go b/go/internal/services/ratelimit/replay.go index a9b613fca0..1edd3c71d5 100644 --- a/go/internal/services/ratelimit/replay.go +++ b/go/internal/services/ratelimit/replay.go @@ -60,6 +60,7 @@ func (s *service) syncWithOrigin(ctx context.Context, req RatelimitRequest) erro } key := bucketKey{ + name: req.Name, identifier: req.Identifier, limit: req.Limit, duration: req.Duration, diff --git a/go/internal/services/ratelimit/service.go b/go/internal/services/ratelimit/service.go index 4831345acd..bf3de74a9a 100644 --- a/go/internal/services/ratelimit/service.go +++ b/go/internal/services/ratelimit/service.go @@ -169,6 +169,7 @@ func (s *service) RatelimitMany(ctx context.Context, reqs []RatelimitRequest) ([ } err := assert.All( + assert.NotEmpty(reqs[i].Name, "ratelimit name must not be empty"), assert.NotEmpty(reqs[i].Identifier, "ratelimit identifier must not be empty"), assert.Greater(reqs[i].Limit, 0, "ratelimit limit must be greater than zero"), assert.GreaterOrEqual(reqs[i].Cost, 0, "ratelimit cost must not be negative"), @@ -183,7 +184,7 @@ func (s *service) RatelimitMany(ctx context.Context, reqs []RatelimitRequest) ([ // Build and sort keys first (before getting buckets) reqsWithKeys := make([]reqWithKey, len(reqs)) for i, req := range reqs { - key := bucketKey{req.Identifier, req.Limit, req.Duration} + key := bucketKey{req.Name, req.Identifier, req.Limit, req.Duration} reqsWithKeys[i] = reqWithKey{ req: req, key: key, @@ -274,6 +275,7 @@ func (s *service) Ratelimit(ctx context.Context, req RatelimitRequest) (Ratelimi err := assert.All( assert.NotEmpty(req.Identifier, "ratelimit identifier must not be empty"), + assert.NotEmpty(req.Name, "ratelimit name must not be empty"), assert.Greater(req.Limit, 0, "ratelimit limit must be greater than zero"), assert.GreaterOrEqual(req.Cost, 0, "ratelimit cost must not be negative"), assert.GreaterOrEqual(req.Duration.Milliseconds(), 1000, "ratelimit duration must be at least 1s"), @@ -283,7 +285,7 @@ func (s *service) Ratelimit(ctx context.Context, req RatelimitRequest) (Ratelimi return RatelimitResponse{}, err } - key := bucketKey{req.Identifier, req.Limit, req.Duration} + key := bucketKey{req.Name, req.Identifier, req.Limit, req.Duration} span.SetAttributes(attribute.String("key", key.toString())) b, _ := s.getOrCreateBucket(key) b.mu.Lock()