Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,13 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer)
return err
}

// In case of a federated response, we need to ensure that the response is valid.
// The number of entities per type must match the number of lookup keys in the variables.
err = builder.validateFederatedResponse(response)
Comment thread
StarpTech marked this conversation as resolved.
if err != nil {
return err
}

responses[index] = response
return nil
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3493,12 +3493,21 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) {
conn, cleanup := setupTestGRPCServer(t)
t.Cleanup(cleanup)

type graphqlError struct {
Message string `json:"message"`
}
type graphqlResponse struct {
Data map[string]interface{} `json:"data"`
Errors []graphqlError `json:"errors,omitempty"`
}

testCases := []struct {
name string
query string
vars string
federationConfigs plan.FederationFieldConfigurations
validate func(t *testing.T, data map[string]interface{})
validateError func(t *testing.T, errData []graphqlError)
}{
{
name: "Query nullable fields type with all fields",
Expand Down Expand Up @@ -3552,6 +3561,32 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) {
require.Equal(t, "4", storage2["id"])
require.Equal(t, "Storage 4", storage2["name"])
},
validateError: func(t *testing.T, errorData []graphqlError) {
require.Empty(t, errorData)
},
},
{
name: "Query warehouse and expect an error",
query: `query($representations: [_Any!]!) { _entities(representations: $representations) { ...on Warehouse { id name } } }`,
vars: `{"variables":{"representations":[
{"__typename":"Warehouse","id":"1"},
{"__typename":"Warehouse","id":"2"},
{"__typename":"Warehouse","id":"3"},
{"__typename":"Warehouse","id":"4"}
]}}`,
federationConfigs: plan.FederationFieldConfigurations{
{
TypeName: "Warehouse",
SelectionSet: "id",
},
},
validate: func(t *testing.T, data map[string]interface{}) {
require.Empty(t, data)
},
validateError: func(t *testing.T, errorData []graphqlError) {
require.NotEmpty(t, errorData)
require.Equal(t, "entity type Warehouse received 3 entities in the subgraph response, but 4 are expected", errorData[0].Message)
},
},
}

Expand Down Expand Up @@ -3589,20 +3624,13 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) {
require.NoError(t, err)

// Parse the response
var resp struct {
Data map[string]interface{} `json:"data"`
Errors []struct {
Message string `json:"message"`
} `json:"errors,omitempty"`
}
var resp graphqlResponse

err = json.Unmarshal(output.Bytes(), &resp)
require.NoError(t, err, "Failed to unmarshal response")
require.Empty(t, resp.Errors, "Response should not contain errors")
require.NotEmpty(t, resp.Data, "Response should contain data")

// Run the validation function
tc.validate(t, resp.Data)
tc.validateError(t, resp.Errors)
})
}
}
36 changes: 36 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/json_builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,42 @@ func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder {
}
}

// validateFederatedResponse validates that the federated response is valid
// by checking that the number of entities per type is correct.
// For non-federated responses, this function is a no-op.
func (j *jsonBuilder) validateFederatedResponse(response *astjson.Value) error {
if j.indexMap == nil {
return nil
}

// Get the entities array from the response
// If we have an index map, we expect it to be a federated response
entities, err := response.Get(entityPath).Array()
if err != nil {
return err
}

// Count the number of entities per type
entitiyCountPerType := make(map[string]int)
for _, entity := range entities {
entityType := entity.Get("__typename").GetStringBytes()
entitiyCountPerType[string(entityType)]++
}

// Check that the number of entities per type is correct and exists in the index map.
for typeName, count := range entitiyCountPerType {
em, found := j.indexMap[typeName]
if !found {
return fmt.Errorf("entity type %s received in the subgraph response, but was not expected", typeName)
}

if len(em) != count {
return fmt.Errorf("entity type %s received %d entities in the subgraph response, but %d are expected", typeName, count, len(em))
}
}
return nil
}
Comment thread
Noroth marked this conversation as resolved.

// mergeValues combines two JSON values while preserving proper federation entity ordering.
// This is a critical function for GraphQL federation where multiple subgraphs may
// return entities that need to be merged in the correct order.
Expand Down
21 changes: 21 additions & 0 deletions v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,16 @@ func testMapping() *GRPCMapping {
},
},
},
"Warehouse": {
{
Key: "id",
RPCConfig: RPCConfig{
RPC: "LookupWarehouseById",
Request: "LookupWarehouseByIdRequest",
Response: "LookupWarehouseByIdResponse",
},
},
},
},
EnumValues: map[string][]EnumValueMapping{
"CategoryKind": {
Expand Down Expand Up @@ -494,6 +504,17 @@ func testMapping() *GRPCMapping {
TargetName: "location",
},
},
"Warehouse": {
"id": {
TargetName: "id",
},
"name": {
TargetName: "name",
},
"location": {
TargetName: "location",
},
},
"User": {
"id": {
TargetName: "id",
Expand Down
40 changes: 38 additions & 2 deletions v2/pkg/engine/resolve/loader.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,35 @@ func newResponseInfo(res *result, subgraphError error) *ResponseInfo {
return responseInfo
}

// batchStats represents an index map for batched items.
// It is used to ensure that the correct json values will be merged with the correct items from the batch.
//
// Example:
// [[0],[1],[0],[1]] We originally have 4 items, but we have 2 unique indexes (0 and 1).
// This means we are deduplicating 2 items by merging them from their response entity indexes.
// 0 -> 0, 1 -> 1, 2 -> 0, 3 -> 1
type batchStats [][]int

// getUniqueIndexes returns the number of unique indexes in the batchStats.
// This is used to ensure that we can provide a valid error message in case of differing array lengths.
func (b *batchStats) getUniqueIndexes() int {
uniqueIndexes := make(map[int]struct{})
for _, bi := range *b {
for _, index := range bi {
if index < 0 {
continue
}
uniqueIndexes[index] = struct{}{}
}
}

return len(uniqueIndexes)
}
Comment thread
Noroth marked this conversation as resolved.

type result struct {
postProcessing PostProcessingConfiguration
out *bytes.Buffer
batchStats [][]int
batchStats batchStats
fetchSkipped bool
nestedMergeItems []*result

Expand Down Expand Up @@ -601,7 +626,13 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson
if batch == nil {
return l.renderErrorsFailedToFetch(fetchItem, res, invalidGraphQLResponseShape)
}

if res.batchStats != nil {
uniqueIndexes := res.batchStats.getUniqueIndexes()
if uniqueIndexes != len(batch) {
return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, uniqueIndexes, len(batch)))
}

for i, stats := range res.batchStats {
for _, item := range stats {
if item == -1 {
Expand All @@ -618,6 +649,10 @@ func (l *Loader) mergeResult(fetchItem *FetchItem, res *result, items []*astjson
}
}
} else {
if batchCount, itemCount := len(batch), len(items); batchCount != itemCount {
return l.renderErrorsFailedToFetch(fetchItem, res, fmt.Sprintf(invalidBatchItemCount, itemCount, batchCount))
}

for i, item := range items {
_, _, err = astjson.MergeValuesWithPath(item, batch[i], res.postProcessing.MergePath...)
if err != nil {
Expand Down Expand Up @@ -953,6 +988,7 @@ const (
emptyGraphQLResponse = "empty response"
invalidGraphQLResponse = "invalid JSON"
invalidGraphQLResponseShape = "no data or errors in response"
invalidBatchItemCount = "returned items from batch do not match the number of items in the request. Expected %d, got %d"
Comment thread
StarpTech marked this conversation as resolved.
Outdated
)

func (l *Loader) renderAtPathErrorPart(path string) string {
Expand Down Expand Up @@ -1380,7 +1416,7 @@ func (l *Loader) loadBatchEntityFetch(ctx context.Context, fetchItem *FetchItem,
if err != nil {
return errors.WithStack(err)
}
res.batchStats = make([][]int, len(items))
res.batchStats = make(batchStats, len(items))
itemHashes := make([]uint64, 0, len(items))
batchItemIndex := 0
addSeparator := false
Expand Down
21 changes: 21 additions & 0 deletions v2/pkg/grpctest/mapping/mapping.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,16 @@ func DefaultGRPCMapping() *grpcdatasource.GRPCMapping {
},
},
},
"Warehouse": {
{
Key: "id",
RPCConfig: grpcdatasource.RPCConfig{
RPC: "LookupWarehouseById",
Request: "LookupWarehouseByIdRequest",
Response: "LookupWarehouseByIdResponse",
},
},
},
},
EnumValues: map[string][]grpcdatasource.EnumValueMapping{
"CategoryKind": {
Expand Down Expand Up @@ -501,6 +511,17 @@ func DefaultGRPCMapping() *grpcdatasource.GRPCMapping {
TargetName: "location",
},
},
"Warehouse": {
"id": {
TargetName: "id",
},
"name": {
TargetName: "name",
},
"location": {
TargetName: "location",
},
},
"User": {
"id": {
TargetName: "id",
Expand Down
33 changes: 33 additions & 0 deletions v2/pkg/grpctest/mockservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,39 @@ type MockService struct {
productv1.UnimplementedProductServiceServer
}

// LookupWarehouseById implements productv1.ProductServiceServer.
func (s *MockService) LookupWarehouseById(ctx context.Context, in *productv1.LookupWarehouseByIdRequest) (*productv1.LookupWarehouseByIdResponse, error) {
var results []*productv1.Warehouse

// Special requirement: return one less item than requested to test error handling
// This deliberately breaks the normal pattern of returning the same number of items as keys
keys := in.GetKeys()
if len(keys) == 0 {
return &productv1.LookupWarehouseByIdResponse{
Result: results,
}, nil
}

// Return all items except the last one to test error scenarios
for i, input := range keys {
// Skip the last item to create an intentional mismatch
if i == len(keys)-1 {
break
}

warehouseId := input.GetId()
results = append(results, &productv1.Warehouse{
Id: warehouseId,
Name: fmt.Sprintf("Warehouse %s", warehouseId),
Location: fmt.Sprintf("Location %d", rand.Intn(100)),
})
}

return &productv1.LookupWarehouseByIdResponse{
Result: results,
}, nil
}

// Helper functions to convert input types to output types
func convertCategoryInputsToCategories(inputs []*productv1.CategoryInput) []*productv1.Category {
if inputs == nil {
Expand Down
42 changes: 42 additions & 0 deletions v2/pkg/grpctest/product.proto
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ service ProductService {
rpc LookupProductById(LookupProductByIdRequest) returns (LookupProductByIdResponse) {}
// Lookup Storage entity by id
rpc LookupStorageById(LookupStorageByIdRequest) returns (LookupStorageByIdResponse) {}
// Lookup Warehouse entity by id
rpc LookupWarehouseById(LookupWarehouseByIdRequest) returns (LookupWarehouseByIdResponse) {}
Comment thread
Noroth marked this conversation as resolved.
rpc MutationBulkCreateAuthors(MutationBulkCreateAuthorsRequest) returns (MutationBulkCreateAuthorsResponse) {}
rpc MutationBulkCreateBlogPosts(MutationBulkCreateBlogPostsRequest) returns (MutationBulkCreateBlogPostsResponse) {}
rpc MutationBulkUpdateAuthors(MutationBulkUpdateAuthorsRequest) returns (MutationBulkUpdateAuthorsResponse) {}
Expand Down Expand Up @@ -256,6 +258,40 @@ message LookupStorageByIdResponse {
repeated Storage result = 1;
}

// Key message for Warehouse entity lookup
message LookupWarehouseByIdRequestKey {
// Key field for Warehouse entity lookup.
string id = 1;
}

// Request message for Warehouse entity lookup.
message LookupWarehouseByIdRequest {
/*
* List of keys to look up Warehouse entities.
* Order matters - each key maps to one entity in LookupWarehouseByIdResponse.
*/
repeated LookupWarehouseByIdRequestKey keys = 1;
}

// Response message for Warehouse entity lookup.
message LookupWarehouseByIdResponse {
/*
* List of Warehouse entities in the same order as the keys in LookupWarehouseByIdRequest.
* Always return the same number of entities as keys. Use null for entities that cannot be found.
*
* Example:
* LookupUserByIdRequest:
* keys:
* - id: 1
* - id: 2
* LookupUserByIdResponse:
* result:
* - id: 1 # User with id 1 found
* - null # User with id 2 not found
*/
repeated Warehouse result = 1;
}

// Request message for users operation.
message QueryUsersRequest {
}
Expand Down Expand Up @@ -596,6 +632,12 @@ message Storage {
string location = 3;
}

message Warehouse {
string id = 1;
string name = 2;
string location = 3;
}

message User {
string id = 1;
string name = 2;
Expand Down
Loading
Loading