diff --git a/go/pkg/db/retry.go b/go/pkg/db/retry.go index fa768a53a1..1cee53e0f5 100644 --- a/go/pkg/db/retry.go +++ b/go/pkg/db/retry.go @@ -56,12 +56,5 @@ func WithRetry[T any](fn func() (T, error)) (T, error) { }), ) - var result T - err := retrier.Do(func() error { - var retryErr error - result, retryErr = fn() - return retryErr - }) - - return result, err + return retry.DoWithResult(retrier, fn) } diff --git a/go/pkg/retry/retry.go b/go/pkg/retry/retry.go index 107d7c39f5..0336132022 100644 --- a/go/pkg/retry/retry.go +++ b/go/pkg/retry/retry.go @@ -12,7 +12,7 @@ type retry struct { attempts int // backoff is a function that returns the duration to wait before the next retry - // based on the current attempt number (zero-based) + // based on the current attempt number (starting at 1) backoff func(n int) time.Duration // shouldRetry is a function that determines if an error is retryable @@ -24,14 +24,14 @@ type retry struct { sleep func(d time.Duration) } -// Build creates a new retry instance with default configuration. +// New creates a new retry instance with default configuration. // Default configuration: // - 3 retry attempts // - Linear backoff starting at 100ms, increasing by 100ms per attempt // // Example: // -// r := retry.Build() +// r := retry.New() // err := r.Do(func() error { // // Simulate an operation that might fail // resp, err := http.Get("https://api.example.com") @@ -52,7 +52,7 @@ type retry struct { // // The retry behavior can be customized using Attempts() and Backoff(): // -// r := retry.Build( +// r := retry.New( // // retry.Attempts(5), // retry.Backoff(func(n int) time.Duration { @@ -90,7 +90,6 @@ func Attempts(attempts int) Apply { // should return the duration to wait before the next attempt. func Backoff(backoff func(n int) time.Duration) Apply { return func(r *retry) *retry { - r.backoff = backoff return r } @@ -154,3 +153,26 @@ func (r *retry) Do(fn func() error) error { return err } + +// DoWithResult executes the given function with configured retry behavior and returns a result. +// Works like Do() but for functions that return a value along with an error. +// On failure, returns the result from the last attempt along with the final error. +// +// Example: +// +// r := retry.New(retry.Attempts(3)) +// user, err := retry.DoWithResult(r, func() (*User, error) { +// return fetchUserFromAPI(userID) +// }) +// if err != nil { +// log.Printf("failed to fetch user after 3 attempts: %v", err) +// } +func DoWithResult[T any](r *retry, fn func() (T, error)) (T, error) { + var result T + err := r.Do(func() error { + var retryErr error + result, retryErr = fn() + return retryErr + }) + return result, err +} diff --git a/go/pkg/retry/retry_test.go b/go/pkg/retry/retry_test.go index 259ff4b29d..65d8e99d8a 100644 --- a/go/pkg/retry/retry_test.go +++ b/go/pkg/retry/retry_test.go @@ -115,7 +115,6 @@ func TestRetry(t *testing.T) { } else { require.NoError(t, err) } - }) } } @@ -222,3 +221,74 @@ func TestShouldRetry(t *testing.T) { }) } } + +func TestDoWithResult(t *testing.T) { + tests := []struct { + name string + errorSequence []error + resultSequence []string + expectedCalls int + expectedResult string + expectedError error + expectedSleep time.Duration + }{ + { + name: "should return result on first success", + errorSequence: []error{nil}, + resultSequence: []string{"success"}, + expectedCalls: 1, + expectedResult: "success", + expectedError: nil, + expectedSleep: 0, + }, + { + name: "should return result after retries", + errorSequence: []error{errors.New("temp"), errors.New("temp"), nil}, + resultSequence: []string{"", "", "success"}, + expectedCalls: 3, + expectedResult: "success", + expectedError: nil, + expectedSleep: 300 * time.Millisecond, + }, + { + name: "should return last result on complete failure", + errorSequence: []error{errors.New("fail1"), errors.New("fail2"), errors.New("fail3")}, + resultSequence: []string{"result1", "result2", "result3"}, + expectedCalls: 3, + expectedResult: "result3", // last attempt's result + expectedError: errors.New("fail3"), + expectedSleep: 300 * time.Millisecond, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + retrier := New() + totalSleep := time.Duration(0) + retrier.sleep = func(d time.Duration) { + totalSleep += d + } + + calls := 0 + result, err := DoWithResult(retrier, func() (string, error) { + idx := calls + calls++ + if idx < len(tt.errorSequence) { + return tt.resultSequence[idx], tt.errorSequence[idx] + } + return "", nil + }) + + require.Equal(t, tt.expectedCalls, calls, "unexpected number of calls") + require.Equal(t, tt.expectedSleep, totalSleep, "unexpected sleep duration") + require.Equal(t, tt.expectedResult, result, "unexpected result") + + if tt.expectedError != nil { + require.Error(t, err) + require.Equal(t, tt.expectedError.Error(), err.Error()) + } else { + require.NoError(t, err) + } + }) + } +}