diff --git a/router-tests/error_handling_test.go b/router-tests/error_handling_test.go
index feca3791a8..394b712492 100644
--- a/router-tests/error_handling_test.go
+++ b/router-tests/error_handling_test.go
@@ -1380,7 +1380,7 @@ func TestErrorPropagation(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 		}, func(t *testing.T, xEnv *testenv.Environment) {
 			resp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
@@ -1412,7 +1412,7 @@ func TestErrorPropagation(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 		}, func(t *testing.T, xEnv *testenv.Environment) {
 			resp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
@@ -1444,7 +1444,7 @@ func TestErrorPropagation(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 		}, func(t *testing.T, xEnv *testenv.Environment) {
 			resp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
@@ -1476,7 +1476,7 @@ func TestErrorPropagation(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 		}, func(t *testing.T, xEnv *testenv.Environment) {
 			resp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
diff --git a/router-tests/go.mod b/router-tests/go.mod
index bec8f63fb7..275c75610d 100644
--- a/router-tests/go.mod
+++ b/router-tests/go.mod
@@ -70,7 +70,7 @@ require (
 	github.com/docker/docker-credential-helpers v0.9.3 // indirect
 	github.com/docker/go-units v0.5.0 // indirect
 	github.com/dustin/go-humanize v1.0.1 // indirect
-	github.com/expr-lang/expr v1.17.3 // indirect
+	github.com/expr-lang/expr v1.17.6 // indirect
 	github.com/fatih/color v1.18.0 // indirect
 	github.com/felixge/httpsnoop v1.0.4 // indirect
 	github.com/go-chi/chi/v5 v5.2.2 // indirect
diff --git a/router-tests/go.sum b/router-tests/go.sum
index 87ab146b7f..29bbe99be7 100644
--- a/router-tests/go.sum
+++ b/router-tests/go.sum
@@ -85,8 +85,8 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4
 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
 github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
 github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
-github.com/expr-lang/expr v1.17.3 h1:myeTTuDFz7k6eFe/JPlep/UsiIjVhG61FMHFu63U7j0=
-github.com/expr-lang/expr v1.17.3/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
+github.com/expr-lang/expr v1.17.6 h1:1h6i8ONk9cexhDmowO/A64VPxHScu7qfSl2k8OlINec=
+github.com/expr-lang/expr v1.17.6/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
 github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
 github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
 github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
diff --git a/router-tests/panic_test.go b/router-tests/panic_test.go
index 153de8f02d..bb12614f44 100644
--- a/router-tests/panic_test.go
+++ b/router-tests/panic_test.go
@@ -48,7 +48,7 @@ func TestEnginePanic(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 		}, func(t *testing.T, xEnv *testenv.Environment) {
 			res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
@@ -80,7 +80,7 @@ func TestEnginePanic(t *testing.T) {
 					EnableSingleFlight: true,
 					ParseKitPoolSize:   1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 		}, func(t *testing.T, xEnv *testenv.Environment) {
 			res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
diff --git a/router-tests/retry_test.go b/router-tests/retry_test.go
new file mode 100644
index 0000000000..cf916e15bc
--- /dev/null
+++ b/router-tests/retry_test.go
@@ -0,0 +1,401 @@
+package integration
+
+import (
+	"github.com/stretchr/testify/require"
+	"github.com/wundergraph/cosmo/router-tests/testenv"
+	"github.com/wundergraph/cosmo/router/core"
+	"github.com/wundergraph/cosmo/router/pkg/config"
+	"net/http"
+	"strconv"
+	"sync/atomic"
+	"testing"
+	"time"
+)
+
+func CreateRetryCounterFunc(counter *atomic.Int32, duration *atomic.Int64) func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) {
+	return func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) {
+		counter.Add(1)
+		if duration != nil {
+			duration.Store(int64(sleepDuration))
+		}
+	}
+}
+
+func TestRetry(t *testing.T) {
+	t.Parallel()
+
+	t.Run("verify mutations are not retried", func(t *testing.T) {
+		t.Parallel()
+
+		onRetryCounter := atomic.Int32{}
+		serviceCallsCounter := atomic.Int32{}
+		retryCounterFunc := CreateRetryCounterFunc(&onRetryCounter, nil)
+
+		maxRetryCount := 3
+		expression := "true"
+
+		options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 10*time.Second, 200*time.Millisecond, expression, retryCounterFunc)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient:   true,
+			AccessLogFields: []config.CustomAttribute{},
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							serviceCallsCounter.Add(1)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `mutation updateEmployeeTag { updateEmployeeTag(id: 10, tag: "dd") { id } }`,
+			})
+
+			require.Equal(t, 0, int(onRetryCounter.Load()))
+			require.Equal(t, 1, int(serviceCallsCounter.Load()))
+
+			require.NoError(t, err)
+			require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"updateEmployeeTag":null}}`, res.Body)
+		})
+
+	})
+
+	t.Run("verify no retries when expression and default check are not met", func(t *testing.T) {
+		t.Parallel()
+
+		onRetryCounter := atomic.Int32{}
+		serviceCallsCounter := atomic.Int32{}
+		retryCounterFunc := CreateRetryCounterFunc(&onRetryCounter, nil)
+
+		maxRetryCount := 3
+		expression := "false"
+
+		options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 10*time.Second, 200*time.Millisecond, expression, retryCounterFunc)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient:   true,
+			AccessLogFields: []config.CustomAttribute{},
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							serviceCallsCounter.Add(1)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query employees { employees { id } }`,
+			})
+
+			require.NoError(t, err)
+			require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"employees":null}}`, res.Body)
+
+			require.Equal(t, 0, int(onRetryCounter.Load()))
+			require.Equal(t, 1, int(serviceCallsCounter.Load()))
+		})
+	})
+
+	t.Run("verify retries when every retry results in a failure", func(t *testing.T) {
+		t.Parallel()
+
+		onRetryCounter := atomic.Int32{}
+		serviceCallsCounter := atomic.Int32{}
+		retryCounterFunc := CreateRetryCounterFunc(&onRetryCounter, nil)
+
+		maxRetryCount := 3
+		expression := "true"
+
+		options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 10*time.Second, 200*time.Millisecond, expression, retryCounterFunc)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient:   true,
+			AccessLogFields: []config.CustomAttribute{},
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							serviceCallsCounter.Add(1)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query employees { employees { id } }`,
+			})
+
+			require.NoError(t, err)
+			require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"employees":null}}`, res.Body)
+
+			require.Equal(t, maxRetryCount, int(onRetryCounter.Load()))
+			require.Equal(t, maxRetryCount+1, int(serviceCallsCounter.Load()))
+		})
+	})
+
+	t.Run("verify retries when only the first n retries result in a failure", func(t *testing.T) {
+		t.Parallel()
+
+		onRetryCounter := atomic.Int32{}
+		serviceCallsCounter := atomic.Int32{}
+		retryCounterFunc := CreateRetryCounterFunc(&onRetryCounter, nil)
+
+		maxRetryCount := 5
+		maxAttemptsBeforeServiceSucceeds := 2
+		expression := "true"
+
+		options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 10*time.Second, 200*time.Millisecond, expression, retryCounterFunc)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient:   true,
+			AccessLogFields: []config.CustomAttribute{},
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							// Only once the Nth retry is executed do we let the request succeed
+							if onRetryCounter.Load() == int32(maxAttemptsBeforeServiceSucceeds) {
+								w.WriteHeader(http.StatusOK)
+								_, _ = w.Write([]byte(`{"data":{"employees":[{"id":1},{"id":2}]}}`))
+							} else {
+								w.WriteHeader(http.StatusBadGateway)
+							}
+							serviceCallsCounter.Add(1)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query employees { employees { id } }`,
+			})
+			require.NoError(t, err)
+			require.Equal(t, `{"data":{"employees":[{"id":1},{"id":2}]}}`, res.Body)
+
+			require.Equal(t, maxAttemptsBeforeServiceSucceeds, int(onRetryCounter.Load()))
+			require.Equal(t, maxAttemptsBeforeServiceSucceeds+1, int(serviceCallsCounter.Load()))
+		})
+	})
+
+	t.Run("verify retry interval for 429 when a nonzero Retry-After is set", func(t *testing.T) {
+		t.Parallel()
+
+		onRetryCounter := atomic.Int32{}
+		serviceCallsCounter := atomic.Int32{}
+		sleepDuration := atomic.Int64{}
+
+		retryCounterFunc := CreateRetryCounterFunc(&onRetryCounter, &sleepDuration)
+
+		maxRetryCount := 3
+		expression := "statusCode == 429"
+		headerRetryIntervalInSeconds := 1
+
+		options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 2000*time.Second, 100*time.Millisecond, expression, retryCounterFunc)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient:   true,
+			AccessLogFields: []config.CustomAttribute{},
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.Header().Set("Retry-After", strconv.Itoa(headerRetryIntervalInSeconds))
+							w.WriteHeader(http.StatusTooManyRequests)
+							serviceCallsCounter.Add(1)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query employees { employees { id } }`,
+			})
+
+			require.NoError(t, err)
+			require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":429}}],"data":{"employees":null}}`, res.Body)
+
+			// The service will get one extra call, in addition to the first request
+			require.Equal(t, maxRetryCount, int(onRetryCounter.Load()))
+			require.Equal(t, maxRetryCount+1, int(serviceCallsCounter.Load()))
+
+			secondsDuration := time.Duration(headerRetryIntervalInSeconds) * time.Second
+			require.Equal(t, int64(secondsDuration), sleepDuration.Load())
+		})
+	})
+
+}
+
+func TestFlakyRetry(t *testing.T) {
+	t.Parallel()
+
+	t.Run("verify max duration is not exceeded on intervals", func(t *testing.T) {
+		t.Parallel()
+
+		onRetryCounter := atomic.Int32{}
+		retryCounterFunc := CreateRetryCounterFunc(&onRetryCounter, nil)
+
+		maxRetryCount := 3
+		retryInterval := 300 * time.Millisecond
+		maxDuration := 100 * time.Millisecond
+		expression := "true"
+
+		options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, maxDuration, retryInterval, expression, retryCounterFunc)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient:   true,
+			AccessLogFields: []config.CustomAttribute{},
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			startTime := time.Now()
+			res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query employees { employees { id } }`,
+			})
+			doneTime := time.Now()
+
+			require.NoError(t, err)
+			require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"employees":null}}`, res.Body)
+
+			// We subtract one from the retry count because only the intervals between attempts matter
+			requestDuration := doneTime.Sub(startTime)
+
+			shouldBeLessThanDuration := (time.Duration(maxRetryCount-1) * retryInterval) - (20 * time.Millisecond)
+			require.Less(t, requestDuration, shouldBeLessThanDuration)
+
+			// We reduce by 100ms to allow for jitter
+			expectedMinDuration := (time.Duration(maxRetryCount-1) * maxDuration) - (100 * time.Millisecond)
+			require.GreaterOrEqual(t, requestDuration, expectedMinDuration)
+		})
+	})
+
+	t.Run("Verify retry interval for 429", func(t *testing.T) {
+		t.Parallel()
+
+		t.Run("when no Retry-After is set", func(t *testing.T) {
+			t.Parallel()
+
+			onRetryCounter := atomic.Int32{}
+			serviceCallsCounter := atomic.Int32{}
+			sleepDuration := atomic.Int64{}
+
+			retryCounterFunc := CreateRetryCounterFunc(&onRetryCounter, &sleepDuration)
+
+			retryInterval := 300 * time.Millisecond
+			maxRetryCount := 3
+			expression := "statusCode == 429"
+
+			options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 1000*time.Millisecond, retryInterval, expression, retryCounterFunc)
+
+			testenv.Run(t, &testenv.Config{
+				NoRetryClient:   true,
+				AccessLogFields: []config.CustomAttribute{},
+				RouterOptions: []core.Option{
+					options,
+				},
+				Subgraphs: testenv.SubgraphsConfig{
+					Employees: testenv.SubgraphConfig{
+						Middleware: func(_ http.Handler) http.Handler {
+							return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+								w.WriteHeader(http.StatusTooManyRequests)
+								serviceCallsCounter.Add(1)
+							})
+						},
+					},
+				},
+			}, func(t *testing.T, xEnv *testenv.Environment) {
+				res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+					Query: `query employees { employees { id } }`,
+				})
+
+				require.NoError(t, err)
+				require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":429}}],"data":{"employees":null}}`, res.Body)
+
+				// The service will get one extra call, in addition to the first request
+				require.Equal(t, maxRetryCount, int(onRetryCounter.Load()))
+				require.Equal(t, maxRetryCount+1, int(serviceCallsCounter.Load()))
+
+				require.NotEqual(t, sleepDuration.Load(), int64(retryInterval))
+			})
+		})
+
+		t.Run("when zero Retry-After is set", func(t *testing.T) {
+			t.Parallel()
+
+			onRetryCounter := atomic.Int32{}
+			serviceCallsCounter := atomic.Int32{}
+			sleepDuration := atomic.Int64{}
+
+			retryCounterFunc := CreateRetryCounterFunc(&onRetryCounter, &sleepDuration)
+
+			maxRetryCount := 3
+			expression := "statusCode == 429"
+			emptyRetryInterval := 0
+			retryInterval := 300 * time.Millisecond
+
+			options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 1000*time.Millisecond, retryInterval, expression, retryCounterFunc)
+
+			testenv.Run(t, &testenv.Config{
+				NoRetryClient:   true,
+				AccessLogFields: []config.CustomAttribute{},
+				RouterOptions: []core.Option{
+					options,
+				},
+				Subgraphs: testenv.SubgraphsConfig{
+					Employees: testenv.SubgraphConfig{
+						Middleware: func(_ http.Handler) http.Handler {
+							return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+								// The header must be set before WriteHeader, otherwise it is never sent
+								w.Header().Set("Retry-After", strconv.Itoa(emptyRetryInterval))
+								w.WriteHeader(http.StatusTooManyRequests)
+								serviceCallsCounter.Add(1)
+							})
+						},
+					},
+				},
+			}, func(t *testing.T, xEnv *testenv.Environment) {
+				res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+					Query: `query employees { employees { id } }`,
+				})
+
+				require.NoError(t, err)
+				require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":429}}],"data":{"employees":null}}`, res.Body)
+
+				// The service will get one extra call, in addition to the first request
+				require.Equal(t, maxRetryCount, int(onRetryCounter.Load()))
+				require.Equal(t, maxRetryCount+1, int(serviceCallsCounter.Load()))
+
+				require.NotEqual(t, sleepDuration.Load(), int64(retryInterval))
+			})
+		})
+	})
+}
diff --git a/router-tests/structured_logging_test.go b/router-tests/structured_logging_test.go
index 82bfd1d9ff..7464fa26c9 100644
--- a/router-tests/structured_logging_test.go
+++ b/router-tests/structured_logging_test.go
@@ -709,7 +709,7 @@ func TestFlakyAccessLogs(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 			LogObservation: testenv.LogObservationConfig{
 				Enabled: true,
@@ -828,7 +828,7 @@ func TestFlakyAccessLogs(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 			LogObservation: testenv.LogObservationConfig{
 				Enabled: true,
@@ -960,7 +960,7 @@ func TestFlakyAccessLogs(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 			LogObservation: testenv.LogObservationConfig{
 				Enabled: true,
@@ -1096,7 +1096,7 @@ func TestFlakyAccessLogs(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 			LogObservation: testenv.LogObservationConfig{
 				Enabled: true,
@@ -2211,7 +2211,7 @@ func TestFlakyAccessLogs(t *testing.T) {
 					EnableSingleFlight:     true,
 					MaxConcurrentResolvers: 1,
 				}),
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 			LogObservation: testenv.LogObservationConfig{
 				Enabled: true,
diff --git a/router-tests/telemetry/telemetry_test.go b/router-tests/telemetry/telemetry_test.go
index 58dc2be217..2b4987c0b7 100644
--- a/router-tests/telemetry/telemetry_test.go
+++ b/router-tests/telemetry/telemetry_test.go
@@ -8706,7 +8706,7 @@ func TestFlakyTelemetry(t *testing.T) {
 				},
 			},
 			RouterOptions: []core.Option{
-				core.WithSubgraphRetryOptions(false, 0, 0, 0),
+				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
 			},
 			Subgraphs: testenv.SubgraphsConfig{
 				Products: testenv.SubgraphConfig{
diff --git a/router/core/context.go b/router/core/context.go
index fee849f778..26f5bcdc7c 100644
--- a/router/core/context.go
+++ b/router/core/context.go
@@ -634,15 +634,6 @@ func (o *operationContext) QueryPlanStats() (QueryPlanStats, error) {
 	return qps, nil
 }
 
-// isMutationRequest returns true if the current request is a mutation request
-func isMutationRequest(ctx context.Context) bool {
-	op := getRequestContext(ctx)
-	if op == nil {
-		return false
-	}
-	return op.Operation().Type() == "mutation"
-}
-
 type SubgraphResolver struct {
 	subgraphsByURL map[string]*Subgraph
 	subgraphsByID  map[string]*Subgraph
diff --git a/router/core/graph_server.go b/router/core/graph_server.go
index 380b8f0573..875b0969d5 100644
--- a/router/core/graph_server.go
+++ b/router/core/graph_server.go
@@ -40,7 +40,6 @@ import (
 	rmiddleware "github.com/wundergraph/cosmo/router/internal/middleware"
 	"github.com/wundergraph/cosmo/router/internal/recoveryhandler"
 	"github.com/wundergraph/cosmo/router/internal/requestlogger"
-	"github.com/wundergraph/cosmo/router/internal/retrytransport"
 	"github.com/wundergraph/cosmo/router/pkg/config"
 	"github.com/wundergraph/cosmo/router/pkg/cors"
"github.com/wundergraph/cosmo/router/pkg/execution_config" @@ -1156,6 +1155,12 @@ func (s *graphServer) buildGraphMux( baseConnMetricStore = s.connectionMetrics } + // Build retry options and handle any expression compilation errors + processedRetryOptions, err := ProcessRetryOptions(s.retryOptions) + if err != nil { + return nil, fmt.Errorf("failed to process retry options: %w", err) + } + ecb := &ExecutorConfigurationBuilder{ introspection: s.introspection, baseURL: s.baseURL, @@ -1171,20 +1176,12 @@ func (s *graphServer) buildGraphMux( FrameTimeout: s.engineExecutionConfiguration.WebSocketClientFrameTimeout, }, transportOptions: &TransportOptions{ - SubgraphTransportOptions: s.subgraphTransportOptions, - PreHandlers: s.preOriginHandlers, - PostHandlers: s.postOriginHandlers, - MetricStore: gm.metricStore, - ConnectionMetricStore: baseConnMetricStore, - RetryOptions: retrytransport.RetryOptions{ - Enabled: s.retryOptions.Enabled, - MaxRetryCount: s.retryOptions.MaxRetryCount, - MaxDuration: s.retryOptions.MaxDuration, - Interval: s.retryOptions.Interval, - ShouldRetry: func(err error, req *http.Request, resp *http.Response) bool { - return retrytransport.IsRetryableError(err, resp) && !isMutationRequest(req.Context()) - }, - }, + SubgraphTransportOptions: s.subgraphTransportOptions, + PreHandlers: s.preOriginHandlers, + PostHandlers: s.postOriginHandlers, + MetricStore: gm.metricStore, + ConnectionMetricStore: baseConnMetricStore, + RetryOptions: *processedRetryOptions, TracerProvider: s.tracerProvider, TracePropagators: s.compositePropagator, LocalhostFallbackInsideDocker: s.localhostFallbackInsideDocker, diff --git a/router/core/retry_builder.go b/router/core/retry_builder.go new file mode 100644 index 0000000000..30fffcb53f --- /dev/null +++ b/router/core/retry_builder.go @@ -0,0 +1,123 @@ +package core + +import ( + "fmt" + "net/http" + "strings" + + "github.com/wundergraph/cosmo/router/internal/expr" + "github.com/wundergraph/cosmo/router/internal/retrytransport" + "go.uber.org/zap" +) + +const ( + defaultRetryExpression = "IsRetryableStatusCode() || IsConnectionError() || IsTimeout()" + + backoffJitter = "backoff_jitter" +) + +var noopRetryFunc = func(err error, req *http.Request, resp *http.Response) bool { + return false +} + +func ProcessRetryOptions(retryOpts retrytransport.RetryOptions) (*retrytransport.RetryOptions, error) { + // Default to backOffJitter if no algorithm is specified + // This will occur either in tests or if the user explicitly makes it an empty string + if retryOpts.Algorithm == "" { + retryOpts.Algorithm = backoffJitter + } + + // We skip validating the algorithm if retries are disabled + if retryOpts.Enabled && retryOpts.Algorithm != backoffJitter { + return nil, fmt.Errorf("unsupported retry algorithm: %s", retryOpts.Algorithm) + } + + shouldRetryFunc, err := buildRetryFunction(retryOpts) + if err != nil { + return nil, fmt.Errorf("failed to build retry function: %w", err) + } + + // Create copy to not mutate the original reference + retryOptions := retrytransport.RetryOptions{ + Enabled: retryOpts.Enabled, + Algorithm: retryOpts.Algorithm, + MaxRetryCount: retryOpts.MaxRetryCount, + MaxDuration: retryOpts.MaxDuration, + Interval: retryOpts.Interval, + Expression: retryOpts.Expression, + + OnRetry: retryOpts.OnRetry, + + ShouldRetry: shouldRetryFunc, + } + + return &retryOptions, nil +} + +// BuildRetryFunction creates a ShouldRetry function based on the provided expression +func buildRetryFunction(retryOpts retrytransport.RetryOptions) 
(retrytransport.ShouldRetryFunc, error) { + // We do not need to build a retry function if retries are disabled + // This means that any bad expressions are ignored if retries are disabled + if !retryOpts.Enabled { + return noopRetryFunc, nil + } + + // Use default expression if empty string is passed + expression := retryOpts.Expression + if expression == "" { + expression = defaultRetryExpression + } + + // Create the retry expression manager + manager, err := expr.NewRetryExpressionManager(expression) + if err != nil { + return nil, fmt.Errorf("failed to create expression manager: %w", err) + } + + // Return expression-based retry function + return func(err error, req *http.Request, resp *http.Response) bool { + reqContext := getRequestContext(req.Context()) + + if reqContext == nil { + return false + } + + // Never retry mutations, regardless of expression result + if strings.ToLower(reqContext.Operation().Type()) == "mutation" { + return false + } + + if isDefaultRetryableError(err) { + return true + } + + // Create retry context + ctx := expr.LoadRetryContext(err, resp) + + // Evaluate the expression + shouldRetry, evalErr := manager.ShouldRetry(ctx) + if evalErr != nil { + reqContext.logger.Error("Failed to evaluate retry expression", + zap.Error(evalErr), + zap.String("expression", expression), + ) + + // Disable retries on evaluation error + return false + } + + return shouldRetry + }, nil +} + +// isDefaultRetryableError checks for errors that should always be retryable +// regardless of the configured retry expression +func isDefaultRetryableError(err error) bool { + if err == nil { + return false + } + + errStr := strings.ToLower(err.Error()) + // EOF errors are always retryable as they indicate connection issues + return strings.Contains(errStr, "unexpected eof") +} diff --git a/router/core/retry_builder_test.go b/router/core/retry_builder_test.go new file mode 100644 index 0000000000..d4fa1cf3f6 --- /dev/null +++ b/router/core/retry_builder_test.go @@ -0,0 +1,481 @@ +package core + +import ( + "errors" + "fmt" + "github.com/wundergraph/cosmo/router/internal/retrytransport" + "io" + "net/http" + "reflect" + "syscall" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/zap" +) + +// Helper functions for creating proper request contexts + +func createOperationContext(opType string) *operationContext { + return &operationContext{ + name: "TestOperation", + opType: opType, + hash: 12345, + content: "test content", + } +} + +func createRequestWithContext(opType string) (*http.Request, *requestContext) { + req, _ := http.NewRequest("POST", "http://example.com/graphql", nil) + logger := zap.NewNop() + + // Create operation context + operationCtx := createOperationContext(opType) + + // Create request context using the buildRequestContext function + reqCtx := buildRequestContext(requestContextOptions{ + operationContext: operationCtx, + requestLogger: logger, + metricsEnabled: false, + traceEnabled: false, + mapper: &attributeMapper{}, + w: nil, + r: req, + }) + + // Attach the request context to the Go context + ctx := withRequestContext(req.Context(), reqCtx) + req = req.WithContext(ctx) + + return req, reqCtx +} + +func TestBuildRetryFunction(t *testing.T) { + t.Run("build function when retry is disabled", func(t *testing.T) { + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: false, + Expression: "invalid expression ++++++", + }) + assert.NoError(t, err) + assert.Equal(t, + reflect.ValueOf(noopRetryFunc).Pointer(), + 
reflect.ValueOf(fn).Pointer(), + ) + }) + + t.Run("default expression behavior", func(t *testing.T) { + // Use the default expression that would be in the config + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: defaultRetryExpression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with proper query context + req, _ := createRequestWithContext(OperationTypeQuery) + + // Test default behavior - should retry on 500 + resp := &http.Response{StatusCode: 500} + assert.True(t, fn(nil, req, resp)) + + // Should not retry on 200 + resp.StatusCode = 200 + assert.False(t, fn(nil, req, resp)) + + // Test with errors - only expression-defined errors are handled here + assert.True(t, fn(syscall.ETIMEDOUT, req, nil)) + assert.True(t, fn(errors.New("connection refused"), req, nil)) + assert.True(t, fn(errors.New("unexpected EOF"), req, nil)) // EOF is now handled at transport layer, not expression + assert.False(t, fn(errors.New("some other error"), req, nil)) + }) + + t.Run("expression-based retry", func(t *testing.T) { + expression := "statusCode == 500 || statusCode == 503" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with proper query context + req, _ := createRequestWithContext(OperationTypeQuery) + + // Should retry on 500 + resp := &http.Response{StatusCode: 500} + assert.True(t, fn(nil, req, resp)) + + // Should retry on 503 + resp.StatusCode = 503 + assert.True(t, fn(nil, req, resp)) + + // Should not retry on 502 + resp.StatusCode = 502 + assert.False(t, fn(nil, req, resp)) + }) + + t.Run("expression with error conditions", func(t *testing.T) { + expression := "IsTimeout() || statusCode == 503" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with proper query context + req, _ := createRequestWithContext(OperationTypeQuery) + + // Should retry on timeout error + err = syscall.ETIMEDOUT + assert.True(t, fn(err, req, nil)) + + // Should retry on 503 + resp := &http.Response{StatusCode: 503} + assert.True(t, fn(nil, req, resp)) + + // Should not retry on other errors + err = errors.New("some other error") + assert.False(t, fn(err, req, nil)) + }) + + t.Run("invalid expression returns error", func(t *testing.T) { + expression := "invalid syntax +++" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.Error(t, err) + assert.Nil(t, fn) + assert.Contains(t, err.Error(), "failed to compile retry expression") + }) + + t.Run("empty expression uses default", func(t *testing.T) { + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: "", + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with proper query context + req, _ := createRequestWithContext(OperationTypeQuery) + + // Test with retryable status code + resp := &http.Response{StatusCode: 502} + assert.True(t, fn(nil, req, resp)) + + // Test with connection error + err = errors.New("connection refused") + assert.True(t, fn(err, req, nil)) + + // Test with timeout error + err = syscall.ETIMEDOUT + assert.True(t, fn(err, req, nil)) + + // Test with non-retryable error + err = errors.New("some other error") + assert.False(t, fn(err, req, nil)) + }) + + t.Run("expression that always returns 
false but the error is an eof error", func(t *testing.T) { + expression := "false" // Don't retry + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with proper query context + req, _ := createRequestWithContext(OperationTypeQuery) + + assert.True(t, fn(io.ErrUnexpectedEOF, req, nil)) + }) + + t.Run("expression that always returns true", func(t *testing.T) { + expression := "true" // Always retry + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with proper query context + req, _ := createRequestWithContext(OperationTypeQuery) + resp := &http.Response{StatusCode: 500} + + // Should retry when expression is true + assert.True(t, fn(nil, req, resp)) + + // Even for status codes that wouldn't normally retry + resp.StatusCode = 200 + assert.True(t, fn(nil, req, resp)) + }) + + t.Run("complex expression", func(t *testing.T) { + expression := "(statusCode >= 500 && statusCode < 600) || IsConnectionError()" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with proper query context + req, _ := createRequestWithContext(OperationTypeQuery) + + // Test 5xx errors + resp := &http.Response{StatusCode: 503} + assert.True(t, fn(nil, req, resp)) + + // Test connection error + err = errors.New("connection refused") + assert.True(t, fn(err, req, nil)) + + // Test non-matching conditions + resp.StatusCode = 404 + err = errors.New("some other error") + assert.False(t, fn(err, req, resp)) + }) + + t.Run("mutation never retries with proper context", func(t *testing.T) { + // Use expression that would normally retry on 500 errors + expression := "statusCode >= 500 || IsTimeout() || IsConnectionError()" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with mutation context + req, _ := createRequestWithContext(OperationTypeMutation) + + // Test with 500 status - should NOT retry because it's a mutation + resp := &http.Response{StatusCode: 500} + assert.False(t, fn(nil, req, resp)) + + // Test with timeout error - should NOT retry because it's a mutation + assert.False(t, fn(syscall.ETIMEDOUT, req, nil)) + + // Test with connection error - should NOT retry because it's a mutation + assert.False(t, fn(errors.New("connection refused"), req, nil)) + + // Test with expression that always returns true - should still NOT retry + alwaysRetryFn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: "true", + }) + assert.NoError(t, err) + assert.False(t, alwaysRetryFn(nil, req, resp)) + }) + + t.Run("query retries with proper context", func(t *testing.T) { + expression := "statusCode >= 500 || IsTimeout()" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with query context + req, _ := createRequestWithContext(OperationTypeQuery) + + // Test with 500 status - should retry because it's a query + resp := &http.Response{StatusCode: 500} + assert.True(t, fn(nil, req, resp)) + + // Test with timeout error - should retry because it's a query + 
assert.True(t, fn(syscall.ETIMEDOUT, req, nil)) + + // Test with 200 status - should not retry even for query + resp.StatusCode = 200 + assert.False(t, fn(nil, req, resp)) + }) + + t.Run("subscription retries with proper context", func(t *testing.T) { + expression := "statusCode >= 500" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with subscription context + req, _ := createRequestWithContext(OperationTypeSubscription) + + // Test with 500 status - should retry because it's a subscription (not mutation) + resp := &http.Response{StatusCode: 500} + assert.True(t, fn(nil, req, resp)) + + // Test with 200 status - should not retry + resp.StatusCode = 200 + assert.False(t, fn(nil, req, resp)) + }) + + t.Run("error logging with proper context", func(t *testing.T) { + // Test that error logging works with proper request context + expression := "statusCode >= 500" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create request with proper context + req, _ := createRequestWithContext(OperationTypeQuery) + + // Test that it works normally with proper context + resp := &http.Response{StatusCode: 500} + assert.True(t, fn(nil, req, resp)) + + resp.StatusCode = 200 + assert.False(t, fn(nil, req, resp)) + }) + + t.Run("request context with query operation", func(t *testing.T) { + expression := "statusCode >= 500" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create request with proper query context + req, _ := createRequestWithContext(OperationTypeQuery) + + // Should work with proper request context - expression should be evaluated normally + resp := &http.Response{StatusCode: 500} + assert.True(t, fn(nil, req, resp)) + + resp.StatusCode = 200 + assert.False(t, fn(nil, req, resp)) + }) + + t.Run("complex expression with mutation context", func(t *testing.T) { + // Complex expression that would normally retry in many cases + expression := "(statusCode >= 500 && statusCode < 600) || IsConnectionError() || IsTimeout() || statusCode == 429" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Create a request with mutation context + req, _ := createRequestWithContext(OperationTypeMutation) + + // Test various conditions that would normally trigger retry + resp := &http.Response{StatusCode: 500} + assert.False(t, fn(nil, req, resp)) + + resp.StatusCode = 503 + assert.False(t, fn(nil, req, resp)) + + resp.StatusCode = 429 + assert.False(t, fn(nil, req, resp)) + + assert.False(t, fn(syscall.ETIMEDOUT, req, nil)) + assert.False(t, fn(errors.New("connection refused"), req, nil)) + }) + + t.Run("new operation with comprehensive retry conditions", func(t *testing.T) { + // Create a new comprehensive operation to test all retry scenarios + expression := "statusCode >= 500 || statusCode == 429 || IsTimeout() || IsConnectionError()" + fn, err := buildRetryFunction(retrytransport.RetryOptions{ + Enabled: true, + Expression: expression, + }) + assert.NoError(t, err) + assert.NotNil(t, fn) + + // Test query operation - should retry on all conditions + req, _ := createRequestWithContext(OperationTypeQuery) + + // Test 5xx errors + resp := 
&http.Response{StatusCode: 500} + assert.True(t, fn(nil, req, resp)) + resp.StatusCode = 503 + assert.True(t, fn(nil, req, resp)) + + // Test rate limiting + resp.StatusCode = 429 + assert.True(t, fn(nil, req, resp)) + + // Test timeouts + assert.True(t, fn(syscall.ETIMEDOUT, req, nil)) + + // Test connection errors + assert.True(t, fn(errors.New("connection refused"), req, nil)) + + // Test success - should not retry + resp.StatusCode = 200 + assert.False(t, fn(nil, req, resp)) + + // Test client errors - should not retry + resp.StatusCode = 404 + assert.False(t, fn(nil, req, resp)) + + // Now test the same conditions with a mutation - should never retry + mutationReq, _ := createRequestWithContext(OperationTypeMutation) + + resp.StatusCode = 500 + assert.False(t, fn(nil, mutationReq, resp)) + resp.StatusCode = 503 + assert.False(t, fn(nil, mutationReq, resp)) + resp.StatusCode = 429 + assert.False(t, fn(nil, mutationReq, resp)) + assert.False(t, fn(syscall.ETIMEDOUT, mutationReq, nil)) + assert.False(t, fn(errors.New("connection refused"), mutationReq, nil)) + }) +} + +func TestProcessRetryOptions(t *testing.T) { + t.Run("process invalid algorithm", func(t *testing.T) { + algorithm := "abcdee" + _, err := ProcessRetryOptions(retrytransport.RetryOptions{ + Enabled: true, + Algorithm: algorithm, + }) + + expectedError := fmt.Sprintf("unsupported retry algorithm: %s", algorithm) + assert.ErrorContains(t, err, expectedError) + }) + + t.Run("process invalid algorithm when retries are disabled", func(t *testing.T) { + algorithm := "abcdee" + _, err := ProcessRetryOptions(retrytransport.RetryOptions{ + Enabled: false, + Algorithm: algorithm, + }) + assert.NoError(t, err) + }) + + t.Run("process invalid expression", func(t *testing.T) { + _, err := ProcessRetryOptions(retrytransport.RetryOptions{ + Enabled: true, + Algorithm: "backoff_jitter", + Expression: "invalid syntax +++", + }) + + assert.ErrorContains(t, err, "failed to build retry function") + }) + + t.Run("process valid options", func(t *testing.T) { + options := retrytransport.RetryOptions{ + Enabled: true, + Algorithm: "backoff_jitter", + Expression: "statusCode == 500 || IsTimeout() || IsConnectionError()", + } + response, err := ProcessRetryOptions(options) + assert.NoError(t, err) + assert.NotSame(t, &options, response) + }) +} diff --git a/router/core/router.go b/router/core/router.go index e5388aeeb7..f5bed12df7 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -1752,13 +1752,25 @@ func WithSubgraphCircuitBreakerOptions(opts *SubgraphCircuitBreakerOptions) Opti } } -func WithSubgraphRetryOptions(enabled bool, maxRetryCount int, retryMaxDuration, retryInterval time.Duration) Option { +func WithSubgraphRetryOptions( + enabled bool, + algorithm string, + maxRetryCount int, + retryMaxDuration, retryInterval time.Duration, + expression string, + onRetryFunc retrytransport.OnRetryFunc, +) Option { return func(r *Router) { r.retryOptions = retrytransport.RetryOptions{ Enabled: enabled, + Algorithm: algorithm, MaxRetryCount: maxRetryCount, MaxDuration: retryMaxDuration, Interval: retryInterval, + Expression: expression, + + // Test case overrides + OnRetry: onRetryFunc, } } } diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 9fab3f93bc..09605f6d9e 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -195,9 +195,12 @@ func optionsFromResources(logger *zap.Logger, config *config.Config) []Option { 
 		WithSubgraphCircuitBreakerOptions(NewSubgraphCircuitBreakerOptions(config.TrafficShaping)),
 		WithSubgraphRetryOptions(
 			config.TrafficShaping.All.BackoffJitterRetry.Enabled,
+			config.TrafficShaping.All.BackoffJitterRetry.Algorithm,
 			config.TrafficShaping.All.BackoffJitterRetry.MaxAttempts,
 			config.TrafficShaping.All.BackoffJitterRetry.MaxDuration,
 			config.TrafficShaping.All.BackoffJitterRetry.Interval,
+			config.TrafficShaping.All.BackoffJitterRetry.Expression,
+			nil,
 		),
 		WithCors(&cors.Config{
 			Enabled: config.CORS.Enabled,
diff --git a/router/go.mod b/router/go.mod
index 4954a40fe1..c8302087bc 100644
--- a/router/go.mod
+++ b/router/go.mod
@@ -64,7 +64,7 @@ require (
 	github.com/caarlos0/env/v11 v11.3.1
 	github.com/cep21/circuit/v4 v4.0.0
 	github.com/dgraph-io/ristretto/v2 v2.1.0
-	github.com/expr-lang/expr v1.17.3
+	github.com/expr-lang/expr v1.17.6
 	github.com/goccy/go-json v0.10.3
 	github.com/google/go-containerregistry v0.20.3
 	github.com/google/uuid v1.6.0
diff --git a/router/go.sum b/router/go.sum
index a72296f7de..a6ad50dc3e 100644
--- a/router/go.sum
+++ b/router/go.sum
@@ -72,8 +72,8 @@ github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4
 github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk=
 github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
 github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
-github.com/expr-lang/expr v1.17.3 h1:myeTTuDFz7k6eFe/JPlep/UsiIjVhG61FMHFu63U7j0=
-github.com/expr-lang/expr v1.17.3/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
+github.com/expr-lang/expr v1.17.6 h1:1h6i8ONk9cexhDmowO/A64VPxHScu7qfSl2k8OlINec=
+github.com/expr-lang/expr v1.17.6/go.mod h1:8/vRC7+7HBzESEqt5kKpYXxrxkr31SaO8r40VO/1IT4=
 github.com/fatih/color v1.13.0/go.mod h1:kLAiJbzzSOZDVNGyDpeOxJ47H46qBXwg5ILebYFFOfk=
 github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM=
 github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU=
diff --git a/router/internal/expr/retry_context.go b/router/internal/expr/retry_context.go
new file mode 100644
index 0000000000..7229180af5
--- /dev/null
+++ b/router/internal/expr/retry_context.go
@@ -0,0 +1,146 @@
+package expr
+
+import (
+	"errors"
+	"net"
+	"net/http"
+	"os"
+	"strings"
+	"syscall"
+)
+
+// RetryContext is the context for retry expressions
+type RetryContext struct {
+	StatusCode int    `expr:"statusCode"`
+	Error      string `expr:"error"`
+	// originalError stores the original error for proper type checking
+	// This field is not exposed to expressions
+	originalError error
+}
+
+// IsHttpReadTimeout returns true if the error is an HTTP-specific timeout
+// waiting for response headers from the server.
+func (ctx RetryContext) IsHttpReadTimeout() bool {
+	// Only check for HTTP-specific timeout awaiting response headers
+	if ctx.Error != "" {
+		errLower := strings.ToLower(ctx.Error)
+		return strings.Contains(errLower, "timeout awaiting response headers")
+	}
+
+	return false
+}
+
+// IsTimeout returns true if the error is any type of timeout error,
+// including HTTP read timeouts, network timeouts, deadline exceeded errors,
+// or direct syscall timeout errors.
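+//
+// For example, a user-supplied retry expression such as
+// "IsTimeout() || statusCode == 429" (an illustrative condition, not a
+// default) combines this helper with the exposed statusCode field.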
+func (ctx RetryContext) IsTimeout() bool {
+	// Check for HTTP-specific read timeouts
+	if ctx.IsHttpReadTimeout() {
+		return true
+	}
+
+	// Check for net package timeout errors using the standard Go method
+	if ctx.originalError != nil {
+		var netErr net.Error
+		if errors.As(ctx.originalError, &netErr) && netErr.Timeout() {
+			return true
+		}
+		// Check for deadline exceeded errors
+		if errors.Is(ctx.originalError, os.ErrDeadlineExceeded) {
+			return true
+		}
+		// Also check for direct syscall timeout errors not wrapped in net.Error
+		if errors.Is(ctx.originalError, syscall.ETIMEDOUT) {
+			return true
+		}
+	}
+
+	return false
+}
+
+// IsConnectionError returns true if the error is a connection-related error,
+// including connection refused, connection reset, DNS resolution failures,
+// or TLS handshake errors.
+func (ctx RetryContext) IsConnectionError() bool {
+	// Use existing helpers for specific connection errors
+	if ctx.IsConnectionRefused() || ctx.IsConnectionReset() {
+		return true
+	}
+
+	// Fall back to string matching for other connection errors not covered by specific helpers
+	if ctx.Error != "" {
+		errLower := strings.ToLower(ctx.Error)
+		return strings.Contains(errLower, "no such host") ||
+			strings.Contains(errLower, "handshake failure") ||
+			strings.Contains(errLower, "handshake timeout")
+	}
+
+	return false
+}
+
+// IsRetryableStatusCode returns true if the HTTP status code is generally
+// considered retryable, including 500, 502, 503, and 504.
+func (ctx RetryContext) IsRetryableStatusCode() bool {
+	switch ctx.StatusCode {
+	case http.StatusInternalServerError,
+		http.StatusBadGateway,
+		http.StatusServiceUnavailable,
+		http.StatusGatewayTimeout:
+		return true
+	default:
+		return false
+	}
+}
+
+// IsConnectionRefused returns true if the error is specifically a connection
+// refused error (ECONNREFUSED), either through direct syscall error checking
+// or string matching.
+func (ctx RetryContext) IsConnectionRefused() bool {
+	if ctx.originalError != nil && errors.Is(ctx.originalError, syscall.ECONNREFUSED) {
+		return true
+	}
+
+	// Fall back to string matching
+	if ctx.Error != "" {
+		errLower := strings.ToLower(ctx.Error)
+		return strings.Contains(errLower, "connection refused")
+	}
+
+	return false
+}
+
+// IsConnectionReset returns true if the error is specifically a connection
+// reset error (ECONNRESET), either through direct syscall error checking
+// or string matching.
+func (ctx RetryContext) IsConnectionReset() bool {
+	if ctx.originalError != nil && errors.Is(ctx.originalError, syscall.ECONNRESET) {
+		return true
+	}
+
+	// Fall back to string matching
+	if ctx.Error != "" {
+		errLower := strings.ToLower(ctx.Error)
+		return strings.Contains(errLower, "connection reset")
+	}
+
+	return false
+}
+
+// LoadRetryContext creates a RetryContext from the given error and HTTP response.
+// It extracts the error message and status code to make them available for
+// retry condition evaluation in expressions.
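+//
+// For example, LoadRetryContext(err, resp) for a nil error and a 503 response
+// yields a context where statusCode == 503 and error is empty.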
+func LoadRetryContext(err error, resp *http.Response) RetryContext {
+	ctx := RetryContext{
+		originalError: err,
+	}
+
+	if err != nil {
+		ctx.Error = err.Error()
+	}
+
+	if resp != nil {
+		ctx.StatusCode = resp.StatusCode
+	}
+
+	return ctx
+}
diff --git a/router/internal/expr/retry_expression.go b/router/internal/expr/retry_expression.go
new file mode 100644
index 0000000000..4414d29181
--- /dev/null
+++ b/router/internal/expr/retry_expression.go
@@ -0,0 +1,56 @@
+package expr
+
+import (
+	"fmt"
+	"reflect"
+
+	"github.com/expr-lang/expr"
+	"github.com/expr-lang/expr/vm"
+)
+
+// RetryExpressionManager handles compilation and evaluation of retry expressions
+type RetryExpressionManager struct {
+	program *vm.Program
+}
+
+// NewRetryExpressionManager creates a new RetryExpressionManager with the given expression
+func NewRetryExpressionManager(expression string) (*RetryExpressionManager, error) {
+	if expression == "" {
+		return nil, nil
+	}
+
+	// Compile the expression with retry context
+	options := []expr.Option{
+		expr.Env(RetryContext{}),
+		expr.AsKind(reflect.Bool),
+	}
+
+	program, err := expr.Compile(expression, options...)
+	if err != nil {
+		return nil, fmt.Errorf("failed to compile retry expression: %w", handleExpressionError(err))
+	}
+
+	return &RetryExpressionManager{
+		program: program,
+	}, nil
+}
+
+// ShouldRetry evaluates the retry expression with the given context
+func (m *RetryExpressionManager) ShouldRetry(ctx RetryContext) (bool, error) {
+	if m == nil || m.program == nil {
+		// Use default behavior if no expression is configured
+		return false, nil
+	}
+
+	result, err := expr.Run(m.program, ctx)
+	if err != nil {
+		return false, fmt.Errorf("failed to evaluate retry expression: %w", handleExpressionError(err))
+	}
+
+	shouldRetry, ok := result.(bool)
+	if !ok {
+		return false, fmt.Errorf("retry expression must return a boolean, got %T", result)
+	}
+
+	return shouldRetry, nil
+}
diff --git a/router/internal/expr/retry_expression_test.go b/router/internal/expr/retry_expression_test.go
new file mode 100644
index 0000000000..04a8c8c268
--- /dev/null
+++ b/router/internal/expr/retry_expression_test.go
@@ -0,0 +1,549 @@
+package expr
+
+import (
+	"errors"
+	"fmt"
+	"net"
+	"net/http"
+	"os"
+	"syscall"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRetryExpressionManager(t *testing.T) {
+	tests := []struct {
+		name       string
+		expression string
+		ctx        RetryContext
+		expected   bool
+		expectErr  bool
+	}{
+		{
+			name:       "status code exact match",
+			expression: "statusCode == 500",
+			ctx:        RetryContext{StatusCode: 500},
+			expected:   true,
+		},
+		{
+			name:       "status code no match",
+			expression: "statusCode == 500",
+			ctx:        RetryContext{StatusCode: 200},
+			expected:   false,
+		},
+		{
+			name:       "OR condition - first true",
+			expression: "statusCode == 500 || statusCode == 502",
+			ctx:        RetryContext{StatusCode: 500},
+			expected:   true,
+		},
+		{
+			name:       "OR condition - second true",
+			expression: "statusCode == 500 || statusCode == 502",
+			ctx:        RetryContext{StatusCode: 502},
+			expected:   true,
+		},
+		{
+			name:       "OR condition - both false",
+			expression: "statusCode == 500 || statusCode == 502",
+			ctx:        RetryContext{StatusCode: 200},
+			expected:   false,
+		},
+		{
+			name:       "IsHttpReadTimeout helper function",
+			expression: "IsHttpReadTimeout()",
+			ctx:        RetryContext{Error: "timeout awaiting response headers"},
+			expected:   true,
+		},
+		{
+			name:       "IsHttpReadTimeout with different error",
+			expression: "IsHttpReadTimeout()",
+			ctx:        RetryContext{Error: "connection refused"},
+			expected:   false,
+		},
+		{
+			name:       "IsTimeout helper function",
+			expression: "IsTimeout()",
+			ctx:        LoadRetryContext(&mockTimeoutError{msg: "net timeout", timeout: true}, nil),
+			expected:   true,
+		},
+		{
+			name:       "IsTimeout helper function wrapped",
+			expression: "IsTimeout()",
+			ctx: LoadRetryContext(
+				fmt.Errorf("wrapped error: %w", &mockTimeoutError{msg: "net timeout", timeout: true}), nil),
+			expected: true,
+		},
+		{
+			name:       "complex expression with helpers",
+			expression: "statusCode == 500 || IsTimeout()",
+			ctx:        LoadRetryContext(&mockTimeoutError{msg: "net timeout", timeout: true}, &http.Response{StatusCode: 200}),
+			expected:   true,
+		},
+		{
+			name:       "IsConnectionError helper",
+			expression: "IsConnectionError()",
+			ctx:        RetryContext{Error: "connection refused"},
+			expected:   true,
+		},
+		{
+			name:       "IsRetryableStatusCode helper",
+			expression: "IsRetryableStatusCode()",
+			ctx:        RetryContext{StatusCode: 429},
+			expected:   false,
+		},
+		{
+			name:       "range check",
+			expression: "statusCode >= 500 && statusCode < 600",
+			ctx:        RetryContext{StatusCode: 503},
+			expected:   true,
+		},
+		{
+			name:       "error string contains",
+			expression: `error contains "timeout"`,
+			ctx:        RetryContext{Error: "request timeout occurred"},
+			expected:   true,
+		},
+		{
+			name:       "error string exact match",
+			expression: `error == "connection refused"`,
+			ctx:        RetryContext{Error: "connection refused"},
+			expected:   true,
+		},
+		{
+			name:       "invalid expression",
+			expression: "invalid syntax +++",
+			expectErr:  true,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			manager, err := NewRetryExpressionManager(tt.expression)
+			if tt.expectErr {
+				assert.Error(t, err)
+				return
+			}
+			require.NoError(t, err)
+			require.NotNil(t, manager)
+
+			result, err := manager.ShouldRetry(tt.ctx)
+			assert.NoError(t, err)
+			assert.Equal(t, tt.expected, result)
+		})
+	}
+}
+
+func TestRetryExpressionManager_EmptyExpression(t *testing.T) {
+	manager, err := NewRetryExpressionManager("")
+	assert.NoError(t, err)
+	assert.Nil(t, manager)
+}
+
+func TestLoadRetryContext(t *testing.T) {
+	t.Run("with error and response", func(t *testing.T) {
+		err := errors.New("connection timeout")
+		resp := &http.Response{StatusCode: 500}
+
+		ctx := LoadRetryContext(err, resp)
+
+		assert.Equal(t, "connection timeout", ctx.Error)
+		assert.Equal(t, 500, ctx.StatusCode)
+	})
+
+	t.Run("with only error", func(t *testing.T) {
+		err := errors.New("network error")
+
+		ctx := LoadRetryContext(err, nil)
+
+		assert.Equal(t, "network error", ctx.Error)
+		assert.Equal(t, 0, ctx.StatusCode)
+	})
+
+	t.Run("with only response", func(t *testing.T) {
+		resp := &http.Response{StatusCode: 503}
+
+		ctx := LoadRetryContext(nil, resp)
+
+		assert.Equal(t, "", ctx.Error)
+		assert.Equal(t, 503, ctx.StatusCode)
+	})
+
+	t.Run("with neither error nor response", func(t *testing.T) {
+		ctx := LoadRetryContext(nil, nil)
+
+		assert.Equal(t, "", ctx.Error)
+		assert.Equal(t, 0, ctx.StatusCode)
+	})
+}
+
+func TestRetryContext_SyscallErrorDetection(t *testing.T) {
+	t.Run("IsConnectionRefused", func(t *testing.T) {
+		tests := []struct {
+			name     string
+			err      error
+			expected bool
+		}{
+			{
+				name:     "direct ECONNREFUSED",
+				err:      syscall.ECONNREFUSED,
+				expected: true,
+			},
+			{
+				name:     "wrapped ECONNREFUSED",
+				err:      fmt.Errorf("connection failed: %w", syscall.ECONNREFUSED),
+				expected: true,
+			},
+			{
+				name: "ECONNREFUSED in net.OpError",
+				err: &net.OpError{
+					Err: &os.SyscallError{
+						Err: syscall.ECONNREFUSED,
+					},
+				},
+				expected: true,
+			},
+			{
+				name:     "string fallback - connection refused",
+				err:      errors.New("connection refused"),
+				expected: true,
+			},
+			{
+				name:     "string fallback - mixed case",
+				err:      errors.New("Connection Refused by server"),
+				expected: true,
+			},
+			{
+				name:     "different error",
+				err:      syscall.ECONNRESET,
+				expected: false,
+			},
+			{
+				name:     "nil error",
+				err:      nil,
+				expected: false,
+			},
+		}
+
+		for _, tt := range tests {
+			t.Run(tt.name, func(t *testing.T) {
+				ctx := LoadRetryContext(tt.err, nil)
+				result := ctx.IsConnectionRefused()
+				assert.Equal(t, tt.expected, result)
+			})
+		}
+	})
+
+	t.Run("IsConnectionReset", func(t *testing.T) {
+		tests := []struct {
+			name     string
+			err      error
+			expected bool
+		}{
+			{
+				name:     "direct ECONNRESET",
+				err:      syscall.ECONNRESET,
+				expected: true,
+			},
+			{
+				name:     "wrapped ECONNRESET",
+				err:      fmt.Errorf("network error: %w", syscall.ECONNRESET),
+				expected: true,
+			},
+			{
+				name: "ECONNRESET in net.OpError",
+				err: &net.OpError{
+					Err: &os.SyscallError{
+						Err: syscall.ECONNRESET,
+					},
+				},
+				expected: true,
+			},
+			{
+				name:     "string fallback - connection reset",
+				err:      errors.New("connection reset by peer"),
+				expected: true,
+			},
+			{
+				name:     "string fallback - mixed case",
+				err:      errors.New("Connection Reset By Peer"),
+				expected: true,
+			},
+			{
+				name:     "different error",
+				err:      syscall.ECONNREFUSED,
+				expected: false,
+			},
+			{
+				name:     "nil error",
+				err:      nil,
+				expected: false,
+			},
+		}
+
+		for _, tt := range tests {
+			t.Run(tt.name, func(t *testing.T) {
+				ctx := LoadRetryContext(tt.err, nil)
+				result := ctx.IsConnectionReset()
+				assert.Equal(t, tt.expected, result)
+			})
+		}
+	})
+}
+
+func TestRetryContext_ImprovedErrorDetection(t *testing.T) {
+	t.Run("IsConnectionError with syscall errors", func(t *testing.T) {
+		tests := []struct {
+			name     string
+			err      error
+			expected bool
+		}{
+			{
+				name:     "ECONNREFUSED detected",
+				err:      syscall.ECONNREFUSED,
+				expected: true,
+			},
+			{
+				name:     "ECONNRESET detected",
+				err:      syscall.ECONNRESET,
+				expected: true,
+			},
+			{
+				name:     "wrapped ECONNREFUSED detected",
+				err:      fmt.Errorf("dial error: %w", syscall.ECONNREFUSED),
+				expected: true,
+			},
+			{
+				name:     "string fallback still works",
+				err:      errors.New("no such host"),
+				expected: true,
+			},
+			{
+				name:     "ETIMEDOUT not detected by IsConnectionError",
+				err:      syscall.ETIMEDOUT,
+				expected: false,
+			},
+		}
+
+		for _, tt := range tests {
+			t.Run(tt.name, func(t *testing.T) {
+				ctx := LoadRetryContext(tt.err, nil)
+				result := ctx.IsConnectionError()
+				assert.Equal(t, tt.expected, result)
+			})
+		}
+	})
+
+	t.Run("IsTimeout with syscall errors", func(t *testing.T) {
+		tests := []struct {
+			name     string
+			err      error
+			expected bool
+		}{
+			{
+				name:     "ETIMEDOUT detected",
+				err:      syscall.ETIMEDOUT,
+				expected: true,
+			},
+			{
+				name:     "wrapped ETIMEDOUT detected",
+				err:      fmt.Errorf("read error: %w", syscall.ETIMEDOUT),
+				expected: true,
+			},
+			{
+				name: "ETIMEDOUT in net.OpError detected",
+				err: &net.OpError{
+					Err: &os.SyscallError{
+						Err: syscall.ETIMEDOUT,
+					},
+				},
+				expected: true,
+			},
+			{
+				name:     "i/o timeout string not detected (no string matching)",
+				err:      errors.New("i/o timeout"),
+				expected: false,
+			},
+			{
+				name:     "operation timed out string not detected (no string matching)",
+				err:      errors.New("operation timed out"),
+				expected: false,
+			},
+			{
+				name:     "connection errors not detected by IsTimeout",
+				err:      syscall.ECONNREFUSED,
+				expected: false,
+			},
+			{
+				name:     "deadline exceeded error should be detected as timeout",
+				err:      os.ErrDeadlineExceeded,
+				expected: true,
+			},
+			{
+ name: "HTTP read timeout detected by IsTimeout", + err: errors.New("timeout awaiting response headers"), + expected: true, + }, + { + name: "non-timeout error not detected", + err: errors.New("some other error"), + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := LoadRetryContext(tt.err, nil) + result := ctx.IsTimeout() + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("IsHttpReadTimeout with specific HTTP timeout", func(t *testing.T) { + tests := []struct { + name string + err error + expected bool + }{ + { + name: "HTTP timeout awaiting response headers", + err: errors.New("timeout awaiting response headers"), + expected: true, + }, + { + name: "HTTP timeout awaiting response headers mixed case", + err: errors.New("Timeout Awaiting Response Headers"), + expected: true, + }, + { + name: "ETIMEDOUT not detected by IsHttpReadTimeout", + err: syscall.ETIMEDOUT, + expected: false, + }, + { + name: "nil error", + err: nil, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := LoadRetryContext(tt.err, nil) + result := ctx.IsHttpReadTimeout() + assert.Equal(t, tt.expected, result) + }) + } + }) + + t.Run("IsTimeout with net timeout errors", func(t *testing.T) { + // Mock net timeout error + mockNetTimeoutErr := &mockTimeoutError{msg: "net timeout error", timeout: true} + mockNetNonTimeoutErr := &mockTimeoutError{msg: "net regular error", timeout: false} + + tests := []struct { + name string + err error + expected bool + }{ + { + name: "net timeout error detected", + err: mockNetTimeoutErr, + expected: true, + }, + { + name: "net non-timeout error not detected", + err: mockNetNonTimeoutErr, + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := LoadRetryContext(tt.err, nil) + result := ctx.IsTimeout() + assert.Equal(t, tt.expected, result) + }) + } + }) +} + +// mockTimeoutError implements net.Error interface for testing +type mockTimeoutError struct { + msg string + timeout bool +} + +func (e *mockTimeoutError) Error() string { + return e.msg +} + +func (e *mockTimeoutError) Timeout() bool { + return e.timeout +} + +func (e *mockTimeoutError) Temporary() bool { + return false // Not a temporary error for this test +} + +func TestRetryExpressionManager_WithSyscallErrors(t *testing.T) { + t.Run("expression with specific syscall error helpers", func(t *testing.T) { + expression := "IsConnectionRefused() || IsConnectionReset() || IsTimeout()" + manager, err := NewRetryExpressionManager(expression) + require.NoError(t, err) + require.NotNil(t, manager) + + // Test ECONNREFUSED + ctx := LoadRetryContext(syscall.ECONNREFUSED, nil) + result, err := manager.ShouldRetry(ctx) + assert.NoError(t, err) + assert.True(t, result) + + // Test ECONNRESET + ctx = LoadRetryContext(syscall.ECONNRESET, nil) + result, err = manager.ShouldRetry(ctx) + assert.NoError(t, err) + assert.True(t, result) + + // Test ETIMEDOUT + ctx = LoadRetryContext(syscall.ETIMEDOUT, nil) + result, err = manager.ShouldRetry(ctx) + assert.NoError(t, err) + assert.True(t, result) + + // Test unrelated error + ctx = LoadRetryContext(errors.New("some other error"), nil) + result, err = manager.ShouldRetry(ctx) + assert.NoError(t, err) + assert.False(t, result) + }) + + t.Run("expression combining status and syscall errors", func(t *testing.T) { + expression := "statusCode == 500 || IsConnectionRefused()" + manager, err := NewRetryExpressionManager(expression) + 
require.NoError(t, err) + require.NotNil(t, manager) + + // Test with status code + ctx := LoadRetryContext(nil, &http.Response{StatusCode: 500}) + result, err := manager.ShouldRetry(ctx) + assert.NoError(t, err) + assert.True(t, result) + + // Test with syscall error + ctx = LoadRetryContext(syscall.ECONNREFUSED, nil) + result, err = manager.ShouldRetry(ctx) + assert.NoError(t, err) + assert.True(t, result) + + // Test with neither condition + ctx = LoadRetryContext(nil, &http.Response{StatusCode: 200}) + result, err = manager.ShouldRetry(ctx) + assert.NoError(t, err) + assert.False(t, result) + }) +} diff --git a/router/internal/retrytransport/retry_transport.go b/router/internal/retrytransport/retry_transport.go index aa9d99038a..da0b262bea 100644 --- a/router/internal/retrytransport/retry_transport.go +++ b/router/internal/retrytransport/retry_transport.go @@ -2,45 +2,29 @@ package retrytransport import ( "errors" - "github.com/cloudflare/backoff" - "go.uber.org/zap" "io" "net/http" - "strings" - "syscall" + "strconv" "time" -) -var defaultRetryableErrors = []error{ - syscall.ECONNREFUSED, // "connection refused" - syscall.ECONNRESET, // "connection reset by peer" - syscall.ETIMEDOUT, // "operation timed out" - errors.New("i/o timeout"), - errors.New("no such host"), - errors.New("handshake failure"), - errors.New("handshake timeout"), - errors.New("timeout awaiting response headers"), - errors.New("unexpected EOF"), - errors.New("unexpected EOF reading trailer"), -} - -var defaultRetryableStatusCodes = []int{ - http.StatusInternalServerError, - http.StatusBadGateway, - http.StatusServiceUnavailable, - http.StatusGatewayTimeout, - http.StatusTooManyRequests, -} + "github.com/cloudflare/backoff" + "go.uber.org/zap" +) type ShouldRetryFunc func(err error, req *http.Request, resp *http.Response) bool +type OnRetryFunc func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) type RetryOptions struct { Enabled bool + Algorithm string MaxRetryCount int Interval time.Duration MaxDuration time.Duration - OnRetry func(count int, req *http.Request, resp *http.Response, err error) + Expression string ShouldRetry ShouldRetryFunc + + // Test specific only + OnRetry OnRetryFunc } type requestLoggerGetter func(req *http.Request) *zap.Logger @@ -51,6 +35,61 @@ type RetryHTTPTransport struct { getRequestLogger requestLoggerGetter } +// parseRetryAfterHeader parses the Retry-After header value according to RFC 7231. +// It supports both delay-seconds and HTTP-date formats. +// Returns the duration to wait before retrying, or 0 if parsing fails. 
+func parseRetryAfterHeader(logger *zap.Logger, retryAfter string) time.Duration { + if retryAfter == "" { + return 0 + } + + var errJoin error + + seconds, err := strconv.Atoi(retryAfter) + if err != nil { + errJoin = errors.Join(errJoin, err) + } else { + if seconds >= 0 { + return time.Duration(seconds) * time.Second + } + } + + t, err := http.ParseTime(retryAfter) + if err != nil { + errJoin = errors.Join(errJoin, err) + } else { + if duration := time.Until(t); duration > 0 { + return duration + } + } + + // Log the collected errors in case of a malformed header + if errJoin != nil { + logger.Error("Failed to parse Retry-After header", zap.String("retry-after", retryAfter), zap.Error(errJoin)) + } + + return 0 +} + +// shouldUseRetryAfter determines whether the Retry-After header of a 429 response should be used +func shouldUseRetryAfter(logger *zap.Logger, resp *http.Response, maxDuration time.Duration) (time.Duration, bool) { + if resp == nil || resp.StatusCode != http.StatusTooManyRequests { + return 0, false + } + + retryAfter := resp.Header.Get("Retry-After") + if retryAfter == "" { + return 0, false + } + + duration := parseRetryAfterHeader(logger, retryAfter) + if duration > maxDuration { + duration = maxDuration + } + + return duration, duration > 0 +} + func NewRetryHTTPTransport( roundTripper http.RoundTripper, retryOptions RetryOptions, @@ -77,23 +116,34 @@ func (rt *RetryHTTPTransport) RoundTrip(req *http.Request) (*http.Response, erro // Retry logic retries := 0 - for rt.RetryOptions.ShouldRetry(err, req, resp) && retries < rt.RetryOptions.MaxRetryCount { - if rt.RetryOptions.OnRetry != nil { - rt.RetryOptions.OnRetry(retries, req, resp, err) - } - + for rt.RetryOptions.ShouldRetry(err, req, resp) && retries < rt.RetryOptions.MaxRetryCount { retries++ - // Wait for the specified backoff period - sleepDuration := b.Duration() + // Check whether the Retry-After header of a 429 response should be used + var sleepDuration time.Duration + if retryAfterDuration, useRetryAfter := shouldUseRetryAfter(requestLogger, resp, rt.RetryOptions.MaxDuration); useRetryAfter { + sleepDuration = retryAfterDuration + requestLogger.Debug("Using Retry-After header for 429 response", + zap.Int("retry", retries), + zap.String("url", req.URL.String()), + zap.Duration("retry-after", sleepDuration), + ) + } else { + // Use normal backoff for non-429 responses or a 429 without a valid Retry-After + sleepDuration = b.Duration() + requestLogger.Debug("Retrying request", + zap.Int("retry", retries), + zap.String("url", req.URL.String()), + zap.Duration("sleep", sleepDuration), + ) + } - requestLogger.Debug("Retrying request", - zap.Int("retry", retries), - zap.String("url", req.URL.String()), - zap.Duration("sleep", sleepDuration), - ) + // Test-specific callback + if rt.RetryOptions.OnRetry != nil { + rt.RetryOptions.OnRetry(retries, req, resp, sleepDuration, err) + } - // Wait for the specified backoff period + // Wait for the specified duration time.Sleep(sleepDuration) // drain the previous response before retrying @@ -137,29 +187,3 @@ func isResponseOK(resp *http.Response) bool { // spec-compliant and returns a different status code than 200. 
return resp.StatusCode >= 200 && resp.StatusCode < 300 } - -func IsRetryableError(err error, resp *http.Response) bool { - - if err != nil { - // Network - s := err.Error() - for _, retryableError := range defaultRetryableErrors { - if strings.HasSuffix( - strings.ToLower(s), - strings.ToLower(retryableError.Error())) { - return true - } - } - } - - if resp != nil { - // HTTP - for _, retryableStatusCode := range defaultRetryableStatusCodes { - if resp.StatusCode == retryableStatusCode { - return true - } - } - } - - return false -} diff --git a/router/internal/retrytransport/retry_transport_test.go b/router/internal/retrytransport/retry_transport_test.go index 117d81e46c..e2d49d3f28 100644 --- a/router/internal/retrytransport/retry_transport_test.go +++ b/router/internal/retrytransport/retry_transport_test.go @@ -4,17 +4,45 @@ import ( "bytes" "errors" "fmt" - "go.uber.org/zap/zapcore" "io" "net/http" "net/http/httptest" + "strings" "testing" "time" + "go.uber.org/zap/zapcore" + "github.com/stretchr/testify/assert" "go.uber.org/zap" ) +const defaultMaxDuration = 100 * time.Second + +// simpleShouldRetry provides simple retry logic for testing the transport implementation +func simpleShouldRetry(err error, req *http.Request, resp *http.Response) bool { + // Simple logic for testing - retry on 5xx status codes or any error + if err != nil { + return true + } + if resp != nil && resp.StatusCode >= 500 { + return true + } + return false +} + +// shouldRetryWith429 includes 429 responses in addition to the simple retry logic +func shouldRetryWith429(err error, req *http.Request, resp *http.Response) bool { + // Include 429 responses in retryable conditions + if err != nil { + return true + } + if resp != nil && (resp.StatusCode >= 500 || resp.StatusCode == http.StatusTooManyRequests) { + return true + } + return false +} + type MockTransport struct { handler func(req *http.Request) (*http.Response, error) } @@ -25,34 +53,34 @@ func (dt *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { func TestRetryOnHTTP5xx(t *testing.T) { retries := 0 - index := -1 + attemptCount := 0 + maxRetries := 3 tr := RetryHTTPTransport{ RoundTripper: &MockTransport{ handler: func(req *http.Request) (*http.Response, error) { - if index < len(defaultRetryableStatusCodes)-1 { - index++ - return &http.Response{ - StatusCode: defaultRetryableStatusCodes[index], - }, nil - } else { + attemptCount++ + if attemptCount <= maxRetries { + // Return 500 to trigger retry return &http.Response{ - StatusCode: http.StatusOK, + StatusCode: http.StatusInternalServerError, }, nil } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil }, }, getRequestLogger: func(req *http.Request) *zap.Logger { return zap.NewNop() }, RetryOptions: RetryOptions{ - MaxRetryCount: len(defaultRetryableStatusCodes), + MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond, - ShouldRetry: func(err error, req *http.Request, resp *http.Response) bool { - return IsRetryableError(err, resp) - }, - OnRetry: func(count int, req *http.Request, resp *http.Response, err error) { + ShouldRetry: simpleShouldRetry, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { retries++ }, }, @@ -62,42 +90,41 @@ func TestRetryOnHTTP5xx(t *testing.T) { resp, err := tr.RoundTrip(req) assert.Nil(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - assert.Equal(t, len(defaultRetryableStatusCodes), retries) - + // Should 
have retried exactly maxRetries times + assert.Equal(t, maxRetries, retries) + // Should have made maxRetries + 1 total attempts + assert.Equal(t, maxRetries+1, attemptCount) } -func TestRetryOnNetErrors(t *testing.T) { +func TestRetryOnErrors(t *testing.T) { retries := 0 - index := -1 + attemptCount := 0 + maxRetries := 3 tr := RetryHTTPTransport{ RoundTripper: &MockTransport{ handler: func(req *http.Request) (*http.Response, error) { - - if index < len(defaultRetryableErrors)-1 { - index++ - return nil, defaultRetryableErrors[index] - } else { - return &http.Response{ - StatusCode: http.StatusOK, - }, nil + attemptCount++ + if attemptCount <= maxRetries { + // Return any error to trigger retry + return nil, errors.New("some network error") } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil }, }, getRequestLogger: func(req *http.Request) *zap.Logger { return zap.NewNop() }, RetryOptions: RetryOptions{ - MaxRetryCount: len(defaultRetryableErrors), + MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond, - ShouldRetry: func(err error, req *http.Request, resp *http.Response) bool { - return IsRetryableError(err, resp) - }, - OnRetry: func(count int, req *http.Request, resp *http.Response, err error) { + ShouldRetry: simpleShouldRetry, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { retries++ }, }, @@ -107,35 +134,42 @@ func TestRetryOnNetErrors(t *testing.T) { resp, err := tr.RoundTrip(req) assert.Nil(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - assert.Equal(t, len(defaultRetryableErrors), retries) - + // Should have retried exactly maxRetries times + assert.Equal(t, maxRetries, retries) + // Should have made maxRetries + 1 total attempts + assert.Equal(t, maxRetries+1, attemptCount) } -func TestDoNotRetryWhenErrorIsNotRetryableAndResponseIsNil(t *testing.T) { - finalError := errors.New("some error") +func TestDoNotRetryWhenShouldRetryReturnsFalse(t *testing.T) { + nonRetryableError := errors.New("non-retryable error") - expectedRetries := 2 retries := 0 - index := -1 - maxRetryCount := 7 + attemptCount := 0 + maxRetryCount := 5 + + // Custom ShouldRetry that returns false for our specific non-retryable error + shouldRetry := func(err error, req *http.Request, resp *http.Response) bool { + if err != nil && err.Error() == "non-retryable error" { + return false + } + return simpleShouldRetry(err, req, resp) + } tr := RetryHTTPTransport{ RoundTripper: &MockTransport{ handler: func(req *http.Request) (*http.Response, error) { - index++ - switch index { - case 0: - // The first retry we return a retryable error - return &http.Response{StatusCode: defaultRetryableStatusCodes[0]}, nil + attemptCount++ + switch attemptCount { case 1: - // The second retry we return a retryable status code - return nil, defaultRetryableErrors[index] + // First attempt: return retryable error + return nil, errors.New("retryable error") + case 2: + // Second attempt: return retryable status code + return &http.Response{StatusCode: http.StatusInternalServerError}, nil default: - // The third retry we return a nil response as well as a non-retryable error - return nil, finalError + // Third attempt: return non-retryable error (should stop retrying) + return nil, nonRetryableError } }, }, @@ -146,10 +180,8 @@ func TestDoNotRetryWhenErrorIsNotRetryableAndResponseIsNil(t *testing.T) { MaxRetryCount: maxRetryCount, Interval: 1 * time.Millisecond, MaxDuration: 10 * 
time.Millisecond, - ShouldRetry: func(err error, req *http.Request, resp *http.Response) bool { - return IsRetryableError(err, resp) - }, - OnRetry: func(count int, req *http.Request, resp *http.Response, err error) { + ShouldRetry: shouldRetry, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { retries++ }, }, @@ -158,10 +190,14 @@ func TestDoNotRetryWhenErrorIsNotRetryableAndResponseIsNil(t *testing.T) { req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) resp, err := tr.RoundTrip(req) - assert.Error(t, finalError, err) + assert.Error(t, err) + assert.Equal(t, nonRetryableError, err) assert.Nil(t, resp) - assert.Equal(t, expectedRetries, retries) + // Should have retried exactly 2 times before encountering non-retryable error + assert.Equal(t, 2, retries) + assert.Equal(t, 3, attemptCount) + // Should not have exhausted max retry count assert.NotEqual(t, maxRetryCount, retries) } @@ -193,6 +229,86 @@ func (b *TrackableBody) Close() error { return nil } +func TestShortCircuitOnSuccess(t *testing.T) { + attemptCount := 0 + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + // Always return success + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("success")), + }, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: 5, + Interval: 1 * time.Millisecond, + MaxDuration: 10 * time.Millisecond, + ShouldRetry: simpleShouldRetry, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + t.Error("OnRetry should not be called when first request succeeds") + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + // Should only make one attempt since first attempt succeeds + assert.Equal(t, 1, attemptCount) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.Equal(t, "success", string(body)) + resp.Body.Close() +} + +func TestMaxRetryCountRespected(t *testing.T) { + maxRetries := 2 + retries := 0 + attemptCount := 0 + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + // Always return retryable error to test max retry limit + return nil, errors.New("always fail") + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 1 * time.Millisecond, + MaxDuration: 10 * time.Millisecond, + ShouldRetry: simpleShouldRetry, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.Error(t, err) + assert.Nil(t, resp) + // Should have retried exactly maxRetries times + assert.Equal(t, maxRetries, retries) + // Should have made maxRetries + 1 total attempts + assert.Equal(t, maxRetries+1, attemptCount) +} + func TestResponseBodyDraining(t *testing.T) { actualRetries := 0 index := -1 @@ -213,7 +329,7 @@ func TestResponseBodyDraining(t *testing.T) { index++ if index < retryCount { return &http.Response{ - StatusCode: 
defaultRetryableStatusCodes[0], + StatusCode: http.StatusInternalServerError, Body: bodies[index], }, nil } else { @@ -231,10 +347,8 @@ func TestResponseBodyDraining(t *testing.T) { MaxRetryCount: retryCount, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond, - ShouldRetry: func(err error, req *http.Request, resp *http.Response) bool { - return IsRetryableError(err, resp) - }, - OnRetry: func(count int, req *http.Request, resp *http.Response, err error) { + ShouldRetry: simpleShouldRetry, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { actualRetries++ }, }, @@ -290,7 +404,7 @@ func TestRequestLoggerIsUsed(t *testing.T) { index++ if index < retryCount { return &http.Response{ - StatusCode: defaultRetryableStatusCodes[0], + StatusCode: http.StatusInternalServerError, Body: bodies[index], }, nil } else { @@ -308,10 +422,8 @@ func TestRequestLoggerIsUsed(t *testing.T) { MaxRetryCount: retryCount, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond, - ShouldRetry: func(err error, req *http.Request, resp *http.Response) bool { - return IsRetryableError(err, resp) - }, - OnRetry: func(count int, req *http.Request, resp *http.Response, err error) { + ShouldRetry: simpleShouldRetry, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { actualRetries++ }, }, @@ -337,3 +449,693 @@ func createTestLogger(t *testing.T) (*bytes.Buffer, *zap.Logger) { unusedLogger := zap.New(core) return &buf, unusedLogger } + +func TestOnRetryCallbackInvoked(t *testing.T) { + maxRetries := 3 + retries := 0 + var retryCallbacks []struct { + count int + err error + resp *http.Response + } + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + if retries < maxRetries { + // Return retryable error + return nil, errors.New("retryable error") + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 1 * time.Millisecond, + MaxDuration: 10 * time.Millisecond, + ShouldRetry: simpleShouldRetry, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + retryCallbacks = append(retryCallbacks, struct { + count int + err error + resp *http.Response + }{count: count, err: err, resp: resp}) + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + + // Verify OnRetry was called the right number of times + assert.Equal(t, maxRetries, retries) + assert.Len(t, retryCallbacks, maxRetries) + + // Verify callback parameters are correct + for i, callback := range retryCallbacks { + assert.Equal(t, i+1, callback.count) + assert.Error(t, callback.err) + assert.Equal(t, "retryable error", callback.err.Error()) + assert.Nil(t, callback.resp) + } +} + +func TestRetryOn429WithDelaySeconds(t *testing.T) { + retries := 0 + attemptCount := 0 + maxRetries := 2 + retryAfterSeconds := 1 // Use 1 second to keep test fast + + // Track what retry duration was requested to verify Retry-After is parsed correctly + var retryAfterUsed []bool + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) 
(*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 429 with Retry-After header in seconds + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", fmt.Sprintf("%d", retryAfterSeconds)) + + // Verify the header is parsed correctly + duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration) + retryAfterUsed = append(retryAfterUsed, useRetryAfter) + assert.True(t, useRetryAfter, "Should use Retry-After header for 429") + assert.Equal(t, time.Duration(retryAfterSeconds)*time.Second, duration) + + return resp, nil + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 100 * time.Millisecond, // This should be ignored for 429 + MaxDuration: 10 * time.Second, + ShouldRetry: shouldRetryWith429, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, maxRetries, retries) + assert.Equal(t, maxRetries+1, attemptCount) + // Verify that Retry-After was detected and used + assert.Len(t, retryAfterUsed, maxRetries) + for i, used := range retryAfterUsed { + assert.True(t, used, "Retry %d should have used Retry-After header", i) + } +} + +func TestRetryOn429WithDelaySecondsLargerThanMaxDuration(t *testing.T) { + retries := 0 + attemptCount := 0 + maxRetries := 2 + retryAfterSeconds := 1 // Use 1 second to keep test fast + maxDuration := 500 * time.Millisecond + + // Track whether the Retry-After header was honored on each retry + var retryAfterUsed []bool + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 429 with Retry-After header in seconds + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", fmt.Sprintf("%d", retryAfterSeconds)) + + // Verify the header is parsed correctly + duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, maxDuration) + retryAfterUsed = append(retryAfterUsed, useRetryAfter) + assert.True(t, useRetryAfter, "Should use Retry-After header for 429") + assert.Equal(t, maxDuration, duration) + + return resp, nil + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 100 * time.Millisecond, // This should be ignored for 429 + MaxDuration: maxDuration, // also exercise the transport-level cap (500ms) + ShouldRetry: shouldRetryWith429, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, maxRetries, retries) + assert.Equal(t, maxRetries+1, attemptCount) + // Verify that Retry-After was 
detected and used + assert.Len(t, retryAfterUsed, maxRetries) + for i, used := range retryAfterUsed { + assert.True(t, used, "Retry %d should have used Retry-After header", i) + } +} + +func TestRetryOn429WithoutRetryAfter(t *testing.T) { + retries := 0 + attemptCount := 0 + maxRetries := 2 + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 429 without Retry-After header + return &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + }, nil + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 1 * time.Millisecond, + MaxDuration: 10 * time.Millisecond, + ShouldRetry: shouldRetryWith429, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + // Should have retried exactly maxRetries times + assert.Equal(t, maxRetries, retries) + assert.Equal(t, maxRetries+1, attemptCount) +} + +func TestRetryOn429WithHTTPDate(t *testing.T) { + retries := 0 + attemptCount := 0 + maxRetries := 2 + + // Track what retry duration was requested to verify Retry-After is parsed correctly + var retryAfterUsed []bool + var expectedDuration time.Duration + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 429 with Retry-After header as HTTP-date (1 second in future to keep test fast) + expectedDuration = 1 * time.Second + futureTime := time.Now().UTC().Add(expectedDuration) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", futureTime.Format(http.TimeFormat)) + + // Verify the header is parsed correctly + duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration) + retryAfterUsed = append(retryAfterUsed, useRetryAfter) + assert.True(t, useRetryAfter, "Should use Retry-After header for 429") + // Allow reasonable tolerance for execution delay between time creation and parsing + assert.True(t, duration > 0 && duration <= expectedDuration, + "Duration should be positive and <= %v, got %v", expectedDuration, duration) + + return resp, nil + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 100 * time.Millisecond, // This should be ignored for 429 + MaxDuration: 10 * time.Second, + ShouldRetry: shouldRetryWith429, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, maxRetries, retries) + assert.Equal(t, maxRetries+1, attemptCount) + // Verify that Retry-After was 
detected and used + assert.Len(t, retryAfterUsed, maxRetries) + for i, used := range retryAfterUsed { + assert.True(t, used, "Retry %d should have used Retry-After header", i) + } +} + +func TestRetryOn429WithInvalidRetryAfterHeader(t *testing.T) { + retries := 0 + attemptCount := 0 + maxRetries := 2 + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 429 with invalid Retry-After header + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", "invalid-value") + return resp, nil + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 1 * time.Millisecond, + MaxDuration: 10 * time.Millisecond, + ShouldRetry: shouldRetryWith429, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + // Should have retried exactly maxRetries times + assert.Equal(t, maxRetries, retries) + assert.Equal(t, maxRetries+1, attemptCount) + // Should fall back to normal backoff when Retry-After is invalid +} + +func TestRetryOn429WithNegativeDelaySeconds(t *testing.T) { + retries := 0 + attemptCount := 0 + maxRetries := 2 + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 429 with negative Retry-After value (should fall back to normal backoff) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", "-1") + return resp, nil + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 1 * time.Millisecond, + MaxDuration: 10 * time.Millisecond, + ShouldRetry: shouldRetryWith429, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + // Should have retried exactly maxRetries times + assert.Equal(t, maxRetries, retries) + assert.Equal(t, maxRetries+1, attemptCount) +} + +func TestRetryMixed429AndOtherErrors(t *testing.T) { + retries := 0 + attemptCount := 0 + maxRetries := 4 + + // Track which responses used Retry-After vs normal backoff + var retryAfterUsedPerAttempt []bool + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + switch attemptCount { + case 1: + // First: 429 with Retry-After + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", "1") + + // Verify this should use Retry-After + _, useRetryAfter := 
shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration) + retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter) + + return resp, nil + case 2: + // Second: Network error (should use normal backoff) + retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, false) + return nil, errors.New("network error") + case 3: + // Third: 500 error (should use normal backoff) + resp := &http.Response{ + StatusCode: http.StatusInternalServerError, + } + _, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration) + retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter) + return resp, nil + case 4: + // Fourth: 429 without Retry-After (should use normal backoff) + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + _, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration) + retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter) + return resp, nil + default: + // Finally: Success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil + } + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: maxRetries, + Interval: 10 * time.Millisecond, // Should be used for non-429-with-Retry-After cases + MaxDuration: 10 * time.Second, + ShouldRetry: shouldRetryWith429, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, maxRetries, retries) + assert.Equal(t, maxRetries+1, attemptCount) + assert.Len(t, retryAfterUsedPerAttempt, maxRetries) + + // First attempt should use Retry-After (429 with header) + assert.True(t, retryAfterUsedPerAttempt[0], "First retry should use Retry-After") + + // Other attempts should not use Retry-After + assert.False(t, retryAfterUsedPerAttempt[1], "Second retry should not use Retry-After (network error)") + assert.False(t, retryAfterUsedPerAttempt[2], "Third retry should not use Retry-After (500 error)") + assert.False(t, retryAfterUsedPerAttempt[3], "Fourth retry should not use Retry-After (429 without header)") +} + +func TestNoRetryOn429WhenShouldRetryReturnsFalse(t *testing.T) { + retries := 0 + attemptCount := 0 + + // ShouldRetry function that excludes 429 responses + shouldNotRetry429 := func(err error, req *http.Request, resp *http.Response) bool { + // Only retry on errors, not on 429 responses + // Do not retry on any HTTP status codes (including 429) + return err != nil + } + + tr := RetryHTTPTransport{ + RoundTripper: &MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + // Always return 429 with Retry-After header + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", "1") + return resp, nil + }, + }, + getRequestLogger: func(req *http.Request) *zap.Logger { + return zap.NewNop() + }, + RetryOptions: RetryOptions{ + MaxRetryCount: 3, + Interval: 1 * time.Millisecond, + MaxDuration: 10 * time.Millisecond, + ShouldRetry: shouldNotRetry429, + OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + }, + }, + } + + req := httptest.NewRequest("GET", 
"http://localhost:3000/graphql", nil) + resp, err := tr.RoundTrip(req) + + assert.NoError(t, err) + assert.Equal(t, http.StatusTooManyRequests, resp.StatusCode) + // Should not have retried at all since ShouldRetry returns false for 429 + assert.Equal(t, 0, retries) + assert.Equal(t, 1, attemptCount) +} + +// Test unit functions directly +func TestParseRetryAfterHeader(t *testing.T) { + tests := []struct { + name string + header string + expected time.Duration + }{ + { + name: "valid delay seconds", + header: "120", + expected: 120 * time.Second, + }, + { + name: "zero delay seconds", + header: "0", + expected: 0, + }, + { + name: "negative delay seconds should return 0", + header: "-1", + expected: 0, + }, + { + name: "invalid string should return 0", + header: "invalid", + expected: 0, + }, + { + name: "empty string should return 0", + header: "", + expected: 0, + }, + { + name: "HTTP date in future", + header: time.Now().UTC().Add(3 * time.Second).Format(http.TimeFormat), + expected: 3 * time.Second, // approximately + }, + { + name: "HTTP date in past should return 0", + header: time.Now().UTC().Add(-3 * time.Second).Format(http.TimeFormat), + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := parseRetryAfterHeader(zap.NewNop(), tt.header) + if tt.name == "HTTP date in future" { + // For HTTP date tests, allow reasonable tolerance for timing variations + assert.True(t, result >= tt.expected-1*time.Second && result <= tt.expected+1*time.Second, + "Expected ~%v, got %v", tt.expected, result) + } else { + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestShouldUseRetryAfter(t *testing.T) { + tests := []struct { + name string + resp *http.Response + expectedDur time.Duration + expectedUse bool + maxDuration time.Duration + }{ + { + name: "nil response", + resp: nil, + expectedDur: 0, + expectedUse: false, + }, + { + name: "non-429 response", + resp: &http.Response{ + StatusCode: http.StatusInternalServerError, + Header: make(http.Header), + }, + expectedDur: 0, + expectedUse: false, + }, + { + name: "429 without Retry-After header", + resp: &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + }, + expectedDur: 0, + expectedUse: false, + }, + { + name: "429 with empty Retry-After header", + resp: func() *http.Response { + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", "") + return resp + }(), + expectedDur: 0, + expectedUse: false, + }, + { + name: "429 with Retry-After seconds larger than allowed", + resp: func() *http.Response { + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", "30") + return resp + }(), + expectedDur: 20 * time.Second, + maxDuration: 20 * time.Second, + expectedUse: true, + }, + { + name: "429 with valid Retry-After seconds", + resp: func() *http.Response { + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", "30") + return resp + }(), + expectedDur: 30 * time.Second, + expectedUse: true, + }, + { + name: "429 with invalid Retry-After", + resp: func() *http.Response { + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), + } + resp.Header.Set("Retry-After", "invalid") + return resp + }(), + expectedDur: 0, + expectedUse: false, + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + maxDuration := defaultMaxDuration + if tt.maxDuration > 0 { + maxDuration = tt.maxDuration + } + dur, use := shouldUseRetryAfter(zap.NewNop(), tt.resp, maxDuration) + assert.Equal(t, tt.expectedDur, dur) + assert.Equal(t, tt.expectedUse, use) + }) + } +} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 1493cefdfe..beeac69969 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -230,10 +230,11 @@ type GraphqlMetrics struct { type BackoffJitterRetry struct { Enabled bool `yaml:"enabled" envDefault:"true" env:"RETRY_ENABLED"` - Algorithm string `yaml:"algorithm" envDefault:"backoff_jitter"` - MaxAttempts int `yaml:"max_attempts" envDefault:"5"` - MaxDuration time.Duration `yaml:"max_duration" envDefault:"10s"` - Interval time.Duration `yaml:"interval" envDefault:"3s"` + Algorithm string `yaml:"algorithm" envDefault:"backoff_jitter" env:"RETRY_ALGORITHM"` + MaxAttempts int `yaml:"max_attempts" envDefault:"5" env:"RETRY_MAX_ATTEMPTS"` + MaxDuration time.Duration `yaml:"max_duration" envDefault:"10s" env:"RETRY_MAX_DURATION"` + Interval time.Duration `yaml:"interval" envDefault:"3s" env:"RETRY_INTERVAL"` + Expression string `yaml:"expression,omitempty" env:"RETRY_EXPRESSION" envDefault:"IsRetryableStatusCode() || IsConnectionError() || IsTimeout()"` } type SubgraphCacheControlRule struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index e10c21578a..9b58d9ac88 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -3239,6 +3239,11 @@ "format": "go-duration", "default": "10s", "description": "The maximum allowable duration between retries (random). The period is specified as a string with a number and a unit, e.g. 10ms, 1s, 1m, 1h. The supported units are 'ms', 's', 'm', 'h'." + }, + "expression": { + "type": "string", + "description": "The expression used to determine if a request should be retried. The expression can reference status codes, error messages, and helper functions like IsRetryableStatusCode(), IsConnectionError(), IsHttpReadTimeout(), IsTimeout() (includes HTTP read timeouts). See https://expr-lang.org/ for expression syntax. Note: Mutations are never retried regardless of this expression. 
EOF errors are always retried at the transport layer regardless of this expression.", + "default": "IsRetryableStatusCode() || IsConnectionError() || IsTimeout()" } } } diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index b67ccd9f8d..5b8ba74b26 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -146,7 +146,8 @@ "Algorithm": "backoff_jitter", "MaxAttempts": 5, "MaxDuration": 10000000000, - "Interval": 3000000000 + "Interval": 3000000000, + "Expression": "IsRetryableStatusCode() || IsConnectionError() || IsTimeout()" }, "CircuitBreaker": { "Enabled": false, diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index d73deddec9..1f70e0987a 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -311,7 +311,8 @@ "Algorithm": "backoff_jitter", "MaxAttempts": 5, "MaxDuration": 10000000000, - "Interval": 3000000000 + "Interval": 3000000000, + "Expression": "IsRetryableStatusCode() || IsConnectionError() || IsTimeout()" }, "CircuitBreaker": { "Enabled": false, @@ -349,7 +350,8 @@ "Algorithm": "", "MaxAttempts": 0, "MaxDuration": 0, - "Interval": 0 + "Interval": 0, + "Expression": "" }, "CircuitBreaker": { "Enabled": false,
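
For reviewers, a minimal sketch (not part of this diff) of how the pieces above compose: the expression manager evaluated per attempt through the transport's ShouldRetryFunc. NewRetryExpressionManager, LoadRetryContext, and ShouldRetryFunc are the identifiers introduced in this PR and are assumed to live in the same package as the transport; the adapter name and the fail-closed handling of evaluation errors are illustrative assumptions, not taken from the router source.

package retrytransport

import "net/http"

// NewExpressionShouldRetry is a hypothetical adapter that compiles a retry
// expression once and returns a ShouldRetryFunc evaluating it per attempt.
func NewExpressionShouldRetry(expression string) (ShouldRetryFunc, error) {
	manager, err := NewRetryExpressionManager(expression)
	if err != nil {
		// Invalid expressions are rejected at startup, not per request.
		return nil, err
	}
	return func(err error, req *http.Request, resp *http.Response) bool {
		// Build the expression context from the attempt's error and response.
		ctx := LoadRetryContext(err, resp)
		retry, evalErr := manager.ShouldRetry(ctx)
		if evalErr != nil {
			// Fail closed (assumption): a failing evaluation must not retry.
			return false
		}
		return retry
	}, nil
}

A func built this way would slot into RetryOptions.ShouldRetry with, for example, the default expression "IsRetryableStatusCode() || IsConnectionError() || IsTimeout()" from config.go above; mutation requests and EOF handling remain outside the expression, as the schema description notes.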