Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions go/apps/api/routes/v2_keys_verify_key/ratelimit_response_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,106 @@ func TestRatelimitResponse(t *testing.T) {
require.Equal(t, int64(7), rl.Remaining, "Should have 7 remaining (10-3)")
require.False(t, rl.AutoApply, "Custom rate limit should not be auto-applied")
})

t.Run("multiple rate limits with accurate remaining counters", func(t *testing.T) {
key := h.CreateKey(seed.CreateKeyRequest{
WorkspaceID: workspace.ID,
KeyAuthID: api.KeyAuthID.String,
Ratelimits: []seed.CreateRatelimitRequest{
{
Name: "fast-limit",
WorkspaceID: workspace.ID,
AutoApply: true,
Duration: time.Minute.Milliseconds(),
Limit: 3,
},
{
Name: "slow-limit",
WorkspaceID: workspace.ID,
AutoApply: true,
Duration: time.Hour.Milliseconds(),
Limit: 10,
},
},
})

req := handler.Request{
Key: key.Key,
}

// Helper function to find rate limit by name
findRatelimit := func(ratelimits []openapi.VerifyKeyRatelimitData, name string) *openapi.VerifyKeyRatelimitData {
for _, rl := range ratelimits {
if rl.Name == name {
return &rl
}
}
return nil
}

// Request 1: Both limits should decrement
res := testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req)
require.Equal(t, 200, res.Status)
require.Equal(t, openapi.VALID, res.Body.Data.Code)
require.True(t, res.Body.Data.Valid)

ratelimits := *res.Body.Data.Ratelimits
require.Len(t, ratelimits, 2, "Should have two rate limits")

fastLimit := findRatelimit(ratelimits, "fast-limit")
slowLimit := findRatelimit(ratelimits, "slow-limit")
require.NotNil(t, fastLimit, "fast-limit should be present")
require.NotNil(t, slowLimit, "slow-limit should be present")

require.Equal(t, int64(2), fastLimit.Remaining, "fast-limit: expected remaining=2 after 1st request")
require.Equal(t, int64(9), slowLimit.Remaining, "slow-limit: expected remaining=9 after 1st request")
require.False(t, fastLimit.Exceeded)
require.False(t, slowLimit.Exceeded)

// Request 2: Both limits should decrement again
res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req)
require.Equal(t, 200, res.Status)
require.Equal(t, openapi.VALID, res.Body.Data.Code)
require.True(t, res.Body.Data.Valid)

ratelimits = *res.Body.Data.Ratelimits
fastLimit = findRatelimit(ratelimits, "fast-limit")
slowLimit = findRatelimit(ratelimits, "slow-limit")

require.Equal(t, int64(1), fastLimit.Remaining, "fast-limit: expected remaining=1 after 2nd request")
require.Equal(t, int64(8), slowLimit.Remaining, "slow-limit: expected remaining=8 after 2nd request")
require.False(t, fastLimit.Exceeded)
require.False(t, slowLimit.Exceeded)

// Request 3: Both limits should decrement, fast-limit reaches 0
res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req)
require.Equal(t, 200, res.Status)
require.Equal(t, openapi.VALID, res.Body.Data.Code)
require.True(t, res.Body.Data.Valid)

ratelimits = *res.Body.Data.Ratelimits
fastLimit = findRatelimit(ratelimits, "fast-limit")
slowLimit = findRatelimit(ratelimits, "slow-limit")

require.Equal(t, int64(0), fastLimit.Remaining, "fast-limit: expected remaining=0 after 3rd request")
require.Equal(t, int64(7), slowLimit.Remaining, "slow-limit: expected remaining=7 after 3rd request")
require.False(t, fastLimit.Exceeded)
require.False(t, slowLimit.Exceeded)

// Request 4: fast-limit should be exceeded, slow-limit continues
res = testutil.CallRoute[handler.Request, handler.Response](h, route, headers, req)
require.Equal(t, 200, res.Status)
require.Equal(t, openapi.RATELIMITED, res.Body.Data.Code)
require.False(t, res.Body.Data.Valid, "Key should be rate limited")

ratelimits = *res.Body.Data.Ratelimits
fastLimit = findRatelimit(ratelimits, "fast-limit")
slowLimit = findRatelimit(ratelimits, "slow-limit")

require.Equal(t, int64(0), fastLimit.Remaining, "fast-limit: expected remaining=0 when exceeded")
require.True(t, fastLimit.Exceeded, "fast-limit should be exceeded")
// slow-limit should NOT increment since the request was denied
require.Equal(t, int64(7), slowLimit.Remaining, "slow-limit: should not decrement when request is denied")
require.False(t, slowLimit.Exceeded, "slow-limit should not be exceeded")
})
}
3 changes: 1 addition & 2 deletions go/apps/api/routes/v2_ratelimit_limit/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,14 +272,13 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
}
}

results, err := h.Ratelimit.Ratelimit(ctx, []ratelimit.RatelimitRequest{limitReq})
result, err := h.Ratelimit.Ratelimit(ctx, limitReq)
if err != nil {
return fault.Wrap(err,
fault.Internal("rate limit failed"),
fault.Public("We're unable to process the rate limit request."),
)
}
result := results[0]

if s.ShouldLogRequestToClickHouse() {
h.ClickHouse.BufferRatelimit(schema.RatelimitRequestV1{
Expand Down
22 changes: 18 additions & 4 deletions go/internal/services/keys/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,20 @@ func (k *KeyVerifier) withRateLimits(ctx context.Context, specifiedLimits []open
})
}

resp, err := k.rateLimiter.Ratelimit(ctx, ratelimitRequests)
// Use different rate limiting paths based on number of limits
var resp []ratelimit.RatelimitResponse
var err error

if len(ratelimitRequests) == 1 {
// Single rate limit - use fast path
singleResp, singleErr := k.rateLimiter.Ratelimit(ctx, ratelimitRequests[0])
resp = []ratelimit.RatelimitResponse{singleResp}
err = singleErr
} else {
// Multiple rate limits - use atomic all-or-nothing path
resp, err = k.rateLimiter.RatelimitMany(ctx, ratelimitRequests)
}

if err != nil {
k.logger.Error("Failed to ratelimit",
"key_id", k.Key.ID,
Expand All @@ -214,13 +227,14 @@ func (k *KeyVerifier) withRateLimits(ctx context.Context, specifiedLimits []open
// We will just allow the request to proceed, but log the error
return nil
}
for i, response := range resp {

for i := range resp {
// Write response back to config to be passed to the client
config := ratelimitsToCheck[names[i]]
config.Response = &response
config.Response = &resp[i]
ratelimitsToCheck[names[i]] = config

if !response.Success {
if !resp[i].Success {
k.setInvalid(StatusRateLimited, fmt.Sprintf("key exceeded rate limit %s", names[i]))
}
}
Expand Down
12 changes: 12 additions & 0 deletions go/internal/services/ratelimit/bucket.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ type bucket struct {
// mu protects all bucket operations
mu sync.RWMutex

// identifier is the rate limit subject (user ID, API key, etc)
identifier string

// limit is the maximum number of requests allowed per duration
limit int64

Expand All @@ -48,6 +51,14 @@ type bucket struct {
strictUntil time.Time
}

func (b *bucket) key() bucketKey {
return bucketKey{
identifier: b.identifier,
limit: b.limit,
duration: b.duration,
}
}

// bucketKey uniquely identifies a rate limit bucket by combining the
// identifier, limit, and duration. This ensures separate tracking when
// the same identifier has different rate limit configurations.
Expand Down Expand Up @@ -101,6 +112,7 @@ func (s *service) getOrCreateBucket(key bucketKey) (*bucket, bool) {
metrics.RatelimitBucketsCreated.Inc()
b = &bucket{
mu: sync.RWMutex{},
identifier: key.identifier,
limit: key.limit,
duration: key.duration,
windows: make(map[int64]*window),
Expand Down
22 changes: 8 additions & 14 deletions go/internal/services/ratelimit/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,21 +33,15 @@ type Service interface {
// Performance: O(1) time complexity for local decisions
//
// Example Usage:
// responses, err := svc.Ratelimit(ctx, []RatelimitRequest{
// {
// Identifier: "user-123",
// Limit: 100,
// Duration: time.Minute,
// Cost: 1,
// },
// {
// Identifier: "user-456",
// Limit: 50,
// Duration: time.Minute,
// Cost: 2,
// },
// response, err := svc.Ratelimit(ctx, RatelimitRequest{
// Identifier: "user-123",
// Limit: 100,
// Duration: time.Minute,
// Cost: 1,
// })
Ratelimit(context.Context, []RatelimitRequest) ([]RatelimitResponse, error)
Ratelimit(context.Context, RatelimitRequest) (RatelimitResponse, error)

RatelimitMany(context.Context, []RatelimitRequest) ([]RatelimitResponse, error)
}

// RatelimitRequest represents a request to check or consume rate limit tokens.
Expand Down
Loading