Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions v2/pkg/engine/resolve/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,77 @@ package resolve

import (
"bytes"
"encoding/json"
"fmt"
"slices"

"github.com/wundergraph/astjson"
)

type GraphQLError struct {
Message string `json:"message"`
Locations []Location `json:"locations,omitempty"`
// Path is a list of path segments that lead to the error, can be number or string
Path []any `json:"path"`
Extensions map[string]any `json:"extensions,omitempty"`
Extensions *astjson.Value `json:"extensions,omitempty"`
}

type Location struct {
Line uint32 `json:"line"`
Column uint32 `json:"column"`
}

// UnmarshalJSON unmarshals the GraphQLError from JSON.
// It unmarshals the Extensions field as a json.RawMessage and then parses it into an astjson.Value.
// This is necessary because we want to be able to keep the orginal order of the extensions fields.
func (e *GraphQLError) UnmarshalJSON(data []byte) error {
type Alias GraphQLError

aux := &struct {
*Alias

Extensions json.RawMessage `json:"extensions,omitempty"`
}{
Alias: (*Alias)(e),
}

if err := json.Unmarshal(data, aux); err != nil {
return err
}

if len(aux.Extensions) > 0 {
extensions, err := astjson.ParseBytes(aux.Extensions)
if err != nil {
return err
}

if extensions.Type() != astjson.TypeNull {
e.Extensions = extensions
}
}
Comment thread
Noroth marked this conversation as resolved.

return nil
}

// MarshalJSON marshals the GraphQLError to JSON.
// This is necessary because we need to marshal the Extensions field from an astjson.Value to a json.RawMessage.
func (e GraphQLError) MarshalJSON() ([]byte, error) {
type Alias GraphQLError
aux := &struct {
*Alias

Extensions json.RawMessage `json:"extensions,omitempty"`
}{
Alias: (*Alias)(&e),
}

if e.Extensions != nil {
aux.Extensions = e.Extensions.MarshalTo(nil)
}

return json.Marshal(aux)
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

type SubgraphError struct {
DataSourceInfo DataSourceInfo
Path string
Expand All @@ -45,11 +99,10 @@ func (e *SubgraphError) Codes() []string {
codes := make([]string, 0, len(e.DownstreamErrors))

for _, downstreamError := range e.DownstreamErrors {
if downstreamError.Extensions != nil {
if ok := downstreamError.Extensions["code"]; ok != nil {
if code, ok := downstreamError.Extensions["code"].(string); ok && !slices.Contains(codes, code) {
codes = append(codes, code)
}
if code := downstreamError.Extensions.Get("code"); code != nil {
codeStr := string(code.GetStringBytes())
if !slices.Contains(codes, codeStr) {
codes = append(codes, codeStr)
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -777,12 +777,14 @@ func (l *Loader) optionallyAllowCustomExtensionProperties(values []*astjson.Valu
value.Del("extensions")
continue
}

newExt := astjson.ObjectValue(l.jsonArena)
for key := range l.allowedErrorExtensionFields {
if v := extensions.Get(key); v != nil {
newExt.Set(l.jsonArena, key, v)
extensions.GetObject().Visit(func(key []byte, v *astjson.Value) {
if _, ok := l.allowedErrorExtensionFields[unsafebytes.BytesToString(key)]; ok {
newExt.Set(l.jsonArena, string(key), v)
}
}
})

if newExt.GetObject().Len() == 0 {
value.Del("extensions")
continue
Expand Down
83 changes: 81 additions & 2 deletions v2/pkg/engine/resolve/loader_hooks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) {
assert.Equal(t, 0, subgraphError.ResponseCode)
assert.Len(t, subgraphError.DownstreamErrors, 2)
assert.Equal(t, "errorMessage", subgraphError.DownstreamErrors[0].Message)
assert.Empty(t, subgraphError.DownstreamErrors[0].Extensions["code"])
assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions.Get("code"))
assert.Equal(t, "errorMessage2", subgraphError.DownstreamErrors[1].Message)
assert.Empty(t, subgraphError.DownstreamErrors[1].Extensions["code"])
assert.Nil(t, subgraphError.DownstreamErrors[1].Extensions.Get("code"))

assert.NotNil(t, resolveCtx.SubgraphErrors())
}
Expand Down Expand Up @@ -1058,4 +1058,83 @@ func TestLoaderHooks_FetchPipeline(t *testing.T) {
}
}))

t.Run("fetch with subgraph error propagates only allowed extension fields to downstream errors in hooks",
testFnWithPostEvaluationAndOptions(ResolverOptions{
MaxConcurrency: 1024,
PropagateSubgraphErrors: true,
PropagateSubgraphStatusCodes: true,
AllowedErrorExtensionFields: []string{"code", "serviceName"},
SubgraphErrorPropagationMode: SubgraphErrorPropagationModePassThrough,
}, func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T)) {
mockDataSource := NewMockDataSource(ctrl)
mockDataSource.EXPECT().
Load(gomock.Any(), gomock.Any(), gomock.Any()).
DoAndReturn(func(ctx context.Context, headers http.Header, input []byte) ([]byte, error) {
return []byte(`{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","serviceName":"products","internalTrace":"abc123","sensitiveField":"secret"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT","serviceName":"users","internalTrace":"def456"}}]}`), nil
})
resolveCtx := NewContext(context.Background())
resolveCtx.LoaderHooks = NewTestLoaderHooks()
return &GraphQLResponse{
Info: &GraphQLResponseInfo{
OperationType: ast.OperationTypeQuery,
},
Fetches: SingleWithPath(&SingleFetch{
FetchConfiguration: FetchConfiguration{
DataSource: mockDataSource,
PostProcessing: PostProcessingConfiguration{
SelectResponseErrorsPath: []string{"errors"},
},
},
Info: &FetchInfo{
DataSourceID: "Products",
DataSourceName: "Products",
},
}, "query"),
Data: &Object{
Nullable: false,
Fields: []*Field{
{
Name: []byte("name"),
Value: &String{
Path: []string{"name"},
Nullable: true,
},
},
},
},
}, resolveCtx, `{"errors":[{"message":"errorMessage","extensions":{"code":"GRAPHQL_VALIDATION_FAILED","serviceName":"products"}},{"message":"errorMessage2","extensions":{"code":"BAD_USER_INPUT","serviceName":"users"}}],"data":{"name":null}}`,
func(t *testing.T) {
loaderHooks := resolveCtx.LoaderHooks.(*TestLoaderHooks)

assert.Equal(t, int64(1), loaderHooks.preFetchCalls.Load())
assert.Equal(t, int64(1), loaderHooks.postFetchCalls.Load())

var subgraphError *SubgraphError
assert.Len(t, loaderHooks.errors, 1)
assert.ErrorAs(t, loaderHooks.errors[0], &subgraphError)
assert.Equal(t, "Products", subgraphError.DataSourceInfo.Name)
assert.Equal(t, "query", subgraphError.Path)
assert.Len(t, subgraphError.DownstreamErrors, 2)

// First error: allowed fields "code" and "serviceName" are present,
// non-allowed fields "internalTrace" and "sensitiveField" are absent.
assert.Equal(t, "errorMessage", subgraphError.DownstreamErrors[0].Message)
assert.NotNil(t, subgraphError.DownstreamErrors[0].Extensions)
assert.Equal(t, `"GRAPHQL_VALIDATION_FAILED"`, subgraphError.DownstreamErrors[0].Extensions.Get("code").String())
assert.Equal(t, `"products"`, subgraphError.DownstreamErrors[0].Extensions.Get("serviceName").String())
assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions.Get("internalTrace"))
assert.Nil(t, subgraphError.DownstreamErrors[0].Extensions.Get("sensitiveField"))

// Second error: allowed fields "code" and "serviceName" are present,
// non-allowed field "internalTrace" is absent.
assert.Equal(t, "errorMessage2", subgraphError.DownstreamErrors[1].Message)
assert.NotNil(t, subgraphError.DownstreamErrors[1].Extensions)
assert.Equal(t, `"BAD_USER_INPUT"`, subgraphError.DownstreamErrors[1].Extensions.Get("code").String())
assert.Equal(t, `"users"`, subgraphError.DownstreamErrors[1].Extensions.Get("serviceName").String())
assert.Nil(t, subgraphError.DownstreamErrors[1].Extensions.Get("internalTrace"))

assert.NotNil(t, resolveCtx.SubgraphErrors())
}
}))

}
26 changes: 26 additions & 0 deletions v2/pkg/engine/resolve/resolve_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,32 @@ func testFnSubgraphErrorsPassthroughAndOmitCustomFields(fn func(t *testing.T, ct
}
}

// testFnWithPostEvaluationAndOptions is like testFnWithPostEvaluation but allows
// configuring arbitrary ResolverOptions, enabling tests that need specific settings
// such as AllowedErrorExtensionFields.
func testFnWithPostEvaluationAndOptions(opts ResolverOptions, fn func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx *Context, expectedOutput string, postEvaluation func(t *testing.T))) func(t *testing.T) {
return func(t *testing.T) {
t.Helper()

ctrl := gomock.NewController(t)
rCtx, cancel := context.WithCancel(context.Background())
defer cancel()
r := New(rCtx, opts)
node, ctx, expectedOutput, postEvaluation := fn(t, ctrl)

if t.Skipped() {
return
}

buf := &bytes.Buffer{}
_, err := r.ResolveGraphQLResponse(ctx, node, nil, buf)
assert.NoError(t, err)
assert.Equal(t, expectedOutput, buf.String())
ctrl.Finish()
postEvaluation(t)
}
}

func TestResolver_ResolveGraphQLResponse(t *testing.T) {

t.Run("empty graphql response", testFn(func(t *testing.T, ctrl *gomock.Controller) (node *GraphQLResponse, ctx Context, expectedOutput string) {
Expand Down