diff --git a/router-tests/observability/structured_logging_test.go b/router-tests/observability/structured_logging_test.go index cb586086b8..ce8d7cfeca 100644 --- a/router-tests/observability/structured_logging_test.go +++ b/router-tests/observability/structured_logging_test.go @@ -745,7 +745,7 @@ func TestFlakyAccessLogs(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{All: config.GlobalSubgraphRequestRule{BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}}})), }, LogObservation: testenv.LogObservationConfig{ Enabled: true, @@ -864,7 +864,7 @@ func TestFlakyAccessLogs(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{All: config.GlobalSubgraphRequestRule{BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}}})), }, LogObservation: testenv.LogObservationConfig{ Enabled: true, @@ -996,7 +996,7 @@ func TestFlakyAccessLogs(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{All: config.GlobalSubgraphRequestRule{BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}}})), }, LogObservation: testenv.LogObservationConfig{ Enabled: true, @@ -1132,7 +1132,7 @@ func TestFlakyAccessLogs(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{All: config.GlobalSubgraphRequestRule{BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}}})), }, LogObservation: testenv.LogObservationConfig{ Enabled: true, @@ -2476,7 +2476,7 @@ func TestFlakyAccessLogs(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{All: config.GlobalSubgraphRequestRule{BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}}})), }, LogObservation: testenv.LogObservationConfig{ Enabled: true, diff --git a/router-tests/security/circuit_breaker_test.go b/router-tests/security/circuit_breaker_test.go index 93a30e91d8..babbd19d70 100644 --- a/router-tests/security/circuit_breaker_test.go +++ b/router-tests/security/circuit_breaker_test.go @@ -4,6 +4,7 @@ import ( "github.com/wundergraph/cosmo/router-tests/testutils" "context" + "github.com/gorilla/websocket" "net/http" "sort" "sync/atomic" @@ -625,6 +626,148 @@ func TestCircuitBreaker(t *testing.T) { }) }) + t.Run("verify circuit breaker tripping on upgrade requests", func(t *testing.T) { + t.Parallel() + + const failedTries int64 = 3 + + const timestampMessage = `{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}` + const completeMessage = `{"type":"complete","id":"1"}` + const defaultErrorMessage = `{"id":"1","type":"error","payload":[{"message":"Internal server error"}]}` + + breaker := getCircuitBreakerWithDefaults() + breaker.RequestThreshold = 3 + breaker.ErrorThresholdPercentage = 100 + + breaker.NumBuckets = 1 + breaker.RollingDuration = 5000 * 
time.Millisecond + trafficConfig := getTrafficConfigWithTimeout(breaker, 1*time.Second) + + employeesCalls := atomic.Int64{} + + testenv.Run(t, &testenv.Config{ + ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { + engineExecutionConfiguration.WebSocketClientReadTimeout = time.Millisecond * 2000 + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.DebugLevel, + }, + RouterOptions: []core.Option{ + core.WithSubgraphCircuitBreakerOptions(core.NewSubgraphCircuitBreakerOptions(trafficConfig)), + core.WithSubgraphTransportOptions(core.NewSubgraphTransportOptions(trafficConfig)), + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(_ http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + employeesCalls.Add(1) + if employeesCalls.Load() <= failedTries { + simulateConnectionFailureOnClose(w) + return + } + + upgrader := websocket.Upgrader{ + CheckOrigin: func(_ *http.Request) bool { + return true + }, + Subprotocols: []string{"graphql-transport-ws"}, + } + conn, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + // Read connection_init + _, _, err = testenv.WSReadMessage(t, conn) + require.NoError(t, err) + + err = testenv.WSWriteMessage(t, conn, websocket.TextMessage, []byte(`{"type":"connection_ack"}`)) + require.NoError(t, err) + + // Read subscribe message before sending data + _, _, err = testenv.WSReadMessage(t, conn) + require.NoError(t, err) + + err = testenv.WSWriteMessage(t, conn, websocket.TextMessage, []byte(timestampMessage)) + require.NoError(t, err) + + err = testenv.WSWriteMessage(t, conn, websocket.TextMessage, []byte(completeMessage)) + require.NoError(t, err) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + for i := range failedTries + 2 { + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + err := testenv.WSWriteJSON(t, conn, &testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + _, message, err := testenv.WSReadMessage(t, conn) + require.NoError(t, err) + + require.JSONEq(t, defaultErrorMessage, string(message)) + require.NoError(t, conn.Close()) + + switch { + case i < breaker.RequestThreshold-1: + require.Zero(t, xEnv.Observer().FilterMessage("Circuit breaker status changed").Len()) + case i == breaker.RequestThreshold-1: + require.Equal(t, 1, xEnv.Observer().FilterMessage("Circuit breaker status changed").Len()) + case i > breaker.RequestThreshold-1: + expectedCount := i - (breaker.RequestThreshold - 1) + require.Equal(t, int(expectedCount), xEnv.Observer().FilterMessage("Circuit breaker open, request callback did not execute").Len()) + } + + } + + require.Equal(t, failedTries, employeesCalls.Load()) + + // Ensure all previous subscriptions are fully cleaned up before + // waiting for the circuit to reset, to prevent leftover subscription + // cleanup from interfering with the half-open circuit state. 
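+			//
+			// With a single bucket and a 5s rolling window, the sleep below
+			// (3x RollingDuration + 1s) lets the recorded failures age out
+			// before the success probe is attempted.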
+ xEnv.WaitForSubscriptionCount(0, time.Second*5) + + // Wait for current bucket to be cleaned up + time.Sleep(breaker.RollingDuration*3 + time.Millisecond*1000) + + // ==== + // Verify a success case with messages validated from here onwards + // ==== + + conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil) + err := testenv.WSWriteJSON(t, conn, &testenv.WebSocketMessage{ + ID: "1", + Type: "subscribe", + Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`), + }) + require.NoError(t, err) + + _, message, err := testenv.WSReadMessage(t, conn) + require.NoError(t, err) + + require.JSONEq(t, timestampMessage, string(message)) + + err = testenv.WSWriteJSON(t, conn, &testenv.WebSocketMessage{ID: "1", Type: "complete"}) + require.NoError(t, err) + + err = conn.SetReadDeadline(time.Now().Add(2 * time.Second)) + require.NoError(t, err) + + _, actualCompleteMessage, err := testenv.WSReadMessage(t, conn) + require.NoError(t, err) + require.JSONEq(t, completeMessage, string(actualCompleteMessage)) + + xEnv.WaitForSubscriptionCount(0, time.Second*5) + }) + }) + t.Run("circuit breaker metrics", func(t *testing.T) { t.Parallel() diff --git a/router-tests/security/error_handling_test.go b/router-tests/security/error_handling_test.go index 52434c6df7..1f3b24ea92 100644 --- a/router-tests/security/error_handling_test.go +++ b/router-tests/security/error_handling_test.go @@ -1563,7 +1563,11 @@ func TestErrorPropagation(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}, + }, + })), }, }, func(t *testing.T, xEnv *testenv.Environment) { resp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ @@ -1595,7 +1599,11 @@ func TestErrorPropagation(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}, + }, + })), }, }, func(t *testing.T, xEnv *testenv.Environment) { resp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ @@ -1627,7 +1635,11 @@ func TestErrorPropagation(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}, + }, + })), }, }, func(t *testing.T, xEnv *testenv.Environment) { resp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ @@ -1659,7 +1671,11 @@ func TestErrorPropagation(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}, + }, + })), }, }, func(t *testing.T, xEnv *testenv.Environment) { resp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ diff --git a/router-tests/security/panic_test.go b/router-tests/security/panic_test.go index bb12614f44..eeca055597 
100644 --- a/router-tests/security/panic_test.go +++ b/router-tests/security/panic_test.go @@ -48,7 +48,11 @@ func TestEnginePanic(t *testing.T) { EnableSingleFlight: true, MaxConcurrentResolvers: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}, + }, + })), }, }, func(t *testing.T, xEnv *testenv.Environment) { res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ @@ -80,7 +84,11 @@ func TestEnginePanic(t *testing.T) { EnableSingleFlight: true, ParseKitPoolSize: 1, }), - core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil), + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false}, + }, + })), }, }, func(t *testing.T, xEnv *testenv.Environment) { res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ diff --git a/router-tests/security/retry_test.go b/router-tests/security/retry_test.go index cf916e15bc..7939240891 100644 --- a/router-tests/security/retry_test.go +++ b/router-tests/security/retry_test.go @@ -1,15 +1,17 @@ package integration import ( - "github.com/stretchr/testify/require" - "github.com/wundergraph/cosmo/router-tests/testenv" - "github.com/wundergraph/cosmo/router/core" - "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/gorilla/websocket" "net/http" "strconv" "sync/atomic" "testing" "time" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" ) func CreateRetryCounterFunc(counter *atomic.Int32, duration *atomic.Int64) func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { @@ -34,7 +36,19 @@ func TestRetry(t *testing.T) { maxRetryCount := 3 expression := "true" - options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 10*time.Second, 200*time.Millisecond, expression, retryCounterFunc) + opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: maxRetryCount, + MaxDuration: 10 * time.Second, + Interval: 200 * time.Millisecond, + Expression: expression, + }, + }, + }) + opts.OnRetryFunc = retryCounterFunc + options := core.WithSubgraphRetryOptions(opts) testenv.Run(t, &testenv.Config{ NoRetryClient: true, @@ -76,7 +90,19 @@ func TestRetry(t *testing.T) { maxRetryCount := 3 expression := "false" - options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 10*time.Second, 200*time.Millisecond, expression, retryCounterFunc) + opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: maxRetryCount, + MaxDuration: 10 * time.Second, + Interval: 200 * time.Millisecond, + Expression: expression, + }, + }, + }) + opts.OnRetryFunc = retryCounterFunc + options := core.WithSubgraphRetryOptions(opts) testenv.Run(t, &testenv.Config{ NoRetryClient: true, @@ -117,7 +143,19 @@ func TestRetry(t *testing.T) { maxRetryCount := 3 expression := "true" - options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 10*time.Second, 200*time.Millisecond, expression, 
retryCounterFunc) + opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: maxRetryCount, + MaxDuration: 10 * time.Second, + Interval: 200 * time.Millisecond, + Expression: expression, + }, + }, + }) + opts.OnRetryFunc = retryCounterFunc + options := core.WithSubgraphRetryOptions(opts) testenv.Run(t, &testenv.Config{ NoRetryClient: true, @@ -159,7 +197,19 @@ func TestRetry(t *testing.T) { maxAttemptsBeforeServiceSucceeds := 2 expression := "true" - options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 10*time.Second, 200*time.Millisecond, expression, retryCounterFunc) + opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: maxRetryCount, + MaxDuration: 10 * time.Second, + Interval: 200 * time.Millisecond, + Expression: expression, + }, + }, + }) + opts.OnRetryFunc = retryCounterFunc + options := core.WithSubgraphRetryOptions(opts) testenv.Run(t, &testenv.Config{ NoRetryClient: true, @@ -208,7 +258,19 @@ func TestRetry(t *testing.T) { expression := "statusCode == 429" headerRetryIntervalInSeconds := 1 - options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 2000*time.Second, 100*time.Millisecond, expression, retryCounterFunc) + opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: maxRetryCount, + MaxDuration: 2000 * time.Second, + Interval: 100 * time.Millisecond, + Expression: expression, + }, + }, + }) + opts.OnRetryFunc = retryCounterFunc + options := core.WithSubgraphRetryOptions(opts) testenv.Run(t, &testenv.Config{ NoRetryClient: true, @@ -260,7 +322,19 @@ func TestFlakyRetry(t *testing.T) { maxDuration := 100 * time.Millisecond expression := "true" - options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, maxDuration, retryInterval, expression, retryCounterFunc) + opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: maxRetryCount, + MaxDuration: maxDuration, + Interval: retryInterval, + Expression: expression, + }, + }, + }) + opts.OnRetryFunc = retryCounterFunc + options := core.WithSubgraphRetryOptions(opts) testenv.Run(t, &testenv.Config{ NoRetryClient: true, @@ -315,7 +389,19 @@ func TestFlakyRetry(t *testing.T) { maxRetryCount := 3 expression := "statusCode == 429" - options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 1000*time.Millisecond, retryInterval, expression, retryCounterFunc) + opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: maxRetryCount, + MaxDuration: 1000 * time.Millisecond, + Interval: retryInterval, + Expression: expression, + }, + }, + }) + opts.OnRetryFunc = retryCounterFunc + options := core.WithSubgraphRetryOptions(opts) testenv.Run(t, &testenv.Config{ NoRetryClient: true, @@ -363,7 +449,19 @@ func TestFlakyRetry(t *testing.T) { emptyRetryInterval := 0 retryInterval := 300 * time.Millisecond - options := core.WithSubgraphRetryOptions(true, "", maxRetryCount, 1000*time.Millisecond, retryInterval, expression, retryCounterFunc) + opts := 
core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			All: config.GlobalSubgraphRequestRule{
+				BackoffJitterRetry: config.BackoffJitterRetry{
+					Enabled:     true,
+					MaxAttempts: maxRetryCount,
+					MaxDuration: 1000 * time.Millisecond,
+					Interval:    retryInterval,
+					Expression:  expression,
+				},
+			},
+		})
+		opts.OnRetryFunc = retryCounterFunc
+		options := core.WithSubgraphRetryOptions(opts)
 
 		testenv.Run(t, &testenv.Config{
 			NoRetryClient: true,
@@ -375,8 +473,8 @@ func TestFlakyRetry(t *testing.T) {
 				Employees: testenv.SubgraphConfig{
 					Middleware: func(_ http.Handler) http.Handler {
 						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
-							w.WriteHeader(http.StatusTooManyRequests)
 							w.Header().Set("Retry-After", strconv.Itoa(emptyRetryInterval))
+							w.WriteHeader(http.StatusTooManyRequests)
 							serviceCallsCounter.Add(1)
 						})
 					},
@@ -399,3 +497,588 @@ func TestFlakyRetry(t *testing.T) {
 		})
 	})
 }
+
+func TestRetryPerSubgraph(t *testing.T) {
+	t.Parallel()
+
+	t.Run("verify invalid algorithm is detected for base", func(t *testing.T) {
+		t.Parallel()
+
+		// Configure the global retry rule with an unsupported algorithm
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			All: config.GlobalSubgraphRequestRule{
+				BackoffJitterRetry: config.BackoffJitterRetry{
+					Enabled:     true,
+					Algorithm:   "invalid_algorithm",
+					MaxAttempts: 2,
+					MaxDuration: 2 * time.Second,
+					Interval:    10 * time.Millisecond,
+					Expression:  "true",
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		err := testenv.RunWithError(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+		}, func(t *testing.T, _ *testenv.Environment) {
+			require.Fail(t, "expected initialization to fail due to invalid algorithm")
+		})
+
+		require.ErrorContains(t, err, "unsupported retry algorithm")
+	})
+
+	t.Run("invalid algorithm is ignored when retries are disabled (base)", func(t *testing.T) {
+		t.Parallel()
+
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			All: config.GlobalSubgraphRequestRule{
+				BackoffJitterRetry: config.BackoffJitterRetry{
+					Enabled:     false,
+					Algorithm:   "invalid_algorithm",
+					MaxAttempts: 2,
+					MaxDuration: 2 * time.Second,
+					Interval:    10 * time.Millisecond,
+					Expression:  "true",
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		err := testenv.RunWithError(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+		}, func(_ *testing.T, _ *testenv.Environment) {})
+
+		require.NoError(t, err)
+	})
+
+	t.Run("invalid algorithm is ignored when retries are disabled (per subgraph)", func(t *testing.T) {
+		t.Parallel()
+
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			Subgraphs: map[string]config.GlobalSubgraphRequestRule{
+				"employees": {
+					BackoffJitterRetry: config.BackoffJitterRetry{
+						Enabled:     false,
+						Algorithm:   "invalid_algorithm",
+						MaxAttempts: 2,
+						MaxDuration: 2 * time.Second,
+						Interval:    10 * time.Millisecond,
+						Expression:  "true",
+					},
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		err := testenv.RunWithError(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+		}, func(_ *testing.T, _ *testenv.Environment) {})
+
+		require.NoError(t, err)
+	})
+
+	t.Run("verify invalid algorithm is detected for per subgraphs", func(t *testing.T) {
+		t.Parallel()
+
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			Subgraphs: map[string]config.GlobalSubgraphRequestRule{
+				"employees": {
+					BackoffJitterRetry: config.BackoffJitterRetry{
+						Enabled:     true,
+						Algorithm:   "invalid_algorithm",
+						MaxAttempts: 2,
+						MaxDuration: 2 * time.Second,
+						Interval:    10 * time.Millisecond,
+						Expression:  "true",
+					},
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		err := testenv.RunWithError(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+		}, func(t *testing.T, _ *testenv.Environment) {
+			require.Fail(t, "expected initialization to fail due to invalid algorithm")
+		})
+
+		require.ErrorContains(t, err, "unsupported retry algorithm")
+	})
+
+	t.Run("verify invalid expression is detected for base", func(t *testing.T) {
+		t.Parallel()
+
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			All: config.GlobalSubgraphRequestRule{
+				BackoffJitterRetry: config.BackoffJitterRetry{
+					Enabled:     true,
+					MaxAttempts: 2,
+					MaxDuration: 2 * time.Second,
+					Interval:    10 * time.Millisecond,
+					Expression:  "truethere",
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		err := testenv.RunWithError(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+		}, func(t *testing.T, _ *testenv.Environment) {
+			require.Fail(t, "expected initialization to fail due to invalid expression")
+		})
+
+		require.ErrorContains(t, err, "failed to add base retry expression: failed to compile retry expression: line 1, column 0: unknown name truethere")
+	})
+
+	t.Run("verify invalid expression is detected per subgraphs", func(t *testing.T) {
+		t.Parallel()
+
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			Subgraphs: map[string]config.GlobalSubgraphRequestRule{
+				"employees": {
+					BackoffJitterRetry: config.BackoffJitterRetry{
+						Enabled:     true,
+						MaxAttempts: 2,
+						MaxDuration: 2 * time.Second,
+						Interval:    10 * time.Millisecond,
+						Expression:  "truethere",
+					},
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		err := testenv.RunWithError(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+		}, func(t *testing.T, _ *testenv.Environment) {
+			require.Fail(t, "expected initialization to fail due to invalid expression")
+		})
+
+		require.ErrorContains(t, err, "failed to add retry expression for subgraph employees: failed to compile retry expression: line 1, column 0: unknown name truethere")
+	})
+
+	t.Run("verify valid expression and algorithm", func(t *testing.T) {
+		t.Parallel()
+
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			Subgraphs: map[string]config.GlobalSubgraphRequestRule{
+				"employees": {
+					BackoffJitterRetry: config.BackoffJitterRetry{
+						Enabled:     true,
+						Algorithm:   "backoff_jitter",
+						MaxAttempts: 2,
+						MaxDuration: 2 * time.Second,
+						Interval:    10 * time.Millisecond,
+						Expression:  "true",
+					},
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		err := testenv.RunWithError(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+		}, func(_ *testing.T, _ *testenv.Environment) {
+		})
+
+		require.NoError(t, err)
+	})
+
+	t.Run("verify retries are applied per subgraph", func(t *testing.T) {
+		t.Parallel()
+
+		employeesCalls := atomic.Int32{}
+		test1Calls := atomic.Int32{}
+
+		// Configure per-subgraph retry: employees gets 3 retries, test1 gets 1 retry
+		employeesMax := 3
+		test1Max := 1
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			Subgraphs: map[string]config.GlobalSubgraphRequestRule{
+				"employees": {
+					BackoffJitterRetry: config.BackoffJitterRetry{
+						Enabled:     true,
+						MaxAttempts: employeesMax,
+						MaxDuration: 2 * time.Second,
+						Interval:    10 * time.Millisecond,
+						Expression:  "true",
+					},
+				},
+				"test1": {
+					BackoffJitterRetry: config.BackoffJitterRetry{
+						Enabled:     true,
+						MaxAttempts: test1Max,
+						MaxDuration: 2 * time.Second,
+						Interval:    10 * time.Millisecond,
+						Expression:  "true",
+					},
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							employeesCalls.Add(1)
+						})
+					},
+				},
+				Test1: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							test1Calls.Add(1)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			// 1) Call employees-only query; expect employees subgraph to be retried employeesMax times (attempts = retries + 1)
+			resEmp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query employees { employees { id } }`,
+			})
+			require.NoError(t, err)
+			require.JSONEq(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"employees":null}}`, resEmp.Body)
+			require.Equal(t, employeesMax+1, int(employeesCalls.Load()))
+			require.Equal(t, 0, int(test1Calls.Load()))
+
+			// 2) Call test1-only query; expect the test1 subgraph to be retried test1Max times (attempts = retries + 1)
+			resTest1, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query { floatField(arg: 1.5) }`,
+			})
+
+			require.NoError(t, err)
+			require.JSONEq(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'test1', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"floatField":null}}`, resTest1.Body)
+			require.Equal(t, employeesMax+1, int(employeesCalls.Load()))
+			require.Equal(t, test1Max+1, int(test1Calls.Load()))
+		})
+	})
+
+	t.Run("verify retries are applied per subgraph when mixed configurations", func(t *testing.T) {
+		t.Parallel()
+
+		employeesCalls := atomic.Int32{}
+		test1Calls := atomic.Int32{}
+
+		// Global default gives employees 3 retries; the test1 override caps it at 1 retry
+		employeesMax := 3
+		test1Max := 1
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			All: config.GlobalSubgraphRequestRule{
+				BackoffJitterRetry: config.BackoffJitterRetry{
+					Enabled:     true,
+					MaxAttempts: employeesMax,
+					MaxDuration: 2 * time.Second,
+					Interval:    10 * time.Millisecond,
+					Expression:  "true",
+				},
+			},
+			Subgraphs: map[string]config.GlobalSubgraphRequestRule{
+				"test1": {
+					BackoffJitterRetry: config.BackoffJitterRetry{
+						Enabled:     true,
+						MaxAttempts: test1Max,
+						MaxDuration: 2 * time.Second,
+						Interval:    10 * time.Millisecond,
+						Expression:  "true",
+					},
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							employeesCalls.Add(1)
+						})
+					},
+				},
+				Test1: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							test1Calls.Add(1)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			// 1) Call employees-only query; expect employees subgraph to be retried employeesMax times (attempts = retries + 1)
+			resEmp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query employees { employees { id } }`,
+			})
+			require.NoError(t, err)
+			require.JSONEq(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"employees":null}}`, resEmp.Body)
+			require.Equal(t, employeesMax+1, int(employeesCalls.Load()))
+			require.Equal(t, 0, int(test1Calls.Load()))
+
+			// 2) Call test1-only query; expect the test1 subgraph to be retried test1Max times (attempts = retries + 1)
+			resTest1, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query { floatField(arg: 1.5) }`,
+			})
+
+			require.NoError(t, err)
+			require.JSONEq(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'test1', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"floatField":null}}`, resTest1.Body)
+			require.Equal(t, employeesMax+1, int(employeesCalls.Load()))
+			require.Equal(t, test1Max+1, int(test1Calls.Load()))
+		})
+	})
+
+	t.Run("verify retries are applied when only all and no subgraph specific overrides are present", func(t *testing.T) {
+		t.Parallel()
+
+		employeesCalls := atomic.Int32{}
+		test1Calls := atomic.Int32{}
+
+		// Configure a single global retry rule: every subgraph gets 3 retries
+		generalMax := 3
+		opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+			All: config.GlobalSubgraphRequestRule{
+				BackoffJitterRetry: config.BackoffJitterRetry{
+					Enabled:     true,
+					MaxAttempts: generalMax,
+					MaxDuration: 2 * time.Second,
+					Interval:    10 * time.Millisecond,
+					Expression:  "true",
+				},
+			},
+		})
+		options := core.WithSubgraphRetryOptions(opts)
+
+		testenv.Run(t, &testenv.Config{
+			NoRetryClient: true,
+			RouterOptions: []core.Option{
+				options,
+			},
+			Subgraphs: testenv.SubgraphsConfig{
+				Employees: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							employeesCalls.Add(1)
+						})
+					},
+				},
+				Test1: testenv.SubgraphConfig{
+					Middleware: func(_ http.Handler) http.Handler {
+						return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
+							w.WriteHeader(http.StatusBadGateway)
+							test1Calls.Add(1)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			// 1) Call employees-only query; expect employees subgraph to be retried generalMax times (attempts = retries + 1)
+			resEmp, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query employees { employees { id } }`,
+			})
+			require.NoError(t, err)
+			require.JSONEq(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"employees":null}}`, resEmp.Body)
+			require.Equal(t, generalMax+1, int(employeesCalls.Load()))
+			require.Equal(t, 0, int(test1Calls.Load()))
+
+			// 2) Call test1-only query; expect the test1 subgraph to be retried generalMax times (attempts = retries + 1)
+			resTest1, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
+				Query: `query { floatField(arg: 1.5) }`,
+			})
+
+			require.NoError(t, err)
+			
require.JSONEq(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'test1', Reason: empty response.","extensions":{"statusCode":502}}],"data":{"floatField":null}}`, resTest1.Body) + require.Equal(t, generalMax+1, int(employeesCalls.Load())) + require.Equal(t, generalMax+1, int(test1Calls.Load())) + }) + }) + + t.Run("verify retries are applied on feature flags when only all and no subgraph specific overrides are present", func(t *testing.T) { + t.Parallel() + + calls := atomic.Int32{} + + generalMax := 5 + opts := core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: generalMax, + MaxDuration: 2 * time.Second, + Interval: 10 * time.Millisecond, + Expression: "true", + }, + }, + }) + options := core.WithSubgraphRetryOptions(opts) + + testenv.Run(t, &testenv.Config{ + NoRetryClient: true, + RouterOptions: []core.Option{ + options, + }, + Subgraphs: testenv.SubgraphsConfig{ + ProductsFg: testenv.SubgraphConfig{ + Middleware: func(_ http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusBadGateway) + calls.Add(1) + }) + }, + }, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { employees { id products } }`, + Header: map[string][]string{ + "X-Feature-Flag": {"myff"}, + }, + }) + require.NoError(t, err) + require.Equal(t, "myff", res.Response.Header.Get("X-Feature-Flag")) + require.Contains(t, res.Body, `{"message":"Failed to fetch from Subgraph 'products_fg' at Path 'employees', Reason: empty response.","extensions":{"statusCode":502}}`) + require.Equal(t, generalMax+1, int(calls.Load())) + }) + }) + + t.Run("verify retry on upgrade requests", func(t *testing.T) { + t.Parallel() + + const failedTries int64 = 2 + + const timestampMessage = `{"type":"next","id":"1","payload":{"data":{"currentTime":{"unixTime":1,"timeStamp":"2021-09-01T12:00:00Z"}}}}` + const completeMessage = `{"type":"complete","id":"1"}` + + employeesCalls := atomic.Int64{} + + testenv.Run(t, &testenv.Config{ + ModifyEngineExecutionConfiguration: func(engineExecutionConfiguration *config.EngineExecutionConfiguration) { + engineExecutionConfiguration.WebSocketClientReadTimeout = time.Millisecond * 2000 + }, + + RouterOptions: []core.Option{ + core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + BackoffJitterRetry: config.BackoffJitterRetry{ + Enabled: true, + MaxAttempts: 5, + MaxDuration: 10 * time.Second, + Interval: 200 * time.Millisecond, + Expression: "IsRetryableStatusCode() || IsConnectionError() || IsTimeout()", + }, + }, + })), + }, + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + Middleware: func(_ http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + employeesCalls.Add(1) + if employeesCalls.Load() <= failedTries { + w.WriteHeader(http.StatusBadGateway) + return + } + + upgrader := websocket.Upgrader{ + CheckOrigin: func(_ *http.Request) bool { + return true + }, + Subprotocols: []string{"graphql-transport-ws"}, + } + conn, err := upgrader.Upgrade(w, r, nil) + require.NoError(t, err) + defer func() { + _ = conn.Close() + }() + + // Read connection_init + _, _, err = testenv.WSReadMessage(t, conn) + require.NoError(t, err) + + err = testenv.WSWriteMessage(t, conn, 
websocket.TextMessage, []byte(`{"type":"connection_ack"}`))
+							require.NoError(t, err)
+
+							// Read subscribe message before sending data
+							_, _, err = testenv.WSReadMessage(t, conn)
+							require.NoError(t, err)
+
+							err = testenv.WSWriteMessage(t, conn, websocket.TextMessage, []byte(timestampMessage))
+							require.NoError(t, err)
+
+							err = testenv.WSWriteMessage(t, conn, websocket.TextMessage, []byte(completeMessage))
+							require.NoError(t, err)
+						})
+					},
+				},
+			},
+		}, func(t *testing.T, xEnv *testenv.Environment) {
+			conn := xEnv.InitGraphQLWebSocketConnection(nil, nil, nil)
+			err := testenv.WSWriteJSON(t, conn, &testenv.WebSocketMessage{
+				ID:      "1",
+				Type:    "subscribe",
+				Payload: []byte(`{"query":"subscription { currentTime { unixTime timeStamp }}"}`),
+			})
+			require.NoError(t, err)
+
+			// Read the first subscription result; it only arrives after the failed upgrade attempts have been retried
+			_, messageBytes, err := testenv.WSReadMessage(t, conn)
+			require.NoError(t, err)
+
+			require.Equal(t, failedTries+1, employeesCalls.Load())
+
+			// ====
+			// Verify the messages are correct from here onwards
+			// ====
+			require.JSONEq(t, timestampMessage, string(messageBytes))
+
+			// Sending a complete must stop the subscription
+			err = testenv.WSWriteJSON(t, conn, &testenv.WebSocketMessage{ID: "1", Type: "complete"})
+			require.NoError(t, err)
+
+			err = conn.SetReadDeadline(time.Now().Add(2 * time.Second))
+			require.NoError(t, err)
+
+			_, actualCompleteMessage, err := testenv.WSReadMessage(t, conn)
+			require.NoError(t, err)
+			require.JSONEq(t, completeMessage, string(actualCompleteMessage))
+
+			require.NoError(t, conn.Close())
+			xEnv.WaitForSubscriptionCount(0, time.Second*5)
+		})
+	})
+}
diff --git a/router-tests/telemetry/telemetry_test.go b/router-tests/telemetry/telemetry_test.go
index 0ac327e1cf..a7f877fe85 100644
--- a/router-tests/telemetry/telemetry_test.go
+++ b/router-tests/telemetry/telemetry_test.go
@@ -9512,7 +9512,11 @@ func TestFlakyTelemetry(t *testing.T) {
 				},
 			},
 			RouterOptions: []core.Option{
-				core.WithSubgraphRetryOptions(false, "", 0, 0, 0, "", nil),
+				core.WithSubgraphRetryOptions(core.NewSubgraphRetryOptions(config.TrafficShapingRules{
+					All: config.GlobalSubgraphRequestRule{
+						BackoffJitterRetry: config.BackoffJitterRetry{Enabled: false},
+					},
+				})),
 			},
 			Subgraphs: testenv.SubgraphsConfig{
 				Products: testenv.SubgraphConfig{
diff --git a/router/core/engine_loader_hooks.go b/router/core/engine_loader_hooks.go
index 1d284f06ed..f3f7f1eafb 100644
--- a/router/core/engine_loader_hooks.go
+++ b/router/core/engine_loader_hooks.go
@@ -99,8 +99,6 @@ func (f *engineLoaderHooks) OnLoad(ctx context.Context, ds resolve.DataSourceInf
 
 	start := time.Now()
 
-	ctx = context.WithValue(ctx, rcontext.CurrentSubgraphContextKey{}, ds.Name)
-
 	duration := atomic.Int64{}
 	ctx = context.WithValue(ctx, rcontext.FetchTimingKey, &duration)
 
diff --git a/router/core/graph_server.go b/router/core/graph_server.go
index f478e53590..8b313c17aa 100644
--- a/router/core/graph_server.go
+++ b/router/core/graph_server.go
@@ -41,6 +41,7 @@ import (
 	rmiddleware "github.com/wundergraph/cosmo/router/internal/middleware"
 	"github.com/wundergraph/cosmo/router/internal/recoveryhandler"
 	"github.com/wundergraph/cosmo/router/internal/requestlogger"
+	"github.com/wundergraph/cosmo/router/internal/retrytransport"
 	"github.com/wundergraph/cosmo/router/pkg/config"
 	"github.com/wundergraph/cosmo/router/pkg/cors"
 	"github.com/wundergraph/cosmo/router/pkg/execution_config"
@@ -104,6 +105,7 @@ type (
 		traceDialer           *TraceDialer
 		connector             *grpcconnector.Connector
circuitBreakerManager *circuit.Manager + retryManager *retrytransport.Manager headerPropagation *HeaderPropagation } ) @@ -294,6 +296,21 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC s.circuitBreakerManager = manager } + if s.retryOptions.IsEnabled() { + retryExprManager := expr.NewRetryExpressionManager() + retryManager := retrytransport.NewManager(retryExprManager, BuildRetryFunction(retryExprManager), s.retryOptions.OnRetryFunc, s.logger) + + if err := retryManager.Initialize( + s.retryOptions.All, + s.retryOptions.SubgraphMap, + routerConfig, + ); err != nil { + return nil, err + } + + s.retryManager = retryManager + } + routingUrlGroupings, err := getRoutingUrlGroupingForCircuitBreakers(routerConfig, s.overrideRoutingURLConfiguration, s.overrides) if err != nil { return nil, err @@ -1290,12 +1307,6 @@ func (s *graphServer) buildGraphMux( baseConnMetricStore = s.connectionMetrics } - // Build retry options and handle any expression compilation errors - processedRetryOptions, err := ProcessRetryOptions(s.retryOptions) - if err != nil { - return nil, fmt.Errorf("failed to process retry options: %w", err) - } - ecb := &ExecutorConfigurationBuilder{ introspection: s.introspection, baseURL: s.baseURL, @@ -1316,13 +1327,13 @@ func (s *graphServer) buildGraphMux( PostHandlers: s.postOriginHandlers, MetricStore: gm.metricStore, ConnectionMetricStore: baseConnMetricStore, - RetryOptions: *processedRetryOptions, TracerProvider: s.tracerProvider, TracePropagators: s.compositePropagator, LocalhostFallbackInsideDocker: s.localhostFallbackInsideDocker, Logger: s.logger, EnableTraceClient: enableTraceClient, CircuitBreaker: s.circuitBreakerManager, + RetryManager: s.retryManager, }, subscriptionHooks: s.subscriptionHooks, } diff --git a/router/core/retry_builder.go b/router/core/retry_builder.go index 30fffcb53f..00ce7443ca 100644 --- a/router/core/retry_builder.go +++ b/router/core/retry_builder.go @@ -1,7 +1,6 @@ package core import ( - "fmt" "net/http" "strings" @@ -10,74 +9,10 @@ import ( "go.uber.org/zap" ) -const ( - defaultRetryExpression = "IsRetryableStatusCode() || IsConnectionError() || IsTimeout()" - - backoffJitter = "backoff_jitter" -) - -var noopRetryFunc = func(err error, req *http.Request, resp *http.Response) bool { - return false -} - -func ProcessRetryOptions(retryOpts retrytransport.RetryOptions) (*retrytransport.RetryOptions, error) { - // Default to backOffJitter if no algorithm is specified - // This will occur either in tests or if the user explicitly makes it an empty string - if retryOpts.Algorithm == "" { - retryOpts.Algorithm = backoffJitter - } - - // We skip validating the algorithm if retries are disabled - if retryOpts.Enabled && retryOpts.Algorithm != backoffJitter { - return nil, fmt.Errorf("unsupported retry algorithm: %s", retryOpts.Algorithm) - } - - shouldRetryFunc, err := buildRetryFunction(retryOpts) - if err != nil { - return nil, fmt.Errorf("failed to build retry function: %w", err) - } - - // Create copy to not mutate the original reference - retryOptions := retrytransport.RetryOptions{ - Enabled: retryOpts.Enabled, - Algorithm: retryOpts.Algorithm, - MaxRetryCount: retryOpts.MaxRetryCount, - MaxDuration: retryOpts.MaxDuration, - Interval: retryOpts.Interval, - Expression: retryOpts.Expression, - - OnRetry: retryOpts.OnRetry, - - ShouldRetry: shouldRetryFunc, - } - - return &retryOptions, nil -} - // BuildRetryFunction creates a ShouldRetry function based on the provided expression -func buildRetryFunction(retryOpts 
retrytransport.RetryOptions) (retrytransport.ShouldRetryFunc, error) { - // We do not need to build a retry function if retries are disabled - // This means that any bad expressions are ignored if retries are disabled - if !retryOpts.Enabled { - return noopRetryFunc, nil - } - - // Use default expression if empty string is passed - expression := retryOpts.Expression - if expression == "" { - expression = defaultRetryExpression - } - - // Create the retry expression manager - manager, err := expr.NewRetryExpressionManager(expression) - if err != nil { - return nil, fmt.Errorf("failed to create expression manager: %w", err) - } - - // Return expression-based retry function - return func(err error, req *http.Request, resp *http.Response) bool { +func BuildRetryFunction(manager *expr.RetryExpressionManager) retrytransport.ShouldRetryFunc { + return func(err error, req *http.Request, resp *http.Response, expression string) bool { reqContext := getRequestContext(req.Context()) - if reqContext == nil { return false } @@ -95,7 +30,7 @@ func buildRetryFunction(retryOpts retrytransport.RetryOptions) (retrytransport.S ctx := expr.LoadRetryContext(err, resp) // Evaluate the expression - shouldRetry, evalErr := manager.ShouldRetry(ctx) + shouldRetry, evalErr := manager.ShouldRetry(ctx, expression) if evalErr != nil { reqContext.logger.Error("Failed to evaluate retry expression", zap.Error(evalErr), @@ -107,7 +42,7 @@ func buildRetryFunction(retryOpts retrytransport.RetryOptions) (retrytransport.S } return shouldRetry - }, nil + } } // isDefaultRetryableError checks for errors that should always be retryable diff --git a/router/core/retry_builder_test.go b/router/core/retry_builder_test.go index d4fa1cf3f6..ecf5b4662f 100644 --- a/router/core/retry_builder_test.go +++ b/router/core/retry_builder_test.go @@ -2,15 +2,13 @@ package core import ( "errors" - "fmt" - "github.com/wundergraph/cosmo/router/internal/retrytransport" "io" "net/http" - "reflect" "syscall" "testing" "github.com/stretchr/testify/assert" + "github.com/wundergraph/cosmo/router/internal/expr" "go.uber.org/zap" ) @@ -52,24 +50,18 @@ func createRequestWithContext(opType string) (*http.Request, *requestContext) { func TestBuildRetryFunction(t *testing.T) { t.Run("build function when retry is disabled", func(t *testing.T) { - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: false, - Expression: "invalid expression ++++++", - }) - assert.NoError(t, err) - assert.Equal(t, - reflect.ValueOf(noopRetryFunc).Pointer(), - reflect.ValueOf(fn).Pointer(), - ) + manager := expr.NewRetryExpressionManager() + fn := BuildRetryFunction(manager) + assert.NotNil(t, fn) }) t.Run("default expression behavior", func(t *testing.T) { // Use the default expression that would be in the config - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: defaultRetryExpression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression("") assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with proper query context @@ -77,26 +69,26 @@ func TestBuildRetryFunction(t *testing.T) { // Test default behavior - should retry on 500 resp := &http.Response{StatusCode: 500} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, "")) // Should not retry on 200 resp.StatusCode = 200 - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, "")) // Test with errors - only expression-defined errors are handled here - 
assert.True(t, fn(syscall.ETIMEDOUT, req, nil)) - assert.True(t, fn(errors.New("connection refused"), req, nil)) - assert.True(t, fn(errors.New("unexpected EOF"), req, nil)) // EOF is now handled at transport layer, not expression - assert.False(t, fn(errors.New("some other error"), req, nil)) + assert.True(t, fn(syscall.ETIMEDOUT, req, nil, "")) + assert.True(t, fn(errors.New("connection refused"), req, nil, "")) + assert.True(t, fn(errors.New("unexpected EOF"), req, nil, "")) // EOF is now handled at transport layer, not expression + assert.False(t, fn(errors.New("some other error"), req, nil, "")) }) t.Run("expression-based retry", func(t *testing.T) { expression := "statusCode == 500 || statusCode == 503" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with proper query context @@ -104,24 +96,24 @@ func TestBuildRetryFunction(t *testing.T) { // Should retry on 500 resp := &http.Response{StatusCode: 500} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Should retry on 503 resp.StatusCode = 503 - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Should not retry on 502 resp.StatusCode = 502 - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) }) t.Run("expression with error conditions", func(t *testing.T) { expression := "IsTimeout() || statusCode == 503" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with proper query context @@ -129,34 +121,31 @@ func TestBuildRetryFunction(t *testing.T) { // Should retry on timeout error err = syscall.ETIMEDOUT - assert.True(t, fn(err, req, nil)) + assert.True(t, fn(err, req, nil, expression)) // Should retry on 503 resp := &http.Response{StatusCode: 503} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Should not retry on other errors err = errors.New("some other error") - assert.False(t, fn(err, req, nil)) + assert.False(t, fn(err, req, nil, expression)) }) t.Run("invalid expression returns error", func(t *testing.T) { expression := "invalid syntax +++" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.Error(t, err) - assert.Nil(t, fn) assert.Contains(t, err.Error(), "failed to compile retry expression") }) t.Run("empty expression uses default", func(t *testing.T) { - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: "", - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression("") assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with proper query context @@ -164,43 +153,43 @@ func TestBuildRetryFunction(t *testing.T) { // Test with retryable status code resp := &http.Response{StatusCode: 502} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, "")) // Test with connection error err = errors.New("connection refused") 
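+		// With the empty expression registered above, the default retry rule
+		// applies, so the connection-refused and timeout errors below are
+		// treated as retryable.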
- assert.True(t, fn(err, req, nil)) + assert.True(t, fn(err, req, nil, "")) // Test with timeout error err = syscall.ETIMEDOUT - assert.True(t, fn(err, req, nil)) + assert.True(t, fn(err, req, nil, "")) // Test with non-retryable error err = errors.New("some other error") - assert.False(t, fn(err, req, nil)) + assert.False(t, fn(err, req, nil, "")) }) t.Run("expression that always returns false but the error is an eof error", func(t *testing.T) { expression := "false" // Don't retry - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with proper query context req, _ := createRequestWithContext(OperationTypeQuery) - assert.True(t, fn(io.ErrUnexpectedEOF, req, nil)) + assert.True(t, fn(io.ErrUnexpectedEOF, req, nil, expression)) }) t.Run("expression that always returns true", func(t *testing.T) { expression := "true" // Always retry - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with proper query context @@ -208,20 +197,20 @@ func TestBuildRetryFunction(t *testing.T) { resp := &http.Response{StatusCode: 500} // Should retry when expression is true - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Even for status codes that wouldn't normally retry resp.StatusCode = 200 - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) }) t.Run("complex expression", func(t *testing.T) { expression := "(statusCode >= 500 && statusCode < 600) || IsConnectionError()" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with proper query context @@ -229,26 +218,26 @@ func TestBuildRetryFunction(t *testing.T) { // Test 5xx errors resp := &http.Response{StatusCode: 503} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Test connection error err = errors.New("connection refused") - assert.True(t, fn(err, req, nil)) + assert.True(t, fn(err, req, nil, expression)) // Test non-matching conditions resp.StatusCode = 404 err = errors.New("some other error") - assert.False(t, fn(err, req, resp)) + assert.False(t, fn(err, req, resp, expression)) }) t.Run("mutation never retries with proper context", func(t *testing.T) { // Use expression that would normally retry on 500 errors expression := "statusCode >= 500 || IsTimeout() || IsConnectionError()" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with mutation context @@ -256,30 +245,28 @@ func TestBuildRetryFunction(t *testing.T) { // Test with 500 status - should NOT retry because it's a mutation resp := &http.Response{StatusCode: 500} - assert.False(t, fn(nil, req, resp)) + 
assert.False(t, fn(nil, req, resp, expression)) // Test with timeout error - should NOT retry because it's a mutation - assert.False(t, fn(syscall.ETIMEDOUT, req, nil)) + assert.False(t, fn(syscall.ETIMEDOUT, req, nil, expression)) // Test with connection error - should NOT retry because it's a mutation - assert.False(t, fn(errors.New("connection refused"), req, nil)) + assert.False(t, fn(errors.New("connection refused"), req, nil, expression)) // Test with expression that always returns true - should still NOT retry - alwaysRetryFn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: "true", - }) + alwaysRetryExpression := "true" + err = manager.AddExpression(alwaysRetryExpression) assert.NoError(t, err) - assert.False(t, alwaysRetryFn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, alwaysRetryExpression)) }) t.Run("query retries with proper context", func(t *testing.T) { expression := "statusCode >= 500 || IsTimeout()" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with query context @@ -287,23 +274,23 @@ func TestBuildRetryFunction(t *testing.T) { // Test with 500 status - should retry because it's a query resp := &http.Response{StatusCode: 500} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Test with timeout error - should retry because it's a query - assert.True(t, fn(syscall.ETIMEDOUT, req, nil)) + assert.True(t, fn(syscall.ETIMEDOUT, req, nil, expression)) // Test with 200 status - should not retry even for query resp.StatusCode = 200 - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) }) t.Run("subscription retries with proper context", func(t *testing.T) { expression := "statusCode >= 500" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with subscription context @@ -311,21 +298,21 @@ func TestBuildRetryFunction(t *testing.T) { // Test with 500 status - should retry because it's a subscription (not mutation) resp := &http.Response{StatusCode: 500} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Test with 200 status - should not retry resp.StatusCode = 200 - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) }) t.Run("error logging with proper context", func(t *testing.T) { // Test that error logging works with proper request context expression := "statusCode >= 500" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create request with proper context @@ -333,19 +320,19 @@ func TestBuildRetryFunction(t *testing.T) { // Test that it works normally with proper context resp := &http.Response{StatusCode: 500} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) resp.StatusCode = 200 - assert.False(t, fn(nil, req, resp)) + assert.False(t, 
fn(nil, req, resp, expression)) }) t.Run("request context with query operation", func(t *testing.T) { expression := "statusCode >= 500" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create request with proper query context @@ -353,20 +340,20 @@ func TestBuildRetryFunction(t *testing.T) { // Should work with proper request context - expression should be evaluated normally resp := &http.Response{StatusCode: 500} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) resp.StatusCode = 200 - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) }) t.Run("complex expression with mutation context", func(t *testing.T) { // Complex expression that would normally retry in many cases expression := "(statusCode >= 500 && statusCode < 600) || IsConnectionError() || IsTimeout() || statusCode == 429" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Create a request with mutation context @@ -374,26 +361,26 @@ func TestBuildRetryFunction(t *testing.T) { // Test various conditions that would normally trigger retry resp := &http.Response{StatusCode: 500} - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) resp.StatusCode = 503 - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) resp.StatusCode = 429 - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) - assert.False(t, fn(syscall.ETIMEDOUT, req, nil)) - assert.False(t, fn(errors.New("connection refused"), req, nil)) + assert.False(t, fn(syscall.ETIMEDOUT, req, nil, expression)) + assert.False(t, fn(errors.New("connection refused"), req, nil, expression)) }) t.Run("new operation with comprehensive retry conditions", func(t *testing.T) { // Create a new comprehensive operation to test all retry scenarios expression := "statusCode >= 500 || statusCode == 429 || IsTimeout() || IsConnectionError()" - fn, err := buildRetryFunction(retrytransport.RetryOptions{ - Enabled: true, - Expression: expression, - }) + manager := expr.NewRetryExpressionManager() + err := manager.AddExpression(expression) assert.NoError(t, err) + + fn := BuildRetryFunction(manager) assert.NotNil(t, fn) // Test query operation - should retry on all conditions @@ -401,81 +388,38 @@ func TestBuildRetryFunction(t *testing.T) { // Test 5xx errors resp := &http.Response{StatusCode: 500} - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) resp.StatusCode = 503 - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Test rate limiting resp.StatusCode = 429 - assert.True(t, fn(nil, req, resp)) + assert.True(t, fn(nil, req, resp, expression)) // Test timeouts - assert.True(t, fn(syscall.ETIMEDOUT, req, nil)) + assert.True(t, fn(syscall.ETIMEDOUT, req, nil, expression)) // Test connection errors - assert.True(t, fn(errors.New("connection refused"), req, nil)) + assert.True(t, fn(errors.New("connection refused"), req, nil, expression)) // Test success - should not retry resp.StatusCode = 200 - assert.False(t, 
fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) // Test client errors - should not retry resp.StatusCode = 404 - assert.False(t, fn(nil, req, resp)) + assert.False(t, fn(nil, req, resp, expression)) // Now test the same conditions with a mutation - should never retry mutationReq, _ := createRequestWithContext(OperationTypeMutation) resp.StatusCode = 500 - assert.False(t, fn(nil, mutationReq, resp)) + assert.False(t, fn(nil, mutationReq, resp, expression)) resp.StatusCode = 503 - assert.False(t, fn(nil, mutationReq, resp)) + assert.False(t, fn(nil, mutationReq, resp, expression)) resp.StatusCode = 429 - assert.False(t, fn(nil, mutationReq, resp)) - assert.False(t, fn(syscall.ETIMEDOUT, mutationReq, nil)) - assert.False(t, fn(errors.New("connection refused"), mutationReq, nil)) - }) -} - -func TestProcessRetryOptions(t *testing.T) { - t.Run("process invalid algorithm", func(t *testing.T) { - algorithm := "abcdee" - _, err := ProcessRetryOptions(retrytransport.RetryOptions{ - Enabled: true, - Algorithm: algorithm, - }) - - expectedError := fmt.Sprintf("unsupported retry algorithm: %s", algorithm) - assert.ErrorContains(t, err, expectedError) - }) - - t.Run("process invalid algorithm when retries are disabled", func(t *testing.T) { - algorithm := "abcdee" - _, err := ProcessRetryOptions(retrytransport.RetryOptions{ - Enabled: false, - Algorithm: algorithm, - }) - assert.NoError(t, err) - }) - - t.Run("process invalid expression", func(t *testing.T) { - _, err := ProcessRetryOptions(retrytransport.RetryOptions{ - Enabled: true, - Algorithm: "backoff_jitter", - Expression: "invalid syntax +++", - }) - - assert.ErrorContains(t, err, "failed to build retry function") - }) - - t.Run("process valid options", func(t *testing.T) { - options := retrytransport.RetryOptions{ - Enabled: true, - Algorithm: "backoff_jitter", - Expression: "statusCode == 500 || IsTimeout() || IsConnectionError()", - } - response, err := ProcessRetryOptions(options) - assert.NoError(t, err) - assert.NotSame(t, &options, response) + assert.False(t, fn(nil, mutationReq, resp, expression)) + assert.False(t, fn(syscall.ETIMEDOUT, mutationReq, nil, expression)) + assert.False(t, fn(errors.New("connection refused"), mutationReq, nil, expression)) }) } diff --git a/router/core/router.go b/router/core/router.go index bb8f8be8eb..c82d3c8e45 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -191,6 +191,27 @@ func (r *SubgraphCircuitBreakerOptions) IsEnabled() bool { return r.CircuitBreaker.Enabled || len(r.SubgraphMap) > 0 } +type SubgraphRetryOptions struct { + All retrytransport.RetryOptions + SubgraphMap map[string]retrytransport.RetryOptions + OnRetryFunc retrytransport.OnRetryFunc +} + +func (r *SubgraphRetryOptions) IsEnabled() bool { + if r == nil { + return false + } + if r.All.Enabled { + return true + } + for _, cfg := range r.SubgraphMap { + if cfg.Enabled { + return true + } + } + return false +} + // NewRouter creates a new Router instance. Router.Start() must be called to start the server. // Alternatively, use Router.NewServer() to create a new server instance without starting it. 
func NewRouter(opts ...Option) (*Router, error) { @@ -2030,26 +2051,9 @@ func WithSubgraphCircuitBreakerOptions(opts *SubgraphCircuitBreakerOptions) Opti } } -func WithSubgraphRetryOptions( - enabled bool, - algorithm string, - maxRetryCount int, - retryMaxDuration, retryInterval time.Duration, - expression string, - onRetryFunc retrytransport.OnRetryFunc, -) Option { +func WithSubgraphRetryOptions(opts *SubgraphRetryOptions) Option { return func(r *Router) { - r.retryOptions = retrytransport.RetryOptions{ - Enabled: enabled, - Algorithm: algorithm, - MaxRetryCount: maxRetryCount, - MaxDuration: retryMaxDuration, - Interval: retryInterval, - Expression: expression, - - // Test case overrides - OnRetry: onRetryFunc, - } + r.retryOptions = opts } } @@ -2179,6 +2183,37 @@ func NewSubgraphCircuitBreakerOptions(cfg config.TrafficShapingRules) *SubgraphC return entry } +func NewSubgraphRetryOptions(cfg config.TrafficShapingRules) *SubgraphRetryOptions { + entry := &SubgraphRetryOptions{ + SubgraphMap: map[string]retrytransport.RetryOptions{}, + } + // If we have a global default + if cfg.All.BackoffJitterRetry.Enabled { + entry.All = newRetryConfig(cfg.All.BackoffJitterRetry) + } + // Subgraph specific retry configs + for k, v := range cfg.Subgraphs { + entry.SubgraphMap[k] = newRetryConfig(v.BackoffJitterRetry) + } + + return entry +} + +func newRetryConfig(config config.BackoffJitterRetry) retrytransport.RetryOptions { + algorithm := config.Algorithm + if algorithm == "" { + algorithm = retrytransport.BackoffJitter + } + return retrytransport.RetryOptions{ + Enabled: config.Enabled, + Algorithm: algorithm, + MaxRetryCount: config.MaxAttempts, + MaxDuration: config.MaxDuration, + Interval: config.Interval, + Expression: config.Expression, + } +} + func newCircuitBreakerConfig(cb config.CircuitBreaker) circuit.CircuitBreakerConfig { return circuit.CircuitBreakerConfig{ Enabled: cb.Enabled, diff --git a/router/core/router_config.go b/router/core/router_config.go index 9f4b0bf84c..e8e9a9da84 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -10,7 +10,6 @@ import ( "github.com/wundergraph/cosmo/router/internal/persistedoperation" "github.com/wundergraph/cosmo/router/internal/persistedoperation/pqlmanifest" rd "github.com/wundergraph/cosmo/router/internal/rediscloser" - "github.com/wundergraph/cosmo/router/internal/retrytransport" "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/connectrpc" "github.com/wundergraph/cosmo/router/pkg/controlplane/configpoller" @@ -105,12 +104,12 @@ type Config struct { headerRules *config.HeaderRules subgraphTransportOptions *SubgraphTransportOptions subgraphCircuitBreakerOptions *SubgraphCircuitBreakerOptions + retryOptions *SubgraphRetryOptions graphqlMetricsConfig *GraphQLMetricsConfig routerTrafficConfig *config.RouterTrafficConfiguration batchingConfig *BatchingConfig fileUploadConfig *config.FileUpload accessController *AccessController - retryOptions retrytransport.RetryOptions redisClient rd.RDCloser mcpServer *mcpserver.GraphQLSchemaServer connectRPCServer *connectrpc.Server @@ -254,7 +253,7 @@ func (c *Config) Usage() map[string]any { usage["file_upload_max_files"] = c.fileUploadConfig.MaxFiles } usage["access_controller"] = c.accessController != nil - usage["retry_options"] = c.retryOptions.Enabled + usage["retry_options"] = c.retryOptions.IsEnabled() usage["development_mode"] = c.developmentMode usage["access_logs"] = c.accessLogsConfig != nil usage["localhost_fallback_inside_docker"] = 
c.localhostFallbackInsideDocker diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 2f9f6fcbfb..3931f69918 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -225,15 +225,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config, reloadPersi WithFileUploadConfig(&config.FileUpload), WithSubgraphTransportOptions(NewSubgraphTransportOptions(config.TrafficShaping)), WithSubgraphCircuitBreakerOptions(NewSubgraphCircuitBreakerOptions(config.TrafficShaping)), - WithSubgraphRetryOptions( - config.TrafficShaping.All.BackoffJitterRetry.Enabled, - config.TrafficShaping.All.BackoffJitterRetry.Algorithm, - config.TrafficShaping.All.BackoffJitterRetry.MaxAttempts, - config.TrafficShaping.All.BackoffJitterRetry.MaxDuration, - config.TrafficShaping.All.BackoffJitterRetry.Interval, - config.TrafficShaping.All.BackoffJitterRetry.Expression, - nil, - ), + WithSubgraphRetryOptions(NewSubgraphRetryOptions(config.TrafficShaping)), WithCors(&cors.Config{ Enabled: config.CORS.Enabled, AllowOrigins: config.CORS.AllowOrigins, diff --git a/router/core/transport.go b/router/core/transport.go index c162fc54a1..8e4031a927 100644 --- a/router/core/transport.go +++ b/router/core/transport.go @@ -46,10 +46,10 @@ type CustomTransport struct { func NewCustomTransport( baseRoundTripper http.RoundTripper, - retryOptions retrytransport.RetryOptions, metricStore metric.Store, connectionMetricStore metric.ConnectionMetricStore, breaker *circuit.Manager, + retryManager *retrytransport.Manager, enableTraceClient bool, ) *CustomTransport { ct := &CustomTransport{ @@ -63,31 +63,41 @@ func NewCustomTransport( // As a workaround we pass in a function that can be used to get the logger from within the round tripper getRequestContextLogger := func(req *http.Request) *zap.Logger { reqContext := getRequestContext(req.Context()) + if reqContext == nil { + return zap.NewNop() + } return reqContext.Logger() } + getActiveSubgraphName := func(req *http.Request) string { + reqContext := getRequestContext(req.Context()) + if reqContext == nil { + return "" + } + subgraph := reqContext.ActiveSubgraph(req) + if subgraph != nil { + return subgraph.Name + } + return "" + } + if enableTraceClient { - getValuesFromRequest := func(ctx context.Context, req *http.Request) (*expr.Context, string) { + getExprContext := func(ctx context.Context) *expr.Context { reqContext := getRequestContext(ctx) if reqContext == nil { - return &expr.Context{}, "" - } - - var activeSubgraphName string - if activeSubgraph := reqContext.ActiveSubgraph(req); activeSubgraph != nil { - activeSubgraphName = activeSubgraph.Name + return &expr.Context{} } - return &reqContext.expressionContext, activeSubgraphName + return &reqContext.expressionContext } - baseRoundTripper = traceclient.NewTraceInjectingRoundTripper(baseRoundTripper, connectionMetricStore, getValuesFromRequest) + baseRoundTripper = traceclient.NewTraceInjectingRoundTripper(baseRoundTripper, connectionMetricStore, getExprContext, getActiveSubgraphName) } if breaker.HasCircuits() { - baseRoundTripper = circuit.NewCircuitTripper(baseRoundTripper, breaker, getRequestContextLogger) + baseRoundTripper = circuit.NewCircuitTripper(baseRoundTripper, breaker, getRequestContextLogger, getActiveSubgraphName) } - if retryOptions.Enabled { - ct.roundTripper = retrytransport.NewRetryHTTPTransport(baseRoundTripper, retryOptions, getRequestContextLogger) + if retryManager.IsEnabled() { + ct.roundTripper = 
retrytransport.NewRetryHTTPTransport(baseRoundTripper, getRequestContextLogger, retryManager, getActiveSubgraphName) } else { ct.roundTripper = baseRoundTripper } @@ -186,11 +196,11 @@ func (ct *CustomTransport) RoundTrip(req *http.Request) (resp *http.Response, er type TransportFactory struct { preHandlers []TransportPreHandler postHandlers []TransportPostHandler - retryOptions retrytransport.RetryOptions localhostFallbackInsideDocker bool metricStore metric.Store connectionMetricStore metric.ConnectionMetricStore circuitBreaker *circuit.Manager + retryManager *retrytransport.Manager logger *zap.Logger tracerProvider *sdktrace.TracerProvider tracePropagators propagation.TextMapPropagator @@ -203,7 +213,6 @@ type TransportOptions struct { PreHandlers []TransportPreHandler PostHandlers []TransportPostHandler SubgraphTransportOptions *SubgraphTransportOptions - RetryOptions retrytransport.RetryOptions LocalhostFallbackInsideDocker bool MetricStore metric.Store ConnectionMetricStore metric.ConnectionMetricStore @@ -212,6 +221,7 @@ type TransportOptions struct { TracerProvider *sdktrace.TracerProvider TracePropagators propagation.TextMapPropagator EnableTraceClient bool + RetryManager *retrytransport.Manager } type SubscriptionClientOptions struct { @@ -225,7 +235,6 @@ func NewTransport(opts *TransportOptions) *TransportFactory { return &TransportFactory{ preHandlers: opts.PreHandlers, postHandlers: opts.PostHandlers, - retryOptions: opts.RetryOptions, localhostFallbackInsideDocker: opts.LocalhostFallbackInsideDocker, metricStore: opts.MetricStore, connectionMetricStore: opts.ConnectionMetricStore, @@ -233,6 +242,7 @@ func NewTransport(opts *TransportOptions) *TransportFactory { tracerProvider: opts.TracerProvider, tracePropagators: opts.TracePropagators, circuitBreaker: opts.CircuitBreaker, + retryManager: opts.RetryManager, enableTraceClient: opts.EnableTraceClient, } } @@ -274,10 +284,10 @@ func (t TransportFactory) RoundTripper(baseTransport http.RoundTripper) http.Rou ) tp := NewCustomTransport( traceTransport, - t.retryOptions, t.metricStore, t.connectionMetricStore, t.circuitBreaker, + t.retryManager, t.enableTraceClient, ) diff --git a/router/internal/circuit/breaker.go b/router/internal/circuit/breaker.go index bbabdb42f2..382f355032 100644 --- a/router/internal/circuit/breaker.go +++ b/router/internal/circuit/breaker.go @@ -4,34 +4,27 @@ import ( "context" "net/http" - rcontext "github.com/wundergraph/cosmo/router/internal/context" "go.uber.org/zap" ) type Breaker struct { - roundTripper http.RoundTripper - loggerFunc func(req *http.Request) *zap.Logger - circuitBreaker *Manager + roundTripper http.RoundTripper + loggerFunc func(req *http.Request) *zap.Logger + circuitBreaker *Manager + getActiveSubgraphName func(req *http.Request) string } -func NewCircuitTripper(roundTripper http.RoundTripper, breaker *Manager, logger func(req *http.Request) *zap.Logger) *Breaker { +func NewCircuitTripper(roundTripper http.RoundTripper, breaker *Manager, logger func(req *http.Request) *zap.Logger, getActiveSubgraphName func(req *http.Request) string) *Breaker { return &Breaker{ - circuitBreaker: breaker, - loggerFunc: logger, - roundTripper: roundTripper, + circuitBreaker: breaker, + loggerFunc: logger, + roundTripper: roundTripper, + getActiveSubgraphName: getActiveSubgraphName, } } func (rt *Breaker) RoundTrip(req *http.Request) (resp *http.Response, err error) { - ctx := req.Context() - - var subgraph string - subgraphCtxVal := ctx.Value(rcontext.CurrentSubgraphContextKey{}) - if subgraphCtxVal 
!= nil { - if sg, ok := subgraphCtxVal.(string); ok { - subgraph = sg - } - } + subgraph := rt.getActiveSubgraphName(req) // If there is no circuit defined for this subgraph circuit := rt.circuitBreaker.GetCircuitBreaker(subgraph) diff --git a/router/internal/context/keys.go b/router/internal/context/keys.go index d1706f4ab4..3524280e02 100644 --- a/router/internal/context/keys.go +++ b/router/internal/context/keys.go @@ -4,8 +4,6 @@ // instead of being moved into core package context -type CurrentSubgraphContextKey struct{} - type ContextKey int const ( diff --git a/router/internal/expr/retry_expression.go b/router/internal/expr/retry_expression.go index 4414d29181..c925bb87dd 100644 --- a/router/internal/expr/retry_expression.go +++ b/router/internal/expr/retry_expression.go @@ -10,13 +10,27 @@ import ( // RetryExpressionManager handles compilation and evaluation of retry expressions type RetryExpressionManager struct { - program *vm.Program + expressionMap map[string]*vm.Program } -// NewRetryExpressionManager creates a new RetryExpressionManager with the given expression -func NewRetryExpressionManager(expression string) (*RetryExpressionManager, error) { +const defaultRetryExpression = "IsRetryableStatusCode() || IsConnectionError() || IsTimeout()" + +// NewRetryExpressionManager creates a new RetryExpressionManager +func NewRetryExpressionManager() *RetryExpressionManager { + return &RetryExpressionManager{ + expressionMap: make(map[string]*vm.Program), + } +} + +func (m *RetryExpressionManager) AddExpression(exprString string) error { + expression := exprString if expression == "" { - return nil, nil + expression = defaultRetryExpression + } + + // The expression has already been processed, skip recompilation + if _, ok := m.expressionMap[expression]; ok { + return nil } // Compile the expression with retry context @@ -27,22 +41,32 @@ func NewRetryExpressionManager(expression string) (*RetryExpressionManager, erro program, err := expr.Compile(expression, options...) 
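+	// For example, a malformed expression such as "invalid syntax +++" fails
+	// to compile here, so the misconfiguration surfaces once at startup
+	// rather than on the request path.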
if err != nil { - return nil, fmt.Errorf("failed to compile retry expression: %w", handleExpressionError(err)) + return fmt.Errorf("failed to compile retry expression: %w", handleExpressionError(err)) } - return &RetryExpressionManager{ - program: program, - }, nil + // Use the normalized expression string as the key for deduplication + m.expressionMap[expression] = program + return nil } // ShouldRetry evaluates the retry expression with the given context -func (m *RetryExpressionManager) ShouldRetry(ctx RetryContext) (bool, error) { - if m == nil || m.program == nil { - // Use default behavior if no expression is configured +func (m *RetryExpressionManager) ShouldRetry(ctx RetryContext, expressionString string) (bool, error) { + if m == nil { + return false, nil + } + + expression := expressionString + if expression == "" { + expression = defaultRetryExpression + } + + program, ok := m.expressionMap[expression] + if !ok { + // If the expression wasn't pre-compiled, do not retry by default return false, nil } - result, err := expr.Run(m.program, ctx) + result, err := expr.Run(program, ctx) if err != nil { return false, fmt.Errorf("failed to evaluate retry expression: %w", handleExpressionError(err)) } diff --git a/router/internal/expr/retry_expression_test.go b/router/internal/expr/retry_expression_test.go index 04a8c8c268..d57a9d37c5 100644 --- a/router/internal/expr/retry_expression_test.go +++ b/router/internal/expr/retry_expression_test.go @@ -122,7 +122,8 @@ func TestRetryExpressionManager(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - manager, err := NewRetryExpressionManager(tt.expression) + manager := NewRetryExpressionManager() + err := manager.AddExpression(tt.expression) if tt.expectErr { assert.Error(t, err) return @@ -130,7 +131,7 @@ func TestRetryExpressionManager(t *testing.T) { require.NoError(t, err) require.NotNil(t, manager) - result, err := manager.ShouldRetry(tt.ctx) + result, err := manager.ShouldRetry(tt.ctx, tt.expression) assert.NoError(t, err) assert.Equal(t, tt.expected, result) }) @@ -138,9 +139,10 @@ func TestRetryExpressionManager(t *testing.T) { } func TestRetryExpressionManager_EmptyExpression(t *testing.T) { - manager, err := NewRetryExpressionManager("") + manager := NewRetryExpressionManager() + err := manager.AddExpression("") assert.NoError(t, err) - assert.Nil(t, manager) + assert.NotNil(t, manager) } func TestLoadRetryContext(t *testing.T) { @@ -493,56 +495,56 @@ func (e *mockTimeoutError) Temporary() bool { func TestRetryExpressionManager_WithSyscallErrors(t *testing.T) { t.Run("expression with specific syscall error helpers", func(t *testing.T) { expression := "IsConnectionRefused() || IsConnectionReset() || IsTimeout()" - manager, err := NewRetryExpressionManager(expression) - require.NoError(t, err) + manager := NewRetryExpressionManager() + require.NoError(t, manager.AddExpression(expression)) require.NotNil(t, manager) // Test ECONNREFUSED ctx := LoadRetryContext(syscall.ECONNREFUSED, nil) - result, err := manager.ShouldRetry(ctx) + result, err := manager.ShouldRetry(ctx, expression) assert.NoError(t, err) assert.True(t, result) // Test ECONNRESET ctx = LoadRetryContext(syscall.ECONNRESET, nil) - result, err = manager.ShouldRetry(ctx) + result, err = manager.ShouldRetry(ctx, expression) assert.NoError(t, err) assert.True(t, result) // Test ETIMEDOUT ctx = LoadRetryContext(syscall.ETIMEDOUT, nil) - result, err = manager.ShouldRetry(ctx) + result, err = manager.ShouldRetry(ctx, expression) assert.NoError(t, err) 
 		assert.True(t, result)
 
 		// Test unrelated error
 		ctx = LoadRetryContext(errors.New("some other error"), nil)
-		result, err = manager.ShouldRetry(ctx)
+		result, err = manager.ShouldRetry(ctx, expression)
 		assert.NoError(t, err)
 		assert.False(t, result)
 	})
 
 	t.Run("expression combining status and syscall errors", func(t *testing.T) {
 		expression := "statusCode == 500 || IsConnectionRefused()"
-		manager, err := NewRetryExpressionManager(expression)
-		require.NoError(t, err)
+		manager := NewRetryExpressionManager()
+		require.NoError(t, manager.AddExpression(expression))
 		require.NotNil(t, manager)
 
 		// Test with status code
 		ctx := LoadRetryContext(nil, &http.Response{StatusCode: 500})
-		result, err := manager.ShouldRetry(ctx)
+		result, err := manager.ShouldRetry(ctx, expression)
 		assert.NoError(t, err)
 		assert.True(t, result)
 
 		// Test with syscall error
 		ctx = LoadRetryContext(syscall.ECONNREFUSED, nil)
-		result, err = manager.ShouldRetry(ctx)
+		result, err = manager.ShouldRetry(ctx, expression)
 		assert.NoError(t, err)
 		assert.True(t, result)
 
 		// Test with neither condition
 		ctx = LoadRetryContext(nil, &http.Response{StatusCode: 200})
-		result, err = manager.ShouldRetry(ctx)
+		result, err = manager.ShouldRetry(ctx, expression)
 		assert.NoError(t, err)
 		assert.False(t, result)
 	})
diff --git a/router/internal/retrytransport/manager.go b/router/internal/retrytransport/manager.go
new file mode 100644
index 0000000000..0271a39706
--- /dev/null
+++ b/router/internal/retrytransport/manager.go
@@ -0,0 +1,154 @@
+package retrytransport
+
+import (
+	"fmt"
+	"net/http"
+	"time"
+
+	nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1"
+	"github.com/wundergraph/cosmo/router/internal/expr"
+	"go.uber.org/zap"
+)
+
+type (
+	ShouldRetryFunc     func(err error, req *http.Request, resp *http.Response, exprString string) bool
+	OnRetryFunc         func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error)
+	requestLoggerGetter func(req *http.Request) *zap.Logger
+)
+
+const (
+	BackoffJitter = "backoff_jitter"
+)
+
+type RetryOptions struct {
+	Enabled       bool
+	Algorithm     string
+	MaxRetryCount int
+	Interval      time.Duration
+	MaxDuration   time.Duration
+	Expression    string
+}
+
+type Manager struct {
+	retries     map[string]*RetryOptions
+	exprManager *expr.RetryExpressionManager
+	retryFunc   ShouldRetryFunc
+	onRetry     OnRetryFunc
+	logger      *zap.Logger
+}
+
+func NewManager(exprManager *expr.RetryExpressionManager, retryFunc ShouldRetryFunc, onRetryFunc OnRetryFunc, logger *zap.Logger) *Manager {
+	if logger == nil {
+		logger = zap.NewNop()
+	}
+	return &Manager{
+		retries:     make(map[string]*RetryOptions),
+		exprManager: exprManager,
+		retryFunc:   retryFunc,
+		onRetry:     onRetryFunc,
+		logger:      logger,
+	}
+}
+
+func (m *Manager) Initialize(baseRetryOptions RetryOptions, subgraphRetryOptions map[string]RetryOptions, routerConfig *nodev1.RouterConfig) error {
+	// Get the list of all subgraphs AND feature subgraphs
+	subgraphNameSet := make(map[string]bool, len(routerConfig.Subgraphs))
+	for _, subgraph := range routerConfig.GetSubgraphs() {
+		subgraphNameSet[subgraph.Name] = true
+	}
+	if routerConfig.FeatureFlagConfigs != nil {
+		for _, ffConfig := range routerConfig.FeatureFlagConfigs.ConfigByFeatureFlagName {
+			for _, subgraph := range ffConfig.GetSubgraphs() {
+				subgraphNameSet[subgraph.Name] = true
+			}
+		}
+	}
+
+	// Warn on retry configs pointing at subgraphs that don't exist in the
+	// router config; this is likely a typo that would otherwise silently
+	// disable the override.
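+	// For example (hypothetical name): an override keyed "product" while the
+	// subgraph is registered as "products" is never applied, and the warning
+	// below is the only signal of that mismatch.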
+	for sgName := range subgraphRetryOptions {
+		if !subgraphNameSet[sgName] {
+			m.logger.Warn("Retry config references unknown subgraph; override will be ignored",
+				zap.String("subgraph_name", sgName),
+			)
+		}
+	}
+
+	defaultSgNames := make([]string, 0, len(subgraphNameSet))
+	customSgNames := make([]string, 0, len(subgraphNameSet))
+
+	for subgraphName := range subgraphNameSet {
+		entry, ok := subgraphRetryOptions[subgraphName]
+		if !ok {
+			defaultSgNames = append(defaultSgNames, subgraphName)
+		} else if entry.Enabled {
+			// An explicitly disabled subgraph lands in neither list and gets no retries
+			customSgNames = append(customSgNames, subgraphName)
+		}
+	}
+
+	// First validate and add expressions for base retry options if needed
+	if len(defaultSgNames) > 0 && baseRetryOptions.Enabled {
+		if baseRetryOptions.Algorithm != BackoffJitter {
+			return fmt.Errorf("unsupported retry algorithm: %s", baseRetryOptions.Algorithm)
+		}
+
+		// Note: this is only evaluated when at least one subgraph falls back
+		// to the base options; an invalid base expression surfaces only then.
+		err := m.exprManager.AddExpression(baseRetryOptions.Expression)
+		if err != nil {
+			return fmt.Errorf("failed to add base retry expression: %w", err)
+		}
+		// Only assign default options if validation succeeds
+		for _, sgName := range defaultSgNames {
+			opts := baseRetryOptions
+			m.retries[sgName] = &opts
+		}
+	}
+
+	// Process custom retry options
+	for _, sgName := range customSgNames {
+		entry, ok := subgraphRetryOptions[sgName]
+		if !ok {
+			return fmt.Errorf("failed to get subgraphRetryOptions: %s", sgName)
+		}
+
+		if entry.Algorithm != BackoffJitter {
+			return fmt.Errorf("unsupported retry algorithm for subgraph %s: %s", sgName, entry.Algorithm)
+		}
+
+		// Validate expression before assigning options
+		err := m.exprManager.AddExpression(entry.Expression)
+		if err != nil {
+			return fmt.Errorf("failed to add retry expression for subgraph %s: %w", sgName, err)
+		}
+
+		// Create a new copy of the options
+		opts := entry
+		m.retries[sgName] = &opts
+	}
+
+	return nil
+}
+
+func (m *Manager) GetSubgraphOptions(name string) *RetryOptions {
+	if m == nil {
+		return nil
+	}
+	return m.retries[name]
+}
+
+func (m *Manager) IsEnabled() bool {
+	if m == nil {
+		return false
+	}
+	return len(m.retries) > 0
+}
+
+func (m *Manager) Retry(err error, req *http.Request, resp *http.Response, exprString string) bool {
+	if m.retryFunc == nil {
+		return false
+	}
+	return m.retryFunc(err, req, resp, exprString)
+}
diff --git a/router/internal/retrytransport/retry_transport.go b/router/internal/retrytransport/retry_transport.go
index 964b030dce..9d9e452497 100644
--- a/router/internal/retrytransport/retry_transport.go
+++ b/router/internal/retrytransport/retry_transport.go
@@ -11,28 +11,11 @@ import (
 	"go.uber.org/zap"
 )
 
-type ShouldRetryFunc func(err error, req *http.Request, resp *http.Response) bool
-type OnRetryFunc func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error)
-
-type RetryOptions struct {
-	Enabled       bool
-	Algorithm     string
-	MaxRetryCount int
-	Interval      time.Duration
-	MaxDuration   time.Duration
-	Expression    string
-	ShouldRetry   ShouldRetryFunc
-
-	// Test specific only
-	OnRetry OnRetryFunc
-}
-
-type requestLoggerGetter func(req *http.Request) *zap.Logger
-
 type RetryHTTPTransport struct {
-	RoundTripper     http.RoundTripper
-	RetryOptions     RetryOptions
-	getRequestLogger requestLoggerGetter
+	roundTripper      http.RoundTripper
+	getRequestLogger  requestLoggerGetter
+	retryManager      *Manager
+	getActiveSubgraph
func(req *http.Request) string } // parseRetryAfterHeader parses the Retry-After header value according to RFC 7231. @@ -90,37 +73,42 @@ func shouldUseRetryAfter(logger *zap.Logger, resp *http.Response, maxDuration ti return duration, duration > 0 } -func NewRetryHTTPTransport( - roundTripper http.RoundTripper, - retryOptions RetryOptions, - getRequestLogger requestLoggerGetter, -) *RetryHTTPTransport { +func NewRetryHTTPTransport(roundTripper http.RoundTripper, getRequestLogger requestLoggerGetter, retryManager *Manager, getActiveSubgraph func(req *http.Request) string) *RetryHTTPTransport { return &RetryHTTPTransport{ - RoundTripper: roundTripper, - RetryOptions: retryOptions, - getRequestLogger: getRequestLogger, + roundTripper: roundTripper, + getRequestLogger: getRequestLogger, + getActiveSubgraph: getActiveSubgraph, + retryManager: retryManager, } } func (rt *RetryHTTPTransport) RoundTrip(req *http.Request) (*http.Response, error) { - resp, err := rt.RoundTripper.RoundTrip(req) + resp, err := rt.roundTripper.RoundTrip(req) // Short circuit if the request was successful. if err == nil && isResponseOK(resp) { return resp, nil } - b := backoff.New(rt.RetryOptions.MaxDuration, rt.RetryOptions.Interval) + activeSubgraph := rt.getActiveSubgraph(req) + retryOptions := rt.retryManager.GetSubgraphOptions(activeSubgraph) + + // If there is no option defined for this subgraph + if retryOptions == nil { + return resp, err + } + + b := backoff.New(retryOptions.MaxDuration, retryOptions.Interval) requestLogger := rt.getRequestLogger(req) // Retry logic retries := 0 - for (rt.RetryOptions.ShouldRetry(err, req, resp)) && retries < rt.RetryOptions.MaxRetryCount { + for (rt.retryManager.Retry(err, req, resp, retryOptions.Expression)) && retries < retryOptions.MaxRetryCount { retries++ // Check if we should use Retry-After header for 429 responses var sleepDuration time.Duration - if retryAfterDuration, useRetryAfter := shouldUseRetryAfter(requestLogger, resp, rt.RetryOptions.MaxDuration); useRetryAfter { + if retryAfterDuration, useRetryAfter := shouldUseRetryAfter(requestLogger, resp, retryOptions.MaxDuration); useRetryAfter { sleepDuration = retryAfterDuration requestLogger.Debug("Using Retry-After header for 429 response", zap.Int("retry", retries), @@ -137,9 +125,9 @@ func (rt *RetryHTTPTransport) RoundTrip(req *http.Request) (*http.Response, erro ) } - // Test Specific - if rt.RetryOptions.OnRetry != nil { - rt.RetryOptions.OnRetry(retries, req, resp, sleepDuration, err) + // A hook used for testing + if rt.retryManager.onRetry != nil { + rt.retryManager.onRetry(retries, req, resp, sleepDuration, err) } // Wait for the specified duration @@ -149,7 +137,7 @@ func (rt *RetryHTTPTransport) RoundTrip(req *http.Request) (*http.Response, erro rt.drainBody(resp, requestLogger) // Retry the request - resp, err = rt.RoundTripper.RoundTrip(req) + resp, err = rt.roundTripper.RoundTrip(req) // Short circuit if the request was successful if err == nil && isResponseOK(resp) { @@ -182,6 +170,10 @@ func (rt *RetryHTTPTransport) drainBody(resp *http.Response, logger *zap.Logger) } func isResponseOK(resp *http.Response) bool { + // 101 Switching Protocols is a successful response for WebSocket upgrades + if resp.StatusCode == http.StatusSwitchingProtocols { + return true + } // Ensure we don't wait for no reason when subgraphs don't behave // spec-compliant and returns a different status code than 200. 
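+	// Anything else in the 2xx range (e.g. 204 No Content) short-circuits
+	// the retry loop the same way a 200 does.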
return resp.StatusCode >= 200 && resp.StatusCode < 300 diff --git a/router/internal/retrytransport/retry_transport_test.go b/router/internal/retrytransport/retry_transport_test.go index aa920ed318..41328c37ce 100644 --- a/router/internal/retrytransport/retry_transport_test.go +++ b/router/internal/retrytransport/retry_transport_test.go @@ -7,18 +7,32 @@ import ( "io" "net/http" "net/http/httptest" + "strconv" "strings" "testing" "time" "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.uber.org/zap" + + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/internal/expr" ) const defaultMaxDuration = 100 * time.Second +var loggerFunc = func(_ *http.Request) *zap.Logger { return zap.NewNop() } + +var getActiveSubgraph = func(subgraphName string) func(_ *http.Request) string { + return func(_ *http.Request) string { + return subgraphName + } +} + // simpleShouldRetry provides simple retry logic for testing the transport implementation func simpleShouldRetry(err error, req *http.Request, resp *http.Response) bool { // Simple logic for testing - retry on 5xx status codes or any error @@ -51,40 +65,40 @@ func (dt *MockTransport) RoundTrip(req *http.Request) (*http.Response, error) { return dt.handler(req) } +func newTestManager(shouldRetry func(error, *http.Request, *http.Response) bool, onRetry OnRetryFunc, opts RetryOptions, subgraphName string) *Manager { + mgr := NewManager(expr.NewRetryExpressionManager(), func(err error, req *http.Request, resp *http.Response, _ string) bool { + return shouldRetry(err, req, resp) + }, onRetry, zap.NewNop()) + // attach options for the subgraph + mgr.retries[subgraphName] = &opts + return mgr +} + func TestRetryOnHTTP5xx(t *testing.T) { retries := 0 attemptCount := 0 maxRetries := 3 - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - attemptCount++ - if attemptCount <= maxRetries { - // Return 500 to trigger retry - return &http.Response{ - StatusCode: http.StatusInternalServerError, - }, nil - } - // Finally return success + subgraphName := "sg" + mgr := newTestManager(simpleShouldRetry, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + retries++ + }, RetryOptions{MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(_ *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 500 to trigger retry return &http.Response{ - StatusCode: http.StatusOK, + StatusCode: http.StatusInternalServerError, }, nil - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() - }, - RetryOptions: RetryOptions{ - MaxRetryCount: maxRetries, - Interval: 1 * time.Millisecond, - MaxDuration: 10 * time.Millisecond, - ShouldRetry: simpleShouldRetry, - OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - retries++ - }, + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) @@ -102,33 +116,24 @@ func TestRetryOnErrors(t *testing.T) { attemptCount := 0 maxRetries := 3 - tr := RetryHTTPTransport{ - RoundTripper: 
&MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - attemptCount++ - if attemptCount <= maxRetries { - // Return any error to trigger retry - return nil, errors.New("some network error") - } - // Finally return success - return &http.Response{ - StatusCode: http.StatusOK, - }, nil - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() - }, - RetryOptions: RetryOptions{ - MaxRetryCount: maxRetries, - Interval: 1 * time.Millisecond, - MaxDuration: 10 * time.Millisecond, - ShouldRetry: simpleShouldRetry, - OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - retries++ - }, + subgraphName := "sg" + mgr := newTestManager(simpleShouldRetry, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + retries++ + }, RetryOptions{MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(_ *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return any error to trigger retry + return nil, errors.New("some network error") + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) @@ -156,36 +161,27 @@ func TestDoNotRetryWhenShouldRetryReturnsFalse(t *testing.T) { return simpleShouldRetry(err, req, resp) } - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - attemptCount++ - switch attemptCount { - case 1: - // First attempt: return retryable error - return nil, errors.New("retryable error") - case 2: - // Second attempt: return retryable status code - return &http.Response{StatusCode: http.StatusInternalServerError}, nil - default: - // Third attempt: return non-retryable error (should stop retrying) - return nil, nonRetryableError - } - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() - }, - RetryOptions: RetryOptions{ - MaxRetryCount: maxRetryCount, - Interval: 1 * time.Millisecond, - MaxDuration: 10 * time.Millisecond, - ShouldRetry: shouldRetry, - OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - retries++ - }, + subgraphName := "sg" + mgr := newTestManager(shouldRetry, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + retries++ + }, RetryOptions{MaxRetryCount: maxRetryCount, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(_ *http.Request) (*http.Response, error) { + attemptCount++ + switch attemptCount { + case 1: + // First attempt: return retryable error + return nil, errors.New("retryable error") + case 2: + // Second attempt: return retryable status code + return &http.Response{StatusCode: http.StatusInternalServerError}, nil + default: + // Third attempt: return non-retryable error (should stop retrying) + return nil, nonRetryableError + } }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) @@ -232,30 +228,21 @@ func (b *TrackableBody) Close() error { func TestShortCircuitOnSuccess(t *testing.T) { attemptCount := 0 - tr := RetryHTTPTransport{ - RoundTripper: 
&MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - attemptCount++ - // Always return success - return &http.Response{ - StatusCode: http.StatusOK, - Body: io.NopCloser(strings.NewReader("success")), - }, nil - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() + subgraphName := "sg" + mgr := newTestManager(simpleShouldRetry, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + t.Error("onRetry should not be called when first request succeeds") + }, RetryOptions{MaxRetryCount: 5, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(_ *http.Request) (*http.Response, error) { + attemptCount++ + // Always return success + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader("success")), + }, nil }, - RetryOptions: RetryOptions{ - MaxRetryCount: 5, - Interval: 1 * time.Millisecond, - MaxDuration: 10 * time.Millisecond, - ShouldRetry: simpleShouldRetry, - OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - t.Error("OnRetry should not be called when first request succeeds") - }, - }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) resp, err := tr.RoundTrip(req) @@ -276,27 +263,18 @@ func TestMaxRetryCountRespected(t *testing.T) { retries := 0 attemptCount := 0 - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - attemptCount++ - // Always return retryable error to test max retry limit - return nil, errors.New("always fail") - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() - }, - RetryOptions: RetryOptions{ - MaxRetryCount: maxRetries, - Interval: 1 * time.Millisecond, - MaxDuration: 10 * time.Millisecond, - ShouldRetry: simpleShouldRetry, - OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - retries++ - }, + subgraphName := "sg" + mgr := newTestManager(simpleShouldRetry, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + retries++ + }, RetryOptions{MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(_ *http.Request) (*http.Response, error) { + attemptCount++ + // Always return retryable error to test max retry limit + return nil, errors.New("always fail") }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) resp, err := tr.RoundTrip(req) @@ -323,36 +301,27 @@ func TestResponseBodyDraining(t *testing.T) { } } - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - index++ - if index < retryCount { - return &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: bodies[index], - }, nil - } else { - return &http.Response{ - StatusCode: http.StatusOK, - Body: bodies[index], - }, nil - } - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() - }, - RetryOptions: RetryOptions{ - MaxRetryCount: retryCount, - Interval: 1 * time.Millisecond, - MaxDuration: 10 * time.Millisecond, - ShouldRetry: simpleShouldRetry, - OnRetry: func(count int, req 
*http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - actualRetries++ - }, + subgraphName := "sg" + mgr := newTestManager(simpleShouldRetry, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + actualRetries++ + }, RetryOptions{MaxRetryCount: retryCount, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(_ *http.Request) (*http.Response, error) { + index++ + if index < retryCount { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: bodies[index], + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: bodies[index], + }, nil + }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) @@ -398,40 +367,31 @@ func TestRequestLoggerIsUsed(t *testing.T) { bodies[i] = trackableBody } - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - index++ - if index < retryCount { - return &http.Response{ - StatusCode: http.StatusInternalServerError, - Body: bodies[index], - }, nil - } else { - return &http.Response{ - StatusCode: http.StatusOK, - Body: bodies[index], - }, nil - } - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return requestLogger - }, - RetryOptions: RetryOptions{ - MaxRetryCount: retryCount, - Interval: 1 * time.Millisecond, - MaxDuration: 10 * time.Millisecond, - ShouldRetry: simpleShouldRetry, - OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - actualRetries++ - }, + subgraphName := "sg" + mgr := newTestManager(simpleShouldRetry, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + actualRetries++ + }, RetryOptions{MaxRetryCount: retryCount, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(_ *http.Request) (*http.Response, error) { + index++ + if index < retryCount { + return &http.Response{ + StatusCode: http.StatusInternalServerError, + Body: bodies[index], + }, nil + } + return &http.Response{ + StatusCode: http.StatusOK, + Body: bodies[index], + }, nil }, - } + }, func(req *http.Request) *zap.Logger { return requestLogger }, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) - _, _ = tr.RoundTrip(req) + _, err := tr.RoundTrip(req) + assert.NoError(t, err) assert.Contains(t, requestLoggerBuf.String(), "Failed draining when discarding the body\t{\"error\": \"retry read error, index: 1\"}") assert.Contains(t, requestLoggerBuf.String(), "Failed draining when closing the body\t{\"error\": \"retry close error, index: 2\"}") @@ -459,37 +419,28 @@ func TestOnRetryCallbackInvoked(t *testing.T) { resp *http.Response } - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - if retries < maxRetries { - // Return retryable error - return nil, errors.New("retryable error") - } - // Finally return success - return &http.Response{ - StatusCode: http.StatusOK, - }, nil - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() - }, - RetryOptions: RetryOptions{ - MaxRetryCount: maxRetries, - Interval: 1 * time.Millisecond, - MaxDuration: 10 * time.Millisecond, - ShouldRetry: simpleShouldRetry, - OnRetry: 
func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - retries++ - retryCallbacks = append(retryCallbacks, struct { - count int - err error - resp *http.Response - }{count: count, err: err, resp: resp}) - }, + subgraphName := "sg" + mgr := newTestManager(simpleShouldRetry, func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { + retries++ + retryCallbacks = append(retryCallbacks, struct { + count int + err error + resp *http.Response + }{count: count, err: err, resp: resp}) + }, RetryOptions{MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + if retries < maxRetries { + // Return retryable error + return nil, errors.New("retryable error") + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) resp, err := tr.RoundTrip(req) @@ -497,7 +448,7 @@ func TestOnRetryCallbackInvoked(t *testing.T) { assert.NoError(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) - // Verify OnRetry was called the right number of times + // Verify onRetry was called the right number of times assert.Equal(t, maxRetries, retries) assert.Len(t, retryCallbacks, maxRetries) @@ -519,45 +470,36 @@ func TestRetryOn429WithDelaySeconds(t *testing.T) { // Track what retry duration was requested to verify Retry-After is parsed correctly var retryAfterUsed []bool - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - attemptCount++ - if attemptCount <= maxRetries { - // Return 429 with Retry-After header in seconds - resp := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: make(http.Header), - } - resp.Header.Set("Retry-After", fmt.Sprintf("%d", retryAfterSeconds)) - - // Verify the header is parsed correctly - duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration) - retryAfterUsed = append(retryAfterUsed, useRetryAfter) - assert.True(t, useRetryAfter, "Should use Retry-After header for 429") - assert.Equal(t, time.Duration(retryAfterSeconds)*time.Second, duration) - - return resp, nil + subgraphName := `sg` + mgr := newTestManager(shouldRetryWith429, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + retries++ + }, RetryOptions{MaxRetryCount: maxRetries, Interval: 100 * time.Millisecond, MaxDuration: 10 * time.Second}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 429 with Retry-After header in seconds + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), } - // Finally return success - return &http.Response{ - StatusCode: http.StatusOK, - }, nil - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() - }, - RetryOptions: RetryOptions{ - MaxRetryCount: maxRetries, - Interval: 100 * time.Millisecond, // This should be ignored for 429 - MaxDuration: 10 * time.Second, - ShouldRetry: shouldRetryWith429, - OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - retries++ - }, + 
resp.Header.Set("Retry-After", strconv.Itoa(retryAfterSeconds)) + + // Verify the header is parsed correctly + duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration) + retryAfterUsed = append(retryAfterUsed, useRetryAfter) + assert.True(t, useRetryAfter, "Should use Retry-After header for 429") + assert.Equal(t, time.Duration(retryAfterSeconds)*time.Second, duration) + + return resp, nil + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) resp, err := tr.RoundTrip(req) @@ -583,45 +525,36 @@ func TestRetryOn429WithDelaySecondsLargerThanMaxDuration(t *testing.T) { // Track what retry duration was requested to verify Retry-After is parsed correctly var retryAfterUsed []bool - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) (*http.Response, error) { - attemptCount++ - if attemptCount <= maxRetries { - // Return 429 with Retry-After header in seconds - resp := &http.Response{ - StatusCode: http.StatusTooManyRequests, - Header: make(http.Header), - } - resp.Header.Set("Retry-After", fmt.Sprintf("%d", retryAfterSeconds)) - - // Verify the header is parsed correctly - duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, maxDuration) - retryAfterUsed = append(retryAfterUsed, useRetryAfter) - assert.True(t, useRetryAfter, "Should use Retry-After header for 429") - assert.Equal(t, maxDuration, duration) - - return resp, nil + subgraphName := "sg" + mgr := newTestManager(shouldRetryWith429, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) { + retries++ + }, RetryOptions{MaxRetryCount: maxRetries, Interval: 100 * time.Millisecond, MaxDuration: 10 * time.Second}, subgraphName) + + tr := NewRetryHTTPTransport(&MockTransport{ + handler: func(req *http.Request) (*http.Response, error) { + attemptCount++ + if attemptCount <= maxRetries { + // Return 429 with Retry-After header in seconds + resp := &http.Response{ + StatusCode: http.StatusTooManyRequests, + Header: make(http.Header), } - // Finally return success - return &http.Response{ - StatusCode: http.StatusOK, - }, nil - }, - }, - getRequestLogger: func(req *http.Request) *zap.Logger { - return zap.NewNop() - }, - RetryOptions: RetryOptions{ - MaxRetryCount: maxRetries, - Interval: 100 * time.Millisecond, // This should be ignored for 429 - MaxDuration: 10 * time.Second, - ShouldRetry: shouldRetryWith429, - OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) { - retries++ - }, + resp.Header.Set("Retry-After", strconv.Itoa(retryAfterSeconds)) + + // Verify the header is parsed correctly + duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, maxDuration) + retryAfterUsed = append(retryAfterUsed, useRetryAfter) + assert.True(t, useRetryAfter, "Should use Retry-After header for 429") + assert.Equal(t, maxDuration, duration) + + return resp, nil + } + // Finally return success + return &http.Response{ + StatusCode: http.StatusOK, + }, nil }, - } + }, loggerFunc, mgr, getActiveSubgraph(subgraphName)) req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil) resp, err := tr.RoundTrip(req) @@ -642,36 +575,27 @@ func TestRetryOn429WithoutRetryAfter(t *testing.T) { attemptCount := 0 maxRetries := 2 - tr := RetryHTTPTransport{ - RoundTripper: &MockTransport{ - handler: func(req *http.Request) 
(*http.Response, error) {
-				attemptCount++
-				if attemptCount <= maxRetries {
-					// Return 429 without Retry-After header
-					return &http.Response{
-						StatusCode: http.StatusTooManyRequests,
-						Header:     make(http.Header),
-					}, nil
-				}
-				// Finally return success
+	subgraphName := "sg"
+	mgr := newTestManager(shouldRetryWith429, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) {
+		retries++
+	}, RetryOptions{MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName)
+
+	tr := NewRetryHTTPTransport(&MockTransport{
+		handler: func(req *http.Request) (*http.Response, error) {
+			attemptCount++
+			if attemptCount <= maxRetries {
+				// Return 429 without Retry-After header
 				return &http.Response{
-					StatusCode: http.StatusOK,
+					StatusCode: http.StatusTooManyRequests,
+					Header:     make(http.Header),
 				}, nil
-			},
-		},
-		getRequestLogger: func(req *http.Request) *zap.Logger {
-			return zap.NewNop()
-		},
-		RetryOptions: RetryOptions{
-			MaxRetryCount: maxRetries,
-			Interval:      1 * time.Millisecond,
-			MaxDuration:   10 * time.Millisecond,
-			ShouldRetry:   shouldRetryWith429,
-			OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) {
-				retries++
-			},
+			}
+			// Finally return success
+			return &http.Response{
+				StatusCode: http.StatusOK,
+			}, nil
 		},
-	}
+	}, loggerFunc, mgr, getActiveSubgraph(subgraphName))
 
 	req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil)
 	resp, err := tr.RoundTrip(req)
@@ -692,49 +616,40 @@ func TestRetryOn429WithHTTPDate(t *testing.T) {
 	var retryAfterUsed []bool
 	var expectedDuration time.Duration
 
-	tr := RetryHTTPTransport{
-		RoundTripper: &MockTransport{
-			handler: func(req *http.Request) (*http.Response, error) {
-				attemptCount++
-				if attemptCount <= maxRetries {
-					// Return 429 with Retry-After header as HTTP-date (1 second in future to keep test fast)
-					expectedDuration = 1 * time.Second
-					futureTime := time.Now().UTC().Add(expectedDuration)
-					resp := &http.Response{
-						StatusCode: http.StatusTooManyRequests,
-						Header:     make(http.Header),
-					}
-					resp.Header.Set("Retry-After", futureTime.Format(http.TimeFormat))
-
-					// Verify the header is parsed correctly
-					duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration)
-					retryAfterUsed = append(retryAfterUsed, useRetryAfter)
-					assert.True(t, useRetryAfter, "Should use Retry-After header for 429")
-					// Allow reasonable tolerance for execution delay between time creation and parsing
-					assert.True(t, duration > 0 && duration <= expectedDuration,
-						"Duration should be positive and <= %v, got %v", expectedDuration, duration)
-
-					return resp, nil
+	subgraphName := "sg"
+	mgr := newTestManager(shouldRetryWith429, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) {
+		retries++
+	}, RetryOptions{MaxRetryCount: maxRetries, Interval: 100 * time.Millisecond, MaxDuration: 10 * time.Second}, subgraphName)
+
+	tr := NewRetryHTTPTransport(&MockTransport{
+		handler: func(req *http.Request) (*http.Response, error) {
+			attemptCount++
+			if attemptCount <= maxRetries {
+				// Return 429 with Retry-After header as HTTP-date (1 second in future to keep test fast)
+				expectedDuration = 1 * time.Second
+				futureTime := time.Now().UTC().Add(expectedDuration)
+				resp := &http.Response{
+					StatusCode: http.StatusTooManyRequests,
+					Header:     make(http.Header),
 				}
-				// Finally return success
-				return &http.Response{
-					StatusCode: http.StatusOK,
-				}, nil
-			},
-		},
-		getRequestLogger: func(req *http.Request) *zap.Logger {
-			return zap.NewNop()
-		},
-		RetryOptions: RetryOptions{
-			MaxRetryCount: maxRetries,
-			Interval:      100 * time.Millisecond, // This should be ignored for 429
-			MaxDuration:   10 * time.Second,
-			ShouldRetry:   shouldRetryWith429,
-			OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) {
-				retries++
-			},
+				resp.Header.Set("Retry-After", futureTime.Format(http.TimeFormat))
+
+				// Verify the header is parsed correctly
+				duration, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration)
+				retryAfterUsed = append(retryAfterUsed, useRetryAfter)
+				assert.True(t, useRetryAfter, "Should use Retry-After header for 429")
+				// Allow reasonable tolerance for execution delay between time creation and parsing
+				assert.True(t, duration > 0 && duration <= expectedDuration,
+					"Duration should be positive and <= %v, got %v", expectedDuration, duration)
+
+				return resp, nil
+			}
+			// Finally return success
+			return &http.Response{
+				StatusCode: http.StatusOK,
+			}, nil
 		},
-	}
+	}, loggerFunc, mgr, getActiveSubgraph(subgraphName))
 
 	req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil)
 	resp, err := tr.RoundTrip(req)
@@ -755,38 +670,29 @@ func TestRetryOn429WithInvalidRetryAfterHeader(t *testing.T) {
 	attemptCount := 0
 	maxRetries := 2
 
-	tr := RetryHTTPTransport{
-		RoundTripper: &MockTransport{
-			handler: func(req *http.Request) (*http.Response, error) {
-				attemptCount++
-				if attemptCount <= maxRetries {
-					// Return 429 with invalid Retry-After header
-					resp := &http.Response{
-						StatusCode: http.StatusTooManyRequests,
-						Header:     make(http.Header),
-					}
-					resp.Header.Set("Retry-After", "invalid-value")
-					return resp, nil
+	subgraphName := "sg"
+	mgr := newTestManager(shouldRetryWith429, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) {
+		retries++
+	}, RetryOptions{MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName)
+
+	tr := NewRetryHTTPTransport(&MockTransport{
+		handler: func(req *http.Request) (*http.Response, error) {
+			attemptCount++
+			if attemptCount <= maxRetries {
+				// Return 429 with invalid Retry-After header
+				resp := &http.Response{
+					StatusCode: http.StatusTooManyRequests,
+					Header:     make(http.Header),
 				}
-				// Finally return success
-				return &http.Response{
-					StatusCode: http.StatusOK,
-				}, nil
-			},
-		},
-		getRequestLogger: func(req *http.Request) *zap.Logger {
-			return zap.NewNop()
-		},
-		RetryOptions: RetryOptions{
-			MaxRetryCount: maxRetries,
-			Interval:      1 * time.Millisecond,
-			MaxDuration:   10 * time.Millisecond,
-			ShouldRetry:   shouldRetryWith429,
-			OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) {
-				retries++
-			},
+				resp.Header.Set("Retry-After", "invalid-value")
+				return resp, nil
+			}
+			// Finally return success
+			return &http.Response{
+				StatusCode: http.StatusOK,
+			}, nil
 		},
-	}
+	}, loggerFunc, mgr, getActiveSubgraph(subgraphName))
 
 	req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil)
 	resp, err := tr.RoundTrip(req)
@@ -804,38 +710,29 @@ func TestRetryOn429WithNegativeDelaySeconds(t *testing.T) {
 	attemptCount := 0
 	maxRetries := 2
 
-	tr := RetryHTTPTransport{
-		RoundTripper: &MockTransport{
-			handler: func(req *http.Request) (*http.Response, error) {
-				attemptCount++
-				if attemptCount <= maxRetries {
-					// Return 429 with negative Retry-After value (should fall back to normal backoff)
-					resp := &http.Response{
-						StatusCode: http.StatusTooManyRequests,
-						Header:     make(http.Header),
-					}
-					resp.Header.Set("Retry-After", "-1")
-					return resp, nil
+	subgraphName := "sg"
+	mgr := newTestManager(shouldRetryWith429, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) {
+		retries++
+	}, RetryOptions{MaxRetryCount: maxRetries, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName)
+
+	tr := NewRetryHTTPTransport(&MockTransport{
+		handler: func(req *http.Request) (*http.Response, error) {
+			attemptCount++
+			if attemptCount <= maxRetries {
+				// Return 429 with negative Retry-After value (should fall back to normal backoff)
+				resp := &http.Response{
+					StatusCode: http.StatusTooManyRequests,
+					Header:     make(http.Header),
 				}
-				// Finally return success
-				return &http.Response{
-					StatusCode: http.StatusOK,
-				}, nil
-			},
-		},
-		getRequestLogger: func(req *http.Request) *zap.Logger {
-			return zap.NewNop()
-		},
-		RetryOptions: RetryOptions{
-			MaxRetryCount: maxRetries,
-			Interval:      1 * time.Millisecond,
-			MaxDuration:   10 * time.Millisecond,
-			ShouldRetry:   shouldRetryWith429,
-			OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) {
-				retries++
-			},
+				resp.Header.Set("Retry-After", "-1")
+				return resp, nil
+			}
+			// Finally return success
+			return &http.Response{
+				StatusCode: http.StatusOK,
+			}, nil
 		},
-	}
+	}, loggerFunc, mgr, getActiveSubgraph(subgraphName))
 
 	req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil)
 	resp, err := tr.RoundTrip(req)
@@ -855,66 +752,57 @@ func TestRetryMixed429AndOtherErrors(t *testing.T) {
 	// Track which responses used Retry-After vs normal backoff
 	var retryAfterUsedPerAttempt []bool
 
-	tr := RetryHTTPTransport{
-		RoundTripper: &MockTransport{
-			handler: func(req *http.Request) (*http.Response, error) {
-				attemptCount++
-				switch attemptCount {
-				case 1:
-					// First: 429 with Retry-After
-					resp := &http.Response{
-						StatusCode: http.StatusTooManyRequests,
-						Header:     make(http.Header),
-					}
-					resp.Header.Set("Retry-After", "1")
-
-					// Verify this should use Retry-After
-					_, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration)
-					retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter)
-
-					return resp, nil
-				case 2:
-					// Second: Network error (should use normal backoff)
-					retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, false)
-					return nil, errors.New("network error")
-				case 3:
-					// Third: 500 error (should use normal backoff)
-					resp := &http.Response{
-						StatusCode: http.StatusInternalServerError,
-					}
-					_, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration)
-					retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter)
-					return resp, nil
-				case 4:
-					// Fourth: 429 without Retry-After (should use normal backoff)
-					resp := &http.Response{
-						StatusCode: http.StatusTooManyRequests,
-						Header:     make(http.Header),
-					}
-					_, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration)
-					retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter)
-					return resp, nil
-				default:
-					// Finally: Success
-					return &http.Response{
-						StatusCode: http.StatusOK,
-					}, nil
+	subgraphName := "sg"
+	mgr := newTestManager(shouldRetryWith429, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) {
+		retries++
+	}, RetryOptions{MaxRetryCount: maxRetries, Interval: 10 * time.Millisecond, MaxDuration: 10 * time.Second}, subgraphName)
+
+	tr := NewRetryHTTPTransport(&MockTransport{
+		handler: func(req *http.Request) (*http.Response, error) {
+			attemptCount++
+			switch attemptCount {
+			case 1:
+				// First: 429 with Retry-After
+				resp := &http.Response{
+					StatusCode: http.StatusTooManyRequests,
+					Header:     make(http.Header),
 				}
-			},
-		},
-		getRequestLogger: func(req *http.Request) *zap.Logger {
-			return zap.NewNop()
-		},
-		RetryOptions: RetryOptions{
-			MaxRetryCount: maxRetries,
-			Interval:      10 * time.Millisecond, // Should be used for non-429-with-Retry-After cases
-			MaxDuration:   10 * time.Second,
-			ShouldRetry:   shouldRetryWith429,
-			OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) {
-				retries++
-			},
+				resp.Header.Set("Retry-After", "1")
+
+				// Verify this should use Retry-After
+				_, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration)
+				retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter)
+
+				return resp, nil
+			case 2:
+				// Second: Network error (should use normal backoff)
+				retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, false)
+				return nil, errors.New("network error")
+			case 3:
+				// Third: 500 error (should use normal backoff)
+				resp := &http.Response{
+					StatusCode: http.StatusInternalServerError,
+				}
+				_, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration)
+				retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter)
+				return resp, nil
+			case 4:
+				// Fourth: 429 without Retry-After (should use normal backoff)
+				resp := &http.Response{
+					StatusCode: http.StatusTooManyRequests,
+					Header:     make(http.Header),
+				}
+				_, useRetryAfter := shouldUseRetryAfter(zap.NewNop(), resp, defaultMaxDuration)
+				retryAfterUsedPerAttempt = append(retryAfterUsedPerAttempt, useRetryAfter)
+				return resp, nil
+			default:
+				// Finally: Success
+				return &http.Response{
+					StatusCode: http.StatusOK,
+				}, nil
+			}
 		},
-	}
+	}, loggerFunc, mgr, getActiveSubgraph(subgraphName))
 
 	req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil)
 	resp, err := tr.RoundTrip(req)
@@ -945,32 +833,23 @@ func TestNoRetryOn429WhenShouldRetryReturnsFalse(t *testing.T) {
 		return err != nil
 	}
 
-	tr := RetryHTTPTransport{
-		RoundTripper: &MockTransport{
-			handler: func(req *http.Request) (*http.Response, error) {
-				attemptCount++
-				// Always return 429 with Retry-After header
-				resp := &http.Response{
-					StatusCode: http.StatusTooManyRequests,
-					Header:     make(http.Header),
-				}
-				resp.Header.Set("Retry-After", "1")
-				return resp, nil
-			},
-		},
-		getRequestLogger: func(req *http.Request) *zap.Logger {
-			return zap.NewNop()
-		},
-		RetryOptions: RetryOptions{
-			MaxRetryCount: 3,
-			Interval:      1 * time.Millisecond,
-			MaxDuration:   10 * time.Millisecond,
-			ShouldRetry:   shouldNotRetry429,
-			OnRetry: func(count int, req *http.Request, resp *http.Response, sleepDuration time.Duration, err error) {
-				retries++
-			},
+	subgraphName := `sg`
+	mgr := newTestManager(shouldNotRetry429, func(_ int, _ *http.Request, _ *http.Response, _ time.Duration, _ error) {
+		retries++
+	}, RetryOptions{MaxRetryCount: 3, Interval: 1 * time.Millisecond, MaxDuration: 10 * time.Millisecond}, subgraphName)
+
+	tr := NewRetryHTTPTransport(&MockTransport{
+		handler: func(req *http.Request) (*http.Response, error) {
+			attemptCount++
+			// Always return 429 with Retry-After header
+			resp := &http.Response{
+				StatusCode: http.StatusTooManyRequests,
+				Header:     make(http.Header),
+			}
+			resp.Header.Set("Retry-After", "1")
+			return resp, nil
 		},
-	}
+	}, loggerFunc, mgr, getActiveSubgraph(subgraphName))
 
 	req := httptest.NewRequest("GET", "http://localhost:3000/graphql", nil)
 	resp, err := tr.RoundTrip(req)
@@ -1139,3 +1018,57 @@ func TestShouldUseRetryAfter(t *testing.T) {
 		})
 	}
 }
+
+func TestManager_Initialize_WarnsOnUnknownSubgraph(t *testing.T) {
+	routerConfig := &nodev1.RouterConfig{
+		Subgraphs: []*nodev1.Subgraph{
+			{Id: "known-id", Name: "known", RoutingUrl: "http://localhost:8001/graphql"},
+		},
+	}
+
+	t.Run("warns for each retry config whose subgraph is not in the router config", func(t *testing.T) {
+		core, logs := observer.New(zapcore.WarnLevel)
+		mgr := NewManager(expr.NewRetryExpressionManager(), nil, nil, zap.New(core))
+
+		err := mgr.Initialize(
+			RetryOptions{},
+			map[string]RetryOptions{
+				"known":      {Enabled: true, Algorithm: BackoffJitter, Expression: "true"},
+				"typo_name":  {Enabled: true, Algorithm: BackoffJitter, Expression: "true"},
+				"other_typo": {Enabled: false},
+			},
+			routerConfig,
+		)
+		require.NoError(t, err)
+
+		warnings := logs.FilterMessage("Retry config references unknown subgraph; override will be ignored").All()
+		require.Len(t, warnings, 2)
+
+		got := make(map[string]bool, len(warnings))
+		for _, entry := range warnings {
+			got[entry.ContextMap()["subgraph_name"].(string)] = true
+		}
+		require.True(t, got["typo_name"])
+		require.True(t, got["other_typo"])
+
+		// Unknown-named overrides must not leak into the retries map
+		require.Nil(t, mgr.GetSubgraphOptions("typo_name"))
+		require.Nil(t, mgr.GetSubgraphOptions("other_typo"))
+		require.NotNil(t, mgr.GetSubgraphOptions("known"))
+	})
+
+	t.Run("does not warn when every retry config matches a known subgraph", func(t *testing.T) {
+		core, logs := observer.New(zapcore.WarnLevel)
+		mgr := NewManager(expr.NewRetryExpressionManager(), nil, nil, zap.New(core))
+
+		err := mgr.Initialize(
+			RetryOptions{},
+			map[string]RetryOptions{
+				"known": {Enabled: true, Algorithm: BackoffJitter, Expression: "true"},
+			},
+			routerConfig,
+		)
+		require.NoError(t, err)
+		require.Zero(t, logs.FilterMessage("Retry config references unknown subgraph; override will be ignored").Len())
+	})
+}
diff --git a/router/internal/traceclient/traceclient.go b/router/internal/traceclient/traceclient.go
index d787473330..53cd7236a8 100644
--- a/router/internal/traceclient/traceclient.go
+++ b/router/internal/traceclient/traceclient.go
@@ -6,7 +6,6 @@ import (
 	"net/http/httptrace"
 	"time"
 
-	rcontext "github.com/wundergraph/cosmo/router/internal/context"
 	"github.com/wundergraph/cosmo/router/internal/expr"
 	"github.com/wundergraph/cosmo/router/pkg/metric"
 
@@ -33,20 +32,23 @@ type ClientTrace struct {
 type ClientTraceContextKey struct{}
 
 type TraceInjectingRoundTripper struct {
-	base                   http.RoundTripper
-	connectionMetricStore  metric.ConnectionMetricStore
-	reqContextValuesGetter func(ctx context.Context, req *http.Request) (*expr.Context, string)
+	base                  http.RoundTripper
+	connectionMetricStore metric.ConnectionMetricStore
+	getExprContext        func(ctx context.Context) *expr.Context
+	getActiveSubgraphName func(req *http.Request) string
 }
 
 func NewTraceInjectingRoundTripper(
 	base http.RoundTripper,
 	connectionMetricStore metric.ConnectionMetricStore,
-	reqContextValuesGetter func(ctx context.Context, req *http.Request) (*expr.Context, string),
+	getExprContext func(ctx context.Context) *expr.Context,
+	getActiveSubgraphName func(req *http.Request) string,
 ) *TraceInjectingRoundTripper {
 	return &TraceInjectingRoundTripper{
-		base:                   base,
-		connectionMetricStore:  connectionMetricStore,
-		reqContextValuesGetter: reqContextValuesGetter,
+		base:                  base,
+		connectionMetricStore: connectionMetricStore,
+		getExprContext:        getExprContext,
+		getActiveSubgraphName: getActiveSubgraphName,
 	}
 }
 
@@ -104,23 +106,14 @@ func (t *TraceInjectingRoundTripper) getClientTrace(ctx context.Context) *httptrace.ClientTrace {
 func (t *TraceInjectingRoundTripper) processConnectionMetrics(ctx context.Context, req *http.Request) {
 	trace := GetClientTraceFromContext(ctx)
 
-	var subgraph string
-	subgraphCtxVal := ctx.Value(rcontext.CurrentSubgraphContextKey{})
-	if subgraphCtxVal != nil {
-		subgraph = subgraphCtxVal.(string)
-	}
-
-	// We have a fallback for active subgraph name in case engine loader hooks is not called
-	// TODO: Evaluate if we actually need a fallback and if we can use only one way to get the active subgraph name
-	exprContext, activeSubgraphName := t.reqContextValuesGetter(ctx, req)
-	if subgraph == "" {
-		subgraph = activeSubgraphName
-	}
+	exprContext := t.getExprContext(ctx)
 
 	if trace.ConnectionGet != nil && trace.ConnectionAcquired != nil {
 		duration := trace.ConnectionAcquired.Time.Sub(trace.ConnectionGet.Time)
 		exprContext.Subgraph.Request.ClientTrace.ConnectionAcquireDuration = duration
 
+		subgraph := t.getActiveSubgraphName(req)
+
 		serverAttributes := rotel.GetServerAttributes(trace.ConnectionGet.HostPort)
 		serverAttributes = append(
 			serverAttributes,