diff --git a/v2/pkg/engine/datasource/httpclient/nethttpclient.go b/v2/pkg/engine/datasource/httpclient/nethttpclient.go index f0fb36694f..02025437c2 100644 --- a/v2/pkg/engine/datasource/httpclient/nethttpclient.go +++ b/v2/pkg/engine/datasource/httpclient/nethttpclient.go @@ -80,6 +80,13 @@ func InjectResponseContext(ctx context.Context) (context.Context, *ResponseConte return context.WithValue(ctx, responseContextKey{}, value), value } +// GetResponseContext retrieves the ResponseContext previously injected into ctx +// via InjectResponseContext. Returns nil if no ResponseContext is present. +func GetResponseContext(ctx context.Context) *ResponseContext { + value, _ := ctx.Value(responseContextKey{}).(*ResponseContext) + return value +} + func setRequest(ctx context.Context, request *http.Request) { if value, ok := ctx.Value(responseContextKey{}).(*ResponseContext); ok { value.Request = request diff --git a/v2/pkg/engine/resolve/context.go b/v2/pkg/engine/resolve/context.go index 846324fb49..dfa481b281 100644 --- a/v2/pkg/engine/resolve/context.go +++ b/v2/pkg/engine/resolve/context.go @@ -44,6 +44,43 @@ type Context struct { subgraphErrors map[string]error SubgraphHeadersBuilder SubgraphHeadersBuilder + + // GetDeduplicationData is called after the leader of an inbound singleflight request + // finishes resolving. It extracts data from the leader's context (e.g. accumulated + // response headers) that should be shared with all follower requests. + // The returned value is stored on the InflightRequest and passed to each follower's + // SetDeduplicationData callback before the follower writes its response. + // Use SetDeduplicationCallbacks to set both callbacks with type safety. + GetDeduplicationData func(ctx context.Context) any + // SetDeduplicationData is called for each follower of an inbound singleflight request, + // before the response body is written to the client. The data argument is the value + // returned by the leader's GetDeduplicationData call. + // Typical use: copy response header propagation state from the leader into the + // follower's context so that the response writer can set the correct HTTP headers. + // Use SetDeduplicationCallbacks to set both callbacks with type safety. + SetDeduplicationData func(ctx context.Context, data any) +} + +// SetDeduplicationCallbacks is a generic helper that configures both GetDeduplicationData +// and SetDeduplicationData on a Context with compile-time type safety. +// The resolve package stores the data as "any" internally, but callers get typed callbacks: +// +// resolve.SetDeduplicationCallbacks(ctx, +// func(ctx context.Context) *MyHeaders { return extractHeaders(ctx) }, +// func(ctx context.Context, h *MyHeaders) { applyHeaders(ctx, h) }, +// ) +// +// The get and set callbacks must use the same concrete type T. If the value returned by +// get cannot be asserted to T when passed to set, the set callback will be skipped. +func SetDeduplicationCallbacks[T any](c *Context, get func(ctx context.Context) T, set func(ctx context.Context, data T)) { + c.GetDeduplicationData = func(ctx context.Context) any { + return get(ctx) + } + c.SetDeduplicationData = func(ctx context.Context, data any) { + if typed, ok := data.(T); ok { + set(ctx, typed) + } + } } // SubgraphHeadersBuilder allows the user of the engine to "define" the headers for a subgraph request @@ -276,6 +313,8 @@ func (c *Context) Free() { c.subgraphErrors = nil c.authorizer = nil c.LoaderHooks = nil + c.GetDeduplicationData = nil + c.SetDeduplicationData = nil } type traceStartKey struct{} diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight.go b/v2/pkg/engine/resolve/inbound_request_singleflight.go index a20629939f..a67d8deccc 100644 --- a/v2/pkg/engine/resolve/inbound_request_singleflight.go +++ b/v2/pkg/engine/resolve/inbound_request_singleflight.go @@ -38,8 +38,13 @@ func NewRequestSingleFlight(shardCount int) *InboundRequestSingleFlight { type InflightRequest struct { Done chan struct{} Data []byte - Err error - ID uint64 + // SharedData carries opaque state from the leader to followers (e.g. accumulated + // response headers). Set by the leader via Context.GetDeduplicationData, read by + // followers via Context.SetDeduplicationData. Typed as "any" because the resolve + // package is data-agnostic — the caller decides the concrete type. + SharedData any + Err error + ID uint64 HasFollowers bool Mu sync.Mutex @@ -95,8 +100,7 @@ func (r *InboundRequestSingleFlight) GetOrCreate(ctx *Context, response *GraphQL } return request, nil case <-ctx.ctx.Done(): - request.Err = ctx.ctx.Err() - return nil, request.Err + return nil, ctx.ctx.Err() } } diff --git a/v2/pkg/engine/resolve/inbound_request_singleflight_test.go b/v2/pkg/engine/resolve/inbound_request_singleflight_test.go new file mode 100644 index 0000000000..a82d372d4f --- /dev/null +++ b/v2/pkg/engine/resolve/inbound_request_singleflight_test.go @@ -0,0 +1,112 @@ +package resolve + +import ( + "context" + "sync" + "testing" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" +) + +// TestInboundSingleFlight_ConcurrentFollowerTimeout exercises the scenario where +// multiple followers time out concurrently. Before the fix, each follower wrote +// its context error to the shared request.Err field without synchronization, +// causing a data race. After the fix, followers return ctx.Err() directly +// without mutating shared state. Run with -race to verify. +func TestInboundSingleFlight_ConcurrentFollowerTimeout(t *testing.T) { + sf := NewRequestSingleFlight(1) + response := &GraphQLResponse{ + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + } + + // Leader creates the inflight request + leaderCtx := NewContext(context.Background()) + leaderCtx.Request.ID = 1 + inflight, err := sf.GetOrCreate(leaderCtx, response) + if err != nil { + t.Fatalf("leader GetOrCreate: %v", err) + } + if inflight == nil { + t.Fatal("expected inflight request from leader") + } + + const numFollowers = 10 + var wg sync.WaitGroup + wg.Add(numFollowers) + + for i := 0; i < numFollowers; i++ { + go func() { + defer wg.Done() + ctx, cancel := context.WithCancel(context.Background()) + followerCtx := NewContext(ctx) + followerCtx.Request.ID = 1 + + // Cancel immediately so the follower's context is done + cancel() + + _, followerErr := sf.GetOrCreate(followerCtx, response) + if followerErr == nil { + t.Error("expected error from timed-out follower") + } + }() + } + + wg.Wait() + + // Clean up: finish the leader request + sf.FinishOk(inflight, []byte("ok")) +} + +func TestInboundSingleFlight_FollowerReceivesLeaderError(t *testing.T) { + sf := NewRequestSingleFlight(1) + response := &GraphQLResponse{ + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + } + + leaderCtx := NewContext(context.Background()) + leaderCtx.Request.ID = 2 + inflight, err := sf.GetOrCreate(leaderCtx, response) + if err != nil { + t.Fatalf("leader GetOrCreate: %v", err) + } + + // The follower calls GetOrCreate which blocks on inflight.Done. + // We wait for HasFollowers to be set before calling FinishErr. + followerReady := make(chan struct{}) + var wg sync.WaitGroup + wg.Add(1) + + go func() { + defer wg.Done() + followerCtx := NewContext(context.Background()) + followerCtx.Request.ID = 2 + + // Signal that we're about to enter GetOrCreate. HasFollowers will be + // set inside GetOrCreate before the select blocks, so closing + // followerReady here is slightly early, but we poll HasFollowers below. + close(followerReady) + + _, followerErr := sf.GetOrCreate(followerCtx, response) + if followerErr == nil { + t.Error("expected error from follower after leader FinishErr") + } + }() + + <-followerReady + // Spin until the follower has actually registered (set HasFollowers) + for { + inflight.Mu.Lock() + ready := inflight.HasFollowers + inflight.Mu.Unlock() + if ready { + break + } + } + + sf.FinishErr(inflight, context.DeadlineExceeded) + wg.Wait() +} diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 07991e8100..3d5a1b40f0 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -1717,6 +1717,21 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem } res.out = item.response + // Populate the ResponseContext that was injected by executeSourceLoad. + // This is the same pointer that executeSourceLoad reads when it assigns + // res.statusCode and res.httpResponseContext, so the follower's result + // fields will be set correctly even though no HTTP call was made. + if rc := httpclient.GetResponseContext(ctx); rc != nil { + rc.StatusCode = item.statusCode + if item.responseHeaders != nil { + // Minimal synthetic http.Response carrying only status and headers. + // Clone headers so each concurrent follower gets an independent copy. + rc.Response = &http.Response{ + StatusCode: item.statusCode, + Header: item.responseHeaders.Clone(), + } + } + } return nil } @@ -1733,6 +1748,14 @@ func (l *Loader) loadByContext(ctx context.Context, source DataSource, fetchItem } item.response = res.out + // Capture the leader's HTTP response metadata so followers can reuse it. + // The ResponseContext was populated by the HTTP client during loadByContextDirect. + if rc := httpclient.GetResponseContext(ctx); rc != nil { + item.statusCode = rc.StatusCode + if rc.Response != nil && rc.Response.Header != nil { + item.responseHeaders = rc.Response.Header.Clone() + } + } return nil } diff --git a/v2/pkg/engine/resolve/resolve.go b/v2/pkg/engine/resolve/resolve.go index 3a93b9d0f3..7ea1bc9860 100644 --- a/v2/pkg/engine/resolve/resolve.go +++ b/v2/pkg/engine/resolve/resolve.go @@ -362,6 +362,11 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe if inflight != nil && inflight.Data != nil { // follower resp.ResolveDeduplicated = true + // Apply the leader's shared state (e.g. response headers) to this follower's context + // before writing the response, so the response writer can propagate headers correctly. + if ctx.SetDeduplicationData != nil && inflight.SharedData != nil { + ctx.SetDeduplicationData(ctx.ctx, inflight.SharedData) + } _, err = writer.Write(inflight.Data) return resp, err } @@ -412,6 +417,14 @@ func (r *Resolver) ArenaResolveGraphQLResponse(ctx *Context, response *GraphQLRe // as such, it can take some time // which is why we split the arenas and released the first one _, err = writer.Write(buf.Bytes()) + // Extract data from the leader's context to share with singleflight followers. + // This runs after the leader has fully resolved and written its response, so all + // subgraph response headers have been accumulated on the leader's context. + // SharedData MUST be set BEFORE FinishOk, which closes the Done channel and + // unblocks followers. Otherwise followers could read SharedData before it is set. + if inflight != nil && ctx.GetDeduplicationData != nil { + inflight.SharedData = ctx.GetDeduplicationData(ctx.ctx) + } r.inboundRequestSingleFlight.FinishOk(inflight, buf.Bytes()) // all data is written to the client // we're safe to release our buffer diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index fda6f60806..287ab72a63 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -4746,6 +4746,149 @@ func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication(t *testing.T) } } +func TestResolver_ArenaResolveGraphQLResponse_RequestDeduplication_SharedData(t *testing.T) { + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := newResolver(rCtx) + + ds := newBlockingDataSource([]byte(`{"value":"slow"}`)) + defer ds.Release() + + response := &GraphQLResponse{ + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + Fetches: Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{ + DataSource: ds, + }, + }), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("value"), + Value: &String{ + Path: []string{"value"}, + Nullable: false, + }, + }, + }, + }, + } + + type sharedHeaders struct { + CacheControl string + } + + // Tracks which contexts received shared data + var followerData sync.Map + + ctxTemplateBase := NewContext(context.Background()) + ctxTemplateBase.Request.ID = 42 + ctxTemplateBase.VariablesHash = 1337 + + const requestCount = 3 + + type result struct { + info *GraphQLResolveInfo + output string + err error + } + + results := make([]result, requestCount) + + var wg sync.WaitGroup + wg.Add(requestCount) + + leaderWriter := newBlockingWriter() + + go func() { + defer wg.Done() + ctx := *ctxTemplateBase + SetDeduplicationCallbacks(&ctx, + func(ctx context.Context) *sharedHeaders { + return &sharedHeaders{CacheControl: "max-age=120"} + }, + func(ctx context.Context, data *sharedHeaders) { + followerData.Store("leader", data) + }, + ) + info, err := r.ArenaResolveGraphQLResponse(&ctx, response, leaderWriter) + results[0] = result{info: info, output: leaderWriter.String(), err: err} + }() + + select { + case <-ds.Ready(): + case <-time.After(time.Second): + t.Fatalf("timeout waiting for leader data source load") + } + + startFollowers := make(chan struct{}) + followersEntered := make(chan struct{}, requestCount-1) + + for i := 1; i < requestCount; i++ { + go func(i int) { + defer wg.Done() + ctx := *ctxTemplateBase + SetDeduplicationCallbacks(&ctx, + func(ctx context.Context) *sharedHeaders { + return &sharedHeaders{CacheControl: "max-age=120"} + }, + func(ctx context.Context, data *sharedHeaders) { + followerData.Store(i, data) + }, + ) + <-startFollowers + followersEntered <- struct{}{} + buf := &bytes.Buffer{} + info, err := r.ArenaResolveGraphQLResponse(&ctx, response, buf) + results[i] = result{info: info, output: buf.String(), err: err} + }(i) + } + + close(startFollowers) + + for i := 1; i < requestCount; i++ { + select { + case <-followersEntered: + case <-time.After(time.Second): + t.Fatalf("timeout waiting for follower %d to start", i) + } + } + + ds.Release() + + select { + case <-leaderWriter.Ready(): + case <-time.After(time.Second): + t.Fatalf("timeout waiting for leader to start writing response") + } + + leaderWriter.Release() + wg.Wait() + + for _, res := range results { + require.NoError(t, res.err) + require.NotNil(t, res.info) + } + + assert.False(t, results[0].info.ResolveDeduplicated) + + // Leader should not have SetDeduplicationData called + _, leaderReceived := followerData.Load("leader") + assert.False(t, leaderReceived, "leader should not receive shared data via SetDeduplicationData") + + // Each follower should have received the shared data + for i := 1; i < requestCount; i++ { + assert.True(t, results[i].info.ResolveDeduplicated) + data, ok := followerData.Load(i) + require.True(t, ok, "follower %d should have received shared data", i) + shared, ok := data.(*sharedHeaders) + require.True(t, ok, "shared data should be *sharedHeaders") + assert.Equal(t, "max-age=120", shared.CacheControl) + } +} + func TestResolver_ApolloCompatibilityMode_FetchError(t *testing.T) { options := apolloCompatibilityOptions{ valueCompletion: true, diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight.go b/v2/pkg/engine/resolve/subgraph_request_singleflight.go index 1f6fcfaf4d..804c189c2e 100644 --- a/v2/pkg/engine/resolve/subgraph_request_singleflight.go +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight.go @@ -2,6 +2,7 @@ package resolve import ( "encoding/binary" + "net/http" "sync" "github.com/cespare/xxhash/v2" @@ -30,6 +31,10 @@ type SingleFlightItem struct { loaded chan struct{} // response is the shared result, it must not be modified response []byte + // responseHeaders contains the cloned subgraph response headers from the leader's HTTP call + responseHeaders http.Header + // statusCode is the HTTP status code from the leader's subgraph response + statusCode int // err is non nil if the leader produced an error while doing the work err error // sizeHint keeps track of the last 50 responses per fetchKey to give an estimate on the size diff --git a/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go b/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go index 312236359a..aaf07f5af7 100644 --- a/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go +++ b/v2/pkg/engine/resolve/subgraph_request_singleflight_test.go @@ -3,6 +3,7 @@ package resolve import ( "bytes" "fmt" + "net/http" "testing" ) @@ -207,3 +208,59 @@ func TestSubgraphRequestSingleFlight_SizeHintRollingWindow(t *testing.T) { t.Fatalf("expected rolling average size hint %d, got %d", expected, next.sizeHint) } } + +func TestSubgraphRequestSingleFlight_LeaderFollowerResponseHeaders(t *testing.T) { + flight := NewSingleFlight(2) + fetchInfo := &FetchInfo{ + DataSourceID: "accounts", + RootFields: []GraphCoordinate{ + {TypeName: "Query", FieldName: "viewer"}, + }, + } + fetchItem := newFetchItem(fetchInfo) + + item, shared := flight.GetOrCreateItem(fetchItem, []byte("query { viewer { id } }"), 42) + if shared { + t.Fatalf("expected leader to be first caller") + } + + follower, followerShared := flight.GetOrCreateItem(fetchItem, []byte("query { viewer { id } }"), 42) + if !followerShared { + t.Fatalf("expected second caller to be follower") + } + if follower != item { + t.Fatalf("expected follower to receive same item instance") + } + + // Leader sets response, headers, and status code + item.response = []byte(`{"data":{"viewer":{"id":"1"}}}`) + item.statusCode = 200 + item.responseHeaders = http.Header{ + "Cache-Control": []string{"max-age=120"}, + "X-Custom": []string{"value"}, + } + flight.Finish(item) + + select { + case <-item.loaded: + default: + t.Fatalf("expected leader to close loaded channel") + } + + // Verify follower can access all shared data + if string(follower.response) != `{"data":{"viewer":{"id":"1"}}}` { + t.Fatalf("expected follower to get leader's response body") + } + if follower.statusCode != 200 { + t.Fatalf("expected follower to get leader's status code, got %d", follower.statusCode) + } + if follower.responseHeaders == nil { + t.Fatalf("expected follower to get leader's response headers") + } + if follower.responseHeaders.Get("Cache-Control") != "max-age=120" { + t.Fatalf("expected Cache-Control header, got %q", follower.responseHeaders.Get("Cache-Control")) + } + if follower.responseHeaders.Get("X-Custom") != "value" { + t.Fatalf("expected X-Custom header, got %q", follower.responseHeaders.Get("X-Custom")) + } +}