Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 42 additions & 23 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,25 @@ const (
IntrospectionTypeEnumValuesDataSourceID = "introspection__type__enumValues"
)

// parseStringOnArena copies the string bytes onto the arena before parsing.
// This is critical because arena-allocated Values store references to the input's
// backing bytes, and Go's GC cannot trace pointers stored in arena memory (which
// is backed by []byte buffers). Without this, the GC may collect the input string
// while Values still reference it, causing segfaults.
func parseStringOnArena(a arena.Arena, s string) (*astjson.Value, error) {
b := arena.AllocateSlice[byte](a, len(s), len(s))
copy(b, s)
return astjson.ParseBytesWithArena(a, b)
}

// stringValueOnArena copies the string bytes onto the arena before creating
// a StringValue. Same GC safety reasoning as parseStringOnArena.
func stringValueOnArena(a arena.Arena, s string) *astjson.Value {
b := arena.AllocateSlice[byte](a, len(s), len(s))
copy(b, s)
return astjson.StringValueBytes(a, b)
}

type LoaderHooks interface {
// OnLoad is called before the fetch is executed
OnLoad(ctx context.Context, ds DataSourceInfo) context.Context
Expand Down Expand Up @@ -735,7 +754,7 @@ func (l *Loader) mergeErrors(res *result, fetchItem *FetchItem, value *astjson.V
}

// Wrap mode (default)
errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason))
errorObject, err := parseStringOnArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, failedToFetchNoReason))
if err != nil {
return err
}
Expand Down Expand Up @@ -807,16 +826,16 @@ func (l *Loader) optionallyEnsureExtensionErrorCode(values []*astjson.Value) {
switch extensions.Type() {
case astjson.TypeObject:
if !extensions.Exists("code") {
extensions.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode))
extensions.Set(l.jsonArena, "code", stringValueOnArena(l.jsonArena, l.defaultErrorExtensionCode))
}
case astjson.TypeNull:
extensionsObj := astjson.ObjectValue(l.jsonArena)
extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode))
extensionsObj.Set(l.jsonArena, "code", stringValueOnArena(l.jsonArena, l.defaultErrorExtensionCode))
value.Set(l.jsonArena, "extensions", extensionsObj)
}
} else {
extensionsObj := astjson.ObjectValue(l.jsonArena)
extensionsObj.Set(l.jsonArena, "code", astjson.StringValue(l.jsonArena, l.defaultErrorExtensionCode))
extensionsObj.Set(l.jsonArena, "code", stringValueOnArena(l.jsonArena, l.defaultErrorExtensionCode))
value.Set(l.jsonArena, "extensions", extensionsObj)
}
}
Expand All @@ -834,15 +853,15 @@ func (l *Loader) optionallyAttachServiceNameToErrorExtension(values []*astjson.V
extensions := value.Get("extensions")
switch extensions.Type() {
case astjson.TypeObject:
extensions.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName))
extensions.Set(l.jsonArena, "serviceName", stringValueOnArena(l.jsonArena, serviceName))
case astjson.TypeNull:
extensionsObj := astjson.ObjectValue(l.jsonArena)
extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName))
extensionsObj.Set(l.jsonArena, "serviceName", stringValueOnArena(l.jsonArena, serviceName))
value.Set(l.jsonArena, "extensions", extensionsObj)
}
} else {
extensionsObj := astjson.ObjectValue(l.jsonArena)
extensionsObj.Set(l.jsonArena, "serviceName", astjson.StringValue(l.jsonArena, serviceName))
extensionsObj.Set(l.jsonArena, "serviceName", stringValueOnArena(l.jsonArena, serviceName))
value.Set(l.jsonArena, "extensions", extensionsObj)
}
}
Expand Down Expand Up @@ -948,7 +967,7 @@ func rewriteErrorPaths(a arena.Arena, fetchItem *FetchItem, values []*astjson.Va
}
arr := astjson.ArrayValue(a)
for j := range pathPrefix {
astjson.AppendToArray(arr, astjson.StringValue(a, pathPrefix[j]))
astjson.AppendToArray(arr, stringValueOnArena(a, pathPrefix[j]))
}
for j := i + 1; j < len(pathItems); j++ {
// If the item after _entities is an index (number), we should ignore it.
Expand Down Expand Up @@ -981,13 +1000,13 @@ func (l *Loader) setSubgraphStatusCode(values []*astjson.Value, statusCode int)
if extensions.Type() != astjson.TypeObject {
continue
}
v, err := astjson.ParseWithArena(l.jsonArena, strconv.Itoa(statusCode))
v, err := parseStringOnArena(l.jsonArena, strconv.Itoa(statusCode))
if err != nil {
continue
}
extensions.Set(l.jsonArena, "statusCode", v)
} else {
v, err := astjson.ParseWithArena(l.jsonArena, `{"statusCode":`+strconv.Itoa(statusCode)+`}`)
v, err := parseStringOnArena(l.jsonArena, `{"statusCode":`+strconv.Itoa(statusCode)+`}`)
if err != nil {
continue
}
Expand Down Expand Up @@ -1028,7 +1047,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error {
}
}
}`, res.ds.Name, http.StatusText(res.statusCode), res.statusCode)
apolloRouterStatusError, err := astjson.ParseWithArena(l.jsonArena, apolloRouterStatusErrorJSON)
apolloRouterStatusError, err := parseStringOnArena(l.jsonArena, apolloRouterStatusErrorJSON)
if err != nil {
return err
}
Expand All @@ -1044,7 +1063,7 @@ func (l *Loader) addApolloRouterCompatibilityError(res *result) error {
func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error {
path := l.renderAtPathErrorPart(fetchItem.ResponsePath)
msg := fmt.Sprintf(`{"message":"Failed to obtain field dependencies from Subgraph '%s'%s."}`, res.ds.Name, path)
errorObject, err := astjson.ParseWithArena(l.jsonArena, msg)
errorObject, err := parseStringOnArena(l.jsonArena, msg)
if err != nil {
return err
}
Expand All @@ -1059,7 +1078,7 @@ func (l *Loader) renderErrorsFailedDeps(fetchItem *FetchItem, res *result) error

func (l *Loader) renderErrorsFailedToFetch(fetchItem *FetchItem, res *result, reason string) error {
l.ctx.appendSubgraphErrors(res.ds, res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode))
errorObject, err := astjson.ParseWithArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason))
errorObject, err := parseStringOnArena(l.jsonArena, l.renderSubgraphBaseError(res.ds, fetchItem.ResponsePath, reason))
if err != nil {
return err
}
Expand All @@ -1080,7 +1099,7 @@ func (l *Loader) renderErrorsStatusFallback(fetchItem *FetchItem, res *result, s

l.ctx.appendSubgraphErrors(res.ds, res.err, NewSubgraphError(res.ds, fetchItem.ResponsePath, reason, res.statusCode))

errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"%s"}`, reason))
errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"%s"}`, reason))
if err != nil {
return err
}
Expand Down Expand Up @@ -1121,13 +1140,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re
if res.ds.Name == "" {
for _, reason := range res.authorizationRejectedReasons {
if reason == "" {
errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode))
errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s.",%s}`, pathPart, extensionErrorCode))
if err != nil {
continue
}
astjson.AppendToArray(l.resolvable.errors, errorObject)
} else {
errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode))
errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized Subgraph request%s, Reason: %s.",%s}`, pathPart, reason, extensionErrorCode))
if err != nil {
continue
}
Expand All @@ -1137,13 +1156,13 @@ func (l *Loader) renderAuthorizationRejectedErrors(fetchItem *FetchItem, res *re
} else {
for _, reason := range res.authorizationRejectedReasons {
if reason == "" {
errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode))
errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s.",%s}`, res.ds.Name, pathPart, extensionErrorCode))
if err != nil {
continue
}
astjson.AppendToArray(l.resolvable.errors, errorObject)
} else {
errorObject, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode))
errorObject, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Unauthorized request to Subgraph '%s'%s, Reason: %s.",%s}`, res.ds.Name, pathPart, reason, extensionErrorCode))
if err != nil {
continue
}
Expand All @@ -1163,31 +1182,31 @@ func (l *Loader) renderRateLimitRejectedErrors(fetchItem *FetchItem, res *result
)
if res.ds.Name == "" {
if res.rateLimitRejectedReason == "" {
errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart))
errorObject, err = parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s."}`, pathPart))
if err != nil {
return err
}
} else {
errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason))
errorObject, err = parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph request%s, Reason: %s."}`, pathPart, res.rateLimitRejectedReason))
if err != nil {
return err
}
}
} else {
if res.rateLimitRejectedReason == "" {
errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart))
errorObject, err = parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s."}`, res.ds.Name, pathPart))
if err != nil {
return err
}
} else {
errorObject, err = astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason))
errorObject, err = parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"message":"Rate limit exceeded for Subgraph '%s'%s, Reason: %s."}`, res.ds.Name, pathPart, res.rateLimitRejectedReason))
if err != nil {
return err
}
}
}
if l.ctx.RateLimitOptions.ErrorExtensionCode.Enabled {
extension, err := astjson.ParseWithArena(l.jsonArena, fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code))
extension, err := parseStringOnArena(l.jsonArena, fmt.Sprintf(`{"code":"%s"}`, l.ctx.RateLimitOptions.ErrorExtensionCode.Code))
if err != nil {
return err
}
Expand Down
Loading