Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer)
return err
}

// In case of a federated response, we need to ensure that the response is valid.
// The number of entities per type must match the number of lookup keys in the variables.
err = builder.validateFederatedResponse(response)
Comment thread
StarpTech marked this conversation as resolved.
if err != nil {
return err
}

responses[index] = response
return nil
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3493,12 +3493,21 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) {
conn, cleanup := setupTestGRPCServer(t)
t.Cleanup(cleanup)

type graphqlError struct {
Message string `json:"message"`
}
type graphqlResponse struct {
Data map[string]interface{} `json:"data"`
Errors []graphqlError `json:"errors,omitempty"`
}

testCases := []struct {
name string
query string
vars string
federationConfigs plan.FederationFieldConfigurations
validate func(t *testing.T, data map[string]interface{})
validateError func(t *testing.T, errData []graphqlError)
}{
{
name: "Query nullable fields type with all fields",
Expand Down Expand Up @@ -3552,6 +3561,32 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) {
require.Equal(t, "4", storage2["id"])
require.Equal(t, "Storage 4", storage2["name"])
},
validateError: func(t *testing.T, errorData []graphqlError) {
require.Empty(t, errorData)
},
},
{
name: "Query warehouse and expect an error",
query: `query($representations: [_Any!]!) { _entities(representations: $representations) { ...on Warehouse { id name } } }`,
vars: `{"variables":{"representations":[
{"__typename":"Warehouse","id":"1"},
{"__typename":"Warehouse","id":"2"},
{"__typename":"Warehouse","id":"3"},
{"__typename":"Warehouse","id":"4"}
]}}`,
federationConfigs: plan.FederationFieldConfigurations{
{
TypeName: "Warehouse",
SelectionSet: "id",
},
},
validate: func(t *testing.T, data map[string]interface{}) {
require.Empty(t, data)
},
validateError: func(t *testing.T, errorData []graphqlError) {
require.NotEmpty(t, errorData)
require.Equal(t, "entity type Warehouse received 3 entities in the subgraph response, but 4 are expected", errorData[0].Message)
},
},
}

Expand Down Expand Up @@ -3589,20 +3624,13 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) {
require.NoError(t, err)

// Parse the response
var resp struct {
Data map[string]interface{} `json:"data"`
Errors []struct {
Message string `json:"message"`
} `json:"errors,omitempty"`
}
var resp graphqlResponse

err = json.Unmarshal(output.Bytes(), &resp)
require.NoError(t, err, "Failed to unmarshal response")
require.Empty(t, resp.Errors, "Response should not contain errors")
require.NotEmpty(t, resp.Data, "Response should contain data")

// Run the validation function
tc.validate(t, resp.Data)
tc.validateError(t, resp.Errors)
})
}
}
36 changes: 36 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/json_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,42 @@ func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder {
}
}

// validateFederatedResponse validates that the federated response is valid
// by checking that the number of entities per type is correct.
// For non-federated responses, this function is a no-op.
func (j *jsonBuilder) validateFederatedResponse(response *astjson.Value) error {
if j.indexMap == nil {
return nil
}

// Get the entities array from the response
// If we have an index map, we expect it to be a federated response
entities, err := response.Get(entityPath).Array()
if err != nil {
return err
}

// Count the number of entities per type
entitiyCountPerType := make(map[string]int)
for _, entity := range entities {
entityType := entity.Get("__typename").GetStringBytes()
entitiyCountPerType[string(entityType)]++
}

// Check that the number of entities per type is correct and exists in the index map.
for typeName, count := range entitiyCountPerType {
em, found := j.indexMap[typeName]
if !found {
return fmt.Errorf("entity type %s received in the subgraph response, but was not expected", typeName)
}

if len(em) != count {
return fmt.Errorf("entity type %s received %d entities in the subgraph response, but %d are expected", typeName, count, len(em))
}
}
return nil
}
Comment thread
Noroth marked this conversation as resolved.

// mergeValues combines two JSON values while preserving proper federation entity ordering.
// This is a critical function for GraphQL federation where multiple subgraphs may
// return entities that need to be merged in the correct order.
Expand Down
21 changes: 21 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ func testMapping() *GRPCMapping {
},
},
},
"Warehouse": {
{
Key: "id",
RPCConfig: RPCConfig{
RPC: "LookupWarehouseById",
Request: "LookupWarehouseByIdRequest",
Response: "LookupWarehouseByIdResponse",
},
},
},
},
EnumValues: map[string][]EnumValueMapping{
"CategoryKind": {
Expand Down Expand Up @@ -494,6 +504,17 @@ func testMapping() *GRPCMapping {
TargetName: "location",
},
},
"Warehouse": {
"id": {
TargetName: "id",
},
"name": {
TargetName: "name",
},
"location": {
TargetName: "location",
},
},
"User": {
"id": {
TargetName: "id",
Expand Down
40 changes: 38 additions & 2 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,35 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo {
return responseInfo
}

// batchStats represents an index map for batched items.
// It is used to ensure that the correct json values will be merged with the correct items from the batch.
//
// Example:
// [[0],[1],[0],[1]] We originally have 4 items, but we have 2 unique indexes (0 and 1).
// This means we are deduplicating 2 items by merging them from their response entity indexes.
// 0 -> 0, 1 -> 1, 2 -> 0, 3 -> 1
type batchStats [][]int

// getUniqueIndexes returns the number of unique indexes in the batchStats.
// This is used to ensure that we can provide a valid error message in case of differing array lengths.
func (b *batchStats) getUniqueIndexes() int {
uniqueIndexes := make(map[int]struct{})
for _, bi := range *b {
for _, index := range bi {
if index < 0 {
continue
}
uniqueIndexes[index] = struct{}{}
}
}

return len(uniqueIndexes)
}
Comment thread
Noroth marked this conversation as resolved.

type result struct {
postProcessing PostProcessingConfiguration
out *bytes.Buffer
batchStats [][]int
batchStats batchStats
fetchSkipped bool
nestedMergeItems []*result

Expand Down Expand Up @@ -601,7 +626,13 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson
if batch == nil {
return l.renderErrorsFailedToFetch(fetchItem, res, invalidGraphQLResponseShape)
}

if res.batchStats != nil {
uniqueIndexes := res.batchStats.getUniqueIndexes()
if uniqueIndexes != len(batch) {
return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, uniqueIndexes, len(batch)))
}

for i, stats := range res.batchStats {
for _, item := range stats {
if item == -1 {
Expand All @@ -618,6 +649,10 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson
}
}
} else {
if batchCount, itemCount := len(batch), len(items); batchCount != itemCount {
return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, itemCount, batchCount))
}

for i, item := range items {
_, _, err = astjson.MergeValuesWithPath(item, batch[i], res.postProcessing.MergePath...)
if err != nil {
Expand Down Expand Up @@ -953,6 +988,7 @@ const (
emptyGraphQLResponse = "empty response"
invalidGraphQLResponse = "invalid JSON"
invalidGraphQLResponseShape = "no data or errors in response"
invalidBatchItemCount = "returned entities count does not match the count of representation variables in the entities request. Expected %d, got %d"
)

func (l *Loader) renderAtPathErrorPart(path string) string {
Expand Down Expand Up @@ -1380,7 +1416,7 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem,
if err != nil {
return errors.WithStack(err)
}
res.batchStats = make([][]int, len(items))
res.batchStats = make(batchStats, len(items))
itemHashes := make([]uint64, 0, len(items))
batchItemIndex := 0
addSeparator := false
Expand Down
Loading
Loading