Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer)
return err
}

// In case of a federated response, we need to ensure that the response is valid.
// The number of entities per type must match the number of lookup keys in the variables.
err = builder.validateFederatedResponse(response)
Comment thread
StarpTech marked this conversation as resolved.
if err != nil {
return err
}

responses[index] = response
return nil
})
Expand Down
36 changes: 36 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/json_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,42 @@ func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder {
}
}

// validateFederatedResponse validates that the federated response is valid
// by checking that the number of entities per type is correct.
// For non-federated responses, this function is a no-op.
func (j *jsonBuilder) validateFederatedResponse(response *astjson.Value) error {
if j.indexMap == nil {
return nil
}

// Get the entities array from the response
// If we have an index map, we expect it to be a federated response
entities, err := response.Get(entityPath).Array()
if err != nil {
return err
}

// Count the number of entities per type
entitiyCountPerType := make(map[string]int)
for _, entity := range entities {
entityType := entity.Get("__typename").GetStringBytes()
entitiyCountPerType[string(entityType)]++
}

// Check that the number of entities per type is correct and exists in the index map.
for typeName, count := range entitiyCountPerType {
em, found := j.indexMap[typeName]
if !found {
return fmt.Errorf("entity type %s received in the subgraph response, but was not expected", typeName)
}

if len(em) != count {
return fmt.Errorf("entity type %s received %d entities in the subgraph response, but %d are expected", typeName, count, len(em))
}
}
return nil
}
Comment thread
Noroth marked this conversation as resolved.

// mergeValues combines two JSON values while preserving proper federation entity ordering.
// This is a critical function for GraphQL federation where multiple subgraphs may
// return entities that need to be merged in the correct order.
Expand Down
40 changes: 38 additions & 2 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,35 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo {
return responseInfo
}

// batchStats represents an index map for batched items.
// It is used to ensure that the correct json values will be merged with the correct items from the batch.
//
// Example:
// [[0],[1],[0],[1]] We originally have 4 items, but we have 2 unique indexes (0 and 1).
// This means we are deduplicating 2 items by merging them from their response entity indexes.
// 0 -> 0, 1 -> 1, 2 -> 0, 3 -> 1
type batchStats [][]int

// getUniqueIndexes returns the number of unique indexes in the batchStats.
// This is used to ensure that we can provide a valid error message in case of differing array lengths.
func (b *batchStats) getUniqueIndexes() int {
uniqueIndexes := make(map[int]struct{})
for _, bi := range *b {
for _, index := range bi {
if index < 0 {
continue
}
uniqueIndexes[index] = struct{}{}
}
}

return len(uniqueIndexes)
}
Comment thread
Noroth marked this conversation as resolved.

type result struct {
postProcessing PostProcessingConfiguration
out *bytes.Buffer
batchStats [][]int
batchStats batchStats
fetchSkipped bool
nestedMergeItems []*result

Expand Down Expand Up @@ -601,7 +626,13 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson
if batch == nil {
return l.renderErrorsFailedToFetch(fetchItem, res, invalidGraphQLResponseShape)
}

if res.batchStats != nil {
uniqueIndexes := res.batchStats.getUniqueIndexes()
if uniqueIndexes != len(batch) {
return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, uniqueIndexes, len(batch)))
}

for i, stats := range res.batchStats {
for _, item := range stats {
if item == -1 {
Expand All @@ -618,6 +649,10 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson
}
}
} else {
if batchCount, itemCount := len(batch), len(items); batchCount != itemCount {
return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, itemCount, batchCount))
}

for i, item := range items {
_, _, err = astjson.MergeValuesWithPath(item, batch[i], res.postProcessing.MergePath...)
if err != nil {
Expand Down Expand Up @@ -953,6 +988,7 @@ const (
emptyGraphQLResponse = "empty response"
invalidGraphQLResponse = "invalid JSON"
invalidGraphQLResponseShape = "no data or errors in response"
invalidBatchItemCount = "returned items from batch do not match the number of items in the request. Expected %d, got %d"
Comment thread
StarpTech marked this conversation as resolved.
Outdated
)

func (l *Loader) renderAtPathErrorPart(path string) string {
Expand Down Expand Up @@ -1380,7 +1416,7 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem,
if err != nil {
return errors.WithStack(err)
}
res.batchStats = make([][]int, len(items))
res.batchStats = make(batchStats, len(items))
itemHashes := make([]uint64, 0, len(items))
batchItemIndex := 0
addSeparator := false
Expand Down
Loading