Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,4 +249,88 @@ func TestRatelimitResponse(t *testing.T) {
require.Equal(t, int64(7), slowLimit.Remaining, "slow-limit: should not decrement when request is denied")
require.False(t, slowLimit.Exceeded, "slow-limit should not be exceeded")
})

t.Run("identity rate limits with same config but different names are isolated", func(t *testing.T) {
// Create an identity with two rate limits that have identical duration and limit but different names
// This tests that the rate limit name is included in the identifier to prevent shared counters
identity := h.CreateIdentity(seed.CreateIdentityRequest{
WorkspaceID: workspace.ID,
ExternalID: "user_with_multiple_limits",
Ratelimits: []seed.CreateRatelimitRequest{
{
Name: "api_requests",
WorkspaceID: workspace.ID,
AutoApply: true,
Duration: time.Minute.Milliseconds(),
Limit: 5,
},
{
Name: "data_access",
WorkspaceID: workspace.ID,
AutoApply: true,
Duration: time.Minute.Milliseconds(), // Same duration as api_requests
Limit: 5, // Same limit as api_requests
},
},
})

// Create a key for this identity
key := h.CreateKey(seed.CreateKeyRequest{
WorkspaceID: workspace.ID,
KeySpaceID: api.KeyAuthID.String,
IdentityID: ptr.P(identity.ID),
})

req := handler.Request{
Key: key.Key,
}

// Helper function to find rate limit by name
findRatelimit := func(ratelimits []openapi.VerifyKeyRatelimitData, name string) *openapi.VerifyKeyRatelimitData {
for _, rl := range ratelimits {
if rl.Name == name {
return &rl
}
}
return nil
}

// Make 5 requests - should use up api_requests limit
for i := 0; i < 5; i++ {
res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req)
require.Equal(t, 200, res.Status)
require.Equal(t, openapi.VALID, res.Body.Data.Code, "Request %d should be valid", i+1)
require.True(t, res.Body.Data.Valid)

ratelimits := *res.Body.Data.Ratelimits
require.Len(t, ratelimits, 2, "Should have two identity rate limits")

apiLimit := findRatelimit(ratelimits, "api_requests")
dataLimit := findRatelimit(ratelimits, "data_access")
require.NotNil(t, apiLimit, "api_requests rate limit should be present")
require.NotNil(t, dataLimit, "data_access rate limit should be present")

// Both limits should decrement independently
require.Equal(t, int64(5-i-1), apiLimit.Remaining, "api_requests: expected remaining=%d after request %d", 5-i-1, i+1)
require.Equal(t, int64(5-i-1), dataLimit.Remaining, "data_access: expected remaining=%d after request %d", 5-i-1, i+1)
require.False(t, apiLimit.Exceeded, "api_requests should not be exceeded yet")
require.False(t, dataLimit.Exceeded, "data_access should not be exceeded yet")
}

// 6th request should be rate limited
res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req)
require.Equal(t, 200, res.Status)
require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code, "6th request should be rate limited")
require.False(t, res.Body.Data.Valid)

ratelimits := *res.Body.Data.Ratelimits
apiLimit := findRatelimit(ratelimits, "api_requests")
dataLimit := findRatelimit(ratelimits, "data_access")

// Both limits should be exceeded (since they have the same config and both were checked)
require.True(t, apiLimit.Exceeded, "api_requests should be exceeded")
require.True(t, dataLimit.Exceeded, "data_access should be exceeded")
require.Equal(t, int64(0), apiLimit.Remaining, "api_requests: should have 0 remaining")
require.Equal(t, int64(0), dataLimit.Remaining, "data_access: should have 0 remaining")
})
}
3 changes: 2 additions & 1 deletion go/apps/api/routes/v2_ratelimit_limit/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {

// Apply rate limit
limitReq := ratelimit.RatelimitRequest{
Identifier: namespace.ID + ":" + req.Identifier,
Name: namespace.ID,
Identifier: req.Identifier,
Duration: time.Duration(duration) * time.Millisecond,
Limit: limit,
Cost: ptr.SafeDeref(req.Cost, 1),
Expand Down
1 change: 1 addition & 0 deletions go/internal/services/keys/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ func (k *KeyVerifier) withRateLimits(ctx context.Context, specifiedLimits []open
for name, config := range ratelimitsToCheck {
names = append(names, name)
ratelimitRequests = append(ratelimitRequests, ratelimit.RatelimitRequest{
Name: config.Name,
Identifier: config.Identifier, // Use the pre-determined identifier
Limit: config.Limit,
Duration: config.Duration,
Expand Down
10 changes: 9 additions & 1 deletion go/internal/services/ratelimit/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type bucket struct {
// mu protects all bucket operations
mu sync.RWMutex

// name is the name of the bucket
name string

// identifier is the rate limit subject (user ID, API key, etc)
identifier string

Expand All @@ -53,6 +56,7 @@ type bucket struct {

func (b *bucket) key() bucketKey {
return bucketKey{
name: b.name,
identifier: b.identifier,
limit: b.limit,
duration: b.duration,
Expand All @@ -76,6 +80,9 @@ func (b *bucket) key() bucketKey {
// }
// bucketID := key.toString()
type bucketKey struct {
// name is an arbitrary name for the bucket
name string

// identifier is the rate limit subject (user ID, API key, etc)
identifier string

Expand All @@ -87,7 +94,7 @@ type bucketKey struct {
}

func (b bucketKey) toString() string {
return fmt.Sprintf("%s-%d-%d", b.identifier, b.limit, b.duration.Milliseconds())
return fmt.Sprintf("%s-%s-%d-%d", b.name, b.identifier, b.limit, b.duration.Milliseconds())
}

// getOrCreateBucket retrieves a rate limiting bucket for the given key.
Expand All @@ -112,6 +119,7 @@ func (s *service) getOrCreateBucket(key bucketKey) (*bucket, bool) {
metrics.RatelimitBucketsCreated.Inc()
b = &bucket{
mu: sync.RWMutex{},
name: key.name,
identifier: key.identifier,
limit: key.limit,
duration: key.duration,
Expand Down
3 changes: 3 additions & 0 deletions go/internal/services/ratelimit/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@ type Service interface {
//
// Thread Safety: This type is immutable and safe for concurrent use.
type RatelimitRequest struct {
// Name is an arbitrary string that identifies the rate limit topic.
Name string

// Identifier uniquely identifies the rate limit subject.
// This could be:
// - A user ID
Expand Down
1 change: 1 addition & 0 deletions go/internal/services/ratelimit/replay.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func (s *service) syncWithOrigin(ctx context.Context, req RatelimitRequest) erro
}

key := bucketKey{
name: req.Name,
identifier: req.Identifier,
limit: req.Limit,
duration: req.Duration,
Expand Down
6 changes: 4 additions & 2 deletions go/internal/services/ratelimit/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ func (s *service) RatelimitMany(ctx context.Context, reqs []RatelimitRequest) ([
}

err := assert.All(
assert.NotEmpty(reqs[i].Name, "ratelimit name must not be empty"),
assert.NotEmpty(reqs[i].Identifier, "ratelimit identifier must not be empty"),
assert.Greater(reqs[i].Limit, 0, "ratelimit limit must be greater than zero"),
assert.GreaterOrEqual(reqs[i].Cost, 0, "ratelimit cost must not be negative"),
Expand All @@ -183,7 +184,7 @@ func (s *service) RatelimitMany(ctx context.Context, reqs []RatelimitRequest) ([
// Build and sort keys first (before getting buckets)
reqsWithKeys := make([]reqWithKey, len(reqs))
for i, req := range reqs {
key := bucketKey{req.Identifier, req.Limit, req.Duration}
key := bucketKey{req.Name, req.Identifier, req.Limit, req.Duration}
reqsWithKeys[i] = reqWithKey{
req: req,
key: key,
Expand Down Expand Up @@ -274,6 +275,7 @@ func (s *service) Ratelimit(ctx context.Context, req RatelimitRequest) (Ratelimi

err := assert.All(
assert.NotEmpty(req.Identifier, "ratelimit identifier must not be empty"),
assert.NotEmpty(req.Name, "ratelimit name must not be empty"),
assert.Greater(req.Limit, 0, "ratelimit limit must be greater than zero"),
assert.GreaterOrEqual(req.Cost, 0, "ratelimit cost must not be negative"),
assert.GreaterOrEqual(req.Duration.Milliseconds(), 1000, "ratelimit duration must be at least 1s"),
Expand All @@ -283,7 +285,7 @@ func (s *service) Ratelimit(ctx context.Context, req RatelimitRequest) (Ratelimi
return RatelimitResponse{}, err
}

key := bucketKey{req.Identifier, req.Limit, req.Duration}
key := bucketKey{req.Name, req.Identifier, req.Limit, req.Duration}
span.SetAttributes(attribute.String("key", key.toString()))
b, _ := s.getOrCreateBucket(key)
b.mu.Lock()
Expand Down