diff --git a/go/apps/api/routes/v2_ratelimit_limit/handler.go b/go/apps/api/routes/v2_ratelimit_limit/handler.go index 7fa1bd5b9e..3c51686b83 100644 --- a/go/apps/api/routes/v2_ratelimit_limit/handler.go +++ b/go/apps/api/routes/v2_ratelimit_limit/handler.go @@ -74,11 +74,14 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { namespace, hit, err := h.RatelimitNamespaceCache.SWR(ctx, cache.ScopedKey{WorkspaceID: auth.AuthorizedWorkspaceID, Key: namespaceKey}, func(ctx context.Context) (db.FindRatelimitNamespace, error) { - response, err := db.Query.FindRatelimitNamespace(ctx, h.DB.RO(), db.FindRatelimitNamespaceParams{ - WorkspaceID: auth.AuthorizedWorkspaceID, - Namespace: namespaceKey, - }) result := db.FindRatelimitNamespace{} // nolint:exhaustruct + + response, err := db.WithRetry(func() (db.FindRatelimitNamespaceRow, error) { + return db.Query.FindRatelimitNamespace(ctx, h.DB.RO(), db.FindRatelimitNamespaceParams{ + WorkspaceID: auth.AuthorizedWorkspaceID, + Namespace: namespaceKey, + }) + }) if err != nil { return result, err } @@ -186,6 +189,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { Cost: cost, Time: time.Time{}, } + if h.TestMode { header := s.Request().Header.Get("X-Test-Time") if header != "" { @@ -216,6 +220,7 @@ func (h *Handler) Handle(ctx context.Context, s *zen.Session) error { Passed: result.Success, }) } + res := Response{ Meta: openapi.Meta{ RequestId: s.RequestID(), diff --git a/go/internal/services/caches/op.go b/go/internal/services/caches/op.go index 14427ca374..cfe5524fee 100644 --- a/go/internal/services/caches/op.go +++ b/go/internal/services/caches/op.go @@ -1,10 +1,8 @@ package caches import ( - "database/sql" - "errors" - "github.com/unkeyed/unkey/go/pkg/cache" + "github.com/unkeyed/unkey/go/pkg/db" ) // DefaultFindFirstOp returns the appropriate cache operation based on the sql error @@ -14,7 +12,7 @@ func DefaultFindFirstOp(err error) cache.Op { return cache.WriteValue } - if errors.Is(err, sql.ErrNoRows) { + if db.IsNotFound(err) { // the response is empty, we need to store that the row does not exist return cache.WriteNull } diff --git a/go/internal/services/keys/get.go b/go/internal/services/keys/get.go index c15c55b848..e39995f69d 100644 --- a/go/internal/services/keys/get.go +++ b/go/internal/services/keys/get.go @@ -68,8 +68,12 @@ func (s *service) Get(ctx context.Context, sess *zen.Session, rawKey string) (*K h := hash.Sha256(rawKey) key, hit, err := s.keyCache.SWR(ctx, h, func(ctx context.Context) (db.FindKeyForVerificationRow, error) { - return db.Query.FindKeyForVerification(ctx, s.db.RO(), h) + // Use database retry with exponential backoff, skipping non-transient errors + return db.WithRetry(func() (db.FindKeyForVerificationRow, error) { + return db.Query.FindKeyForVerification(ctx, s.db.RO(), h) + }) }, caches.DefaultFindFirstOp) + if err != nil { if db.IsNotFound(err) { // nolint:exhaustruct diff --git a/go/internal/services/usagelimiter/limit.go b/go/internal/services/usagelimiter/limit.go index 862f7b9766..7a1391e7fa 100644 --- a/go/internal/services/usagelimiter/limit.go +++ b/go/internal/services/usagelimiter/limit.go @@ -13,7 +13,9 @@ func (s *service) Limit(ctx context.Context, req UsageRequest) (UsageResponse, e ctx, span := tracing.Start(ctx, "usagelimiter.Limit") defer span.End() - limit, err := db.Query.FindKeyCredits(ctx, s.db.RW(), req.KeyId) + limit, err := db.WithRetry(func() (sql.NullInt32, error) { + return db.Query.FindKeyCredits(ctx, s.db.RO(), req.KeyId) + }) if err != nil { if db.IsNotFound(err) { return UsageResponse{Valid: false, Remaining: 0}, nil diff --git a/go/internal/services/usagelimiter/redis.go b/go/internal/services/usagelimiter/redis.go index 9ba3268032..34aba39539 100644 --- a/go/internal/services/usagelimiter/redis.go +++ b/go/internal/services/usagelimiter/redis.go @@ -232,7 +232,10 @@ func (s *counterService) initializeFromDatabase(ctx context.Context, req UsageRe ctx, span := tracing.Start(ctx, "usagelimiter.counter.initializeFromDatabase") defer span.End() - limit, err := db.Query.FindKeyCredits(ctx, s.db.RO(), req.KeyId) + limit, err := db.WithRetry(func() (sql.NullInt32, error) { + return db.Query.FindKeyCredits(ctx, s.db.RO(), req.KeyId) + }) + if err != nil { if db.IsNotFound(err) { return UsageResponse{Valid: false, Remaining: 0}, nil diff --git a/go/pkg/db/retry.go b/go/pkg/db/retry.go new file mode 100644 index 0000000000..e9bc9f95cf --- /dev/null +++ b/go/pkg/db/retry.go @@ -0,0 +1,67 @@ +package db + +import ( + "time" + + "github.com/unkeyed/unkey/go/pkg/retry" +) + +const ( + // DefaultBackoff is the base duration for exponential backoff in database retries + DefaultBackoff = 50 * time.Millisecond +) + +// WithRetry executes a database operation with optimized retry configuration. +// It retries transient errors with exponential backoff but skips non-retryable errors +// like "not found" or "duplicate key" to avoid unnecessary delays. +// +// Configuration: +// - 3 attempts maximum +// - Exponential backoff: 50ms, 100ms, 200ms +// - Skips retries for "not found" and "duplicate key" errors +// +// Usage: +// +// result, err := db.WithRetry(func() (SomeType, error) { +// return db.Query.SomeOperation(ctx, db.RO(), params) +// }) +func WithRetry[T any](fn func() (T, error)) (T, error) { + retrier := retry.New( + retry.Attempts(3), + retry.Backoff(func(n int) time.Duration { + // Predefined backoff delays: 50ms, 100ms, 200ms + delays := []time.Duration{ + DefaultBackoff, // 50ms for attempt 1 + DefaultBackoff * 2, // 100ms for attempt 2 + DefaultBackoff * 4, // 200ms for attempt 3 + } + if n <= 0 || n > len(delays) { + return DefaultBackoff // fallback to base delay + } + return delays[n-1] + }), + retry.ShouldRetry(func(err error) bool { + // Don't retry if resource is not found - this is a valid response + if IsNotFound(err) { + return false + } + + // Don't retry duplicate key errors - these won't succeed on retry + if IsDuplicateKeyError(err) { + return false + } + + // Retry all other errors (network issues, timeouts, deadlocks, etc.) + return true + }), + ) + + var result T + err := retrier.Do(func() error { + var retryErr error + result, retryErr = fn() + return retryErr + }) + + return result, err +} diff --git a/go/pkg/db/retry_test.go b/go/pkg/db/retry_test.go new file mode 100644 index 0000000000..c4d8463f43 --- /dev/null +++ b/go/pkg/db/retry_test.go @@ -0,0 +1,246 @@ +package db + +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + "github.com/go-sql-driver/mysql" + "github.com/stretchr/testify/require" + "github.com/unkeyed/unkey/go/pkg/hash" + "github.com/unkeyed/unkey/go/pkg/otel/logging" + "github.com/unkeyed/unkey/go/pkg/testutil/containers" + "github.com/unkeyed/unkey/go/pkg/uid" +) + +func TestWithRetry_Success(t *testing.T) { + callCount := 0 + + result, err := WithRetry(func() (string, error) { + callCount++ + return "success", nil + }) + + require.NoError(t, err) + require.Equal(t, "success", result) + require.Equal(t, 1, callCount, "should succeed on first try") +} + +func TestWithRetry_RetriesTransientErrors(t *testing.T) { + callCount := 0 + transientErr := errors.New("connection timeout") + + result, err := WithRetry(func() (string, error) { + callCount++ + if callCount < 3 { + return "", transientErr + } + return "success", nil + }) + + require.NoError(t, err) + require.Equal(t, "success", result) + require.Equal(t, 3, callCount, "should retry twice then succeed") +} + +func TestWithRetry_SkipsRetryOnNotFound(t *testing.T) { + callCount := 0 + + result, err := WithRetry(func() (string, error) { + callCount++ + return "", sql.ErrNoRows + }) + + require.Error(t, err) + require.True(t, IsNotFound(err)) + require.Equal(t, "", result) + require.Equal(t, 1, callCount, "should not retry on not found error") +} + +func TestWithRetry_SkipsRetryOnDuplicateKey(t *testing.T) { + callCount := 0 + duplicateKeyErr := &mysql.MySQLError{Number: 1062, Message: "Duplicate entry"} + + result, err := WithRetry(func() (string, error) { + callCount++ + return "", duplicateKeyErr + }) + + require.Error(t, err) + require.True(t, IsDuplicateKeyError(err)) + require.Equal(t, "", result) + require.Equal(t, 1, callCount, "should not retry on duplicate key error") +} + +func TestWithRetry_ExhaustsRetries(t *testing.T) { + callCount := 0 + transientErr := errors.New("persistent connection failure") + + result, err := WithRetry(func() (string, error) { + callCount++ + return "", transientErr + }) + + require.Error(t, err) + require.Equal(t, transientErr, err) + require.Equal(t, "", result) + require.Equal(t, 3, callCount, "should try 3 times then give up") +} + +func TestWithRetry_GenericTypes(t *testing.T) { + t.Run("int type", func(t *testing.T) { + result, err := WithRetry(func() (int, error) { + return 42, nil + }) + + require.NoError(t, err) + require.Equal(t, 42, result) + }) + + t.Run("struct type", func(t *testing.T) { + type TestStruct struct { + ID int + Name string + } + + expected := TestStruct{ID: 1, Name: "test"} + result, err := WithRetry(func() (TestStruct, error) { + return expected, nil + }) + + require.NoError(t, err) + require.Equal(t, expected, result) + }) +} + +// TestWithRetry_Integration tests retry functionality with a real database connection +// This test requires Docker to be running for the MySQL container +func TestWithRetry_Integration(t *testing.T) { + if testing.Short() { + t.Skip("Skipping integration test in short mode") + } + ctx := context.Background() + + // Set up test database using containers + mysqlCfg := containers.MySQL(t) + mysqlCfg.DBName = "unkey" + + // Create database instance + dbInstance, err := New(Config{ + PrimaryDSN: mysqlCfg.FormatDSN(), + Logger: logging.NewNoop(), + }) + require.NoError(t, err) + defer dbInstance.Close() + + // Create test data using sqlc statements + workspaceID := uid.New(uid.WorkspacePrefix) + keyringID := uid.New(uid.KeyAuthPrefix) + + // Insert workspace using sqlc + err = Query.InsertWorkspace(ctx, dbInstance.RW(), InsertWorkspaceParams{ + ID: workspaceID, + OrgID: workspaceID, + Name: "Test Workspace", + CreatedAt: time.Now().UnixMilli(), + }) + require.NoError(t, err) + + // Insert keyring using sqlc + err = Query.InsertKeyring(ctx, dbInstance.RW(), InsertKeyringParams{ + ID: keyringID, + WorkspaceID: workspaceID, + CreatedAtM: time.Now().UnixMilli(), + }) + require.NoError(t, err) + + t.Run("retry with real database - success after transient failure", func(t *testing.T) { + callCount := 0 + _, err := WithRetry(func() (string, error) { + callCount++ + + // Simulate transient failure on first attempt + if callCount == 1 { + return "", errors.New("dial tcp: connection refused") + } + + // Succeed on second attempt - insert using sqlc + keyID := uid.New(uid.KeyPrefix) + err := Query.InsertKey(ctx, dbInstance.RW(), InsertKeyParams{ + ID: keyID, + KeyringID: keyringID, + Hash: hash.Sha256(keyID), + Start: "retry_start", + WorkspaceID: workspaceID, + ForWorkspaceID: sql.NullString{}, + Name: sql.NullString{String: "retry_key", Valid: true}, + IdentityID: sql.NullString{}, + Meta: sql.NullString{}, + Expires: sql.NullTime{}, + CreatedAtM: time.Now().UnixMilli(), + Enabled: true, + RemainingRequests: sql.NullInt32{}, + RefillDay: sql.NullInt16{}, + RefillAmount: sql.NullInt32{}, + }) + + return keyID, err + }) + + require.NoError(t, err) + require.Equal(t, 2, callCount, "should retry once then succeed") + }) + + t.Run("retry with real database - no retry on duplicate key", func(t *testing.T) { + // Insert initial key using sqlc + keyID := uid.New(uid.KeyPrefix) + + keyParams := InsertKeyParams{ + ID: keyID, + KeyringID: keyringID, + Hash: hash.Sha256(keyID), + Start: "dup_start", + WorkspaceID: workspaceID, + ForWorkspaceID: sql.NullString{}, + Name: sql.NullString{String: "dup_key", Valid: true}, + IdentityID: sql.NullString{}, + Meta: sql.NullString{}, + Expires: sql.NullTime{}, + CreatedAtM: time.Now().UnixMilli(), + Enabled: true, + RemainingRequests: sql.NullInt32{}, + RefillDay: sql.NullInt16{}, + RefillAmount: sql.NullInt32{}, + } + err := Query.InsertKey(ctx, dbInstance.RW(), keyParams) + require.NoError(t, err) + + callCount := 0 + _, err = WithRetry(func() (string, error) { + callCount++ + + // Try to insert duplicate key - should not be retried + err := Query.InsertKey(ctx, dbInstance.RW(), keyParams) + return "success", err + }) + + require.Error(t, err) + require.True(t, IsDuplicateKeyError(err)) + require.Equal(t, 1, callCount, "should not retry on duplicate key error") + }) + + t.Run("retry with real database - no retry on not found", func(t *testing.T) { + callCount := 0 + _, err := WithRetry(func() (FindKeyForVerificationRow, error) { + callCount++ + // Try to find non-existent key using sqlc - should not be retried + return Query.FindKeyForVerification(ctx, dbInstance.RO(), uid.New(uid.KeyPrefix)) + }) + + require.Error(t, err) + require.True(t, IsNotFound(err)) + require.Equal(t, 1, callCount, "should not retry on not found error") + }) +} diff --git a/go/pkg/retry/retry.go b/go/pkg/retry/retry.go index 437dd6188a..107d7c39f5 100644 --- a/go/pkg/retry/retry.go +++ b/go/pkg/retry/retry.go @@ -10,10 +10,15 @@ import ( type retry struct { // attempts is the maximum number of times to try the operation attempts int + // backoff is a function that returns the duration to wait before the next retry // based on the current attempt number (zero-based) backoff func(n int) time.Duration + // shouldRetry is a function that determines if an error is retryable + // If nil, all errors are considered retryable + shouldRetry func(error) bool + // used for testing // overwrite time.Sleep to speed up tests sleep func(d time.Duration) @@ -55,11 +60,12 @@ type retry struct { // }), // // ) -func New(applies ...apply) *retry { +func New(applies ...Apply) *retry { r := &retry{ - attempts: 3, - backoff: func(n int) time.Duration { return time.Duration(n) * 100 * time.Millisecond }, - sleep: time.Sleep, + attempts: 3, + backoff: func(n int) time.Duration { return time.Duration(n) * 100 * time.Millisecond }, + shouldRetry: nil, // nil means all errors are retryable + sleep: time.Sleep, } for _, a := range applies { r = a(r) @@ -67,12 +73,12 @@ func New(applies ...apply) *retry { return r } -// apply modifies r and returns it -type apply func(r *retry) *retry +// Apply modifies r and returns it +type Apply func(r *retry) *retry // Attempts sets the maximum number of retry attempts. // The operation will be attempted up to this many times before giving up. -func Attempts(attempts int) apply { +func Attempts(attempts int) Apply { return func(r *retry) *retry { r.attempts = attempts return r @@ -82,7 +88,7 @@ func Attempts(attempts int) apply { // Backoff sets the backoff strategy function. // The function receives the current attempt number (starting with 1) and // should return the duration to wait before the next attempt. -func Backoff(backoff func(n int) time.Duration) apply { +func Backoff(backoff func(n int) time.Duration) Apply { return func(r *retry) *retry { r.backoff = backoff @@ -90,11 +96,37 @@ func Backoff(backoff func(n int) time.Duration) apply { } } +// ShouldRetry sets a function to determine if an error should trigger a retry. +// If not set or set to nil, all errors will trigger retries. +// This is useful for skipping retries on non-transient errors like "not found". +// +// Example: +// +// r := retry.New( +// retry.Attempts(3), +// retry.ShouldRetry(func(err error) bool { +// // Don't retry if the error is a "not found" error +// if errors.Is(err, ErrNotFound) { +// return false +// } +// // Retry all other errors +// return true +// }), +// ) +func ShouldRetry(shouldRetry func(error) bool) Apply { + return func(r *retry) *retry { + r.shouldRetry = shouldRetry + return r + } +} + // Do executes the given function with configured retry behavior. // The function is retried until it succeeds or the maximum number of attempts is reached. // Between attempts, the backoff function determines the wait duration. +// If shouldRetry is configured, it will be called to determine if a retry should occur. // -// Returns nil if the operation succeeds, or the last error encountered if all retries fail. +// Returns nil if the operation succeeds, or the last error encountered if all retries fail +// or if the error is non-retryable according to shouldRetry. // Returns an error if attempts is configured to less than 1. func (r *retry) Do(fn func() error) error { if r.attempts < 1 { @@ -107,7 +139,18 @@ func (r *retry) Do(fn func() error) error { if err == nil { return nil } - r.sleep(r.backoff(i)) + + // Check if we should retry this error + if r.shouldRetry != nil && !r.shouldRetry(err) { + // Error is not retryable, return immediately + return err + } + + // Don't sleep after the last attempt + if i < r.attempts { + r.sleep(r.backoff(i)) + } } + return err } diff --git a/go/pkg/retry/retry_test.go b/go/pkg/retry/retry_test.go index f3aa0b4de1..5f6234026c 100644 --- a/go/pkg/retry/retry_test.go +++ b/go/pkg/retry/retry_test.go @@ -43,7 +43,7 @@ func TestRetry(t *testing.T) { }, expectedCalls: 3, expectedError: true, - expectedSleep: 600 * time.Millisecond, + expectedSleep: 300 * time.Millisecond, // 100ms + 200ms (no sleep after final attempt) }, { name: "invalid attempts", @@ -73,7 +73,7 @@ func TestRetry(t *testing.T) { ), fn: failNTimes(3), expectedCalls: 3, - expectedSleep: 14 * time.Second, + expectedSleep: 5 * time.Second, // 1s + 4s (no sleep after final attempt) expectedError: true, }, { @@ -85,7 +85,7 @@ func TestRetry(t *testing.T) { fn: failNTimes(3), expectedCalls: 3, expectedError: true, - expectedSleep: 3 * time.Second, + expectedSleep: 2 * time.Second, // 1s + 1s (no sleep after final attempt) }, } @@ -132,3 +132,93 @@ func failNTimes(n int) func() error { return nil } } + +func TestShouldRetry(t *testing.T) { + nonRetryableError := errors.New("non-retryable") + retryableError := errors.New("retryable") + + tests := []struct { + name string + shouldRetry func(error) bool + errorSequence []error + expectedCalls int + expectedError error + expectedSleep time.Duration + }{ + { + name: "should retry all errors by default", + shouldRetry: nil, // default behavior + errorSequence: []error{retryableError, retryableError, nil}, + expectedCalls: 3, + expectedError: nil, + expectedSleep: 300 * time.Millisecond, // 100ms + 200ms + }, + { + name: "should not retry non-retryable errors", + shouldRetry: func(err error) bool { + return err != nonRetryableError + }, + errorSequence: []error{nonRetryableError}, + expectedCalls: 1, + expectedError: nonRetryableError, + expectedSleep: 0, // no retry, no sleep + }, + { + name: "should retry retryable errors but not non-retryable ones", + shouldRetry: func(err error) bool { + return err != nonRetryableError + }, + errorSequence: []error{retryableError, nonRetryableError}, + expectedCalls: 2, + expectedError: nonRetryableError, + expectedSleep: 100 * time.Millisecond, // only one retry before hitting non-retryable + }, + { + name: "should eventually succeed after retrying retryable errors", + shouldRetry: func(err error) bool { + return err != nonRetryableError + }, + errorSequence: []error{retryableError, retryableError, nil}, + expectedCalls: 3, + expectedError: nil, + expectedSleep: 300 * time.Millisecond, // 100ms + 200ms + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var retrier *retry + if tt.shouldRetry != nil { + retrier = New(ShouldRetry(tt.shouldRetry)) + } else { + retrier = New() + } + + totalSleep := time.Duration(0) + retrier.sleep = func(d time.Duration) { + totalSleep += d + } + + calls := 0 + err := retrier.Do(func() error { + if calls < len(tt.errorSequence) { + err := tt.errorSequence[calls] + calls++ + return err + } + calls++ + return nil + }) + + require.Equal(t, tt.expectedCalls, calls, "unexpected number of calls") + require.Equal(t, tt.expectedSleep, totalSleep, "unexpected sleep duration") + + if tt.expectedError != nil { + require.Error(t, err) + require.Equal(t, tt.expectedError, err) + } else { + require.NoError(t, err) + } + }) + } +} diff --git a/go/pkg/zen/middleware_errors.go b/go/pkg/zen/middleware_errors.go index 7dd1ee1261..92c2b11a1d 100644 --- a/go/pkg/zen/middleware_errors.go +++ b/go/pkg/zen/middleware_errors.go @@ -201,6 +201,7 @@ func WithErrorHandling(logger logging.Logger) Middleware { "requestId", s.RequestID(), "publicMessage", fault.UserFacingMessage(err), ) + return s.JSON(http.StatusInternalServerError, openapi.InternalServerErrorResponse{ Meta: openapi.Meta{ RequestId: s.RequestID(),