Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions go/pkg/db/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,5 @@ func WithRetry[T any](fn func() (T, error)) (T, error) {
}),
)

var result T
err := retrier.Do(func() error {
var retryErr error
result, retryErr = fn()
return retryErr
})

return result, err
return retry.DoWithResult(retrier, fn)
}
32 changes: 27 additions & 5 deletions go/pkg/retry/retry.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ type retry struct {
attempts int

// backoff is a function that returns the duration to wait before the next retry
// based on the current attempt number (zero-based)
// based on the current attempt number (starting at 1)
backoff func(n int) time.Duration

// shouldRetry is a function that determines if an error is retryable
Expand All @@ -24,14 +24,14 @@ type retry struct {
sleep func(d time.Duration)
}

// Build creates a new retry instance with default configuration.
// New creates a new retry instance with default configuration.
// Default configuration:
// - 3 retry attempts
// - Linear backoff starting at 100ms, increasing by 100ms per attempt
//
// Example:
//
// r := retry.Build()
// r := retry.New()
// err := r.Do(func() error {
// // Simulate an operation that might fail
// resp, err := http.Get("https://api.example.com")
Expand All @@ -52,7 +52,7 @@ type retry struct {
//
// The retry behavior can be customized using Attempts() and Backoff():
//
// r := retry.Build(
// r := retry.New(
//
// retry.Attempts(5),
// retry.Backoff(func(n int) time.Duration {
Expand Down Expand Up @@ -90,7 +90,6 @@ func Attempts(attempts int) Apply {
// should return the duration to wait before the next attempt.
func Backoff(backoff func(n int) time.Duration) Apply {
return func(r *retry) *retry {

r.backoff = backoff
return r
}
Expand Down Expand Up @@ -154,3 +153,26 @@ func (r *retry) Do(fn func() error) error {

return err
}

// DoWithResult executes the given function with configured retry behavior and returns a result.
// Works like Do() but for functions that return a value along with an error.
// On failure, returns the result from the last attempt along with the final error.
//
// Example:
//
// r := retry.New(retry.Attempts(3))
// user, err := retry.DoWithResult(r, func() (*User, error) {
// return fetchUserFromAPI(userID)
// })
// if err != nil {
// log.Printf("failed to fetch user after 3 attempts: %v", err)
// }
func DoWithResult[T any](r *retry, fn func() (T, error)) (T, error) {
var result T
err := r.Do(func() error {
var retryErr error
result, retryErr = fn()
return retryErr
})
return result, err
}
72 changes: 71 additions & 1 deletion go/pkg/retry/retry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,6 @@ func TestRetry(t *testing.T) {
} else {
require.NoError(t, err)
}

})
}
}
Expand Down Expand Up @@ -222,3 +221,74 @@ func TestShouldRetry(t *testing.T) {
})
}
}

func TestDoWithResult(t *testing.T) {
tests := []struct {
name string
errorSequence []error
resultSequence []string
expectedCalls int
expectedResult string
expectedError error
expectedSleep time.Duration
}{
{
name: "should return result on first success",
errorSequence: []error{nil},
resultSequence: []string{"success"},
expectedCalls: 1,
expectedResult: "success",
expectedError: nil,
expectedSleep: 0,
},
{
name: "should return result after retries",
errorSequence: []error{errors.New("temp"), errors.New("temp"), nil},
resultSequence: []string{"", "", "success"},
expectedCalls: 3,
expectedResult: "success",
expectedError: nil,
expectedSleep: 300 * time.Millisecond,
},
{
name: "should return last result on complete failure",
errorSequence: []error{errors.New("fail1"), errors.New("fail2"), errors.New("fail3")},
resultSequence: []string{"result1", "result2", "result3"},
expectedCalls: 3,
expectedResult: "result3", // last attempt's result
expectedError: errors.New("fail3"),
expectedSleep: 300 * time.Millisecond,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
retrier := New()
totalSleep := time.Duration(0)
retrier.sleep = func(d time.Duration) {
totalSleep += d
}

calls := 0
result, err := DoWithResult(retrier, func() (string, error) {
idx := calls
calls++
if idx < len(tt.errorSequence) {
return tt.resultSequence[idx], tt.errorSequence[idx]
}
return "", nil
})

require.Equal(t, tt.expectedCalls, calls, "unexpected number of calls")
require.Equal(t, tt.expectedSleep, totalSleep, "unexpected sleep duration")
require.Equal(t, tt.expectedResult, result, "unexpected result")

if tt.expectedError != nil {
require.Error(t, err)
require.Equal(t, tt.expectedError.Error(), err.Error())
} else {
require.NoError(t, err)
}
})
}
}