diff --git a/execution/federationtesting/skipped_fetch_test.go b/execution/federationtesting/skipped_fetch_test.go new file mode 100644 index 0000000000..ed3ea8c467 --- /dev/null +++ b/execution/federationtesting/skipped_fetch_test.go @@ -0,0 +1,70 @@ +package federationtesting + +import ( + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/jensneuse/abstractlogger" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/wundergraph/graphql-go-tools/execution/engine" + "github.com/wundergraph/graphql-go-tools/execution/graphql" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +func TestSkippedFetchOnNullParent(t *testing.T) { + // Users subgraph: returns null for the "user" field. + usersServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = io.ReadAll(r.Body) + _, _ = w.Write([]byte(`{"data":{"user":null}}`)) + })) + t.Cleanup(usersServer.Close) + + // Reviews subgraph: tracks all requests. Should never be called at query time + // because the user is null and the entity fetch must be skipped. + var reviewsCalls atomic.Int64 + reviewsServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + reviewsCalls.Add(1) + w.Header().Set("Content-Type", "application/json") + _, _ = io.ReadAll(r.Body) + _, _ = w.Write([]byte(`{"data":{"_entities":[]}}`)) + })) + t.Cleanup(reviewsServer.Close) + + const usersSDL = `type Query { user(id: ID!): User } type User @key(fields: "id") { id: ID! name: String! }` + const reviewsSDL = `type User @key(fields: "id") { id: ID! @external reviews: [Review] } type Review { body: String! }` + + ctx := t.Context() + factory := engine.NewFederationEngineConfigFactory(ctx, []engine.SubgraphConfiguration{ + {Name: "users", URL: usersServer.URL, SDL: usersSDL}, + {Name: "reviews", URL: reviewsServer.URL, SDL: reviewsSDL}, + }) + + engineConfig, err := factory.BuildEngineConfiguration() + require.NoError(t, err) + + eng, err := engine.NewExecutionEngine(ctx, abstractlogger.NoopLogger, engineConfig, resolve.ResolverOptions{ + MaxConcurrency: 1024, + }) + require.NoError(t, err) + + gqlRequest := &graphql.Request{ + Query: `{ user(id: "1") { id name reviews { body } } }`, + } + + resultWriter := graphql.NewEngineResultWriter() + err = eng.Execute(ctx, gqlRequest, &resultWriter) + require.NoError(t, err) + + // The user is null, so the response should reflect that without panic. + assert.Equal(t, `{"data":{"user":null}}`, resultWriter.String()) + + // The reviews subgraph must NOT have been called — the entity fetch was skipped + // because the parent user is null. + assert.Equal(t, int64(0), reviewsCalls.Load(), "reviews subgraph should not be called when parent user is null") +} diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 3d5a1b40f0..014762df5e 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -56,9 +56,12 @@ func stringValueOnArena(a arena.Arena, s string) *astjson.Value { } type LoaderHooks interface { - // OnLoad is called before the fetch is executed + // OnLoad is called before a fetch is executed. + // The returned context is passed to OnFinished after the fetch completes. + // OnLoad is not called when the fetch is skipped (e.g. null parent data, auth rejection). OnLoad(ctx context.Context, ds DataSourceInfo) context.Context - // OnFinished is called after the fetch has been executed and the response has been processed and merged + // OnFinished is called after a fetch has been executed and the response has been processed and merged. + // It is only called when OnLoad was called, i.e. when the fetch was not skipped. OnFinished(ctx context.Context, ds DataSourceInfo, info *ResponseInfo) } @@ -139,8 +142,9 @@ type result struct { rateLimitRejected bool rateLimitRejectedReason string - // loaderHookContext used to share data between the OnLoad and OnFinished hooks - // It should be valid even when OnLoad isn't called + // loaderHookContext is set by OnLoad during fetch execution. + // It is nil when the fetch was skipped (e.g. null parent data, auth rejection), + // in which case OnFinished must not be called. loaderHookContext context.Context httpResponseContext *httpclient.ResponseContext @@ -270,20 +274,14 @@ func (l *Loader) resolveParallel(nodes []*FetchTreeNode) error { if results[i].nestedMergeItems != nil { for j := range results[i].nestedMergeItems { err = l.mergeResult(nodes[i].Item, results[i].nestedMergeItems[j], itemsItems[i][j:j+1]) - if l.ctx.LoaderHooks != nil && results[i].nestedMergeItems[j].loaderHookContext != nil { - l.ctx.LoaderHooks.OnFinished(results[i].nestedMergeItems[j].loaderHookContext, - results[i].nestedMergeItems[j].ds, - newResponseInfo(results[i].nestedMergeItems[j], l.ctx.subgraphErrors)) - } + l.callOnFinished(results[i].nestedMergeItems[j]) if err != nil { return errors.WithStack(err) } } } else { err = l.mergeResult(nodes[i].Item, results[i], itemsItems[i]) - if l.ctx.LoaderHooks != nil { - l.ctx.LoaderHooks.OnFinished(results[i].loaderHookContext, results[i].ds, newResponseInfo(results[i], l.ctx.subgraphErrors)) - } + l.callOnFinished(results[i]) if err != nil { return errors.WithStack(err) } @@ -316,9 +314,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return err } err = l.mergeResult(item, res, items) - if l.ctx.LoaderHooks != nil { - l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) - } + l.callOnFinished(res) return err case *BatchEntityFetch: res := &result{} @@ -328,9 +324,7 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return errors.WithStack(err) } err = l.mergeResult(item, res, items) - if l.ctx.LoaderHooks != nil { - l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) - } + l.callOnFinished(res) return err case *EntityFetch: res := &result{} @@ -339,15 +333,19 @@ func (l *Loader) resolveSingle(item *FetchItem) error { return errors.WithStack(err) } err = l.mergeResult(item, res, items) - if l.ctx.LoaderHooks != nil { - l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) - } + l.callOnFinished(res) return err default: return nil } } +func (l *Loader) callOnFinished(res *result) { + if l.ctx.LoaderHooks != nil && res.loaderHookContext != nil { + l.ctx.LoaderHooks.OnFinished(res.loaderHookContext, res.ds, newResponseInfo(res, l.ctx.subgraphErrors)) + } +} + func (l *Loader) selectItemsForPath(path []FetchItemPathElement) []*astjson.Value { // Use arena allocation for the initial items slice items := arena.AllocateSlice[*astjson.Value](l.jsonArena, 1, 1) diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index ea63ebac65..286fe7330f 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -21,13 +21,8 @@ type TestLoaderHooks struct { mu sync.Mutex } -func NewTestLoaderHooks() LoaderHooks { - return &TestLoaderHooks{ - preFetchCalls: atomic.Int64{}, - postFetchCalls: atomic.Int64{}, - errors: make([]error, 0), - mu: sync.Mutex{}, - } +func NewTestLoaderHooks() *TestLoaderHooks { + return &TestLoaderHooks{} } func (f *TestLoaderHooks) OnLoad(ctx context.Context, ds DataSourceInfo) context.Context { @@ -548,6 +543,99 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { }, *NewContext(context.Background()), `{"errors":[{"message":"errorMessage","extensions":{"code":"DOWNSTREAM_SERVICE_ERROR"}},{"message":"errorMessage2","extensions":{"code":"DOWNSTREAM_SERVICE_ERROR"}}],"data":{"name":null}}` })) + // Test that skipped fetches (null parent) don't call OnFinished with nil loaderHookContext. + // Covers both the serial (resolveSingle) and parallel (resolveParallel) code paths. + for _, tc := range []struct { + name string + wrapSecondFetch func(node *FetchTreeNode) *FetchTreeNode + }{ + { + name: "skipped fetch does not call OnFinished with nil loaderHookContext", + wrapSecondFetch: func(node *FetchTreeNode) *FetchTreeNode { return node }, + }, + { + name: "parallel skipped fetch does not call OnFinished with nil loaderHookContext", + wrapSecondFetch: func(node *FetchTreeNode) *FetchTreeNode { return Parallel(node) }, + }, + } { + t.Run(tc.name, testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { + userService := NewMockDataSource(ctrl) + userService.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"data":{"user":null}}`), nil + }) + + detailsService := NewMockDataSource(ctrl) + detailsService.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + + resolveCtx := NewContext(context.Background()) + resolveCtx.LoaderHooks = NewTestLoaderHooks() + + return &GraphQLResponse{ + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + Fetches: Sequence( + Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{ + DataSource: userService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + Info: &FetchInfo{ + DataSourceID: "Users", + DataSourceName: "Users", + }, + }), + tc.wrapSecondFetch(SingleWithPath(&SingleFetch{ + FetchConfiguration: FetchConfiguration{ + DataSource: detailsService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + Info: &FetchInfo{ + DataSourceID: "Details", + DataSourceName: "Details", + }, + }, "query.user", ObjectPath("user"))), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Nullable: true, + Path: []string{"user"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: true, + }, + }, + }, + }, + }, + }, + }, + }, resolveCtx, `{"data":{"user":null}}`, + func(t *testing.T) { + loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks) + // Only the first fetch should trigger OnLoad/OnFinished. + // The second fetch is skipped (null parent), so OnFinished must NOT be called + // (its loaderHookContext would be nil, which previously caused a panic in the router). + assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load()) + assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load()) + } + })) + } + t.Run("Fallback to default extension code value when extensions is an empty object", testFnSubgraphErrorsWithExtensionDefaultCode(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { mockDataSource := NewMockDataSource(ctrl) mockDataSource.EXPECT(). @@ -583,4 +671,286 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { }, *NewContext(context.Background()), `{"errors":[{"message":"errorMessage","extensions":{"code":"DOWNSTREAM_SERVICE_ERROR"}},{"message":"errorMessage2","extensions":{"code":"DOWNSTREAM_SERVICE_ERROR"}}],"data":{"name":null}}` })) + t.Run("skipped entity fetch does not call OnFinished with nil loaderHookContext", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { + userService := NewMockDataSource(ctrl) + userService.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"data":{"user":{"name":"Bill","info":null}}}`), nil + }) + + infoService := NewMockDataSource(ctrl) + infoService.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + + resolveCtx := NewContext(context.Background()) + resolveCtx.LoaderHooks = NewTestLoaderHooks() + + return &GraphQLResponse{ + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + Fetches: Sequence( + Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{ + DataSource: userService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + Info: &FetchInfo{ + DataSourceID: "Users", + DataSourceName: "Users", + }, + }), + SingleWithPath(&EntityFetch{ + FetchDependencies: FetchDependencies{ + FetchID: 1, + DependsOnFetchIDs: []int{0}, + }, + Input: EntityInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Item: InputTemplate{ + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + }, + OnTypeNames: [][]byte{[]byte("Info")}, + }, + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + OnTypeNames: [][]byte{[]byte("Info")}, + }, + }, + }), + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + SkipErrItem: true, + }, + DataSource: infoService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities", "0"}, + }, + Info: &FetchInfo{ + DataSourceID: "Info", + DataSourceName: "Info", + }, + }, "user.info", ObjectPath("user"), ObjectPath("info")), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("info"), + Value: &Object{ + Nullable: true, + Path: []string{"info"}, + Fields: []*Field{ + { + Name: []byte("age"), + Value: &Integer{ + Path: []string{"age"}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, resolveCtx, `{"data":{"user":{"name":"Bill","info":null}}}`, + func(t *testing.T) { + loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks) + assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load()) + assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load()) + } + })) + + t.Run("skipped batch entity fetch does not call OnFinished with nil loaderHookContext", testFnWithPostEvaluation(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { + userService := NewMockDataSource(ctrl) + userService.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"data":{"user":{"name":"Bill","infoList":[{"id":1,"__typename":"Unknown"}]}}}`), nil + }) + + infoService := NewMockDataSource(ctrl) + infoService.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + Times(0) + + resolveCtx := NewContext(context.Background()) + resolveCtx.LoaderHooks = NewTestLoaderHooks() + + return &GraphQLResponse{ + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + Fetches: Sequence( + Single(&SingleFetch{ + FetchConfiguration: FetchConfiguration{ + DataSource: userService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + }, + }, + Info: &FetchInfo{ + DataSourceID: "Users", + DataSourceName: "Users", + }, + }), + SingleWithPath(&BatchEntityFetch{ + FetchDependencies: FetchDependencies{ + FetchID: 1, + DependsOnFetchIDs: []int{0}, + }, + Input: BatchInput{ + Header: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`{"method":"POST","url":"http://localhost:4002","body":{"query":"query($representations: [_Any!]!){_entities(representations: $representations) { ... on Info { age }}}","variables":{"representations":[`), + SegmentType: StaticSegmentType, + }, + }, + }, + Items: []InputTemplate{ + { + Segments: []TemplateSegment{ + { + SegmentType: VariableSegmentType, + VariableKind: ResolvableObjectVariableKind, + Renderer: NewGraphQLVariableResolveRenderer(&Object{ + Fields: []*Field{ + { + Name: []byte("id"), + Value: &Integer{ + Path: []string{"id"}, + }, + OnTypeNames: [][]byte{[]byte("Info")}, + }, + { + Name: []byte("__typename"), + Value: &String{ + Path: []string{"__typename"}, + }, + OnTypeNames: [][]byte{[]byte("Info")}, + }, + }, + }), + }, + }, + }, + }, + Separator: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`,`), + SegmentType: StaticSegmentType, + }, + }, + }, + Footer: InputTemplate{ + Segments: []TemplateSegment{ + { + Data: []byte(`]}}}`), + SegmentType: StaticSegmentType, + }, + }, + }, + SkipNullItems: true, + SkipEmptyObjectItems: true, + SkipErrItems: true, + }, + DataSource: infoService, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data", "_entities"}, + }, + Info: &FetchInfo{ + DataSourceID: "Info", + DataSourceName: "Info", + }, + }, "user.infoList", ObjectPath("user"), ArrayPath("infoList")), + ), + Data: &Object{ + Fields: []*Field{ + { + Name: []byte("user"), + Value: &Object{ + Path: []string{"user"}, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + }, + }, + { + Name: []byte("infoList"), + Value: &Array{ + Path: []string{"infoList"}, + Item: &Object{ + Fields: []*Field{ + { + Name: []byte("age"), + Value: &Integer{ + Path: []string{"age"}, + }, + OnTypeNames: [][]byte{[]byte("Info")}, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, resolveCtx, `{"data":{"user":{"name":"Bill","infoList":[{}]}}}`, + func(t *testing.T) { + loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks) + assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load()) + assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load()) + } + })) + }