Skip to content
9 changes: 5 additions & 4 deletions go/apps/api/routes/v2_ratelimit_delete_override/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@ import (
"github.com/unkeyed/unkey/go/pkg/zen"
)

type Request = openapi.V2RatelimitDeleteOverrideRequestBody
type Response = openapi.V2RatelimitDeleteOverrideResponseBody
type (
Request = openapi.V2RatelimitDeleteOverrideRequestBody
Response = openapi.V2RatelimitDeleteOverrideResponseBody
)

type Handler struct {
Logger logging.Logger
Expand Down Expand Up @@ -59,7 +61,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
cache.ScopedKey{WorkspaceID: auth.AuthorizedWorkspaceID, Key: req.Namespace},
func(ctx context.Context) (db.FindRatelimitNamespace, error) {
result := db.FindRatelimitNamespace{} // nolint:exhaustruct
response, err := db.WithRetry(func() (db.FindRatelimitNamespaceRow, error) {
response, err := db.WithRetryContext(ctx, func() (db.FindRatelimitNamespaceRow, error) {
return db.Query.FindRatelimitNamespace(ctx, h.DB.RO(), db.FindRatelimitNamespaceParams{
WorkspaceID: auth.AuthorizedWorkspaceID,
Namespace: req.Namespace,
Expand Down Expand Up @@ -99,7 +101,6 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {
},
caches.DefaultFindFirstOp,
)

if err != nil {
if db.IsNotFound(err) {
return fault.New("namespace not found",
Expand Down
10 changes: 6 additions & 4 deletions go/apps/api/routes/v2_ratelimit_limit/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ import (
"github.com/unkeyed/unkey/go/pkg/zen"
)

type Request = openapi.V2RatelimitLimitRequestBody
type Response = openapi.V2RatelimitLimitResponseBody
type (
Request = openapi.V2RatelimitLimitRequestBody
Response = openapi.V2RatelimitLimitResponseBody
)

// Handler implements zen.Route interface for the v2 ratelimit limit endpoint
type Handler struct {
Expand Down Expand Up @@ -73,9 +75,9 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error {

cacheKey := cache.ScopedKey{WorkspaceID: auth.AuthorizedWorkspaceID, Key: req.Namespace}

var loader = func(ctx context.Context) (db.FindRatelimitNamespace, error) {
loader := func(ctx context.Context) (db.FindRatelimitNamespace, error) {
result := db.FindRatelimitNamespace{} // nolint:exhaustruct
response, err := db.WithRetry(func() (db.FindRatelimitNamespaceRow, error) {
response, err := db.WithRetryContext(ctx, func() (db.FindRatelimitNamespaceRow, error) {
return db.Query.FindRatelimitNamespace(ctx, h.DB.RO(), db.FindRatelimitNamespaceParams{
WorkspaceID: auth.AuthorizedWorkspaceID,
Namespace: req.Namespace,
Expand Down
3 changes: 1 addition & 2 deletions go/internal/services/keys/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (s *service) Get(ctx context.Context, sess *zen.Session, rawKey string) (*K
h := hash.Sha256(rawKey)
key, hit, err := s.keyCache.SWR(ctx, h, func(ctx context.Context) (db.CachedKeyData, error) {
// Use database retry with exponential backoff, skipping non-transient errors
row, err := db.WithRetry(func() (db.FindKeyForVerificationRow, error) {
row, err := db.WithRetryContext(ctx, func() (db.FindKeyForVerificationRow, error) {
return db.Query.FindKeyForVerification(ctx, s.db.RO(), h)
})
if err != nil {
Expand All @@ -94,7 +94,6 @@ func (s *service) Get(ctx context.Context, sess *zen.Session, rawKey string) (*K
ParsedIPWhitelist: parsedIPWhitelist,
}, nil
}, caches.DefaultFindFirstOp)

if err != nil {
if db.IsNotFound(err) {
// nolint:exhaustruct
Expand Down
2 changes: 1 addition & 1 deletion go/internal/services/usagelimiter/limit.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ func (s *service) Limit(ctx context.Context, req UsageRequest) (UsageResponse, e
ctx, span := tracing.Start(ctx, "usagelimiter.Limit")
defer span.End()

limit, err := db.WithRetry(func() (sql.NullInt32, error) {
limit, err := db.WithRetryContext(ctx, func() (sql.NullInt32, error) {
return db.Query.FindKeyCredits(ctx, s.db.RO(), req.KeyId)
})
if err != nil {
Expand Down
7 changes: 1 addition & 6 deletions go/internal/services/usagelimiter/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@ func (s *counterService) Limit(ctx context.Context, req UsageRequest) (UsageResp

func (s *counterService) Invalidate(ctx context.Context, keyID string) error {
return s.counter.Delete(ctx, s.redisKey(keyID))

}

func (s *counterService) redisKey(keyID string) string {
Expand Down Expand Up @@ -239,10 +238,9 @@ func (s *counterService) initializeFromDatabase(ctx context.Context, req UsageRe
ctx, span := tracing.Start(ctx, "usagelimiter.counter.initializeFromDatabase")
defer span.End()

limit, err := db.WithRetry(func() (sql.NullInt32, error) {
limit, err := db.WithRetryContext(ctx, func() (sql.NullInt32, error) {
return db.Query.FindKeyCredits(ctx, s.db.RO(), req.KeyId)
})

if err != nil {
if db.IsNotFound(err) {
return UsageResponse{Valid: false, Remaining: 0}, nil
Expand Down Expand Up @@ -302,7 +300,6 @@ func (s *counterService) initializeFromDatabase(ctx context.Context, req UsageRe
func (s *counterService) replayRequests() {
for change := range s.replayBuffer.Consume() {
err := s.syncWithDB(context.Background(), change)

if err != nil {
s.logger.Error("failed to replay credit change", "error", err)
}
Expand All @@ -324,7 +321,6 @@ func (s *counterService) syncWithDB(ctx context.Context, change CreditChange) er
Credits: sql.NullInt32{Int32: change.Cost, Valid: true},
})
})

if err != nil {
metrics.UsagelimiterReplayOperations.WithLabelValues("error").Inc()
return err
Expand Down Expand Up @@ -393,5 +389,4 @@ func (s *counterService) Close() error {
s.logger.Debug("usage limiter replay buffer drained successfully")
return nil
}

}
75 changes: 43 additions & 32 deletions go/pkg/db/retry.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package db

import (
"context"
"time"

"github.com/unkeyed/unkey/go/pkg/retry"
Expand All @@ -9,52 +10,62 @@ import (
const (
// DefaultBackoff is the base duration for exponential backoff in database retries
DefaultBackoff = 50 * time.Millisecond
// DefaultAttempts is the maximum number of retry attempts for database operations
DefaultAttempts = 3
)

// WithRetry executes a database operation with optimized retry configuration.
// WithRetryContext executes a database operation with optimized retry configuration while respecting context cancellation and deadlines.
// It retries transient errors with exponential backoff but skips non-retryable errors
// like "not found" or "duplicate key" to avoid unnecessary delays.
//
// Context behavior:
// - Returns immediately if context is already cancelled or deadline exceeded
// - Detects context cancellation during backoff sleep without waiting for full duration
// - Returns context.Canceled or context.DeadlineExceeded on context errors
//
// Configuration:
// - 3 attempts maximum
// - Exponential backoff: 50ms, 100ms, 200ms
// - Skips retries for "not found" and "duplicate key" errors
//
// Usage:
//
// result, err := db.WithRetry(func() (SomeType, error) {
// result, err := db.WithRetryContext(ctx, func() (SomeType, error) {
// return db.Query.SomeOperation(ctx, db.RO(), params)
// })
func WithRetry[T any](fn func() (T, error)) (T, error) {
retrier := retry.New(
retry.Attempts(3),
retry.Backoff(func(n int) time.Duration {
// Predefined backoff delays: 50ms, 100ms, 200ms
delays := []time.Duration{
DefaultBackoff, // 50ms for attempt 1
DefaultBackoff * 2, // 100ms for attempt 2
DefaultBackoff * 4, // 200ms for attempt 3
}
if n <= 0 || n > len(delays) {
return DefaultBackoff // fallback to base delay
}
return delays[n-1]
}),
retry.ShouldRetry(func(err error) bool {
// Don't retry if resource is not found - this is a valid response
if IsNotFound(err) {
return false
}

// Don't retry duplicate key errors - these won't succeed on retry
if IsDuplicateKeyError(err) {
return false
}

// Retry all other errors (network issues, timeouts, deadlocks, etc.)
return true
}),
func WithRetryContext[T any](ctx context.Context, fn func() (T, error)) (T, error) {
return retry.DoWithResultContext(
retry.New(
retry.Attempts(DefaultAttempts),
retry.Backoff(backoffStrategy),
retry.ShouldRetry(shouldRetryError),
),
ctx,
fn,
)
}

// backoffStrategy defines exponential backoff delays: 50ms, 100ms, 200ms
func backoffStrategy(n int) time.Duration {
delays := []time.Duration{
DefaultBackoff, // 50ms for attempt 1
DefaultBackoff * 2, // 100ms for attempt 2
DefaultBackoff * 4, // 200ms for attempt 3
}
if n <= 0 || n > len(delays) {
return DefaultBackoff
}
return delays[n-1]
}

return retry.DoWithResult(retrier, fn)
// shouldRetryError determines if a database error should trigger a retry.
// Returns false for "not found" and "duplicate key" errors as these won't succeed on retry.
func shouldRetryError(err error) bool {
if IsNotFound(err) {
return false
}
if IsDuplicateKeyError(err) {
return false
}
return true
}
Loading