From a48f46499951d80055a6f2f1ce441b69da2e5621 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Tue, 24 Feb 2026 17:21:32 +0100 Subject: [PATCH 1/4] fix: keep correct order of error extension fields when removing not allowed fields --- v2/pkg/engine/resolve/errors.go | 51 +++++++++++-- v2/pkg/engine/resolve/loader.go | 11 +-- v2/pkg/engine/resolve/loader_hooks_test.go | 83 +++++++++++++++++++++- v2/pkg/engine/resolve/resolve_test.go | 26 +++++++ 4 files changed, 159 insertions(+), 12 deletions(-) diff --git a/v2/pkg/engine/resolve/errors.go b/v2/pkg/engine/resolve/errors.go index 93b43f8c01..0ac969f1d0 100644 --- a/v2/pkg/engine/resolve/errors.go +++ b/v2/pkg/engine/resolve/errors.go @@ -2,8 +2,11 @@ package resolve import ( "bytes" + "encoding/json" "fmt" "slices" + + "github.com/wundergraph/astjson" ) type GraphQLError struct { @@ -11,7 +14,7 @@ type GraphQLError struct { Locations []Location `json:"locations,omitempty"` // Path is a list of path segments that lead to the error, can be number or string Path []any `json:"path"` - Extensions map[string]any `json:"extensions,omitempty"` + Extensions *astjson.Value `json:"extensions,omitempty"` } type Location struct { @@ -19,6 +22,44 @@ type Location struct { Column uint32 `json:"column"` } +func (e *GraphQLError) UnmarshalJSON(data []byte) error { + type Alias GraphQLError + + aux := &struct { + *Alias + + Extensions json.RawMessage `json:"extensions,omitempty"` + }{ + Alias: (*Alias)(e), + } + + if err := json.Unmarshal(data, aux); err != nil { + return err + } + + if len(aux.Extensions) > 0 { + e.Extensions = astjson.MustParseBytes(aux.Extensions) + } + + return nil +} + +func (e GraphQLError) MarshalJSON() ([]byte, error) { + aux := &struct { + *GraphQLError + + Extensions json.RawMessage `json:"extensions,omitempty"` + }{ + GraphQLError: &e, + } + + if e.Extensions != nil { + aux.Extensions = e.Extensions.MarshalTo(nil) + } + + return json.Marshal(aux) +} + type SubgraphError struct { DataSourceInfo DataSourceInfo Path string @@ -45,11 +86,9 @@ func (e *SubgraphError) Codes() []string { codes := make([]string, 0, len(e.DownstreamErrors)) for _, downstreamError := range e.DownstreamErrors { - if downstreamError.Extensions != nil { - if ok := downstreamError.Extensions["code"]; ok != nil { - if code, ok := downstreamError.Extensions["code"].(string); ok && !slices.Contains(codes, code) { - codes = append(codes, code) - } + if code := downstreamError.Extensions.Get("code"); code != nil { + if !slices.Contains(codes, code.String()) { + codes = append(codes, code.String()) } } } diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index e01f0ca32a..27580db559 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -682,6 +682,7 @@ func (l *Loader) appendSubgraphError(res *result, fetchItem *FetchItem, value *a } func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.Value) error { + fmt.Println("mergeErrors", string(value.String())) values := value.GetArray() l.optionallyOmitErrorLocations(values) if l.rewriteSubgraphErrorPaths { @@ -777,12 +778,14 @@ func (l *Loader) optionallyAllowCustomExtensionProperties(values []*astjson.Valu value.Del("extensions") continue } + newExt := astjson.ObjectValue(l.jsonArena) - for key := range l.allowedErrorExtensionFields { - if v := extensions.Get(key); v != nil { - newExt.Set(l.jsonArena, key, v) + extensions.GetObject().Visit(func(key []byte, v *astjson.Value) { + if _, ok := l.allowedErrorExtensionFields[unsafebytes.BytesToString(key)]; ok { + newExt.Set(l.jsonArena, string(key), v) } - } + }) + if newExt.GetObject().Len() == 0 { value.Del("extensions") continue diff --git a/v2/pkg/engine/resolve/loader_hooks_test.go b/v2/pkg/engine/resolve/loader_hooks_test.go index 36cc19449b..11462d8c3b 100644 --- a/v2/pkg/engine/resolve/loader_hooks_test.go +++ b/v2/pkg/engine/resolve/loader_hooks_test.go @@ -290,9 +290,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { assert.Equal(t, 0, subgraphError.ResponseCode) assert.Len(t, subgraphError.DownstreamErrors, 2) assert.Equal(t, "errorMessage", subgraphError.DownstreamErrors[0].Message) - assert.Empty(t, subgraphError.DownstreamErrors[0].Extensions["code"]) + assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions.Get("code")) assert.Equal(t, "errorMessage2", subgraphError.DownstreamErrors[1].Message) - assert.Empty(t, subgraphError.DownstreamErrors[1].Extensions["code"]) + assert.Nil(t, subgraphError.DownstreamErrors[1].Extensions.Get("code")) assert.NotNil(t, resolveCtx.SubgraphErrors()) } @@ -1058,4 +1058,83 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) { } })) + t.Run("fetch with subgraph error propagates only allowed extension fields to downstream errors in hooks", + testFnWithPostEvaluationAndOptions(ResolverOptions{ + MaxConcurrency: 1024, + PropagateSubgraphErrors: true, + PropagateSubgraphStatusCodes: true, + AllowedErrorExtensionFields: []string{"code", "serviceName"}, + SubgraphErrorPropagationMode: SubgraphErrorPropagationModePassThrough, + }, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) { + mockDataSource := NewMockDataSource(ctrl) + mockDataSource.EXPECT(). + Load(gomock.Any(), gomock.Any(), gomock.Any()). + DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","serviceName":"products","internalTrace":"abc123","sensitiveField":"secret"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT","serviceName":"users","internalTrace":"def456"}}]}`), nil + }) + resolveCtx := NewContext(context.Background()) + resolveCtx.LoaderHooks = NewTestLoaderHooks() + return &GraphQLResponse{ + Info: &GraphQLResponseInfo{ + OperationType: ast.OperationTypeQuery, + }, + Fetches: SingleWithPath(&SingleFetch{ + FetchConfiguration: FetchConfiguration{ + DataSource: mockDataSource, + PostProcessing: PostProcessingConfiguration{ + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + Info: &FetchInfo{ + DataSourceID: "Products", + DataSourceName: "Products", + }, + }, "query"), + Data: &Object{ + Nullable: false, + Fields: []*Field{ + { + Name: []byte("name"), + Value: &String{ + Path: []string{"name"}, + Nullable: true, + }, + }, + }, + }, + }, resolveCtx, `{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","serviceName":"products"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT","serviceName":"users"}}],"data":{"name":null}}`, + func(t *testing.T) { + loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks) + + assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load()) + assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load()) + + var subgraphError *SubgraphError + assert.Len(t, loaderHooks.errors, 1) + assert.ErrorAs(t, loaderHooks.errors[0], &subgraphError) + assert.Equal(t, "Products", subgraphError.DataSourceInfo.Name) + assert.Equal(t, "query", subgraphError.Path) + assert.Len(t, subgraphError.DownstreamErrors, 2) + + // First error: allowed fields "code" and "serviceName" are present, + // non-allowed fields "internalTrace" and "sensitiveField" are absent. + assert.Equal(t, "errorMessage", subgraphError.DownstreamErrors[0].Message) + assert.NotNil(t, subgraphError.DownstreamErrors[0].Extensions) + assert.Equal(t, `"GRAPHQL_VALIDATION_FAILED"`, subgraphError.DownstreamErrors[0].Extensions.Get("code").String()) + assert.Equal(t, `"products"`, subgraphError.DownstreamErrors[0].Extensions.Get("serviceName").String()) + assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions.Get("internalTrace")) + assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions.Get("sensitiveField")) + + // Second error: allowed fields "code" and "serviceName" are present, + // non-allowed field "internalTrace" is absent. + assert.Equal(t, "errorMessage2", subgraphError.DownstreamErrors[1].Message) + assert.NotNil(t, subgraphError.DownstreamErrors[1].Extensions) + assert.Equal(t, `"BAD_USER_INPUT"`, subgraphError.DownstreamErrors[1].Extensions.Get("code").String()) + assert.Equal(t, `"users"`, subgraphError.DownstreamErrors[1].Extensions.Get("serviceName").String()) + assert.Nil(t, subgraphError.DownstreamErrors[1].Extensions.Get("internalTrace")) + + assert.NotNil(t, resolveCtx.SubgraphErrors()) + } + })) + } diff --git a/v2/pkg/engine/resolve/resolve_test.go b/v2/pkg/engine/resolve/resolve_test.go index 56619c9035..82a8e1e635 100644 --- a/v2/pkg/engine/resolve/resolve_test.go +++ b/v2/pkg/engine/resolve/resolve_test.go @@ -1624,6 +1624,32 @@ func testFnSubgraphErrorsPassthroughAndOmitCustomFields(fn func(t *testing.T, ct } } +// testFnWithPostEvaluationAndOptions is like testFnWithPostEvaluation but allows +// configuring arbitrary ResolverOptions, enabling tests that need specific settings +// such as AllowedErrorExtensionFields. +func testFnWithPostEvaluationAndOptions(opts ResolverOptions, fn func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T))) func(t *testing.T) { + return func(t *testing.T) { + t.Helper() + + ctrl := gomock.NewController(t) + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + r := New(rCtx, opts) + node, ctx, expectedOutput, postEvaluation := fn(t, ctrl) + + if t.Skipped() { + return + } + + buf := &bytes.Buffer{} + _, err := r.ResolveGraphQLResponse(ctx, node, nil, buf) + assert.NoError(t, err) + assert.Equal(t, expectedOutput, buf.String()) + ctrl.Finish() + postEvaluation(t) + } +} + func TestResolver_ResolveGraphQLResponse(t *testing.T) { t.Run("empty graphql response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) { From 4d7345e52c03c2577777bf9a04c754bc053f81bb Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 25 Feb 2026 11:29:16 +0100 Subject: [PATCH 2/4] chore: improve code --- v2/pkg/engine/resolve/errors.go | 17 ++++++++++++++--- v2/pkg/engine/resolve/loader.go | 1 - 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/resolve/errors.go b/v2/pkg/engine/resolve/errors.go index 0ac969f1d0..1a0610a1c0 100644 --- a/v2/pkg/engine/resolve/errors.go +++ b/v2/pkg/engine/resolve/errors.go @@ -22,6 +22,9 @@ type Location struct { Column uint32 `json:"column"` } +// UnmarshalJSON unmarshals the GraphQLError from JSON. +// It unmarshals the Extensions field as a json.RawMessage and then parses it into an astjson.Value. +// This is necessary because we want to be able to keep the orginal order of the extensions fields. func (e *GraphQLError) UnmarshalJSON(data []byte) error { type Alias GraphQLError @@ -38,19 +41,27 @@ func (e *GraphQLError) UnmarshalJSON(data []byte) error { } if len(aux.Extensions) > 0 { - e.Extensions = astjson.MustParseBytes(aux.Extensions) + extensions, err := astjson.ParseBytes(aux.Extensions) + if err != nil { + return err + } + + e.Extensions = extensions } return nil } +// MarshalJSON marshals the GraphQLError to JSON. +// This is necessary because we need to marshal the Extensions field from an astjson.Value to a json.RawMessage. func (e GraphQLError) MarshalJSON() ([]byte, error) { + type Alias GraphQLError aux := &struct { - *GraphQLError + *Alias Extensions json.RawMessage `json:"extensions,omitempty"` }{ - GraphQLError: &e, + Alias: (*Alias)(&e), } if e.Extensions != nil { diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 27580db559..8c6fbed84f 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -682,7 +682,6 @@ func (l *Loader) appendSubgraphError(res *result, fetchItem *FetchItem, value *a } func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.Value) error { - fmt.Println("mergeErrors", string(value.String())) values := value.GetArray() l.optionallyOmitErrorLocations(values) if l.rewriteSubgraphErrorPaths { From 0abe0dc15b36d63dd9687b6298014c60b089c6b0 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 25 Feb 2026 12:29:52 +0100 Subject: [PATCH 3/4] chore: use GetStringBytes --- v2/pkg/engine/resolve/errors.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/resolve/errors.go b/v2/pkg/engine/resolve/errors.go index 1a0610a1c0..a4f74889bc 100644 --- a/v2/pkg/engine/resolve/errors.go +++ b/v2/pkg/engine/resolve/errors.go @@ -98,8 +98,9 @@ func (e *SubgraphError) Codes() []string { for _, downstreamError := range e.DownstreamErrors { if code := downstreamError.Extensions.Get("code"); code != nil { - if !slices.Contains(codes, code.String()) { - codes = append(codes, code.String()) + codeStr := string(code.GetStringBytes()) + if !slices.Contains(codes, codeStr) { + codes = append(codes, codeStr) } } } From c2e9bef7f019f2c060bb4a3c3a77e1811f1c5026 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 25 Feb 2026 12:56:52 +0100 Subject: [PATCH 4/4] chore: do not assign an astjson null type --- v2/pkg/engine/resolve/errors.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/resolve/errors.go b/v2/pkg/engine/resolve/errors.go index a4f74889bc..81b9d748da 100644 --- a/v2/pkg/engine/resolve/errors.go +++ b/v2/pkg/engine/resolve/errors.go @@ -46,7 +46,9 @@ func (e *GraphQLError) UnmarshalJSON(data []byte) error { return err } - e.Extensions = extensions + if extensions.Type() != astjson.TypeNull { + e.Extensions = extensions + } } return nil