diff --git a/v2/pkg/engine/resolve/loader.go b/v2/pkg/engine/resolve/loader.go index 48eafb525a..07991e8100 100644 --- a/v2/pkg/engine/resolve/loader.go +++ b/v2/pkg/engine/resolve/loader.go @@ -36,6 +36,25 @@ const ( IntrospectionTypeEnumValuesDataSourceID = "introspection__type__enumValues" ) +// parseStringOnArena copies the string bytes onto the arena before parsing. +// This is critical because arena-allocated Values store references to the input's +// backing bytes, and Go's GC cannot trace pointers stored in arena memory (which +// is backed by []byte buffers). Without this, the GC may collect the input string +// while Values still reference it, causing segfaults. +func parseStringOnArena(a arena.Arena, s string) (*astjson.Value, error) { + b := arena.AllocateSlice[byte](a, len(s), len(s)) + copy(b, s) + return astjson.ParseBytesWithArena(a, b) +} + +// stringValueOnArena copies the string bytes onto the arena before creating +// a StringValue. Same GC safety reasoning as parseStringOnArena. +func stringValueOnArena(a arena.Arena, s string) *astjson.Value { + b := arena.AllocateSlice[byte](a, len(s), len(s)) + copy(b, s) + return astjson.StringValueBytes(a, b) +} + type LoaderHooks interface { // OnLoad is called before the fetch is executed OnLoad(ctx context.Context, ds DataSourceInfo) context.Context @@ -735,7 +754,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V } // Wrap mode (default) - errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) + errorObject, err := parseStringOnArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason)) if err != nil { return err } @@ -807,16 +826,16 @@ func (l *Loader) optionallyEnsureExtensionErrorCode(values []*astjson.Value) { switch extensions.Type() { case astjson.TypeObject: if !extensions.Exists("code") { - extensions.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + extensions.Set(l.jsonArena, "code", stringValueOnArena(l.jsonArena, l.defaultErrorExtensionCode)) } case astjson.TypeNull: extensionsObj := astjson.ObjectValue(l.jsonArena) - extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + extensionsObj.Set(l.jsonArena, "code", stringValueOnArena(l.jsonArena, l.defaultErrorExtensionCode)) value.Set(l.jsonArena, "extensions", extensionsObj) } } else { extensionsObj := astjson.ObjectValue(l.jsonArena) - extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode)) + extensionsObj.Set(l.jsonArena, "code", stringValueOnArena(l.jsonArena, l.defaultErrorExtensionCode)) value.Set(l.jsonArena, "extensions", extensionsObj) } } @@ -834,15 +853,15 @@ func (l *Loader) optionallyAttachServiceNameToErrorExtension(values []*astjson.V extensions := value.Get("extensions") switch extensions.Type() { case astjson.TypeObject: - extensions.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + extensions.Set(l.jsonArena, "serviceName", stringValueOnArena(l.jsonArena, serviceName)) case astjson.TypeNull: extensionsObj := astjson.ObjectValue(l.jsonArena) - extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + extensionsObj.Set(l.jsonArena, "serviceName", stringValueOnArena(l.jsonArena, serviceName)) value.Set(l.jsonArena, "extensions", extensionsObj) } } else { extensionsObj := astjson.ObjectValue(l.jsonArena) - extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName)) + extensionsObj.Set(l.jsonArena, "serviceName", stringValueOnArena(l.jsonArena, serviceName)) value.Set(l.jsonArena, "extensions", extensionsObj) } } @@ -948,7 +967,7 @@ func rewriteErrorPaths(a arena.Arena, fetchItem *FetchItem, values []*astjson.Va } arr := astjson.ArrayValue(a) for j := range pathPrefix { - astjson.AppendToArray(arr, astjson.StringValue(a, pathPrefix[j])) + astjson.AppendToArray(arr, stringValueOnArena(a, pathPrefix[j])) } for j := i + 1; j < len(pathItems); j++ { // If the item after _entities is an index (number), we should ignore it. @@ -981,13 +1000,13 @@ func (l *Loader) setSubgraphStatusCode(values []*astjson.Value, statusCode int) if extensions.Type() != astjson.TypeObject { continue } - v, err := astjson.ParseWithArena(l.jsonArena, strconv.Itoa(statusCode)) + v, err := parseStringOnArena(l.jsonArena, strconv.Itoa(statusCode)) if err != nil { continue } extensions.Set(l.jsonArena, "statusCode", v) } else { - v, err := astjson.ParseWithArena(l.jsonArena, `{"statusCode":`+strconv.Itoa(statusCode)+`}`) + v, err := parseStringOnArena(l.jsonArena, `{"statusCode":`+strconv.Itoa(statusCode)+`}`) if err != nil { continue } @@ -1028,7 +1047,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { } } }`, res.ds.Name, http.StatusText(res.statusCode), res.statusCode) - apolloRouterStatusError, err := astjson.ParseWithArena(l.jsonArena, apolloRouterStatusErrorJSON) + apolloRouterStatusError, err := parseStringOnArena(l.jsonArena, apolloRouterStatusErrorJSON) if err != nil { return err } @@ -1044,7 +1063,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error { func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error { path := l.renderAtPathErrorPart(fetchItem.ResponsePath) msg := fmt.Sprintf(`{"message":"Failed to obtain field dependencies from Subgraph '%s'%s."}`, res.ds.Name, path) - errorObject, err := astjson.ParseWithArena(l.jsonArena, msg) + errorObject, err := parseStringOnArena(l.jsonArena, msg) if err != nil { return err } @@ -1059,7 +1078,7 @@ func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, reason string) error { l.ctx.appendSubgraphErrors(res.ds, res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) + errorObject, err := parseStringOnArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason)) if err != nil { return err } @@ -1080,7 +1099,7 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s l.ctx.appendSubgraphErrors(res.ds, res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode)) - errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"%s"}`, reason)) + errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"%s"}`, reason)) if err != nil { return err } @@ -1121,13 +1140,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re if res.ds.Name == "" { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) + errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) + errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1137,13 +1156,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re } else { for _, reason := range res.authorizationRejectedReasons { if reason == "" { - errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) + errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode)) if err != nil { continue } astjson.AppendToArray(l.resolvable.errors, errorObject) } else { - errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) + errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode)) if err != nil { continue } @@ -1163,31 +1182,31 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result ) if res.ds.Name == "" { if res.rateLimitRejectedReason == "" { - errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) + errorObject, err = parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) + errorObject, err = parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } else { if res.rateLimitRejectedReason == "" { - errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) + errorObject, err = parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart)) if err != nil { return err } } else { - errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) + errorObject, err = parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason)) if err != nil { return err } } } if l.ctx.RateLimitOptions.ErrorExtensionCode.Enabled { - extension, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) + extension, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code)) if err != nil { return err } diff --git a/v2/pkg/engine/resolve/loader_arena_gc_test.go b/v2/pkg/engine/resolve/loader_arena_gc_test.go new file mode 100644 index 0000000000..d1bde362ce --- /dev/null +++ b/v2/pkg/engine/resolve/loader_arena_gc_test.go @@ -0,0 +1,318 @@ +package resolve + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "net/http" + "runtime" + "testing" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" +) + +// _errorReturningDataSource implements DataSource and returns a configurable error from Load. +type _errorReturningDataSource struct { + err error +} + +func (d *_errorReturningDataSource) Load(ctx context.Context, headers http.Header, input []byte) ([]byte, error) { + return nil, d.err +} + +func (d *_errorReturningDataSource) LoadWithFiles(ctx context.Context, headers http.Header, input []byte, files []*httpclient.FileUpload) ([]byte, error) { + return nil, d.err +} + +// gcTestResponse builds a minimal GraphQLResponse with a single fetch using the given DataSource. +// Callers can override FetchConfiguration and Info fields on the returned SingleFetch. +func gcTestResponse(ds DataSource) (*GraphQLResponse, *SingleFetch) { + fetch := &SingleFetch{ + FetchConfiguration: FetchConfiguration{ + DataSource: ds, + PostProcessing: PostProcessingConfiguration{ + SelectResponseDataPath: []string{"data"}, + SelectResponseErrorsPath: []string{"errors"}, + }, + }, + Info: &FetchInfo{ + DataSourceID: "test-ds", + DataSourceName: "testService", + RootFields: []GraphCoordinate{ + { + TypeName: "Query", + FieldName: "field", + }, + }, + }, + } + return &GraphQLResponse{ + Fetches: SingleWithPath(fetch, "query"), + Data: &Object{ + Nullable: true, + Fields: []*Field{ + { + Name: []byte("field"), + Value: &String{Path: []string{"field"}, Nullable: true}, + }, + }, + }, + Info: &GraphQLResponseInfo{OperationType: ast.OperationTypeQuery}, + }, fetch +} + +// Benchmark_ArenaGCSafety exercises all error codepaths that produce arena-allocated +// JSON values via parseStringOnArena / stringValueOnArena. Each sub-benchmark resolves +// a GraphQLResponse through ArenaResolveGraphQLResponse with runtime.GC() calls between +// iterations to maximize GC pressure on any dangling pointers. +// +// If the GC safety fix (copying string bytes onto the arena before parsing) were reverted, +// these benchmarks would SIGSEGV. +// +// Codepaths NOT directly covered (require HTTP status codes injected by loadByContext): +// - renderErrorsStatusFallback (status code fallback) +// - addApolloRouterCompatibilityError (Apollo compat error) +// - setSubgraphStatusCode (subgraph status code propagation) +// +// These use the same parseStringOnArena helper and are covered transitively. +func Benchmark_ArenaGCSafety(b *testing.B) { + type testCase struct { + name string + resolverOpts func() ResolverOptions + setupCtx func() *Context + setupResp func() *GraphQLResponse + } + + baseResolverOpts := func() ResolverOptions { + return ResolverOptions{ + MaxConcurrency: 1024, + PropagateSubgraphErrors: true, + PropagateSubgraphStatusCodes: true, + } + } + + cases := []testCase{ + { + // Codepath 1: DataSource.Load() returns error → renderErrorsFailedToFetch + name: "fetchError", + resolverOpts: baseResolverOpts, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(&_errorReturningDataSource{err: errors.New("connection refused")}) + return resp + }, + }, + { + // Codepath 4: DataSource.Load() returns empty response → renderErrorsFailedToFetch(emptyGraphQLResponse) + name: "emptyResponse", + resolverOpts: baseResolverOpts, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource("")) + return resp + }, + }, + { + // Codepath 5: DataSource.Load() returns invalid JSON → renderErrorsFailedToFetch(invalidGraphQLResponse) + name: "invalidJSON", + resolverOpts: baseResolverOpts, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource("{invalid")) + return resp + }, + }, + { + // Codepath 6: Response has no data/errors key → renderErrorsFailedToFetch(invalidGraphQLResponseShape) + name: "invalidShape", + resolverOpts: baseResolverOpts, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource(`{"something":"else"}`)) + return resp + }, + }, + { + // Codepath 3: Subgraph returns errors with wrap mode (default) → mergeErrors wrap path → parseStringOnArena + name: "subgraphErrorsWrapMode", + resolverOpts: baseResolverOpts, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource(`{"errors":[{"message":"downstream error"}],"data":null}`)) + return resp + }, + }, + { + // Codepath 2: Subgraph returns errors with passthrough mode → mergeErrors passthrough path + name: "subgraphErrorsPassthroughMode", + resolverOpts: func() ResolverOptions { + opts := baseResolverOpts() + opts.SubgraphErrorPropagationMode = SubgraphErrorPropagationModePassThrough + return opts + }, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource(`{"errors":[{"message":"downstream error"}],"data":null}`)) + return resp + }, + }, + { + // Codepath 12: defaultErrorExtensionCode set + subgraph errors → stringValueOnArena + name: "subgraphErrorsWithExtensionCode", + resolverOpts: func() ResolverOptions { + opts := baseResolverOpts() + opts.SubgraphErrorPropagationMode = SubgraphErrorPropagationModePassThrough + opts.DefaultErrorExtensionCode = "DOWNSTREAM_SERVICE_ERROR" + return opts + }, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource(`{"errors":[{"message":"downstream error"}],"data":null}`)) + return resp + }, + }, + { + // Codepath 13: attachServiceNameToErrorExtension set → stringValueOnArena + name: "subgraphErrorsWithServiceName", + resolverOpts: func() ResolverOptions { + opts := baseResolverOpts() + opts.SubgraphErrorPropagationMode = SubgraphErrorPropagationModePassThrough + opts.AttachServiceNameToErrorExtensions = true + return opts + }, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource(`{"errors":[{"message":"downstream error"}],"data":null}`)) + return resp + }, + }, + { + // Codepath 12+13 combined: both extension code and service name + name: "subgraphErrorsWithExtensionCodeAndServiceName", + resolverOpts: func() ResolverOptions { + opts := baseResolverOpts() + opts.SubgraphErrorPropagationMode = SubgraphErrorPropagationModePassThrough + opts.DefaultErrorExtensionCode = "DOWNSTREAM_SERVICE_ERROR" + opts.AttachServiceNameToErrorExtensions = true + return opts + }, + setupCtx: func() *Context { + return NewContext(context.Background()) + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource(`{"errors":[{"message":"downstream error"}],"data":null}`)) + return resp + }, + }, + { + // Codepath 9: Authorization rejected → renderAuthorizationRejectedErrors → parseStringOnArena + name: "authorizationRejected", + resolverOpts: baseResolverOpts, + setupCtx: func() *Context { + ctx := NewContext(context.Background()) + ctx.SetAuthorizer(createTestAuthorizer( + func(ctx *Context, dataSourceID string, input json.RawMessage, coordinate GraphCoordinate) (*AuthorizationDeny, error) { + return &AuthorizationDeny{Reason: "not allowed"}, nil + }, + func(ctx *Context, dataSourceID string, object json.RawMessage, coordinate GraphCoordinate) (*AuthorizationDeny, error) { + return nil, nil + }, + )) + return ctx + }, + setupResp: func() *GraphQLResponse { + resp, fetch := gcTestResponse(FakeDataSource(`{"data":{"field":"value"}}`)) + fetch.Info.RootFields[0].HasAuthorizationRule = true + return resp + }, + }, + { + // Codepath 10: Rate limit rejected → renderRateLimitRejectedErrors → parseStringOnArena + name: "rateLimitRejected", + resolverOpts: baseResolverOpts, + setupCtx: func() *Context { + ctx := NewContext(context.Background()) + ctx.SetRateLimiter(&testRateLimiter{ + allowFn: func(ctx *Context, info *FetchInfo, input json.RawMessage) (*RateLimitDeny, error) { + return &RateLimitDeny{Reason: "rate limit exceeded"}, nil + }, + }) + ctx.RateLimitOptions = RateLimitOptions{Enable: true} + return ctx + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource(`{"data":{"field":"value"}}`)) + return resp + }, + }, + { + // Codepath 14: Rate limit with extension code → extra parseStringOnArena for extension + name: "rateLimitWithExtensionCode", + resolverOpts: baseResolverOpts, + setupCtx: func() *Context { + ctx := NewContext(context.Background()) + ctx.SetRateLimiter(&testRateLimiter{ + allowFn: func(ctx *Context, info *FetchInfo, input json.RawMessage) (*RateLimitDeny, error) { + return &RateLimitDeny{Reason: "rate limit exceeded"}, nil + }, + }) + ctx.RateLimitOptions = RateLimitOptions{ + Enable: true, + ErrorExtensionCode: RateLimitErrorExtensionCode{Enabled: true, Code: "RATE_LIMIT_EXCEEDED"}, + } + return ctx + }, + setupResp: func() *GraphQLResponse { + resp, _ := gcTestResponse(FakeDataSource(`{"data":{"field":"value"}}`)) + return resp + }, + }, + } + + for _, tc := range cases { + b.Run(tc.name, func(b *testing.B) { + rCtx, cancel := context.WithCancel(context.Background()) + defer cancel() + resolver := New(rCtx, tc.resolverOpts()) + buf := &bytes.Buffer{} + + b.ResetTimer() + for i := 0; i < b.N; i++ { + response := tc.setupResp() + resolveCtx := tc.setupCtx() + + // Force GC between iterations to maximize pressure on any + // dangling pointers in arena-allocated Values. If the GC safety + // fix were reverted, this would cause SIGSEGV. + runtime.GC() + + buf.Reset() + _, err := resolver.ArenaResolveGraphQLResponse(resolveCtx, response, buf) + if err != nil { + b.Fatal(err) + } + if buf.Len() == 0 { + b.Fatal("empty output") + } + } + }) + } +}