From 20e396e583f6f4abd9901247ab2785fa3fc9fd96 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Mon, 4 Aug 2025 13:57:26 +0200 Subject: [PATCH 01/24] feat: support multiple and nested keys --- .../engine/execution_engine_grpc_test.go | 4 +- v2/pkg/astvisitor/visitor.go | 24 + .../graphql_datasource/graphql_datasource.go | 11 +- .../grpc_datasource/configuration.go | 17 +- .../grpc_datasource/execution_plan.go | 247 +++- .../execution_plan_federation_test.go | 1048 ++++++++++++++++- .../grpc_datasource/execution_plan_test.go | 75 +- .../grpc_datasource/execution_plan_visitor.go | 520 +------- .../execution_plan_visitor_federation.go | 476 ++++++++ .../grpc_datasource/grpc_datasource.go | 43 +- .../grpc_datasource/grpc_datasource_test.go | 55 + .../grpc_datasource/mapping_test_helper.go | 34 +- .../required_fields_visitor.go | 164 +++ v2/pkg/grpctest/mapping/mapping.go | 26 +- v2/pkg/grpctest/testdata/products.graphqls | 9 +- 15 files changed, 2173 insertions(+), 580 deletions(-) create mode 100644 v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go create mode 100644 v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go diff --git a/execution/engine/execution_engine_grpc_test.go b/execution/engine/execution_engine_grpc_test.go index 3a58ba841a..3c506417b2 100644 --- a/execution/engine/execution_engine_grpc_test.go +++ b/execution/engine/execution_engine_grpc_test.go @@ -234,7 +234,7 @@ func TestGRPCSubgraphExecution(t *testing.T) { Query: "query UserQuery { users { id name } }", } - response, err := executeOperation(t, conn, operation) + response, err := executeOperation(t, conn, operation, withGRPCMapping(mapping.MustDefaultGRPCMapping(t))) require.NoError(t, err) require.Equal(t, `{"data":{"users":[{"id":"user-1","name":"User 1"},{"id":"user-2","name":"User 2"},{"id":"user-3","name":"User 3"}]}}`, response) }) @@ -255,7 +255,7 @@ func TestGRPCSubgraphExecution(t *testing.T) { `, } - response, err := executeOperation(t, conn, operation) + response, err := executeOperation(t, conn, operation, withGRPCMapping(mapping.MustDefaultGRPCMapping(t))) require.NoError(t, err) require.Equal(t, `{"data":{"user":{"id":"1","name":"User 1"}}}`, response) }) diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index f7e71d4174..9a48666e3d 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -3998,3 +3998,27 @@ func (w *Walker) FieldDefinitionDirectiveArgumentValueByName(field int, directiv return w.definition.DirectiveArgumentValueByName(directive, argumentName) } + +// InRootField returns true if the current field is a root field. +func (w *Walker) InRootField() bool { + return w.CurrentKind == ast.NodeKindField && + len(w.Ancestors) == 2 && + w.Ancestors[0].Kind == ast.NodeKindOperationDefinition +} + +// ResolveInlineFragment returns the inline fragment ref if the current field is inside of +// an inline fragment. +// It returns the inline fragment ref and true if the current field is inside of an inline fragment. +// It returns -1 and false if the current field is not inside of an inline fragment. +func (w *Walker) ResolveInlineFragment() (int, bool) { + if len(w.Ancestors) < 2 { + return -1, false + } + + node := w.Ancestors[len(w.Ancestors)-2] + if node.Kind != ast.NodeKindInlineFragment { + return -1, false + } + + return node.Ref, true +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 741022c1a6..1e7fadc0a0 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -363,11 +363,12 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { } dataSource, err = grpcdatasource.NewDataSource(p.grpcClient, grpcdatasource.DataSourceConfig{ - Operation: &opDocument, - Definition: p.config.schemaConfiguration.upstreamSchemaAst, - Mapping: p.config.grpc.Mapping, - Compiler: p.config.grpc.Compiler, - Disabled: p.config.grpc.Disabled, + Operation: &opDocument, + Definition: p.config.schemaConfiguration.upstreamSchemaAst, + Mapping: p.config.grpc.Mapping, + Compiler: p.config.grpc.Compiler, + Disabled: p.config.grpc.Disabled, + FederationConfigs: p.dataSourcePlannerConfig.RequiredFields, // TODO: remove fallback logic in visitor for subgraph name and // add proper error handling if the subgraph name is not set in the mapping SubgraphName: p.dataSourceConfig.Name(), diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration.go b/v2/pkg/engine/datasource/grpc_datasource/configuration.go index 6ee194004b..68ed82947e 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/configuration.go +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration.go @@ -17,7 +17,7 @@ type GRPCMapping struct { // SubscriptionRPCs maps GraphQL subscription fields to the corresponding gRPC RPC configurations SubscriptionRPCs RPCConfigMap // EntityRPCs defines how GraphQL types are resolved as entities using specific RPCs - EntityRPCs map[string]EntityRPCConfig + EntityRPCs map[string][]EntityRPCConfig // Fields defines the field mappings between GraphQL types and gRPC messages Fields map[string]FieldMap // EnumValues defines the enum values for each enum type @@ -121,3 +121,18 @@ func (g *GRPCMapping) ResolveEnumValue(enumName, enumValue string) (string, bool return "", false } + +func (g *GRPCMapping) ResolveEntityRPCConfig(typeName, key string) (RPCConfig, bool) { + rpcConfig, ok := g.EntityRPCs[typeName] + if !ok { + return RPCConfig{}, false + } + + for _, ei := range rpcConfig { + if ei.Key == key { + return ei.RPCConfig, true + } + } + + return RPCConfig{}, false +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index ad41e4baf8..4817e41bf3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -6,6 +6,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) @@ -56,8 +57,6 @@ type RPCExecutionPlan struct { // RPCCall represents a single call to a gRPC service method. // It contains all the information needed to make the call and process the response. type RPCCall struct { - // CallID is the unique identifier for the call - CallID int // DependentCalls is a list of calls that must be executed before this call DependentCalls []int // ServiceName is the name of the gRPC service to call @@ -259,7 +258,6 @@ func (r *RPCExecutionPlan) String() string { for j, call := range r.Calls { result.WriteString(fmt.Sprintf(" Call %d:\n", j)) - result.WriteString(fmt.Sprintf(" CallID: %d\n", call.CallID)) if len(call.DependentCalls) > 0 { result.WriteString(" DependentCalls: [") @@ -287,8 +285,12 @@ func (r *RPCExecutionPlan) String() string { return result.String() } +type PlanVisitor interface { + ExecutionPlan() *RPCExecutionPlan +} + type Planner struct { - visitor *rpcPlanVisitor + visitor PlanVisitor walker *astvisitor.Walker } @@ -297,19 +299,31 @@ type Planner struct { // The planner is responsible for creating an RPCExecutionPlan from a given // GraphQL operation. It is used by the engine to execute operations against // gRPC services. -func NewPlanner(subgraphName string, mapping *GRPCMapping) *Planner { +func NewPlanner(subgraphName string, mapping *GRPCMapping, federationConfigs plan.FederationFieldConfigurations) *Planner { walker := astvisitor.NewWalker(48) if mapping == nil { mapping = new(GRPCMapping) } + var visitor PlanVisitor + if len(federationConfigs) > 0 { + visitor = newRPCPlanVisitorFederation(&walker, rpcPlanVisitorConfig{ + subgraphName: subgraphName, + mapping: mapping, + federationConfigs: federationConfigs, + }) + } else { + visitor = newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + subgraphName: subgraphName, + mapping: mapping, + federationConfigs: federationConfigs, + }) + } + return &Planner{ - visitor: newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ - subgraphName: subgraphName, - mapping: mapping, - }), - walker: &walker, + visitor: visitor, + walker: &walker, } } @@ -320,7 +334,7 @@ func (p *Planner) PlanOperation(operation, definition *ast.Document) (*RPCExecut return nil, fmt.Errorf("unable to plan operation: %w", report) } - return p.visitor.plan, nil + return p.visitor.ExecutionPlan(), nil } // formatRPCMessage formats an RPCMessage and adds it to the string builder with the specified indentation @@ -342,3 +356,214 @@ func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) { } } } + +type rpcPlanningContext struct { + operation *ast.Document + definition *ast.Document + mapping *GRPCMapping +} + +func newRPCPlanningContext(operation *ast.Document, definition *ast.Document, mapping *GRPCMapping) *rpcPlanningContext { + return &rpcPlanningContext{ + operation: operation, + definition: definition, + mapping: mapping, + } +} + +// toDataType converts an ast.Type to a DataType. +// It handles the different type kinds and non-null types. +func (r *rpcPlanningContext) toDataType(t *ast.Type) DataType { + switch t.TypeKind { + case ast.TypeKindNamed: + return r.parseGraphQLType(t) + case ast.TypeKindList: + return r.toDataType(&r.definition.Types[t.OfType]) + case ast.TypeKindNonNull: + return r.toDataType(&r.definition.Types[t.OfType]) + } + + return DataTypeUnknown +} + +// parseGraphQLType parses an ast.Type and returns the corresponding DataType. +// It handles the different type kinds and non-null types. +func (r *rpcPlanningContext) parseGraphQLType(t *ast.Type) DataType { + dt := r.definition.Input.ByteSliceString(t.Name) + + // Retrieve the node to check the kind + node, found := r.definition.NodeByNameStr(dt) + if !found { + return DataTypeUnknown + } + + // For non-scalar types, return the corresponding DataType + switch node.Kind { + case ast.NodeKindInterfaceTypeDefinition: + fallthrough + case ast.NodeKindUnionTypeDefinition: + fallthrough + case ast.NodeKindObjectTypeDefinition, ast.NodeKindInputObjectTypeDefinition: + return DataTypeMessage + case ast.NodeKindEnumTypeDefinition: + return DataTypeEnum + default: + return fromGraphQLType(dt) + } +} + +func (r *rpcPlanningContext) resolveRPCMethodMapping(operationType ast.OperationType, operationFieldName string) (RPCConfig, error) { + if r.mapping == nil { + return RPCConfig{}, nil + } + + var rpcConfig RPCConfig + var ok bool + + switch operationType { + case ast.OperationTypeQuery: + rpcConfig, ok = r.mapping.QueryRPCs[operationFieldName] + case ast.OperationTypeMutation: + rpcConfig, ok = r.mapping.MutationRPCs[operationFieldName] + case ast.OperationTypeSubscription: + rpcConfig, ok = r.mapping.SubscriptionRPCs[operationFieldName] + } + + if !ok { + return RPCConfig{}, nil + } + + // We require all fields to be present when defining a mapping for the operation + if rpcConfig.RPC == "" { + return RPCConfig{}, fmt.Errorf("no rpc method name mapping found for operation %s", operationFieldName) + } + + if rpcConfig.Request == "" { + return RPCConfig{}, fmt.Errorf("no request message name mapping found for operation %s", operationFieldName) + } + + if rpcConfig.Response == "" { + return RPCConfig{}, fmt.Errorf("no response message name mapping found for operation %s", operationFieldName) + } + + return rpcConfig, nil +} + +// newMessageFromSelectionSet creates a new message from the enclosing type node and the selection set reference. +func (r *rpcPlanningContext) newMessageFromSelectionSet(enclosingTypeNode ast.Node, selectSetRef int) *RPCMessage { + message := &RPCMessage{ + Name: enclosingTypeNode.NameString(r.definition), + Fields: make(RPCFields, 0, len(r.operation.SelectionSets[selectSetRef].SelectionRefs)), + } + + return message +} + +// resolveFieldMapping resolves the field mapping for a field. +// This applies both for complex types in the input and for all fields in the response. +func (r *rpcPlanningContext) resolveFieldMapping(typeName, fieldName string) string { + grpcFieldName, ok := r.mapping.ResolveFieldMapping(typeName, fieldName) + if !ok { + return fieldName + } + + return grpcFieldName +} + +func (r *rpcPlanningContext) typeIsNullableOrNestedList(typeRef int) bool { + if !r.definition.TypeIsNonNull(typeRef) && r.definition.TypeIsList(typeRef) { + return true + } + + if r.definition.TypeNumberOfListWraps(typeRef) > 1 { + return true + } + + return false +} + +func (r *rpcPlanningContext) createListMetadata(typeRef int) (*ListMetadata, error) { + nestingLevel := r.definition.TypeNumberOfListWraps(typeRef) + + md := &ListMetadata{ + NestingLevel: nestingLevel, + LevelInfo: make([]LevelInfo, nestingLevel), + } + + for i := 0; i < nestingLevel; i++ { + md.LevelInfo[i] = LevelInfo{ + Optional: !r.definition.TypeIsNonNull(typeRef), + } + + typeRef = r.definition.ResolveNestedListOrListType(typeRef) + if typeRef == ast.InvalidRef { + return nil, fmt.Errorf("unable to resolve underlying list type for ref: %d", typeRef) + } + } + + return md, nil +} + +// buildField builds a field from a field definition. +// It handles lists, enums, and other types. +func (r *rpcPlanningContext) buildField(enclosingTypeNode ast.Node, fd int, fieldName, fieldAlias string) (RPCField, error) { + fdt := r.definition.FieldDefinitionType(fd) + typeName := r.toDataType(&r.definition.Types[fdt]) + parentTypeName := enclosingTypeNode.NameString(r.definition) + + field := RPCField{ + Name: r.resolveFieldMapping(parentTypeName, fieldName), + Alias: fieldAlias, + Optional: !r.definition.TypeIsNonNull(fdt), + JSONPath: fieldName, + TypeName: typeName.String(), + } + + if r.definition.TypeIsList(fdt) { + switch { + // for nullable or nested lists we need to build a wrapper message + // Nullability is handled by the datasource during the execution. + case r.typeIsNullableOrNestedList(fdt): + md, err := r.createListMetadata(fdt) + if err != nil { + return field, err + } + field.ListMetadata = md + field.IsListType = true + default: + // For non-nullable single lists we can directly use the repeated syntax in protobuf. + field.Repeated = true + } + } + + if typeName == DataTypeEnum { + field.EnumName = r.definition.FieldDefinitionTypeNameString(fd) + } + + if fieldName == "__typename" { + field.StaticValue = parentTypeName + } + + return field, nil +} + +func (r *rpcPlanningContext) ensureRequiredFields(message *RPCMessage, fc *federationConfigData) error { + // If the message is nil, we can't add any fields to it. + if message == nil { + return nil + } + + walker := astvisitor.WalkerFromPool() + defer walker.Release() + + requiredFieldsVisitor := newRequiredFieldsVisitor(walker, message, r) + return requiredFieldsVisitor.visitRequiredFields(r.definition, fc.entityTypeName, fc.requiredFields) +} + +func (r *rpcPlanningContext) resolveServiceName(subgraphName string) string { + if r.mapping == nil || r.mapping.Service == "" { + return subgraphName + } + + return r.mapping.Service +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go index 5756d9d167..ee8170a2c3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go @@ -1,37 +1,53 @@ package grpcdatasource import ( + "fmt" + "reflect" + "strings" "testing" "github.com/google/go-cmp/cmp" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" grpctest "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafeparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) func TestEntityLookup(t *testing.T) { tests := []struct { - name string - query string - expectedPlan *RPCExecutionPlan - mapping *GRPCMapping + name string + query string + expectedPlan *RPCExecutionPlan + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations }{ { name: "Should create an execution plan for an entity lookup with one key field", query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on Product { __typename id name price } } }`, mapping: &GRPCMapping{ Service: "Products", - EntityRPCs: map[string]EntityRPCConfig{ + EntityRPCs: map[string][]EntityRPCConfig{ "Product": { - Key: "id", - RPCConfig: RPCConfig{ - RPC: "LookupProductById", - Request: "LookupProductByIdRequest", - Response: "LookupProductByIdResponse", + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupProductById", + Request: "LookupProductByIdRequest", + Response: "LookupProductByIdResponse", + }, }, }, }, }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "Product", + SelectionSet: "id", + }, + }, expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -101,6 +117,232 @@ func TestEntityLookup(t *testing.T) { }, }, }, + { + name: "Should create an execution plan for an entity lookup multiple types", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on Product { __typename id name price } ... on Storage { __typename id name location } } }`, + mapping: testMapping(), + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "Product", + SelectionSet: "id", + }, + { + TypeName: "Storage", + SelectionSet: "id", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupProductById", + Request: RPCMessage{ + Name: "LookupProductByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupProductByIdKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupProductByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "Product", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "Product", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + + { + Name: "price", + TypeName: string(DataTypeDouble), + JSONPath: "price", + }, + }, + }, + }, + }, + }, + }, + { + ServiceName: "Products", + MethodName: "LookupStorageById", + Request: RPCMessage{ + Name: "LookupStorageByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupStorageByIdKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupStorageByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "Storage", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "Storage", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + { + Name: "location", + TypeName: string(DataTypeString), + JSONPath: "location", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Should create an execution plan for an entity lookup with required fields", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on Product { __typename id name } } }`, + mapping: testMapping(), + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "Product", + SelectionSet: "id", + }, + { + TypeName: "Product", + FieldName: "name", // Field name requires price + SelectionSet: "price", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupProductById", + Request: RPCMessage{ + Name: "LookupProductByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupProductByIdKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupProductByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "Product", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "Product", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + { + Name: "price", + TypeName: string(DataTypeDouble), + JSONPath: "price", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, // TODO implement multiple entity lookup types // { @@ -304,23 +546,783 @@ func TestEntityLookup(t *testing.T) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ - subgraphName: "Products", - mapping: tt.mapping, - }) - - walker.Walk(&queryDoc, &schemaDoc, &report) - - if report.HasErrors() { - t.Fatalf("failed to walk AST: %s", report.Error()) + planner := NewPlanner("Products", tt.mapping, tt.federationConfigs) + plan, err := planner.PlanOperation(&queryDoc, &schemaDoc) + if err != nil { + t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(tt.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } }) } } + +func TestEntityKeys(t *testing.T) { + tests := []struct { + name string + query string + schema string + expectedPlan *RPCExecutionPlan + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations + }{ + { + name: "Should create an execution plan for an entity lookup with a key field", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + user(id: ID!): User + } + type User @key(fields: "id") { + id: ID! + name: String! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupUserById", + Request: "LookupUserByIdRequest", + Response: "LookupUserByIdResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id", + }, + }, + + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserById", + // Define the structure of the request message + Request: RPCMessage{ + Name: "LookupUserByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + // Define the structure of the response message + Response: RPCMessage{ + Name: "LookupUserByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Should create an execution plan for an entity lookup with a nested key", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + user(id: ID!): User + } + + type Address { + id: ID! + street: String! + city: String! + state: String! + zip: String! + } + + type User @key(fields: "id address { id }") { + id: ID! + name: String! + address: Address! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "id address { id }", + RPCConfig: RPCConfig{ + RPC: "LookupUserByIdAndAddress", + Request: "LookupUserByIdAndAddressRequest", + Response: "LookupUserByIdAndAddressResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id address { id }", + }, + }, + + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserByIdAndAddress", + // Define the structure of the request message + Request: RPCMessage{ + Name: "LookupUserByIdAndAddressRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdAndAddressKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "address", + TypeName: string(DataTypeMessage), + JSONPath: "address", + Message: &RPCMessage{ + Name: "Address", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + }, + }, + }, + // Define the structure of the response message + Response: RPCMessage{ + Name: "LookupUserByIdAndAddressResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Should create an execution plan for an entity lookup with a compound key", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + _entities(representations: [_Any!]!): [_Entity]! + } + type User @key(fields: "id name") { + id: ID! + name: String! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "id name", + RPCConfig: RPCConfig{ + RPC: "LookupUserByIdAndName", + Request: "LookupUserByIdAndNameRequest", + Response: "LookupUserByIdAndNameResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id name", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserByIdAndName", + Request: RPCMessage{ + Name: "LookupUserByIdAndNameRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdAndNameKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupUserByIdAndNameResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + runFederationTest(t, tt) + } +} + +func TestRequriedFields(t *testing.T) { + tests := []struct { + name string + query string + schema string + expectedPlan *RPCExecutionPlan + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations + }{ + { + name: "Should also require reviews field when name is selected ", + schema: testFederationSchemaString(` + type Query { + user(id: ID!): User + } + type User @key(fields: "id") { + id: ID! + name: String! + reviews: [String!]! + } + `, []string{"User"}), + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id", // no field name mean this is related to the key + }, + { + TypeName: "User", + SelectionSet: "reviews", + FieldName: "name", // name requires reviews + }, + }, + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupUserById", + Request: "LookupUserByIdRequest", + Response: "LookupUserByIdResponse", + }, + }, + }, + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserById", + Request: RPCMessage{ + Name: "LookupUserByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupUserByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + JSONPath: "_entities", + Repeated: true, + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + { + Name: "reviews", + TypeName: string(DataTypeString), + JSONPath: "reviews", + Repeated: true, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Should require nested fields", + schema: testFederationSchemaString(` + type Query { + user(id: ID!): User + } + + type Review { + body: String! + title: String! + } + + type User @key(fields: "id") { + id: ID! + name: String! + reviews: [Review!]! + } + `, []string{"User"}), + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name reviews } } }`, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id", + }, + { + TypeName: "User", + SelectionSet: "reviews { body title }", + FieldName: "name", // name requires reviews { body title } + }, + }, + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupUserById", + Request: "LookupUserByIdRequest", + Response: "LookupUserByIdResponse", + }, + }, + }, + }, + Fields: map[string]FieldMap{ + "User": { + "id": { + TargetName: "id", + }, + }, + "Review": { + "body": { + TargetName: "body", + }, + "title": { + TargetName: "title", + }, + }, + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserById", + Request: RPCMessage{ + Name: "LookupUserByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupUserByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + JSONPath: "_entities", + Repeated: true, + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + { + Name: "reviews", + TypeName: string(DataTypeMessage), + JSONPath: "reviews", + Repeated: true, + Message: &RPCMessage{ + Name: "Review", + Fields: []RPCField{ + { + Name: "body", + TypeName: string(DataTypeString), + JSONPath: "body", + }, + { + Name: "title", + TypeName: string(DataTypeString), + JSONPath: "title", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + runFederationTest(t, tt) + } +} + +func runFederationTest(t *testing.T, tt struct { + name string + query string + schema string + expectedPlan *RPCExecutionPlan + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations +}) { + + t.Helper() + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var operation, definition ast.Document + + definition = unsafeparser.ParseGraphqlDocumentStringWithBaseSchema(tt.schema) + + report := operationreport.Report{} + astvalidation.DefaultDefinitionValidator().Validate(&definition, &report) + if report.HasErrors() { + t.Fatalf("failed to validate schema: %s", report.Error()) + } + + operation, report = astparser.ParseGraphqlDocumentString(tt.query) + if report.HasErrors() { + t.Fatalf("failed to parse query: %s", report.Error()) + } + + operation, report = astparser.ParseGraphqlDocumentString(tt.query) + if report.HasErrors() { + t.Fatalf("failed to parse query: %s", report.Error()) + } + + planner := NewPlanner("Products", tt.mapping, tt.federationConfigs) + plan, err := planner.PlanOperation(&operation, &definition) + if err != nil { + t.Fatalf("failed to plan operation: %s", err) + } + + diff := cmp.Diff(tt.expectedPlan, plan) + if diff != "" { + t.Fatalf("execution plan mismatch: %s", diff) + } + }) + +} + +func testFederationSchemaString(schema string, entities []string) string { + entityUnion := strings.Join(entities, " | ") + return fmt.Sprintf(` + schema { + query: Query + } + %s + + union _Entity = %s + scalar _Any + `, schema, entityUnion) +} + +func TestRootKeyNames(t *testing.T) { + tests := []struct { + name string + keyFields string + expected []string + }{ + { + name: "Should return root key names with nested fields", + keyFields: "id address { id }", + expected: []string{"id", "address"}, + }, + { + name: "Should handle single field", + keyFields: "id", + expected: []string{"id"}, + }, + { + name: "Should handle multiple simple fields", + keyFields: "id name email status", + expected: []string{"id", "name", "email", "status"}, + }, + { + name: "Should handle mixed simple and complex fields", + keyFields: "id user { name email } status", + expected: []string{"id", "user", "status"}, + }, + { + name: "Should handle multiple nested objects", + keyFields: "id user { name email } address { street city } status", + expected: []string{"id", "user", "address", "status"}, + }, + { + name: "Should handle deeply nested fields", + keyFields: "id user { profile { personal { name age } } }", + expected: []string{"id", "user"}, + }, + { + name: "Should handle fields with underscores and numbers", + keyFields: "user_id item_2 product_variant { sku_code }", + expected: []string{"user_id", "item_2", "product_variant"}, + }, + { + name: "Should handle camelCase fields", + keyFields: "userId firstName lastName userProfile { emailAddress }", + expected: []string{"userId", "firstName", "lastName", "userProfile"}, + }, + { + name: "Should handle empty string", + keyFields: "", + expected: []string{}, + }, + { + name: "Should handle whitespace only", + keyFields: " \t\n ", + expected: []string{}, + }, + { + name: "Should handle extra whitespace", + keyFields: " id name address { street city } status ", + expected: []string{"id", "name", "address", "status"}, + }, + { + name: "Should handle nested braces", + keyFields: "id organization { department { team { lead { name } } } }", + expected: []string{"id", "organization"}, + }, + { + name: "Should handle multiple levels of nesting", + keyFields: "id user { contact { phone { primary secondary } email } } metadata { tags }", + expected: []string{"id", "user", "metadata"}, + }, + { + name: "Should handle fields with dashes", + keyFields: "user-id product-code shipping-address { street-name }", + expected: []string{"user-id", "product-code", "shipping-address"}, + }, + { + name: "Should handle single nested field", + keyFields: "user { id }", + expected: []string{"user"}, + }, + { + name: "Should handle adjacent nested fields", + keyFields: "user { name } address { street }", + expected: []string{"user", "address"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + fc := federationConfigData{ + keyFields: tt.keyFields, + } + t.Logf("keyFields: %s", tt.keyFields) + actual := fc.getRootKeyNames() + if !reflect.DeepEqual(actual, tt.expected) { + t.Fatalf("expected %v, got %v", tt.expected, actual) + } + }) + } +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go index ea9625552e..cc892e4679 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go @@ -153,7 +153,6 @@ func TestQueryExecutionPlans(t *testing.T) { { ServiceName: "Products", MethodName: "QueryUser", - CallID: 1, Request: RPCMessage{ Name: "QueryUserRequest", Fields: []RPCField{ @@ -359,8 +358,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query with a complex input type and variables", - query: `query ComplexFilterTypeQuery($filter: ComplexFilterTypeInput!) { complexFilterType(filter: $filter) { id name } }`, + name: "Should create an execution plan for a query with a complex input type and variables", + query: `query ComplexFilterTypeQuery($filter: ComplexFilterTypeInput!) { complexFilterType(filter: $filter) { id name } }`, + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -389,12 +389,12 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "name", }, { - Name: "filterField1", + Name: "filter_field_1", TypeName: string(DataTypeString), JSONPath: "filterField1", }, { - Name: "filterField2", + Name: "filter_field_2", TypeName: string(DataTypeString), JSONPath: "filterField2", }, @@ -412,7 +412,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "page", }, { - Name: "perPage", + Name: "per_page", TypeName: string(DataTypeInt32), JSONPath: "perPage", }, @@ -432,7 +432,7 @@ func TestQueryExecutionPlans(t *testing.T) { Fields: []RPCField{ { Repeated: true, - Name: "complexFilterType", + Name: "complex_filter_type", TypeName: string(DataTypeMessage), JSONPath: "complexFilterType", Message: &RPCMessage{ @@ -459,8 +459,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, { - name: "Should create an execution plan for a query with a complex input type and variables with different name", - query: `query ComplexFilterTypeQuery($foobar: ComplexFilterTypeInput!) { complexFilterType(filter: $foobar) { id name } }`, + name: "Should create an execution plan for a query with a complex input type and variables with different name", + query: `query ComplexFilterTypeQuery($foobar: ComplexFilterTypeInput!) { complexFilterType(filter: $foobar) { id name } }`, + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -489,12 +490,12 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "name", }, { - Name: "filterField1", + Name: "filter_field_1", TypeName: string(DataTypeString), JSONPath: "filterField1", }, { - Name: "filterField2", + Name: "filter_field_2", TypeName: string(DataTypeString), JSONPath: "filterField2", }, @@ -512,7 +513,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "page", }, { - Name: "perPage", + Name: "per_page", TypeName: string(DataTypeInt32), JSONPath: "perPage", }, @@ -532,7 +533,7 @@ func TestQueryExecutionPlans(t *testing.T) { Fields: []RPCField{ { Repeated: true, - Name: "complexFilterType", + Name: "complex_filter_type", TypeName: string(DataTypeMessage), JSONPath: "complexFilterType", Message: &RPCMessage{ @@ -558,8 +559,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query with a type filter with arguments and variables", - query: "query TypeWithMultipleFilterFieldsQuery($filter: FilterTypeInput!) { typeWithMultipleFilterFields(filter: $filter) { id name } }", + name: "Should create an execution plan for a query with a type filter with arguments and variables", + query: "query TypeWithMultipleFilterFieldsQuery($filter: FilterTypeInput!) { typeWithMultipleFilterFields(filter: $filter) { id name } }", + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -577,13 +579,13 @@ func TestQueryExecutionPlans(t *testing.T) { Fields: []RPCField{ { Repeated: false, - Name: "filterField1", + Name: "filter_field_1", TypeName: string(DataTypeString), JSONPath: "filterField1", }, { Repeated: false, - Name: "filterField2", + Name: "filter_field_2", TypeName: string(DataTypeString), JSONPath: "filterField2", }, @@ -596,7 +598,7 @@ func TestQueryExecutionPlans(t *testing.T) { Name: "QueryTypeWithMultipleFilterFieldsResponse", Fields: []RPCField{ { - Name: "typeWithMultipleFilterFields", + Name: "type_with_multiple_filter_fields", TypeName: string(DataTypeMessage), Repeated: true, JSONPath: "typeWithMultipleFilterFields", @@ -623,8 +625,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query", - query: "query UserQuery { users { id name } }", + name: "Should create an execution plan for a query", + query: "query UserQuery { users { id name } }", + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -678,8 +681,9 @@ func TestQueryExecutionPlans(t *testing.T) { expectedError: "no request message name mapping found for operation user", }, { - name: "Should create an execution plan for a query with a user", - query: `query UserQuery { user(id: "abc123") { id name } }`, + name: "Should create an execution plan for a query with a user", + query: `query UserQuery { user(id: "abc123") { id name } }`, + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -726,8 +730,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query with a nested type", - query: "query NestedTypeQuery { nestedType { id name b { id name c { id name } } } }", + name: "Should create an execution plan for a query with a nested type", + query: "query NestedTypeQuery { nestedType { id name b { id name c { id name } } } }", + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -740,7 +745,7 @@ func TestQueryExecutionPlans(t *testing.T) { Name: "QueryNestedTypeResponse", Fields: []RPCField{ { - Name: "nestedType", + Name: "nested_type", TypeName: string(DataTypeMessage), Repeated: true, JSONPath: "nestedType", @@ -807,8 +812,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query with a recursive type", - query: "query RecursiveTypeQuery { recursiveType { id name recursiveType { id recursiveType { id name recursiveType { id name } } name } } }", + name: "Should create an execution plan for a query with a recursive type", + query: "query RecursiveTypeQuery { recursiveType { id name recursiveType { id recursiveType { id name recursiveType { id name } } name } } }", + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -821,7 +827,7 @@ func TestQueryExecutionPlans(t *testing.T) { Name: "QueryRecursiveTypeResponse", Fields: []RPCField{ { - Name: "recursiveType", + Name: "recursive_type", TypeName: string(DataTypeMessage), JSONPath: "recursiveType", Message: &RPCMessage{ @@ -838,7 +844,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "name", }, { - Name: "recursiveType", + Name: "recursive_type", TypeName: string(DataTypeMessage), JSONPath: "recursiveType", Message: &RPCMessage{ @@ -850,7 +856,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "id", }, { - Name: "recursiveType", + Name: "recursive_type", TypeName: string(DataTypeMessage), JSONPath: "recursiveType", Message: &RPCMessage{ @@ -867,7 +873,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "name", }, { - Name: "recursiveType", + Name: "recursive_type", TypeName: string(DataTypeMessage), JSONPath: "recursiveType", Message: &RPCMessage{ @@ -1340,7 +1346,7 @@ func TestProductExecutionPlan(t *testing.T) { t.Fatalf("failed to validate query: %s", report.Error()) } - planner := NewPlanner("Products", testMapping()) + planner := NewPlanner("Products", testMapping(), nil) outPlan, err := planner.PlanOperation(&queryDoc, &schemaDoc) if tt.expectedError != "" { @@ -1507,7 +1513,6 @@ func TestProductExecutionPlanWithAliases(t *testing.T) { { ServiceName: "Products", MethodName: "QueryCategories", - CallID: 1, Request: RPCMessage{ Name: "QueryCategoriesRequest", }, @@ -2026,7 +2031,6 @@ func TestProductExecutionPlanWithAliases(t *testing.T) { { ServiceName: "Products", MethodName: "QueryUser", - CallID: 1, Request: RPCMessage{ Name: "QueryUserRequest", Fields: []RPCField{ @@ -2068,7 +2072,6 @@ func TestProductExecutionPlanWithAliases(t *testing.T) { { ServiceName: "Products", MethodName: "QueryUser", - CallID: 2, Request: RPCMessage{ Name: "QueryUserRequest", Fields: []RPCField{ @@ -2446,7 +2449,7 @@ func TestProductExecutionPlanWithAliases(t *testing.T) { t.Fatalf("failed to validate query: %s", report.Error()) } - planner := NewPlanner("Products", testMapping()) + planner := NewPlanner("Products", testMapping(), nil) outPlan, err := planner.PlanOperation(&queryDoc, &schemaDoc) if tt.expectedError != "" { diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go index edf2dd880b..a8c7946269 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go @@ -1,11 +1,12 @@ package grpcdatasource import ( + "errors" "fmt" - "strings" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" "golang.org/x/text/cases" "golang.org/x/text/language" @@ -16,21 +17,10 @@ type keyField struct { fieldType string } -type entityInfo struct { - name string - keyFields []keyField - keyTypeName string - entityRootFieldRef int - entityInlineFragmentRef int -} - type planningInfo struct { - entityInfo entityInfo // resolvers []string operationType ast.OperationType operationFieldName string - isEntityLookup bool - methodName string requestMessageAncestors []*RPCMessage currentRequestMessage *RPCMessage @@ -43,64 +33,66 @@ type planningInfo struct { } type rpcPlanVisitor struct { - walker *astvisitor.Walker - operation *ast.Document - definition *ast.Document - planInfo planningInfo - - subgraphName string - mapping *GRPCMapping - plan *RPCExecutionPlan - operationDefinitionRef int - operationFieldRef int - operationFieldRefs []int - currentCall *RPCCall - currentCallID int + walker *astvisitor.Walker + operation *ast.Document + definition *ast.Document + planCtx *rpcPlanningContext + planInfo planningInfo + federationConfigs plan.FederationFieldConfigurations + + subgraphName string + mapping *GRPCMapping + plan *RPCExecutionPlan + operationFieldRef int + operationFieldRefs []int + currentCall *RPCCall + currentCallID int } type rpcPlanVisitorConfig struct { - subgraphName string - mapping *GRPCMapping + subgraphName string + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations } // newRPCPlanVisitor creates a new RPCPlanVisitor. // It registers the visitor with the walker and returns it. func newRPCPlanVisitor(walker *astvisitor.Walker, config rpcPlanVisitorConfig) *rpcPlanVisitor { visitor := &rpcPlanVisitor{ - walker: walker, - plan: &RPCExecutionPlan{}, - subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), - mapping: config.mapping, - operationDefinitionRef: -1, - operationFieldRef: -1, + walker: walker, + plan: &RPCExecutionPlan{}, + subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), + mapping: config.mapping, + operationFieldRef: -1, + federationConfigs: config.federationConfigs, } walker.RegisterEnterDocumentVisitor(visitor) walker.RegisterEnterOperationVisitor(visitor) walker.RegisterFieldVisitor(visitor) walker.RegisterSelectionSetVisitor(visitor) - walker.RegisterInlineFragmentVisitor(visitor) walker.RegisterEnterArgumentVisitor(visitor) return visitor } +func (r *rpcPlanVisitor) ExecutionPlan() *RPCExecutionPlan { + return r.plan +} + // EnterDocument implements astvisitor.EnterDocumentVisitor. func (r *rpcPlanVisitor) EnterDocument(operation *ast.Document, definition *ast.Document) { r.definition = definition r.operation = operation + + r.planCtx = newRPCPlanningContext(operation, definition, r.mapping) } // EnterOperationDefinition implements astvisitor.EnterOperationDefinitionVisitor. // This is called when entering the operation definition node. // It retrieves information about the operation // and creates a new group in the plan. -// -// The function also checks if this is an entity lookup operation, -// which requires special handling. func (r *rpcPlanVisitor) EnterOperationDefinition(ref int) { - r.operationDefinitionRef = ref - // Retrieves the fields from the root selection set. // These fields determine the names for the RPC functions to call. // TODO: handle fragments on root level `... on Query {}` @@ -117,10 +109,6 @@ func (r *rpcPlanVisitor) EnterOperationDefinition(ref int) { // // TODO handle field arguments to define resolvers func (r *rpcPlanVisitor) EnterArgument(ref int) { - if r.planInfo.isEntityLookup { - return - } - a := r.walker.Ancestor() if a.Kind != ast.NodeKindField && a.Ref != r.operationFieldRef { return @@ -137,9 +125,6 @@ func (r *rpcPlanVisitor) EnterArgument(ref int) { // EnterSelectionSet implements astvisitor.EnterSelectionSetVisitor. // Checks if this is in the root level below the operation definition. -// -// TODO handle multiple entity lookups in a single query. -// We need to create a new call for each entity lookup. func (r *rpcPlanVisitor) EnterSelectionSet(ref int) { if r.walker.Ancestor().Kind == ast.NodeKindOperationDefinition { return @@ -151,7 +136,7 @@ func (r *rpcPlanVisitor) EnterSelectionSet(ref int) { // In nested selection sets, a new message needs to be created, which will be added to the current response message. if r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message == nil { - r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message = r.newMessageFromSelectionSet(ref) + r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message = r.planCtx.newMessageFromSelectionSet(r.walker.EnclosingTypeDefinition, ref) } // Add the current response message to the ancestors and set the current response message to the current field message @@ -230,66 +215,34 @@ func (r *rpcPlanVisitor) LeaveSelectionSet(ref int) { } } -// EnterInlineFragment implements astvisitor.InlineFragmentVisitor. -func (r *rpcPlanVisitor) EnterInlineFragment(ref int) { - entityInfo := &r.planInfo.entityInfo - if entityInfo.entityRootFieldRef != -1 && entityInfo.entityInlineFragmentRef == -1 { - entityInfo.entityInlineFragmentRef = ref - r.resolveEntityInformation(ref) - r.scaffoldEntityLookup() - - return - } -} - -// LeaveInlineFragment implements astvisitor.InlineFragmentVisitor. -func (r *rpcPlanVisitor) LeaveInlineFragment(ref int) { - if ref == r.planInfo.entityInfo.entityInlineFragmentRef { - r.planInfo.entityInfo.entityInlineFragmentRef = -1 - } -} - -func (r *rpcPlanVisitor) IsRootField() bool { - return len(r.walker.Ancestors) == 2 && r.walker.Ancestors[0].Kind == ast.NodeKindOperationDefinition -} - -func (r *rpcPlanVisitor) IsInlineFragmentField() (int, bool) { - if len(r.walker.Ancestors) < 2 { - return -1, false - } - - node := r.walker.Ancestors[len(r.walker.Ancestors)-2] - if node.Kind != ast.NodeKindInlineFragment { - return -1, false - } - - return node.Ref, true -} - func (r *rpcPlanVisitor) handleRootField(ref int) error { r.operationFieldRef = ref r.planInfo.operationFieldName = r.operation.FieldNameString(ref) r.currentCall = &RPCCall{ - CallID: r.currentCallID, - ServiceName: r.resolveServiceName(), + ServiceName: r.planCtx.resolveServiceName(r.subgraphName), } r.planInfo.currentRequestMessage = &r.currentCall.Request r.planInfo.currentResponseMessage = &r.currentCall.Response // attempt to resolve the name from the mapping - if err := r.resolveRPCMethodMapping(); err != nil { + rpcConfig, err := r.planCtx.resolveRPCMethodMapping(r.planInfo.operationType, r.planInfo.operationFieldName) + if err != nil { return err } + r.currentCall.MethodName = rpcConfig.RPC + r.currentCall.Request.Name = rpcConfig.Request + r.currentCall.Response.Name = rpcConfig.Response + return nil } // EnterField implements astvisitor.EnterFieldVisitor. func (r *rpcPlanVisitor) EnterField(ref int) { fieldName := r.operation.FieldNameString(ref) - if r.IsRootField() { + if r.walker.InRootField() { if err := r.handleRootField(ref); err != nil { r.walker.StopWithInternalErr(err) return @@ -297,16 +250,7 @@ func (r *rpcPlanVisitor) EnterField(ref int) { } if fieldName == "_entities" { - // _entities is a special field that is used to look up entities - // Entity lookups are handled differently as we use special types for - // Providing variables (_Any) and the response type is a Union that needs to be - // determined from the first inline fragment. - r.planInfo.entityInfo = entityInfo{ - entityRootFieldRef: ref, - entityInlineFragmentRef: -1, - } - r.planInfo.isEntityLookup = true - r.planInfo.entityInfo.entityRootFieldRef = ref + r.walker.StopWithInternalErr(errors.New("entities field is not supported in this visitor")) return } @@ -324,9 +268,14 @@ func (r *rpcPlanVisitor) EnterField(ref int) { return } - field := r.buildField(fd, fieldName, fieldAlias) + field, err := r.planCtx.buildField(r.walker.EnclosingTypeDefinition, fd, fieldName, fieldAlias) + if err != nil { + r.walker.StopWithInternalErr(err) + return + } - if ref, ok := r.IsInlineFragmentField(); ok && !r.planInfo.isEntityLookup { + // check if we are inside of an inline fragment + if ref, ok := r.walker.ResolveInlineFragment(); ok { if r.planInfo.currentResponseMessage.FieldSelectionSet == nil { r.planInfo.currentResponseMessage.FieldSelectionSet = make(RPCFieldSelectionSet) } @@ -339,65 +288,14 @@ func (r *rpcPlanVisitor) EnterField(ref int) { r.planInfo.currentResponseMessage.Fields = append(r.planInfo.currentResponseMessage.Fields, field) } -// buildField builds a field from a field definition. -// It handles lists, enums, and other types. -func (r *rpcPlanVisitor) buildField(fd int, fieldName, fieldAlias string) RPCField { - fdt := r.definition.FieldDefinitionType(fd) - typeName := r.toDataType(&r.definition.Types[fdt]) - parentTypeName := r.walker.EnclosingTypeDefinition.NameString(r.definition) - - field := RPCField{ - Name: r.resolveFieldMapping(parentTypeName, fieldName), - Alias: fieldAlias, - Optional: !r.definition.TypeIsNonNull(fdt), - JSONPath: fieldName, - TypeName: typeName.String(), - } - - if r.definition.TypeIsList(fdt) { - switch { - // for nullable or nested lists we need to build a wrapper message - // Nullability is handled by the datasource during the execution. - case r.typeIsNullableOrNestedList(fdt): - field.ListMetadata = r.createListMetadata(fdt) - field.IsListType = true - default: - // For non-nullable single lists we can directly use the repeated syntax in protobuf. - field.Repeated = true - } - } - - if typeName == DataTypeEnum { - field.EnumName = r.definition.FieldDefinitionTypeNameString(fd) - } - - if fieldName == "__typename" { - field.StaticValue = parentTypeName - } - - return field -} - // LeaveField implements astvisitor.FieldVisitor. func (r *rpcPlanVisitor) LeaveField(ref int) { - if ref == r.planInfo.entityInfo.entityRootFieldRef { - r.planInfo.entityInfo.entityRootFieldRef = -1 - } - // If we are not in the operation field, we can increment the response field index. - if !r.IsRootField() { + if !r.walker.InRootField() { r.planInfo.currentResponseFieldIndex++ return } - // If we left the operation field, we need to finalize the current call and prepare the next one. - if r.currentCall.MethodName == "" { - methodName := r.rpcMethodName() - r.currentCall.MethodName = methodName - r.currentCall.Request.Name = methodName + "Request" - r.currentCall.Response.Name = methodName + "Response" - } - r.plan.Calls[r.currentCallID] = *r.currentCall r.currentCall = &RPCCall{} @@ -409,16 +307,6 @@ func (r *rpcPlanVisitor) LeaveField(ref int) { r.planInfo.currentResponseFieldIndex = 0 } -// newMessageFromSelectionSet creates a new message from a selection set. -func (r *rpcPlanVisitor) newMessageFromSelectionSet(ref int) *RPCMessage { - message := &RPCMessage{ - Name: r.walker.EnclosingTypeDefinition.NameString(r.definition), - Fields: make(RPCFields, 0, len(r.operation.SelectionSets[ref].SelectionRefs)), - } - - return message -} - // enrichRequestMessageFromInputArgument constructs a request message from an input argument based on its type. // It retrieves the underlying type and builds the request message from the underlying type. // If the underlying type is an input object type, it creates a new message and adds it to the current request message. @@ -466,7 +354,7 @@ func (r *rpcPlanVisitor) enrichRequestMessageFromInputArgument(argRef, typeRef i r.planInfo.requestMessageAncestors = r.planInfo.requestMessageAncestors[:len(r.planInfo.requestMessageAncestors)-1] case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: - dt := r.toDataType(&r.definition.Types[typeRef]) + dt := r.planCtx.toDataType(&r.definition.Types[typeRef]) r.planInfo.currentRequestMessage.Fields = append(r.planInfo.currentRequestMessage.Fields, r.buildInputMessageField(typeRef, mappedInputName, jsonPath, dt)) @@ -507,7 +395,7 @@ func (r *rpcPlanVisitor) buildMessageField(fieldName string, typeRef, parentType // If the type is not an object, directly add the field to the request message if underlyingTypeNode.Kind != ast.NodeKindInputObjectTypeDefinition { - dt := r.toDataType(&inputValueDefinitionType) + dt := r.planCtx.toDataType(&inputValueDefinitionType) r.planInfo.currentRequestMessage.Fields = append(r.planInfo.currentRequestMessage.Fields, r.buildInputMessageField(typeRef, mappedName, fieldName, dt)) @@ -545,8 +433,13 @@ func (r *rpcPlanVisitor) buildInputMessageField(typeRef int, fieldName, jsonPath switch { // for nullable or nested lists we need to build a wrapper message // Nullability is handled by the datasource during the execution. - case r.typeIsNullableOrNestedList(typeRef): - field.ListMetadata = r.createListMetadata(typeRef) + case r.planCtx.typeIsNullableOrNestedList(typeRef): + md, err := r.planCtx.createListMetadata(typeRef) + if err != nil { + r.walker.StopWithInternalErr(err) + return field + } + field.ListMetadata = md field.IsListType = true default: // For non-nullable single lists we can directly use the repeated syntax in protobuf. @@ -561,156 +454,6 @@ func (r *rpcPlanVisitor) buildInputMessageField(typeRef int, fieldName, jsonPath return field } -func (r *rpcPlanVisitor) createListMetadata(typeRef int) *ListMetadata { - nestingLevel := r.definition.TypeNumberOfListWraps(typeRef) - - md := &ListMetadata{ - NestingLevel: nestingLevel, - LevelInfo: make([]LevelInfo, nestingLevel), - } - - for i := 0; i < nestingLevel; i++ { - md.LevelInfo[i] = LevelInfo{ - Optional: !r.definition.TypeIsNonNull(typeRef), - } - - typeRef = r.definition.ResolveNestedListOrListType(typeRef) - if typeRef == ast.InvalidRef { - r.walker.StopWithInternalErr(fmt.Errorf("unable to resolve underlying list type for ref: %d", typeRef)) - return nil - } - } - - return md -} - -func (r *rpcPlanVisitor) resolveEntityInformation(inlineFragmentRef int) { - // TODO support multiple entities in a single query - if !r.planInfo.isEntityLookup || r.planInfo.entityInfo.name != "" { - return - } - - fragmentName := r.operation.InlineFragmentTypeConditionNameString(inlineFragmentRef) - node, found := r.definition.NodeByNameStr(fragmentName) - if !found { - return - } - - // Only process object type definitions - // TODO: handle interfaces - if node.Kind != ast.NodeKindObjectTypeDefinition { - return - } - - // An entity must at least have a key directive - def := r.definition.ObjectTypeDefinitions[node.Ref] - if !def.HasDirectives { - return - } - - // TODO: We currently only support one key directive per entity - // We need to get the used key from the graphql datasource. - for _, directiveRef := range def.Directives.Refs { - if r.definition.DirectiveNameString(directiveRef) != federationKeyDirectiveName { - continue - } - - r.planInfo.entityInfo.name = fragmentName - - directive := r.definition.Directives[directiveRef] - for _, argRef := range directive.Arguments.Refs { - if r.definition.ArgumentNameString(argRef) != "fields" { - continue - } - argument := r.definition.Arguments[argRef] - keyFieldName := r.definition.ValueContentString(argument.Value) - - fieldDef, ok := r.definition.NodeFieldDefinitionByName(node, ast.ByteSlice(keyFieldName)) - if !ok { - r.walker.Report.AddExternalError(operationreport.ExternalError{ - Message: fmt.Sprintf("Field %s not found in definition", keyFieldName), - }) - return - } - - fdt := r.definition.FieldDefinitionType(fieldDef) - ft := r.definition.Types[fdt] - - r.planInfo.entityInfo.keyFields = - append(r.planInfo.entityInfo.keyFields, keyField{ - fieldName: keyFieldName, - fieldType: r.toDataType(&ft).String(), - }) - } - - break - } - - keyFields := make([]string, 0, len(r.planInfo.entityInfo.keyFields)) - for _, key := range r.planInfo.entityInfo.keyFields { - keyFields = append(keyFields, key.fieldName) - } - - if ei, exists := r.mapping.EntityRPCs[r.planInfo.entityInfo.name]; exists { - r.currentCall.Request.Name = ei.RPCConfig.Request - r.currentCall.Response.Name = ei.RPCConfig.Response - r.planInfo.methodName = ei.RPCConfig.RPC - } - - r.planInfo.entityInfo.keyTypeName = r.planInfo.entityInfo.name + "By" + strings.Join(titleSlice(keyFields), "And") -} - -// scaffoldEntityLookup creates the entity lookup call structure -// by creating the key field message and adding it to the current request message. -// It also adds the results message to the current response message. -func (r *rpcPlanVisitor) scaffoldEntityLookup() { - if !r.planInfo.isEntityLookup { - return - } - - entityInfo := &r.planInfo.entityInfo - keyFieldMessage := &RPCMessage{ - Name: r.rpcMethodName() + "Key", - } - for _, key := range entityInfo.keyFields { - keyFieldMessage.Fields = append(keyFieldMessage.Fields, RPCField{ - Name: key.fieldName, - TypeName: key.fieldType, - JSONPath: key.fieldName, - }) - } - - r.planInfo.currentRequestMessage.Fields = []RPCField{ - { - Name: "keys", - TypeName: DataTypeMessage.String(), - Repeated: true, // The inputs are always a list of objects - JSONPath: "representations", - Message: keyFieldMessage, - }, - } - - // The proto response message has a field `result` which is a list of entities. - // As this is a special case we directly map it to _entities. - r.planInfo.currentResponseMessage.Fields = []RPCField{ - { - Name: "result", - TypeName: DataTypeMessage.String(), - JSONPath: "_entities", - Repeated: true, - }, - } -} - -func (r *rpcPlanVisitor) resolveServiceName() string { - if r.mapping == nil || r.mapping.Service == "" { - return r.subgraphName - } - - return r.mapping.Service -} - -// resolveFieldMapping resolves the field mapping for a field. // This applies both for complex types in the input and for all fields in the response. func (r *rpcPlanVisitor) resolveFieldMapping(typeName, fieldName string) string { grpcFieldName, ok := r.mapping.ResolveFieldMapping(typeName, fieldName) @@ -733,150 +476,3 @@ func (r *rpcPlanVisitor) resolveInputArgument(baseType string, fieldRef int, arg return grpcFieldName } - -func (r *rpcPlanVisitor) resolveRPCMethodMapping() error { - if r.mapping == nil { - return nil - } - - if r.planInfo.isEntityLookup && r.planInfo.entityInfo.name != "" { - // Resolving the entity lookup method name is done differently - return nil - } - - var rpcConfig RPCConfig - var ok bool - - switch r.planInfo.operationType { - case ast.OperationTypeQuery: - rpcConfig, ok = r.mapping.QueryRPCs[r.planInfo.operationFieldName] - case ast.OperationTypeMutation: - rpcConfig, ok = r.mapping.MutationRPCs[r.planInfo.operationFieldName] - case ast.OperationTypeSubscription: - rpcConfig, ok = r.mapping.SubscriptionRPCs[r.planInfo.operationFieldName] - } - - // if we don't have a mapping, we can skip the operation - if !ok { - return nil - } - - // We require all fields to be present when defining a mapping for the operation - if rpcConfig.RPC == "" { - return fmt.Errorf("no rpc method name mapping found for operation %s", r.planInfo.operationFieldName) - } - - if rpcConfig.Request == "" { - return fmt.Errorf("no request message name mapping found for operation %s", r.planInfo.operationFieldName) - } - - if rpcConfig.Response == "" { - return fmt.Errorf("no response message name mapping found for operation %s", r.planInfo.operationFieldName) - } - - r.currentCall.MethodName = rpcConfig.RPC - r.currentCall.Request.Name = rpcConfig.Request - r.currentCall.Response.Name = rpcConfig.Response - - return nil -} - -// rpcMethodName determines the appropriate method name based on operation type. -func (r *rpcPlanVisitor) rpcMethodName() string { - if r.planInfo.methodName != "" { - return r.planInfo.methodName - } - - switch r.planInfo.operationType { - case ast.OperationTypeQuery: - r.planInfo.methodName = r.buildQueryMethodName() - case ast.OperationTypeMutation: - r.planInfo.methodName = r.buildMutationMethodName() - case ast.OperationTypeSubscription: - r.planInfo.methodName = r.buildSubscriptionMethodName() - } - - return r.planInfo.methodName -} - -// buildQueryMethodName constructs a method name for query operations. -func (r *rpcPlanVisitor) buildQueryMethodName() string { - if r.planInfo.isEntityLookup && r.planInfo.entityInfo.name != "" { - return "Lookup" + r.planInfo.entityInfo.keyTypeName - } - - return "Query" + cases.Title(language.Und, cases.NoLower).String(r.planInfo.operationFieldName) -} - -// buildMutationMethodName constructs a method name for mutation operations. -func (r *rpcPlanVisitor) buildMutationMethodName() string { - // TODO implement mutation method name handling - return "Mutation" + cases.Title(language.Und, cases.NoLower).String(r.planInfo.operationFieldName) -} - -// buildSubscriptionMethodName constructs a method name for subscription operations. -func (r *rpcPlanVisitor) buildSubscriptionMethodName() string { - // TODO implement subscription method name handling - return "Subscription" + cases.Title(language.Und, cases.NoLower).String(r.planInfo.operationFieldName) -} - -// toDataType converts an ast.Type to a DataType. -// It handles the different type kinds and non-null types. -func (r *rpcPlanVisitor) toDataType(t *ast.Type) DataType { - switch t.TypeKind { - case ast.TypeKindNamed: - return r.parseGraphQLType(t) - case ast.TypeKindList: - return r.toDataType(&r.definition.Types[t.OfType]) - case ast.TypeKindNonNull: - return r.toDataType(&r.definition.Types[t.OfType]) - } - - return DataTypeUnknown -} - -// parseGraphQLType parses an ast.Type and returns the corresponding DataType. -// It handles the different type kinds and non-null types. -func (r *rpcPlanVisitor) parseGraphQLType(t *ast.Type) DataType { - dt := r.definition.Input.ByteSliceString(t.Name) - - // Retrieve the node to check the kind - node, found := r.definition.NodeByNameStr(dt) - if !found { - return DataTypeUnknown - } - - // For non-scalar types, return the corresponding DataType - switch node.Kind { - case ast.NodeKindInterfaceTypeDefinition: - fallthrough - case ast.NodeKindUnionTypeDefinition: - fallthrough - case ast.NodeKindObjectTypeDefinition, ast.NodeKindInputObjectTypeDefinition: - return DataTypeMessage - case ast.NodeKindEnumTypeDefinition: - return DataTypeEnum - default: - return fromGraphQLType(dt) - } -} - -func (r *rpcPlanVisitor) typeIsNullableOrNestedList(typeRef int) bool { - if !r.definition.TypeIsNonNull(typeRef) && r.definition.TypeIsList(typeRef) { - return true - } - - if r.definition.TypeNumberOfListWraps(typeRef) > 1 { - return true - } - - return false -} - -// titleSlice capitalizes the first letter of each string in a slice. -func titleSlice(s []string) []string { - for i, v := range s { - s[i] = cases.Title(language.Und, cases.NoLower).String(v) - } - return s -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go new file mode 100644 index 0000000000..541d30fe69 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -0,0 +1,476 @@ +package grpcdatasource + +import ( + "errors" + "fmt" + "strings" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/runes" + "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// entityInfo contains the information about the entity that is being looked up. +type entityInfo struct { + typeName string + entityRootFieldRef int + entityInlineFragmentRef int +} + +type federationConfigData struct { + entityTypeName string + keyFields string + requiredFields string +} + +func newFederationConfigData(entityTypeName string) federationConfigData { + return federationConfigData{ + entityTypeName: entityTypeName, + keyFields: "", + requiredFields: "", + } +} + +func (f *federationConfigData) getRootKeyNames() []string { + keys := make([]string, 0) + + selectionSetQueue := []struct{}{} + currentIndex := 0 + + for i := range f.keyFields { + switch f.keyFields[i] { + case runes.LBRACE: + selectionSetQueue = append(selectionSetQueue, struct{}{}) + case runes.RBRACE: + if len(selectionSetQueue) == 0 { + continue + } + selectionSetQueue = selectionSetQueue[:len(selectionSetQueue)-1] + currentIndex = i + 1 + case runes.SPACE: + if len(selectionSetQueue) > 0 { + continue + } + + key := strings.TrimSpace(f.keyFields[currentIndex:i]) + currentIndex = i + 1 + + if key == "" { + continue + } + + keys = append(keys, key) + } + } + + if currentIndex < len(f.keyFields) { + key := strings.TrimSpace(f.keyFields[currentIndex:]) + if key != "" { + keys = append(keys, key) + } + } + + return keys +} + +type rpcPlanVisitorFederation struct { + walker *astvisitor.Walker + operation *ast.Document + definition *ast.Document + planCtx *rpcPlanningContext + mapping *GRPCMapping + + planInfo planningInfo + entityInfo entityInfo + federationConfigData []federationConfigData + + plan *RPCExecutionPlan + subgraphName string + currentCall *RPCCall + currentCallIndex int +} + +func newRPCPlanVisitorFederation(walker *astvisitor.Walker, config rpcPlanVisitorConfig) *rpcPlanVisitorFederation { + visitor := &rpcPlanVisitorFederation{ + walker: walker, + plan: &RPCExecutionPlan{}, + subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), + mapping: config.mapping, + entityInfo: entityInfo{ + entityRootFieldRef: -1, + entityInlineFragmentRef: -1, + }, + federationConfigData: parseFederationConfigData(config.federationConfigs), + } + + walker.RegisterEnterDocumentVisitor(visitor) + walker.RegisterEnterOperationVisitor(visitor) + walker.RegisterInlineFragmentVisitor(visitor) + walker.RegisterSelectionSetVisitor(visitor) + walker.RegisterFieldVisitor(visitor) + + return visitor +} + +func (r *rpcPlanVisitorFederation) ExecutionPlan() *RPCExecutionPlan { + return r.plan +} + +// EnterDocument implements astvisitor.EnterDocumentVisitor. +func (r *rpcPlanVisitorFederation) EnterDocument(operation *ast.Document, definition *ast.Document) { + r.operation = operation + r.definition = definition + + r.planCtx = newRPCPlanningContext(operation, definition, r.mapping) +} + +// EnterOperationDefinition implements astvisitor.EnterOperationDefinitionVisitor. +func (r *rpcPlanVisitorFederation) EnterOperationDefinition(ref int) { + if r.operation.OperationDefinitions[ref].OperationType != ast.OperationTypeQuery { + r.walker.StopWithInternalErr(errors.New("only query operations are supported for the federation plan visitor")) + return + } + + r.planInfo.operationType = r.operation.OperationDefinitions[ref].OperationType +} + +// EnterInlineFragment implements astvisitor.InlineFragmentVisitor. +func (r *rpcPlanVisitorFederation) EnterInlineFragment(ref int) { + // if !r.IsEntityInlineFragment(ref) { + // return + // } + + fragmentName := r.operation.InlineFragmentTypeConditionNameString(ref) + fc, ok := r.FederationConfigDataByEntityTypeName(fragmentName) + if !ok { + return + } + + r.currentCall = &RPCCall{ + ServiceName: r.planCtx.resolveServiceName(r.subgraphName), + } + + r.planInfo.currentRequestMessage = &r.currentCall.Request + r.planInfo.currentResponseMessage = &r.currentCall.Response + + r.entityInfo.entityInlineFragmentRef = ref + r.entityInfo.typeName = fragmentName + r.resolveEntityInformation(ref, fc) + r.scaffoldEntityLookup(fc) +} + +// LeaveInlineFragment implements astvisitor.InlineFragmentVisitor. +func (r *rpcPlanVisitorFederation) LeaveInlineFragment(ref int) { + if r.entityInfo.entityInlineFragmentRef != ref { + // We only handle the entity inline fragment + return + } + + // We need to ensure that all the fields that are in the required fields are also present in the response message. + fc, found := r.FederationConfigDataByEntityTypeName(r.entityInfo.typeName) + if !found { + r.walker.StopWithInternalErr(errors.New("federation config data not found for entity type name: " + r.entityInfo.typeName)) + return + } + + if fc.requiredFields != "" { + r.planCtx.ensureRequiredFields(r.planInfo.currentResponseMessage, &fc) + } + + r.plan.Calls = append(r.plan.Calls, *r.currentCall) + r.currentCall = &RPCCall{} + r.currentCallIndex++ + + r.planInfo = planningInfo{ + operationType: r.planInfo.operationType, + operationFieldName: r.planInfo.operationFieldName, + currentRequestMessage: &RPCMessage{}, + currentResponseMessage: &RPCMessage{}, + currentResponseFieldIndex: 0, + responseMessageAncestors: []*RPCMessage{}, + responseFieldIndexAncestors: []int{}, + } + + r.entityInfo.entityInlineFragmentRef = ast.InvalidRef +} + +// EnterSelectionSet implements astvisitor.SelectionSetVisitor. +func (r *rpcPlanVisitorFederation) EnterSelectionSet(ref int) { + if r.walker.Ancestor().Kind == ast.NodeKindOperationDefinition { + return + } + + if r.planInfo.currentRequestMessage == nil || len(r.planInfo.currentResponseMessage.Fields) == 0 || len(r.planInfo.currentResponseMessage.Fields) <= r.planInfo.currentResponseFieldIndex { + return + } + + // In nested selection sets, a new message needs to be created, which will be added to the current response message. + if r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message == nil { + r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message = r.planCtx.newMessageFromSelectionSet(r.walker.EnclosingTypeDefinition, ref) + } + + // Add the current response message to the ancestors and set the current response message to the current field message + r.planInfo.responseMessageAncestors = append(r.planInfo.responseMessageAncestors, r.planInfo.currentResponseMessage) + r.planInfo.currentResponseMessage = r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message + + // Check if the ancestor type is a composite type (interface or union) + // and set the oneof type and member types. + if err := r.handleCompositeType(r.walker.Ancestor()); err != nil { + // If the ancestor is a composite type, but we were unable to resolve the member types, + // we stop the walker and return an internal error. + r.walker.StopWithInternalErr(err) + return + } + + // Keep track of the field indices for the current response message. + // This is used to set the correct field index for the current response message + // when leaving the selection set. + r.planInfo.responseFieldIndexAncestors = append(r.planInfo.responseFieldIndexAncestors, r.planInfo.currentResponseFieldIndex) + + r.planInfo.currentResponseFieldIndex = 0 // reset the field index for the current selection set +} + +func (r *rpcPlanVisitorFederation) handleCompositeType(node ast.Node) error { + if node.Ref < 0 { + return nil + } + + var ( + ok bool + oneOfType OneOfType + memberTypes []string + ) + + switch node.Kind { + case ast.NodeKindField: + return r.handleCompositeType(r.walker.EnclosingTypeDefinition) + case ast.NodeKindInterfaceTypeDefinition: + oneOfType = OneOfTypeInterface + memberTypes, ok = r.definition.InterfaceTypeDefinitionImplementedByObjectWithNames(node.Ref) + if !ok { + return fmt.Errorf("interface type %s is not implemented by any object", r.definition.InterfaceTypeDefinitionNameString(node.Ref)) + } + case ast.NodeKindUnionTypeDefinition: + oneOfType = OneOfTypeUnion + memberTypes, ok = r.definition.UnionTypeDefinitionMemberTypeNames(node.Ref) + if !ok { + return fmt.Errorf("union type %s is not defined", r.definition.UnionTypeDefinitionNameString(node.Ref)) + } + default: + return nil + } + + r.planInfo.currentResponseMessage.OneOfType = oneOfType + r.planInfo.currentResponseMessage.MemberTypes = memberTypes + + return nil +} + +// LeaveSelectionSet implements astvisitor.SelectionSetVisitor. +func (r *rpcPlanVisitorFederation) LeaveSelectionSet(ref int) { + if r.walker.Ancestor().Kind == ast.NodeKindInlineFragment { + return + } + + if len(r.planInfo.responseFieldIndexAncestors) > 0 { + r.planInfo.currentResponseFieldIndex = r.planInfo.responseFieldIndexAncestors[len(r.planInfo.responseFieldIndexAncestors)-1] + r.planInfo.responseFieldIndexAncestors = r.planInfo.responseFieldIndexAncestors[:len(r.planInfo.responseFieldIndexAncestors)-1] + } + + if len(r.planInfo.responseMessageAncestors) > 0 { + r.planInfo.currentResponseMessage = r.planInfo.responseMessageAncestors[len(r.planInfo.responseMessageAncestors)-1] + r.planInfo.responseMessageAncestors = r.planInfo.responseMessageAncestors[:len(r.planInfo.responseMessageAncestors)-1] + } +} + +// EnterField implements astvisitor.FieldVisitor. +func (r *rpcPlanVisitorFederation) EnterField(ref int) { + fieldName := r.operation.FieldNameString(ref) + if r.walker.InRootField() { + r.planInfo.operationFieldName = r.operation.FieldNameString(ref) + } + + if fieldName == "_entities" { + // _entities is a special field that is used to look up entities + // Entity lookups are handled differently as we use special types for + // Providing variables (_Any) and the response type is a Union that needs to be + // determined from the first inline fragment. + r.entityInfo = entityInfo{ + entityRootFieldRef: ref, + entityInlineFragmentRef: -1, + } + + r.entityInfo.entityRootFieldRef = ref + return + } + + // prevent duplicate fields + fieldAlias := r.operation.FieldAliasString(ref) + if r.planInfo.currentResponseMessage.Fields.Exists(fieldName, fieldAlias) { + return + } + + fd, ok := r.walker.FieldDefinition(ref) + if !ok { + r.walker.Report.AddExternalError(operationreport.ExternalError{ + Message: fmt.Sprintf("Field %s not found in definition %s", r.operation.FieldNameString(ref), r.walker.EnclosingTypeDefinition.NameString(r.definition)), + }) + return + } + + field, err := r.planCtx.buildField(r.walker.EnclosingTypeDefinition, fd, fieldName, fieldAlias) + if err != nil { + r.walker.StopWithInternalErr(err) + return + } + + // check if we are inside of an inline fragment and not the entity inline fragment + if ref, ok := r.walker.ResolveInlineFragment(); ok && r.entityInfo.entityInlineFragmentRef != ref { + if r.planInfo.currentResponseMessage.FieldSelectionSet == nil { + r.planInfo.currentResponseMessage.FieldSelectionSet = make(RPCFieldSelectionSet) + } + + inlineFragmentName := r.operation.InlineFragmentTypeConditionNameString(ref) + r.planInfo.currentResponseMessage.FieldSelectionSet.Add(inlineFragmentName, field) + return + } + + r.planInfo.currentResponseMessage.Fields = append(r.planInfo.currentResponseMessage.Fields, field) +} + +// LeaveField implements astvisitor.FieldVisitor. +func (r *rpcPlanVisitorFederation) LeaveField(ref int) { + // If we are not in the operation field, we can increment the response field index. + if !r.walker.InRootField() { + r.planInfo.currentResponseFieldIndex++ + return + } + + r.planInfo.currentResponseFieldIndex = 0 +} + +func (r *rpcPlanVisitorFederation) resolveEntityInformation(inlineFragmentRef int, fc federationConfigData) { + fragmentName := r.operation.InlineFragmentTypeConditionNameString(inlineFragmentRef) + node, found := r.definition.NodeByNameStr(r.operation.InlineFragmentTypeConditionNameString(inlineFragmentRef)) + if !found { + r.walker.StopWithInternalErr(errors.New("definition node not found for inline fragment: " + fragmentName)) + return + } + + // Only process object type definitions + // TODO: handle interfaces + if node.Kind != ast.NodeKindObjectTypeDefinition { + return + } + + rpcConfig, exists := r.mapping.ResolveEntityRPCConfig(fc.entityTypeName, fc.keyFields) + if !exists { + return + } + + r.currentCall.Request.Name = rpcConfig.Request + r.currentCall.Response.Name = rpcConfig.Response + r.currentCall.MethodName = rpcConfig.RPC +} + +// scaffoldEntityLookup creates the entity lookup call structure +// by creating the key field message and adding it to the current request message. +// It also adds the results message to the current response message. +func (r *rpcPlanVisitorFederation) scaffoldEntityLookup(fc federationConfigData) { + keyFieldMessage := &RPCMessage{ + Name: r.currentCall.MethodName + "Key", + } + + walker := astvisitor.WalkerFromPool() + defer walker.Release() + + requiredFieldsVisitor := newRequiredFieldsVisitor(walker, keyFieldMessage, r.planCtx) + err := requiredFieldsVisitor.visitRequiredFields(r.definition, fc.entityTypeName, fc.keyFields) + if err != nil { + r.walker.StopWithInternalErr(err) + return + } + + r.planInfo.currentRequestMessage.Fields = []RPCField{ + { + Name: "keys", + TypeName: DataTypeMessage.String(), + Repeated: true, // The inputs are always a list of objects + JSONPath: "representations", + Message: keyFieldMessage, + }, + } + + // The proto response message has a field `result` which is a list of entities. + // As this is a special case we directly map it to _entities. + r.planInfo.currentResponseMessage.Fields = []RPCField{ + { + Name: "result", + TypeName: DataTypeMessage.String(), + JSONPath: "_entities", + Repeated: true, + }, + } +} + +// FederationConfigDataByEntityTypeName returns the entity config data for the given entity type name. +func (r *rpcPlanVisitorFederation) FederationConfigDataByEntityTypeName(entityTypeName string) (federationConfigData, bool) { + for _, fc := range r.federationConfigData { + if fc.entityTypeName == entityTypeName { + return fc, true + } + } + return federationConfigData{}, false +} + +func (r *rpcPlanVisitorFederation) IsEntityInlineFragment(ref int) bool { + if r.entityInfo.entityInlineFragmentRef == ast.InvalidRef { + return false + } + + return r.entityInfo.entityInlineFragmentRef == ref +} + +func parseFederationConfigData(federationConfigs plan.FederationFieldConfigurations) []federationConfigData { + var out []federationConfigData + + typeNameIndexSet := map[string]int{} + typeNameIndex := 0 + + for _, fc := range federationConfigs { + // If the entity is not resolvable, we skip it + // TODO: Is this needed? + if fc.DisableEntityResolver { + continue + } + + // Create a new entity type if it doesn't exist + if _, ok := typeNameIndexSet[fc.TypeName]; !ok { + out = append(out, newFederationConfigData(fc.TypeName)) + typeNameIndexSet[fc.TypeName] = typeNameIndex + typeNameIndex++ + } + + data := &out[typeNameIndexSet[fc.TypeName]] + + // Selection set determines whether we have key fields or additional required fields + if fc.SelectionSet == "" { + continue + } + + // This is a required field, so we add it to the required fields + if fc.FieldName != "" { + data.requiredFields = fc.SelectionSet + continue + } + + // This is a key field, so we add it to the key fields + data.keyFields = fc.SelectionSet + } + + return out +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index a718b71aac..d17c4fcc23 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -17,10 +17,12 @@ import ( "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" protoref "google.golang.org/protobuf/reflect/protoreflect" ) @@ -32,11 +34,12 @@ var _ resolve.DataSource = (*DataSource)(nil) // transforms the responses back to GraphQL format. type DataSource struct { // Invocations is a list of gRPC invocations to be executed - plan *RPCExecutionPlan - cc grpc.ClientConnInterface - rc *RPCCompiler - mapping *GRPCMapping - disabled bool + plan *RPCExecutionPlan + cc grpc.ClientConnInterface + rc *RPCCompiler + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations + disabled bool } type ProtoConfig struct { @@ -44,28 +47,30 @@ type ProtoConfig struct { } type DataSourceConfig struct { - Operation *ast.Document - Definition *ast.Document - Compiler *RPCCompiler - SubgraphName string - Mapping *GRPCMapping - Disabled bool + Operation *ast.Document + Definition *ast.Document + Compiler *RPCCompiler + SubgraphName string + Mapping *GRPCMapping + FederationConfigs plan.FederationFieldConfigurations + Disabled bool } // NewDataSource creates a new gRPC datasource func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*DataSource, error) { - planner := NewPlanner(config.SubgraphName, config.Mapping) + planner := NewPlanner(config.SubgraphName, config.Mapping, config.FederationConfigs) plan, err := planner.PlanOperation(config.Operation, config.Definition) if err != nil { return nil, err } return &DataSource{ - plan: plan, - cc: client, - rc: config.Compiler, - mapping: config.Mapping, - disabled: config.Disabled, + plan: plan, + cc: client, + rc: config.Compiler, + mapping: config.Mapping, + federationConfigs: config.FederationConfigs, + disabled: config.Disabled, }, nil } @@ -97,6 +102,10 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) for _, invocation := range invocations { // Invoke the gRPC method - this will populate invocation.Output methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) + + b, _ := protojson.Marshal(invocation.Input) + fmt.Println(string(b)) + err := d.cc.Invoke(ctx, methodName, invocation.Input, invocation.Output) if err != nil { out.Write(writeErrorBytes(err)) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index aae0711585..627611a5c6 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -125,6 +125,14 @@ func Test_DataSource_Load(t *testing.T) { SubgraphName: "Products", Compiler: compiler, Mapping: &GRPCMapping{ + Service: "Products", + QueryRPCs: RPCConfigMap{ + "complexFilterType": { + RPC: "QueryComplexFilterType", + Request: "QueryComplexFilterTypeRequest", + Response: "QueryComplexFilterTypeResponse", + }, + }, Fields: map[string]FieldMap{ "Query": { "complexFilterType": { @@ -176,10 +184,21 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { SubgraphName: "Products", Compiler: compiler, Mapping: &GRPCMapping{ + Service: "Products", + QueryRPCs: RPCConfigMap{ + "complexFilterType": { + RPC: "QueryComplexFilterType", + Request: "QueryComplexFilterTypeRequest", + Response: "QueryComplexFilterTypeResponse", + }, + }, Fields: map[string]FieldMap{ "Query": { "complexFilterType": { TargetName: "complex_filter_type", + ArgumentMappings: map[string]string{ + "filter": "filter", + }, }, }, "FilterType": { @@ -255,10 +274,21 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { SubgraphName: "Products", Compiler: compiler, Mapping: &GRPCMapping{ + Service: "Products", + QueryRPCs: RPCConfigMap{ + "complexFilterType": { + RPC: "QueryComplexFilterType", + Request: "QueryComplexFilterTypeRequest", + Response: "QueryComplexFilterTypeResponse", + }, + }, Fields: map[string]FieldMap{ "Query": { "complexFilterType": { TargetName: "complex_filter_type", + ArgumentMappings: map[string]string{ + "filter": "filter", + }, }, }, "FilterType": { @@ -346,6 +376,31 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { Definition: &schemaDoc, SubgraphName: "Products", Compiler: compiler, + Mapping: &GRPCMapping{ + Service: "Products", + QueryRPCs: RPCConfigMap{ + "user": { + RPC: "QueryUser", + Request: "QueryUserRequest", + Response: "QueryUserResponse", + }, + }, + Fields: map[string]FieldMap{ + "Query": { + "user": { + TargetName: "user", + }, + }, + "User": { + "id": { + TargetName: "id", + }, + "name": { + TargetName: "name", + }, + }, + }, + }, }) require.NoError(t, err) diff --git a/v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go b/v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go index 146f371778..bb52f42954 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go +++ b/v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go @@ -218,21 +218,25 @@ func testMapping() *GRPCMapping { }, }, SubscriptionRPCs: RPCConfigMap{}, - EntityRPCs: map[string]EntityRPCConfig{ + EntityRPCs: map[string][]EntityRPCConfig{ "Product": { - Key: "id", - RPCConfig: RPCConfig{ - RPC: "LookupProductById", - Request: "LookupProductByIdRequest", - Response: "LookupProductByIdResponse", + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupProductById", + Request: "LookupProductByIdRequest", + Response: "LookupProductByIdResponse", + }, }, }, "Storage": { - Key: "id", - RPCConfig: RPCConfig{ - RPC: "LookupStorageById", - Request: "LookupStorageByIdRequest", - Response: "LookupStorageByIdResponse", + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupStorageById", + Request: "LookupStorageByIdRequest", + Response: "LookupStorageByIdResponse", + }, }, }, }, @@ -624,6 +628,14 @@ func testMapping() *GRPCMapping { TargetName: "filter", }, }, + "FilterTypeInput": { + "filterField1": { + TargetName: "filter_field_1", + }, + "filterField2": { + TargetName: "filter_field_2", + }, + }, "Category": { "id": { TargetName: "id", diff --git a/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go new file mode 100644 index 0000000000..19f3de23c0 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go @@ -0,0 +1,164 @@ +package grpcdatasource + +import ( + "errors" + "fmt" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" +) + +// requiredFieldsVisitor is a visitor that visits the required fields of a message. +type requiredFieldsVisitor struct { + operation *ast.Document + definition *ast.Document + + walker *astvisitor.Walker + message *RPCMessage + + planCtx *rpcPlanningContext + + messageAncestors []*RPCMessage +} + +// newRequiredFieldsVisitor creates a new requiredFieldsVisitor. +// It registers the visitor with the walker and returns it. +func newRequiredFieldsVisitor(walker *astvisitor.Walker, message *RPCMessage, planCtx *rpcPlanningContext) *requiredFieldsVisitor { + visitor := &requiredFieldsVisitor{ + walker: walker, + message: message, + planCtx: planCtx, + messageAncestors: []*RPCMessage{}, + } + + walker.RegisterEnterDocumentVisitor(visitor) + walker.RegisterSelectionSetVisitor(visitor) + walker.RegisterEnterFieldVisitor(visitor) + + return visitor +} + +// visitRequiredFields visits the required fields of a message. +// It creates a new document with the required fields and walks it. +// To achieve that we create a fragment with the required fields and walk it. +func (r *requiredFieldsVisitor) visitRequiredFields(definition *ast.Document, typeName, requiredFields string) error { + doc, report := plan.RequiredFieldsFragment(typeName, requiredFields, false) + if report.HasErrors() { + return report + } + + r.walker.Walk(doc, definition, report) + if report.HasErrors() { + return report + } + + return nil +} + +// EnterDocument implements astvisitor.EnterDocumentVisitor. +func (r *requiredFieldsVisitor) EnterDocument(operation *ast.Document, definition *ast.Document) { + if r.message == nil { + r.walker.StopWithInternalErr(errors.New("unable to visit required fields. Message is required")) + return + } + + r.operation = operation + r.definition = definition +} + +// EnterSelectionSet implements astvisitor.SelectionSetVisitor. +func (r *requiredFieldsVisitor) EnterSelectionSet(ref int) { + // Ignore the root selection set + if r.walker.Ancestor().Kind == ast.NodeKindFragmentDefinition { + return + } + + lastField := &r.message.Fields[len(r.message.Fields)-1] + if lastField.Message == nil { + lastField.Message = r.planCtx.newMessageFromSelectionSet(r.walker.EnclosingTypeDefinition, ref) + } + + r.messageAncestors = append(r.messageAncestors, r.message) + r.message = lastField.Message + + if err := r.handleCompositeType(r.walker.EnclosingTypeDefinition); err != nil { + r.walker.StopWithInternalErr(err) + return + } +} + +// LeaveSelectionSet implements astvisitor.SelectionSetVisitor. +func (r *requiredFieldsVisitor) LeaveSelectionSet(ref int) { + if r.walker.Ancestor().Kind == ast.NodeKindFragmentDefinition { + return + } + + if len(r.messageAncestors) > 0 { + r.message = r.messageAncestors[len(r.messageAncestors)-1] + r.messageAncestors = r.messageAncestors[:len(r.messageAncestors)-1] + } +} + +// EnterField implements astvisitor.EnterFieldVisitor. +func (r *requiredFieldsVisitor) EnterField(ref int) { + fieldName := r.operation.FieldNameString(ref) + + // prevent duplicate fields + if r.message.Fields.Exists(fieldName, "") { + return + } + + fd, ok := r.walker.FieldDefinition(ref) + if !ok { + r.walker.Report.AddExternalError(operationreport.ExternalError{ + Message: fmt.Sprintf("Field %s not found in definition %s", fieldName, r.walker.EnclosingTypeDefinition.NameString(r.definition)), + }) + return + } + + field, err := r.planCtx.buildField(r.walker.EnclosingTypeDefinition, fd, fieldName, "") + if err != nil { + r.walker.StopWithInternalErr(err) + return + } + + r.message.Fields = append(r.message.Fields, field) +} + +func (r *requiredFieldsVisitor) handleCompositeType(node ast.Node) error { + if node.Ref < 0 { + return nil + } + + var ( + ok bool + oneOfType OneOfType + memberTypes []string + ) + + switch node.Kind { + case ast.NodeKindField: + return r.handleCompositeType(r.walker.EnclosingTypeDefinition) + case ast.NodeKindInterfaceTypeDefinition: + oneOfType = OneOfTypeInterface + memberTypes, ok = r.definition.InterfaceTypeDefinitionImplementedByObjectWithNames(node.Ref) + if !ok { + return fmt.Errorf("interface type %s is not implemented by any object", r.definition.InterfaceTypeDefinitionNameString(node.Ref)) + } + case ast.NodeKindUnionTypeDefinition: + oneOfType = OneOfTypeUnion + memberTypes, ok = r.definition.UnionTypeDefinitionMemberTypeNames(node.Ref) + if !ok { + return fmt.Errorf("union type %s is not defined", r.definition.UnionTypeDefinitionNameString(node.Ref)) + } + default: + return nil + } + + r.message.OneOfType = oneOfType + r.message.MemberTypes = memberTypes + + return nil +} diff --git a/v2/pkg/grpctest/mapping/mapping.go b/v2/pkg/grpctest/mapping/mapping.go index 842a029997..1476d084ab 100644 --- a/v2/pkg/grpctest/mapping/mapping.go +++ b/v2/pkg/grpctest/mapping/mapping.go @@ -225,21 +225,25 @@ func DefaultGRPCMapping() *grpcdatasource.GRPCMapping { }, }, SubscriptionRPCs: grpcdatasource.RPCConfigMap{}, - EntityRPCs: map[string]grpcdatasource.EntityRPCConfig{ + EntityRPCs: map[string][]grpcdatasource.EntityRPCConfig{ "Product": { - Key: "id", - RPCConfig: grpcdatasource.RPCConfig{ - RPC: "LookupProductById", - Request: "LookupProductByIdRequest", - Response: "LookupProductByIdResponse", + { + Key: "id", + RPCConfig: grpcdatasource.RPCConfig{ + RPC: "LookupProductById", + Request: "LookupProductByIdRequest", + Response: "LookupProductByIdResponse", + }, }, }, "Storage": { - Key: "id", - RPCConfig: grpcdatasource.RPCConfig{ - RPC: "LookupStorageById", - Request: "LookupStorageByIdRequest", - Response: "LookupStorageByIdResponse", + { + Key: "id", + RPCConfig: grpcdatasource.RPCConfig{ + RPC: "LookupStorageById", + Request: "LookupStorageByIdRequest", + Response: "LookupStorageByIdResponse", + }, }, }, }, diff --git a/v2/pkg/grpctest/testdata/products.graphqls b/v2/pkg/grpctest/testdata/products.graphqls index 18197a3320..ac009e83e6 100644 --- a/v2/pkg/grpctest/testdata/products.graphqls +++ b/v2/pkg/grpctest/testdata/products.graphqls @@ -11,6 +11,13 @@ type Storage @key(fields: "id") { location: String! } +type Warehouse @key(fields: "id") @key(fields: "slug") { + id: ID! + slug: String! + name: String! + location: String! +} + type User { id: ID! name: String! @@ -393,5 +400,5 @@ type Mutation { -union _Entity = Product | Storage +union _Entity = Product | Storage | Warehouse scalar _Any From 5b22725f350e2cbab5c5715b361ae2f30bad117b Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Mon, 4 Aug 2025 18:31:32 +0200 Subject: [PATCH 02/24] chore: remove debug statement --- v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go | 4 ---- 1 file changed, 4 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index d17c4fcc23..3e09581197 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -22,7 +22,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" - "google.golang.org/protobuf/encoding/protojson" protoref "google.golang.org/protobuf/reflect/protoreflect" ) @@ -103,9 +102,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // Invoke the gRPC method - this will populate invocation.Output methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) - b, _ := protojson.Marshal(invocation.Input) - fmt.Println(string(b)) - err := d.cc.Invoke(ctx, methodName, invocation.Input, invocation.Output) if err != nil { out.Write(writeErrorBytes(err)) From 96991d5ba83df362995f31b4dda56133946d0465 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Tue, 5 Aug 2025 15:37:09 +0200 Subject: [PATCH 03/24] chore: check error --- .../grpc_datasource/execution_plan_visitor_federation.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index 541d30fe69..eddc89981f 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -178,7 +178,11 @@ func (r *rpcPlanVisitorFederation) LeaveInlineFragment(ref int) { } if fc.requiredFields != "" { - r.planCtx.ensureRequiredFields(r.planInfo.currentResponseMessage, &fc) + if err := r.planCtx.ensureRequiredFields(r.planInfo.currentResponseMessage, &fc); err != nil { + r.walker.StopWithInternalErr(err) + return + } + } r.plan.Calls = append(r.plan.Calls, *r.currentCall) From 6b23a531c19336f1c733add12de34c4906e6e3a5 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 6 Aug 2025 10:51:50 +0200 Subject: [PATCH 04/24] chore: improve mapping logic for key fields --- .../grpc_datasource/configuration.go | 101 +++- .../grpc_datasource/configuration_test.go | 432 ++++++++++++++++++ .../grpc_datasource/execution_plan.go | 1 - .../execution_plan_federation_test.go | 325 ++++++++----- .../grpc_datasource/execution_plan_visitor.go | 6 - .../execution_plan_visitor_federation.go | 44 -- 6 files changed, 753 insertions(+), 156 deletions(-) create mode 100644 v2/pkg/engine/datasource/grpc_datasource/configuration_test.go diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration.go b/v2/pkg/engine/datasource/grpc_datasource/configuration.go index 68ed82947e..12375abe76 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/configuration.go +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration.go @@ -1,5 +1,11 @@ package grpcdatasource +import ( + "strings" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/runes" +) + type ( // RPCConfigMap is a map of RPC names to RPC configurations RPCConfigMap map[string]RPCConfig @@ -129,10 +135,103 @@ func (g *GRPCMapping) ResolveEntityRPCConfig(typeName, key string) (RPCConfig, b } for _, ei := range rpcConfig { - if ei.Key == key { + if compareKeyFields(ei.Key, key) { return ei.RPCConfig, true } + } return RPCConfig{}, false } + +type keySet map[string]struct{} + +func (k keySet) add(keys ...string) { + for _, key := range keys { + trimmedKey := strings.TrimSpace(key) + if trimmedKey != "" { + k[trimmedKey] = struct{}{} + } + } +} + +func (k keySet) equals(other keySet) bool { + if len(k) != len(other) { + return false + } + + for key := range k { + if _, ok := other[key]; !ok { + return false + } + } + + return true +} + +// We compare only top level key +func compareKeyFields(left, right string) bool { + if left == right { + return true + } + + left = stripSelectionSets(left) + right = stripSelectionSets(right) + + leftKeys := strings.Split(left, " ") + rightKeys := strings.Split(right, " ") + + leftSet := make(keySet) + leftSet.add(leftKeys...) + + rightSet := make(keySet) + rightSet.add(rightKeys...) + + return leftSet.equals(rightSet) +} + +func stripSelectionSets(keyString string) string { + + selectionSetQueue := []struct{}{} + currentIndex := 0 + + keyString = strings.ReplaceAll(keyString, ",", " ") + + var sb strings.Builder + + for i := range keyString { + switch keyString[i] { + case runes.LBRACE: + selectionSetQueue = append(selectionSetQueue, struct{}{}) + case runes.RBRACE: + currentIndex = i + 1 + if len(selectionSetQueue) == 0 { + continue + } + selectionSetQueue = selectionSetQueue[:len(selectionSetQueue)-1] + case runes.SPACE: + if len(selectionSetQueue) > 0 { + continue + } + + key := strings.TrimSpace(keyString[currentIndex:i]) + currentIndex = i + 1 + + if key == "" { + continue + } + + sb.WriteString(key) + sb.WriteRune(runes.SPACE) + } + } + + if currentIndex < len(keyString) && len(selectionSetQueue) == 0 { + key := strings.TrimSpace(keyString[currentIndex:]) + if key != "" { + sb.WriteString(key) + } + } + + return strings.TrimSpace(sb.String()) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go b/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go new file mode 100644 index 0000000000..9ec743a1f3 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go @@ -0,0 +1,432 @@ +package grpcdatasource + +import "testing" + +func TestCompareKeyFields(t *testing.T) { + tests := []struct { + name string + left string + right string + expected bool + }{ + { + name: "identical strings", + left: "id name", + right: "id name", + expected: true, + }, + { + name: "empty strings", + left: "", + right: "", + expected: true, + }, + { + name: "single field same", + left: "id", + right: "id", + expected: true, + }, + { + name: "single field different", + left: "id", + right: "name", + expected: false, + }, + { + name: "different order same fields", + left: "id name email", + right: "name email id", + expected: true, + }, + { + name: "comma separated same fields", + left: "id,name,email", + right: "name,email,id", + expected: true, + }, + { + name: "mixed separators same fields", + left: "id,name email", + right: "name email,id", + expected: true, + }, + { + name: "different number of fields", + left: "id name", + right: "id name email", + expected: false, + }, + { + name: "completely different fields", + left: "id name", + right: "email phone", + expected: false, + }, + { + name: "one field missing", + left: "id name email", + right: "id name", + expected: false, + }, + { + name: "extra whitespace handling", + left: " id name ", + right: "name id", + expected: true, + }, + { + name: "extra whitespace with commas", + left: " id , name , email ", + right: "email,name,id", + expected: true, + }, + { + name: "multiple consecutive spaces", + left: "id name email", + right: "email name id", + expected: true, + }, + { + name: "mixed spaces and commas with whitespace", + left: "id, name email", + right: "email, name id", + expected: true, + }, + { + name: "empty fields filtered out", + left: "id , , name", + right: "name id", + expected: true, + }, + { + name: "only spaces and commas", + left: " , , ", + right: "", + expected: true, + }, + { + name: "one empty one with fields", + left: "", + right: "id name", + expected: false, + }, + { + name: "duplicate fields in left", + left: "id name id", + right: "id name", + expected: true, + }, + { + name: "duplicate fields in both", + left: "id name id email", + right: "name email name id", + expected: true, + }, + { + name: "case sensitive comparison", + left: "ID name", + right: "id name", + expected: false, + }, + { + name: "single character fields", + left: "a b c", + right: "c a b", + expected: true, + }, + { + name: "complex field names", + left: "user_id created_at updated_at", + right: "updated_at user_id created_at", + expected: true, + }, + { + name: "simple selection set", + left: "id name { firstName lastName }", + right: "name id", + expected: true, + }, + { + name: "nested selection sets", + left: "id user { profile { name email } }", + right: "user id", + expected: true, + }, + { + name: "deeply nested selection sets", + left: "id foo { bar { baz { qux } } }", + right: "foo id", + expected: true, + }, + { + name: "multiple fields with selection sets", + left: "id name user { email } address { street city }", + right: "address user name id", + expected: true, + }, + { + name: "selection sets with different order", + left: "user { name email } id address { street }", + right: "id address user", + expected: true, + }, + { + name: "mixed fields and selection sets", + left: "id name user { profile { personal { firstName lastName } work { company } } } status", + right: "status user name id", + expected: true, + }, + { + name: "selection sets with spaces inside braces", + left: "id user { name email } address", + right: "address user id", + expected: true, + }, + { + name: "selection sets with commas and spaces", + left: "id, user { name, email }, address { street, city }", + right: "address, user, id", + expected: true, + }, + { + name: "empty selection sets", + left: "id user { } address { }", + right: "address user id", + expected: true, + }, + { + name: "selection sets with nested empty braces", + left: "id user { profile { } contact { email } }", + right: "user id", + expected: true, + }, + { + name: "different nested structures same top level", + left: "user { name email } product { title price { amount currency } }", + right: "product { id sku } user { firstName lastName }", + expected: true, + }, + { + name: "selection sets vs simple fields different", + left: "id user { name }", + right: "id name", + expected: false, + }, + { + name: "only selection set fields", + left: "user { name email } address { street city }", + right: "address user", + expected: true, + }, + { + name: "complex real-world example", + left: "id user { profile { personal { name age } professional { company role } } } orders { items { product { name price } quantity } total }", + right: "orders user id", + expected: true, + }, + { + name: "unbalanced braces handled gracefully", + left: "id user { name email", + right: "id user { name email", + expected: true, + }, + { + name: "extra closing braces", + left: "id user { name } } extra", + right: "extra user id", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := compareKeyFields(tt.left, tt.right) + if result != tt.expected { + t.Errorf("compareKeyFields(%q, %q) = %v, expected %v", + tt.left, tt.right, result, tt.expected) + } + }) + } +} + +func TestKeySet(t *testing.T) { + t.Run("add method", func(t *testing.T) { + ks := make(keySet) + ks.add("id", "name", "", " ", "email") + + expected := keySet{ + "id": struct{}{}, + "name": struct{}{}, + "email": struct{}{}, + } + + if !ks.equals(expected) { + t.Errorf("keySet.add() did not produce expected result. Got %v, expected %v", ks, expected) + } + }) + + t.Run("equals method same sets", func(t *testing.T) { + ks1 := keySet{ + "id": struct{}{}, + "name": struct{}{}, + } + ks2 := keySet{ + "id": struct{}{}, + "name": struct{}{}, + } + + if !ks1.equals(ks2) { + t.Error("keySet.equals() should return true for identical sets") + } + }) + + t.Run("equals method different sizes", func(t *testing.T) { + ks1 := keySet{ + "id": struct{}{}, + "name": struct{}{}, + } + ks2 := keySet{ + "id": struct{}{}, + } + + if ks1.equals(ks2) { + t.Error("keySet.equals() should return false for sets of different sizes") + } + }) + + t.Run("equals method different keys", func(t *testing.T) { + ks1 := keySet{ + "id": struct{}{}, + "name": struct{}{}, + } + ks2 := keySet{ + "id": struct{}{}, + "email": struct{}{}, + } + + if ks1.equals(ks2) { + t.Error("keySet.equals() should return false for sets with different keys") + } + }) + + t.Run("equals method empty sets", func(t *testing.T) { + ks1 := make(keySet) + ks2 := make(keySet) + + if !ks1.equals(ks2) { + t.Error("keySet.equals() should return true for empty sets") + } + }) +} + +func TestStripSelectionSets(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "no selection sets", + input: "id name email", + expected: "id name email", + }, + { + name: "simple selection set", + input: "id name { firstName lastName }", + expected: "id name", + }, + { + name: "nested selection sets", + input: "id user { profile { name email } }", + expected: "id user", + }, + { + name: "deeply nested selection sets", + input: "id foo { bar { baz { qux } } }", + expected: "id foo", + }, + { + name: "multiple selection sets", + input: "id user { name } address { street }", + expected: "id user address", + }, + { + name: "empty selection sets", + input: "id user { } address { }", + expected: "id user address", + }, + { + name: "selection sets with commas", + input: "id, user { name, email }, address", + expected: "id user address", + }, + { + name: "complex nested example", + input: "id user { profile { personal { name age } work { company } } } orders { items { product } }", + expected: "id user orders", + }, + { + name: "spaces inside selection sets", + input: "id user { name email } address", + expected: "id user address", + }, + { + name: "empty input", + input: "", + expected: "", + }, + { + name: "only selection sets", + input: "user { name } address { street }", + expected: "user address", + }, + { + name: "unbalanced opening brace", + input: "id user { name email", + expected: "id user", + }, + { + name: "extra closing braces", + input: "id user { name } } extra", + expected: "id user extra", + }, + { + name: "consecutive spaces", + input: "id name user { email }", + expected: "id name user", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := stripSelectionSets(tt.input) + if result != tt.expected { + t.Errorf("stripSelectionSets(%q) = %q, expected %q", + tt.input, result, tt.expected) + } + }) + } +} + +// Benchmark tests for performance +func BenchmarkCompareKeyFields(b *testing.B) { + testCases := []struct { + name string + left string + right string + }{ + {"simple", "id name", "name id"}, + {"complex", "id,name,email,phone,address", "address phone email name id"}, + {"long", "field1 field2 field3 field4 field5 field6 field7 field8 field9 field10", "field10 field9 field8 field7 field6 field5 field4 field3 field2 field1"}, + {"long and nested", "field1 field2 field3 field4 field5 field6 field7 field8 field9 field10 { field11 field12 field13 field14 field15 field16 field17 field18 field19 field20 }", "field20 field19 field18 field17 field16 field15 field14 field13 field12 field11 field10 field9 field8 field7 field6 field5 field4 field3 field2 field1"}, + } + + for _, tc := range testCases { + b.Run(tc.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + compareKeyFields(tc.left, tc.right) + } + }) + } +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index 4817e41bf3..9dca2a6ed9 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -11,7 +11,6 @@ import ( ) const ( - federationKeyDirectiveName = "key" // knownTypeOptionalFieldValueName is the name of the field that is used to wrap optional scalar values // in a message as protobuf scalar types are not nullable. knownTypeOptionalFieldValueName = "value" diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go index ee8170a2c3..f002833063 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go @@ -2,7 +2,6 @@ package grpcdatasource import ( "fmt" - "reflect" "strings" "testing" @@ -890,6 +889,227 @@ func TestEntityKeys(t *testing.T) { }, }, }, + { + name: "Order in a compound key should not matter", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + _entities(representations: [_Any!]!): [_Entity]! + } + type User @key(fields: "id name") { + id: ID! + name: String! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "name id", + RPCConfig: RPCConfig{ + RPC: "LookupUserByIdAndName", + Request: "LookupUserByIdAndNameRequest", + Response: "LookupUserByIdAndNameResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id name", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserByIdAndName", + Request: RPCMessage{ + Name: "LookupUserByIdAndNameRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdAndNameKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupUserByIdAndNameResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Nested fields in a compound key should be ignored", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + _entities(representations: [_Any!]!): [_Entity]! + } + + type Address { + id: ID! + street: String! + } + type User @key(fields: "id name address { id }") { + id: ID! + name: String! + address: Address! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "name id address", + RPCConfig: RPCConfig{ + RPC: "LookupUserByIdAndNameAndAddress", + Request: "LookupUserByIdAndNameAndAddressRequest", + Response: "LookupUserByIdAndNameAndAddressResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id name address { id }", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserByIdAndNameAndAddress", + Request: RPCMessage{ + Name: "LookupUserByIdAndNameAndAddressRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdAndNameAndAddressKey", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + { + Name: "address", + TypeName: string(DataTypeMessage), + JSONPath: "address", + Message: &RPCMessage{ + Name: "Address", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupUserByIdAndNameAndAddressResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, } for _, tt := range tests { @@ -1223,106 +1443,3 @@ func testFederationSchemaString(schema string, entities []string) string { scalar _Any `, schema, entityUnion) } - -func TestRootKeyNames(t *testing.T) { - tests := []struct { - name string - keyFields string - expected []string - }{ - { - name: "Should return root key names with nested fields", - keyFields: "id address { id }", - expected: []string{"id", "address"}, - }, - { - name: "Should handle single field", - keyFields: "id", - expected: []string{"id"}, - }, - { - name: "Should handle multiple simple fields", - keyFields: "id name email status", - expected: []string{"id", "name", "email", "status"}, - }, - { - name: "Should handle mixed simple and complex fields", - keyFields: "id user { name email } status", - expected: []string{"id", "user", "status"}, - }, - { - name: "Should handle multiple nested objects", - keyFields: "id user { name email } address { street city } status", - expected: []string{"id", "user", "address", "status"}, - }, - { - name: "Should handle deeply nested fields", - keyFields: "id user { profile { personal { name age } } }", - expected: []string{"id", "user"}, - }, - { - name: "Should handle fields with underscores and numbers", - keyFields: "user_id item_2 product_variant { sku_code }", - expected: []string{"user_id", "item_2", "product_variant"}, - }, - { - name: "Should handle camelCase fields", - keyFields: "userId firstName lastName userProfile { emailAddress }", - expected: []string{"userId", "firstName", "lastName", "userProfile"}, - }, - { - name: "Should handle empty string", - keyFields: "", - expected: []string{}, - }, - { - name: "Should handle whitespace only", - keyFields: " \t\n ", - expected: []string{}, - }, - { - name: "Should handle extra whitespace", - keyFields: " id name address { street city } status ", - expected: []string{"id", "name", "address", "status"}, - }, - { - name: "Should handle nested braces", - keyFields: "id organization { department { team { lead { name } } } }", - expected: []string{"id", "organization"}, - }, - { - name: "Should handle multiple levels of nesting", - keyFields: "id user { contact { phone { primary secondary } email } } metadata { tags }", - expected: []string{"id", "user", "metadata"}, - }, - { - name: "Should handle fields with dashes", - keyFields: "user-id product-code shipping-address { street-name }", - expected: []string{"user-id", "product-code", "shipping-address"}, - }, - { - name: "Should handle single nested field", - keyFields: "user { id }", - expected: []string{"user"}, - }, - { - name: "Should handle adjacent nested fields", - keyFields: "user { name } address { street }", - expected: []string{"user", "address"}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - fc := federationConfigData{ - keyFields: tt.keyFields, - } - t.Logf("keyFields: %s", tt.keyFields) - actual := fc.getRootKeyNames() - if !reflect.DeepEqual(actual, tt.expected) { - t.Fatalf("expected %v, got %v", tt.expected, actual) - } - }) - } -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go index a8c7946269..d806172d20 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go @@ -12,13 +12,7 @@ import ( "golang.org/x/text/language" ) -type keyField struct { - fieldName string - fieldType string -} - type planningInfo struct { - // resolvers []string operationType ast.OperationType operationFieldName string diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index eddc89981f..f05989b453 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -3,12 +3,10 @@ package grpcdatasource import ( "errors" "fmt" - "strings" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/runes" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" "golang.org/x/text/cases" "golang.org/x/text/language" @@ -35,48 +33,6 @@ func newFederationConfigData(entityTypeName string) federationConfigData { } } -func (f *federationConfigData) getRootKeyNames() []string { - keys := make([]string, 0) - - selectionSetQueue := []struct{}{} - currentIndex := 0 - - for i := range f.keyFields { - switch f.keyFields[i] { - case runes.LBRACE: - selectionSetQueue = append(selectionSetQueue, struct{}{}) - case runes.RBRACE: - if len(selectionSetQueue) == 0 { - continue - } - selectionSetQueue = selectionSetQueue[:len(selectionSetQueue)-1] - currentIndex = i + 1 - case runes.SPACE: - if len(selectionSetQueue) > 0 { - continue - } - - key := strings.TrimSpace(f.keyFields[currentIndex:i]) - currentIndex = i + 1 - - if key == "" { - continue - } - - keys = append(keys, key) - } - } - - if currentIndex < len(f.keyFields) { - key := strings.TrimSpace(f.keyFields[currentIndex:]) - if key != "" { - keys = append(keys, key) - } - } - - return keys -} - type rpcPlanVisitorFederation struct { walker *astvisitor.Walker operation *ast.Document From 8042e65f60557fe1ec15817a1fcab39d99450a21 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 6 Aug 2025 10:59:10 +0200 Subject: [PATCH 05/24] chore: fix alias behavior for nullable fields --- .../grpc_datasource/grpc_datasource.go | 2 +- .../grpc_datasource/grpc_datasource_test.go | 41 +++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 3e09581197..5ee811dcd2 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -235,7 +235,7 @@ func (d *DataSource) marshalResponseJSON(arena *astjson.Arena, message *RPCMessa } if field.IsOptionalScalar() { - err := d.resolveOptionalField(arena, root, field.JSONPath, msg) + err := d.resolveOptionalField(arena, root, field.AliasOrPath(), msg) if err != nil { return nil, err } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 627611a5c6..ce279a11b3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -1858,6 +1858,47 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { } }, }, + { + name: "Query nullable fields type with all fields and aliases", + query: `query { nullableFieldsType { id name optionalString1: optionalString optionalInt1: optionalInt optionalFloat1: optionalFloat optionalBoolean1: optionalBoolean requiredString1: requiredString requiredInt1: requiredInt } }`, + vars: "{}", + validate: func(t *testing.T, data map[string]interface{}) { + nullableFieldsType, ok := data["nullableFieldsType"].(map[string]interface{}) + require.True(t, ok, "nullableFieldsType should be an object") + require.NotEmpty(t, nullableFieldsType, "nullableFieldsType should not be empty") + + // Check required fields are present + require.Contains(t, nullableFieldsType, "id") + require.Contains(t, nullableFieldsType, "name") + require.Contains(t, nullableFieldsType, "requiredString1") + require.Contains(t, nullableFieldsType, "requiredInt1") + + require.NotEmpty(t, nullableFieldsType["id"], "id should not be empty") + require.NotEmpty(t, nullableFieldsType["name"], "name should not be empty") + require.NotEmpty(t, nullableFieldsType["requiredString1"], "requiredString1 should not be empty") + require.NotEmpty(t, nullableFieldsType["requiredInt1"], "requiredInt1 should not be empty") + + // Check optional fields are present (but may be null) + require.Contains(t, nullableFieldsType, "optionalString1") + require.Contains(t, nullableFieldsType, "optionalInt1") + require.Contains(t, nullableFieldsType, "optionalFloat1") + require.Contains(t, nullableFieldsType, "optionalBoolean1") + + // Verify types of non-null optional fields + if nullableFieldsType["optionalString1"] != nil { + require.IsType(t, "", nullableFieldsType["optionalString1"]) + } + if nullableFieldsType["optionalInt1"] != nil { + require.IsType(t, float64(0), nullableFieldsType["optionalInt1"]) // JSON numbers are float64 + } + if nullableFieldsType["optionalFloat1"] != nil { + require.IsType(t, float64(0), nullableFieldsType["optionalFloat1"]) + } + if nullableFieldsType["optionalBoolean1"] != nil { + require.IsType(t, false, nullableFieldsType["optionalBoolean1"]) + } + }, + }, { name: "Query nullable fields type by ID", query: `query($id: ID!) { nullableFieldsTypeById(id: $id) { id name optionalString requiredString } }`, From 04af7d693a8a4f89284367db97acf1b34c5e3dd0 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 6 Aug 2025 11:01:32 +0200 Subject: [PATCH 06/24] chore: remove invalid test scenarios --- .../datasource/grpc_datasource/configuration_test.go | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go b/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go index 9ec743a1f3..b34c11bea3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go @@ -225,18 +225,6 @@ func TestCompareKeyFields(t *testing.T) { right: "orders user id", expected: true, }, - { - name: "unbalanced braces handled gracefully", - left: "id user { name email", - right: "id user { name email", - expected: true, - }, - { - name: "extra closing braces", - left: "id user { name } } extra", - right: "extra user id", - expected: false, - }, } for _, tt := range tests { From d7cc0e846d52df230af6662503dd1b87e6562691 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 6 Aug 2025 11:16:22 +0200 Subject: [PATCH 07/24] chore: simplify selection set stripping logic --- .../grpc_datasource/configuration.go | 43 ++++++++----------- 1 file changed, 17 insertions(+), 26 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration.go b/v2/pkg/engine/datasource/grpc_datasource/configuration.go index 12375abe76..f054d2b3bc 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/configuration.go +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration.go @@ -191,45 +191,36 @@ func compareKeyFields(left, right string) bool { } func stripSelectionSets(keyString string) string { + depth := 0 - selectionSetQueue := []struct{}{} - currentIndex := 0 + lastIndex := len(keyString) - 1 + var prev rune keyString = strings.ReplaceAll(keyString, ",", " ") var sb strings.Builder - for i := range keyString { - switch keyString[i] { + for i, r := range keyString { + switch r { case runes.LBRACE: - selectionSetQueue = append(selectionSetQueue, struct{}{}) + depth++ case runes.RBRACE: - currentIndex = i + 1 - if len(selectionSetQueue) == 0 { - continue - } - selectionSetQueue = selectionSetQueue[:len(selectionSetQueue)-1] - case runes.SPACE: - if len(selectionSetQueue) > 0 { + if depth == 0 { continue } - key := strings.TrimSpace(keyString[currentIndex:i]) - currentIndex = i + 1 - - if key == "" { - continue + depth-- + case runes.COMMA: + if i < lastIndex && keyString[i+1] != runes.SPACE { + sb.WriteRune(runes.SPACE) + } + default: + if depth != 0 || (r == runes.SPACE && prev == runes.SPACE) { + break } - sb.WriteString(key) - sb.WriteRune(runes.SPACE) - } - } - - if currentIndex < len(keyString) && len(selectionSetQueue) == 0 { - key := strings.TrimSpace(keyString[currentIndex:]) - if key != "" { - sb.WriteString(key) + sb.WriteRune(r) + prev = r } } From ce4d4452549f611f6deb595ceb6a27cefcb6af7f Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 6 Aug 2025 11:26:44 +0200 Subject: [PATCH 08/24] chore: remove comma switch --- v2/pkg/engine/datasource/grpc_datasource/configuration.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration.go b/v2/pkg/engine/datasource/grpc_datasource/configuration.go index f054d2b3bc..18860026d2 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/configuration.go +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration.go @@ -193,14 +193,13 @@ func compareKeyFields(left, right string) bool { func stripSelectionSets(keyString string) string { depth := 0 - lastIndex := len(keyString) - 1 var prev rune keyString = strings.ReplaceAll(keyString, ",", " ") var sb strings.Builder - for i, r := range keyString { + for _, r := range keyString { switch r { case runes.LBRACE: depth++ @@ -210,10 +209,6 @@ func stripSelectionSets(keyString string) string { } depth-- - case runes.COMMA: - if i < lastIndex && keyString[i+1] != runes.SPACE { - sb.WriteRune(runes.SPACE) - } default: if depth != 0 || (r == runes.SPACE && prev == runes.SPACE) { break From 45227e2ba108a369c0fdeba30fa726010e70f62d Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Wed, 6 Aug 2025 11:56:04 +0200 Subject: [PATCH 09/24] chore: address review comments --- .../grpc_datasource/execution_plan_federation_test.go | 7 +------ .../datasource/grpc_datasource/required_fields_visitor.go | 5 +++++ 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go index f002833063..8684f34870 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go @@ -1117,7 +1117,7 @@ func TestEntityKeys(t *testing.T) { } } -func TestRequriedFields(t *testing.T) { +func TestRequiredFields(t *testing.T) { tests := []struct { name string query string @@ -1412,11 +1412,6 @@ func runFederationTest(t *testing.T, tt struct { t.Fatalf("failed to parse query: %s", report.Error()) } - operation, report = astparser.ParseGraphqlDocumentString(tt.query) - if report.HasErrors() { - t.Fatalf("failed to parse query: %s", report.Error()) - } - planner := NewPlanner("Products", tt.mapping, tt.federationConfigs) plan, err := planner.PlanOperation(&operation, &definition) if err != nil { diff --git a/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go index 19f3de23c0..7a6809f513 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go +++ b/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go @@ -75,6 +75,11 @@ func (r *requiredFieldsVisitor) EnterSelectionSet(ref int) { return } + if len(r.message.Fields) == 0 { + r.walker.StopWithInternalErr(errors.New("cannot access last field: message has no fields")) + return + } + lastField := &r.message.Fields[len(r.message.Fields)-1] if lastField.Message == nil { lastField.Message = r.planCtx.newMessageFromSelectionSet(r.walker.EnclosingTypeDefinition, ref) From add12415814b08149ec8508eba35ffd032d8a214 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Thu, 7 Aug 2025 11:02:04 +0200 Subject: [PATCH 10/24] chore: address review comments --- .../grpc_datasource/execution_plan.go | 47 +-- .../execution_plan_composite_test.go | 13 +- .../execution_plan_federation_test.go | 357 +----------------- .../grpc_datasource/execution_plan_test.go | 19 +- .../grpc_datasource/execution_plan_visitor.go | 29 +- .../execution_plan_visitor_federation.go | 44 +-- .../grpc_datasource/grpc_datasource_test.go | 2 +- 7 files changed, 55 insertions(+), 456 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index 9dca2a6ed9..ab63824613 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -5,9 +5,7 @@ import ( "strings" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" - "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) const ( @@ -285,55 +283,35 @@ func (r *RPCExecutionPlan) String() string { } type PlanVisitor interface { - ExecutionPlan() *RPCExecutionPlan + PlanOperation(operation, definition *ast.Document) (*RPCExecutionPlan, error) } -type Planner struct { - visitor PlanVisitor - walker *astvisitor.Walker -} - -// NewPlanner returns a new Planner instance. +// NewPlanner returns a new PlanVisitor instance. // // The planner is responsible for creating an RPCExecutionPlan from a given // GraphQL operation. It is used by the engine to execute operations against // gRPC services. -func NewPlanner(subgraphName string, mapping *GRPCMapping, federationConfigs plan.FederationFieldConfigurations) *Planner { - walker := astvisitor.NewWalker(48) - +func NewPlanner(subgraphName string, mapping *GRPCMapping, federationConfigs plan.FederationFieldConfigurations) PlanVisitor { if mapping == nil { mapping = new(GRPCMapping) } var visitor PlanVisitor if len(federationConfigs) > 0 { - visitor = newRPCPlanVisitorFederation(&walker, rpcPlanVisitorConfig{ + visitor = newRPCPlanVisitorFederation(rpcPlanVisitorConfig{ subgraphName: subgraphName, mapping: mapping, federationConfigs: federationConfigs, }) } else { - visitor = newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + visitor = newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: subgraphName, mapping: mapping, federationConfigs: federationConfigs, }) } - return &Planner{ - visitor: visitor, - walker: &walker, - } -} - -func (p *Planner) PlanOperation(operation, definition *ast.Document) (*RPCExecutionPlan, error) { - report := &operationreport.Report{} - p.walker.Walk(operation, definition, report) - if report.HasErrors() { - return nil, fmt.Errorf("unable to plan operation: %w", report) - } - - return p.visitor.ExecutionPlan(), nil + return visitor } // formatRPCMessage formats an RPCMessage and adds it to the string builder with the specified indentation @@ -546,19 +524,6 @@ func (r *rpcPlanningContext) buildField(enclosingTypeNode ast.Node, fd int, fiel return field, nil } -func (r *rpcPlanningContext) ensureRequiredFields(message *RPCMessage, fc *federationConfigData) error { - // If the message is nil, we can't add any fields to it. - if message == nil { - return nil - } - - walker := astvisitor.WalkerFromPool() - defer walker.Release() - - requiredFieldsVisitor := newRequiredFieldsVisitor(walker, message, r) - return requiredFieldsVisitor.visitRequiredFields(r.definition, fc.entityTypeName, fc.requiredFields) -} - func (r *rpcPlanningContext) resolveServiceName(subgraphName string) string { if r.mapping == nil || r.mapping.Service == "" { return subgraphName diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go index 30ae95163e..6e784a42bd 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go @@ -6,7 +6,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" grpctest "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" ) @@ -604,14 +603,12 @@ func TestCompositeTypeExecutionPlan(t *testing.T) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: "Products", mapping: testMapping(), }) - walker.Walk(&queryDoc, &schemaDoc, &report) + rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) if report.HasErrors() { require.NotEmpty(t, tt.expectedError) @@ -871,14 +868,12 @@ func TestMutationUnionExecutionPlan(t *testing.T) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: "Products", mapping: testMapping(), }) - walker.Walk(&queryDoc, &schemaDoc, &report) + rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) if report.HasErrors() { require.NotEmpty(t, tt.expectedError) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go index 8684f34870..247eeec653 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go @@ -260,88 +260,6 @@ func TestEntityLookup(t *testing.T) { }, }, }, - { - name: "Should create an execution plan for an entity lookup with required fields", - query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on Product { __typename id name } } }`, - mapping: testMapping(), - federationConfigs: plan.FederationFieldConfigurations{ - { - TypeName: "Product", - SelectionSet: "id", - }, - { - TypeName: "Product", - FieldName: "name", // Field name requires price - SelectionSet: "price", - }, - }, - expectedPlan: &RPCExecutionPlan{ - Calls: []RPCCall{ - { - ServiceName: "Products", - MethodName: "LookupProductById", - Request: RPCMessage{ - Name: "LookupProductByIdRequest", - Fields: []RPCField{ - { - Name: "keys", - TypeName: string(DataTypeMessage), - Repeated: true, - JSONPath: "representations", - Message: &RPCMessage{ - Name: "LookupProductByIdKey", - Fields: []RPCField{ - { - Name: "id", - TypeName: string(DataTypeString), - JSONPath: "id", - }, - }, - }, - }, - }, - }, - Response: RPCMessage{ - Name: "LookupProductByIdResponse", - Fields: []RPCField{ - { - Name: "result", - TypeName: string(DataTypeMessage), - Repeated: true, - JSONPath: "_entities", - Message: &RPCMessage{ - Name: "Product", - Fields: []RPCField{ - { - Name: "__typename", - TypeName: string(DataTypeString), - JSONPath: "__typename", - StaticValue: "Product", - }, - { - Name: "id", - TypeName: string(DataTypeString), - JSONPath: "id", - }, - { - Name: "name", - TypeName: string(DataTypeString), - JSONPath: "name", - }, - { - Name: "price", - TypeName: string(DataTypeDouble), - JSONPath: "price", - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, // TODO implement multiple entity lookup types // { @@ -573,7 +491,7 @@ func TestEntityKeys(t *testing.T) { query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, schema: testFederationSchemaString(` type Query { - user(id: ID!): User + _entities(representations: [_Any!]!): [_Entity]! } type User @key(fields: "id") { id: ID! @@ -671,7 +589,7 @@ func TestEntityKeys(t *testing.T) { query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, schema: testFederationSchemaString(` type Query { - user(id: ID!): User + _entities(representations: [_Any!]!): [_Entity]! } type Address { @@ -1117,272 +1035,6 @@ func TestEntityKeys(t *testing.T) { } } -func TestRequiredFields(t *testing.T) { - tests := []struct { - name string - query string - schema string - expectedPlan *RPCExecutionPlan - mapping *GRPCMapping - federationConfigs plan.FederationFieldConfigurations - }{ - { - name: "Should also require reviews field when name is selected ", - schema: testFederationSchemaString(` - type Query { - user(id: ID!): User - } - type User @key(fields: "id") { - id: ID! - name: String! - reviews: [String!]! - } - `, []string{"User"}), - query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, - federationConfigs: plan.FederationFieldConfigurations{ - { - TypeName: "User", - SelectionSet: "id", // no field name mean this is related to the key - }, - { - TypeName: "User", - SelectionSet: "reviews", - FieldName: "name", // name requires reviews - }, - }, - mapping: &GRPCMapping{ - Service: "Products", - EntityRPCs: map[string][]EntityRPCConfig{ - "User": { - { - Key: "id", - RPCConfig: RPCConfig{ - RPC: "LookupUserById", - Request: "LookupUserByIdRequest", - Response: "LookupUserByIdResponse", - }, - }, - }, - }, - }, - expectedPlan: &RPCExecutionPlan{ - Calls: []RPCCall{ - { - ServiceName: "Products", - MethodName: "LookupUserById", - Request: RPCMessage{ - Name: "LookupUserByIdRequest", - Fields: []RPCField{ - { - Name: "keys", - TypeName: string(DataTypeMessage), - Repeated: true, - JSONPath: "representations", - Message: &RPCMessage{ - Name: "LookupUserByIdKey", - Fields: []RPCField{ - { - Name: "id", - TypeName: string(DataTypeString), - JSONPath: "id", - }, - }, - }, - }, - }, - }, - Response: RPCMessage{ - Name: "LookupUserByIdResponse", - Fields: []RPCField{ - { - Name: "result", - TypeName: string(DataTypeMessage), - JSONPath: "_entities", - Repeated: true, - Message: &RPCMessage{ - Name: "User", - Fields: []RPCField{ - { - Name: "__typename", - TypeName: string(DataTypeString), - JSONPath: "__typename", - StaticValue: "User", - }, - { - Name: "id", - TypeName: string(DataTypeString), - JSONPath: "id", - }, - { - Name: "name", - TypeName: string(DataTypeString), - JSONPath: "name", - }, - { - Name: "reviews", - TypeName: string(DataTypeString), - JSONPath: "reviews", - Repeated: true, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - { - name: "Should require nested fields", - schema: testFederationSchemaString(` - type Query { - user(id: ID!): User - } - - type Review { - body: String! - title: String! - } - - type User @key(fields: "id") { - id: ID! - name: String! - reviews: [Review!]! - } - `, []string{"User"}), - query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name reviews } } }`, - federationConfigs: plan.FederationFieldConfigurations{ - { - TypeName: "User", - SelectionSet: "id", - }, - { - TypeName: "User", - SelectionSet: "reviews { body title }", - FieldName: "name", // name requires reviews { body title } - }, - }, - mapping: &GRPCMapping{ - Service: "Products", - EntityRPCs: map[string][]EntityRPCConfig{ - "User": { - { - Key: "id", - RPCConfig: RPCConfig{ - RPC: "LookupUserById", - Request: "LookupUserByIdRequest", - Response: "LookupUserByIdResponse", - }, - }, - }, - }, - Fields: map[string]FieldMap{ - "User": { - "id": { - TargetName: "id", - }, - }, - "Review": { - "body": { - TargetName: "body", - }, - "title": { - TargetName: "title", - }, - }, - }, - }, - expectedPlan: &RPCExecutionPlan{ - Calls: []RPCCall{ - { - ServiceName: "Products", - MethodName: "LookupUserById", - Request: RPCMessage{ - Name: "LookupUserByIdRequest", - Fields: []RPCField{ - { - Name: "keys", - TypeName: string(DataTypeMessage), - Repeated: true, - JSONPath: "representations", - Message: &RPCMessage{ - Name: "LookupUserByIdKey", - Fields: []RPCField{ - { - Name: "id", - TypeName: string(DataTypeString), - JSONPath: "id", - }, - }, - }, - }, - }, - }, - Response: RPCMessage{ - Name: "LookupUserByIdResponse", - Fields: []RPCField{ - { - Name: "result", - TypeName: string(DataTypeMessage), - JSONPath: "_entities", - Repeated: true, - Message: &RPCMessage{ - Name: "User", - Fields: []RPCField{ - { - Name: "__typename", - TypeName: string(DataTypeString), - JSONPath: "__typename", - StaticValue: "User", - }, - { - Name: "id", - TypeName: string(DataTypeString), - JSONPath: "id", - }, - { - Name: "name", - TypeName: string(DataTypeString), - JSONPath: "name", - }, - { - Name: "reviews", - TypeName: string(DataTypeMessage), - JSONPath: "reviews", - Repeated: true, - Message: &RPCMessage{ - Name: "Review", - Fields: []RPCField{ - { - Name: "body", - TypeName: string(DataTypeString), - JSONPath: "body", - }, - { - Name: "title", - TypeName: string(DataTypeString), - JSONPath: "title", - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - }, - } - - for _, tt := range tests { - runFederationTest(t, tt) - } -} - func runFederationTest(t *testing.T, tt struct { name string query string @@ -1412,6 +1064,11 @@ func runFederationTest(t *testing.T, tt struct { t.Fatalf("failed to parse query: %s", report.Error()) } + astvalidation.DefaultOperationValidator().Validate(&operation, &definition, &report) + if report.HasErrors() { + t.Fatalf("failed to validate query: %s", report.Error()) + } + planner := NewPlanner("Products", tt.mapping, tt.federationConfigs) plan, err := planner.PlanOperation(&operation, &definition) if err != nil { diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go index cc892e4679..79afa9ab50 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go @@ -11,7 +11,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" ) type testCase struct { @@ -30,14 +29,12 @@ func runTest(t *testing.T, testCase testCase) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: "Products", mapping: testMapping(), }) - walker.Walk(&queryDoc, &schemaDoc, &report) + rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) if report.HasErrors() { require.NotEmpty(t, testCase.expectedError, "expected error to be empty, got: %s", report.Error()) @@ -1093,23 +1090,21 @@ func TestQueryExecutionPlans(t *testing.T) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: "Products", mapping: tt.mapping, }) - walker.Walk(&queryDoc, &schemaDoc, &report) + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) - if report.HasErrors() { - require.Contains(t, report.Error(), tt.expectedError) + if err != nil { + require.Contains(t, err.Error(), tt.expectedError) require.NotEmpty(t, tt.expectedError) return } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(tt.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go index d806172d20..431357af14 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go @@ -27,12 +27,11 @@ type planningInfo struct { } type rpcPlanVisitor struct { - walker *astvisitor.Walker - operation *ast.Document - definition *ast.Document - planCtx *rpcPlanningContext - planInfo planningInfo - federationConfigs plan.FederationFieldConfigurations + walker *astvisitor.Walker + operation *ast.Document + definition *ast.Document + planCtx *rpcPlanningContext + planInfo planningInfo subgraphName string mapping *GRPCMapping @@ -51,14 +50,14 @@ type rpcPlanVisitorConfig struct { // newRPCPlanVisitor creates a new RPCPlanVisitor. // It registers the visitor with the walker and returns it. -func newRPCPlanVisitor(walker *astvisitor.Walker, config rpcPlanVisitorConfig) *rpcPlanVisitor { +func newRPCPlanVisitor(config rpcPlanVisitorConfig) *rpcPlanVisitor { + walker := astvisitor.NewWalker(48) visitor := &rpcPlanVisitor{ - walker: walker, + walker: &walker, plan: &RPCExecutionPlan{}, subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), mapping: config.mapping, operationFieldRef: -1, - federationConfigs: config.federationConfigs, } walker.RegisterEnterDocumentVisitor(visitor) @@ -70,8 +69,14 @@ func newRPCPlanVisitor(walker *astvisitor.Walker, config rpcPlanVisitorConfig) * return visitor } -func (r *rpcPlanVisitor) ExecutionPlan() *RPCExecutionPlan { - return r.plan +func (r *rpcPlanVisitor) PlanOperation(operation, definition *ast.Document) (*RPCExecutionPlan, error) { + report := &operationreport.Report{} + r.walker.Walk(operation, definition, report) + if report.HasErrors() { + return nil, fmt.Errorf("unable to plan operation: %w", report) + } + + return r.plan, nil } // EnterDocument implements astvisitor.EnterDocumentVisitor. @@ -155,7 +160,7 @@ func (r *rpcPlanVisitor) EnterSelectionSet(ref int) { } func (r *rpcPlanVisitor) handleCompositeType(node ast.Node) error { - if node.Ref < 0 { + if node.Ref == ast.InvalidRef { return nil } diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index f05989b453..b768f0efaa 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -50,9 +50,10 @@ type rpcPlanVisitorFederation struct { currentCallIndex int } -func newRPCPlanVisitorFederation(walker *astvisitor.Walker, config rpcPlanVisitorConfig) *rpcPlanVisitorFederation { +func newRPCPlanVisitorFederation(config rpcPlanVisitorConfig) *rpcPlanVisitorFederation { + walker := astvisitor.NewWalker(48) visitor := &rpcPlanVisitorFederation{ - walker: walker, + walker: &walker, plan: &RPCExecutionPlan{}, subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), mapping: config.mapping, @@ -72,8 +73,14 @@ func newRPCPlanVisitorFederation(walker *astvisitor.Walker, config rpcPlanVisito return visitor } -func (r *rpcPlanVisitorFederation) ExecutionPlan() *RPCExecutionPlan { - return r.plan +func (r *rpcPlanVisitorFederation) PlanOperation(operation, definition *ast.Document) (*RPCExecutionPlan, error) { + report := &operationreport.Report{} + r.walker.Walk(operation, definition, report) + if report.HasErrors() { + return nil, fmt.Errorf("unable to plan operation: %w", report) + } + + return r.plan, nil } // EnterDocument implements astvisitor.EnterDocumentVisitor. @@ -96,10 +103,6 @@ func (r *rpcPlanVisitorFederation) EnterOperationDefinition(ref int) { // EnterInlineFragment implements astvisitor.InlineFragmentVisitor. func (r *rpcPlanVisitorFederation) EnterInlineFragment(ref int) { - // if !r.IsEntityInlineFragment(ref) { - // return - // } - fragmentName := r.operation.InlineFragmentTypeConditionNameString(ref) fc, ok := r.FederationConfigDataByEntityTypeName(fragmentName) if !ok { @@ -126,21 +129,6 @@ func (r *rpcPlanVisitorFederation) LeaveInlineFragment(ref int) { return } - // We need to ensure that all the fields that are in the required fields are also present in the response message. - fc, found := r.FederationConfigDataByEntityTypeName(r.entityInfo.typeName) - if !found { - r.walker.StopWithInternalErr(errors.New("federation config data not found for entity type name: " + r.entityInfo.typeName)) - return - } - - if fc.requiredFields != "" { - if err := r.planCtx.ensureRequiredFields(r.planInfo.currentResponseMessage, &fc); err != nil { - r.walker.StopWithInternalErr(err) - return - } - - } - r.plan.Calls = append(r.plan.Calls, *r.currentCall) r.currentCall = &RPCCall{} r.currentCallIndex++ @@ -195,7 +183,7 @@ func (r *rpcPlanVisitorFederation) EnterSelectionSet(ref int) { } func (r *rpcPlanVisitorFederation) handleCompositeType(node ast.Node) error { - if node.Ref < 0 { + if node.Ref == ast.InvalidRef { return nil } @@ -261,7 +249,7 @@ func (r *rpcPlanVisitorFederation) EnterField(ref int) { // determined from the first inline fragment. r.entityInfo = entityInfo{ entityRootFieldRef: ref, - entityInlineFragmentRef: -1, + entityInlineFragmentRef: ast.InvalidRef, } r.entityInfo.entityRootFieldRef = ref @@ -402,12 +390,6 @@ func parseFederationConfigData(federationConfigs plan.FederationFieldConfigurati typeNameIndex := 0 for _, fc := range federationConfigs { - // If the entity is not resolvable, we skip it - // TODO: Is this needed? - if fc.DisableEntityResolver { - continue - } - // Create a new entity type if it doesn't exist if _, ok := typeNameIndexSet[fc.TypeName]; !ok { out = append(out, newFederationConfigData(fc.TypeName)) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index ce279a11b3..85381fb426 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -1859,7 +1859,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { }, }, { - name: "Query nullable fields type with all fields and aliases", + name: "Query nullable fields type with all aliased fields", query: `query { nullableFieldsType { id name optionalString1: optionalString optionalInt1: optionalInt optionalFloat1: optionalFloat optionalBoolean1: optionalBoolean requiredString1: requiredString requiredInt1: requiredInt } }`, vars: "{}", validate: func(t *testing.T, data map[string]interface{}) { From 681f2ae57c8b4d97e2ca1f46d160ba9170c6cf8b Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Thu, 7 Aug 2025 11:05:32 +0200 Subject: [PATCH 11/24] chore: proper error checks --- .../execution_plan_composite_test.go | 16 ++++++++-------- .../grpc_datasource/execution_plan_test.go | 10 +++++----- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go index 6e784a42bd..f6a81aec3c 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go @@ -608,16 +608,16 @@ func TestCompositeTypeExecutionPlan(t *testing.T) { mapping: testMapping(), }) - rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) - if report.HasErrors() { + if err != nil { require.NotEmpty(t, tt.expectedError) - require.Contains(t, report.Error(), tt.expectedError) + require.Contains(t, err.Error(), tt.expectedError) return } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(tt.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } @@ -873,16 +873,16 @@ func TestMutationUnionExecutionPlan(t *testing.T) { mapping: testMapping(), }) - rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) - if report.HasErrors() { + if err != nil { require.NotEmpty(t, tt.expectedError) - require.Contains(t, report.Error(), tt.expectedError) + require.Contains(t, err.Error(), tt.expectedError) return } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(tt.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go index 79afa9ab50..a19ef4a8ff 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go @@ -34,16 +34,16 @@ func runTest(t *testing.T, testCase testCase) { mapping: testMapping(), }) - rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) - if report.HasErrors() { - require.NotEmpty(t, testCase.expectedError, "expected error to be empty, got: %s", report.Error()) - require.Contains(t, report.Error(), testCase.expectedError, "expected error to contain: %s, got: %s", testCase.expectedError, report.Error()) + if err != nil { + require.NotEmpty(t, testCase.expectedError, "expected error to be empty, got: %s", err.Error()) + require.Contains(t, err.Error(), testCase.expectedError, "expected error to contain: %s, got: %s", testCase.expectedError, err.Error()) return } require.Empty(t, testCase.expectedError) - diff := cmp.Diff(testCase.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(testCase.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } From 6bd3025b856b5f17381b0f6801157c63c175555e Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Thu, 7 Aug 2025 14:32:23 +0200 Subject: [PATCH 12/24] chore: use InvalidRef --- v2/pkg/astvisitor/visitor.go | 4 ++-- .../grpc_datasource/execution_plan_visitor_federation.go | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index 9a48666e3d..a2cbb102da 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -4012,12 +4012,12 @@ func (w *Walker) InRootField() bool { // It returns -1 and false if the current field is not inside of an inline fragment. func (w *Walker) ResolveInlineFragment() (int, bool) { if len(w.Ancestors) < 2 { - return -1, false + return ast.InvalidRef, false } node := w.Ancestors[len(w.Ancestors)-2] if node.Kind != ast.NodeKindInlineFragment { - return -1, false + return ast.InvalidRef, false } return node.Ref, true diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index b768f0efaa..a65bcb2f9b 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -58,8 +58,8 @@ func newRPCPlanVisitorFederation(config rpcPlanVisitorConfig) *rpcPlanVisitorFed subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), mapping: config.mapping, entityInfo: entityInfo{ - entityRootFieldRef: -1, - entityInlineFragmentRef: -1, + entityRootFieldRef: ast.InvalidRef, + entityInlineFragmentRef: ast.InvalidRef, }, federationConfigData: parseFederationConfigData(config.federationConfigs), } From 4e36ffb5b8236b75fa39f07a4888f6e5b0a8338e Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 09:14:51 +0200 Subject: [PATCH 13/24] chore: handle entity order correctly --- .../datasource/grpc_datasource/compiler.go | 30 ++ .../grpc_datasource/execution_plan.go | 9 + .../execution_plan_federation_test.go | 24 +- .../execution_plan_visitor_federation.go | 12 +- .../grpc_datasource/grpc_datasource.go | 404 ++------------- .../grpc_datasource/grpc_datasource_test.go | 125 ++++- .../grpc_datasource/json_builder.go | 483 ++++++++++++++++++ .../required_fields_visitor.go | 1 + 8 files changed, 716 insertions(+), 372 deletions(-) create mode 100644 v2/pkg/engine/datasource/grpc_datasource/json_builder.go diff --git a/v2/pkg/engine/datasource/grpc_datasource/compiler.go b/v2/pkg/engine/datasource/grpc_datasource/compiler.go index ee49e8176c..959cb43c05 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/compiler.go +++ b/v2/pkg/engine/datasource/grpc_datasource/compiler.go @@ -3,6 +3,7 @@ package grpcdatasource import ( "context" "fmt" + "slices" "github.com/bufbuild/protocompile" "github.com/tidwall/gjson" @@ -433,6 +434,12 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes for _, element := range elements { switch field.Type { case DataTypeMessage: + // If we handle entity lookups, we get a list of representation variables that need to + // be applied to the correct type names. + if !isAllowedForTypename(rpcField.Message, element) { + continue + } + fieldMsg := p.buildProtoMessage(p.doc.Messages[field.MessageRef], rpcField.Message, element) list.Append(protoref.ValueOfMessage(fieldMsg)) default: @@ -785,3 +792,26 @@ func (p *RPCCompiler) parseField(f protoref.FieldDescriptor) Field { MessageRef: -1, } } + +func isAllowedForTypename(message *RPCMessage, element gjson.Result) bool { + if message == nil { + // We assume that having a nil message expects a null value. + return true + } + + // If we don't have a member types, we assume that the message is allowed for all types. + if message.MemberTypes == nil { + return true + } + + typeName := element.Get("__typename") + if !typeName.Exists() { + // If we don't have a type name, we assume that the message is allowed for all types. + return true + } + + typeString := typeName.String() + + // If we have a type name, we need to check if the message is restricted to a specific type. + return slices.Contains(message.MemberTypes, typeString) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index ab63824613..166ec3166a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -104,6 +104,15 @@ func (r *RPCMessage) SelectValidTypes(typeName string) []string { return []string{r.Name, typeName} } +func (r *RPCMessage) AppendTypeNameField(typeName string) { + r.Fields = append(r.Fields, RPCField{ + Name: "__typename", + TypeName: DataTypeString.String(), + StaticValue: typeName, + JSONPath: "__typename", + }) +} + // RPCFieldSelectionSet is a map of field selections based on inline fragments type RPCFieldSelectionSet map[string]RPCFields diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go index 247eeec653..f0184b93cd 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go @@ -62,7 +62,8 @@ func TestEntityLookup(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupProductByIdKey", + Name: "LookupProductByIdKey", + MemberTypes: []string{"Product"}, Fields: []RPCField{ { Name: "id", @@ -144,7 +145,8 @@ func TestEntityLookup(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupProductByIdKey", + Name: "LookupProductByIdKey", + MemberTypes: []string{"Product"}, Fields: []RPCField{ { Name: "id", @@ -207,7 +209,8 @@ func TestEntityLookup(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupStorageByIdKey", + Name: "LookupStorageByIdKey", + MemberTypes: []string{"Storage"}, Fields: []RPCField{ { Name: "id", @@ -535,7 +538,8 @@ func TestEntityKeys(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupUserByIdKey", + Name: "LookupUserByIdKey", + MemberTypes: []string{"User"}, Fields: []RPCField{ { Name: "id", @@ -643,7 +647,8 @@ func TestEntityKeys(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupUserByIdAndAddressKey", + Name: "LookupUserByIdAndAddressKey", + MemberTypes: []string{"User"}, Fields: []RPCField{ { Name: "id", @@ -754,7 +759,8 @@ func TestEntityKeys(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupUserByIdAndNameKey", + Name: "LookupUserByIdAndNameKey", + MemberTypes: []string{"User"}, Fields: []RPCField{ { Name: "id", @@ -854,7 +860,8 @@ func TestEntityKeys(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupUserByIdAndNameKey", + Name: "LookupUserByIdAndNameKey", + MemberTypes: []string{"User"}, Fields: []RPCField{ { Name: "id", @@ -960,7 +967,8 @@ func TestEntityKeys(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupUserByIdAndNameAndAddressKey", + Name: "LookupUserByIdAndNameAndAddressKey", + MemberTypes: []string{"User"}, Fields: []RPCField{ { Name: "id", diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index a65bcb2f9b..1e33f52118 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -161,6 +161,10 @@ func (r *rpcPlanVisitorFederation) EnterSelectionSet(ref int) { r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message = r.planCtx.newMessageFromSelectionSet(r.walker.EnclosingTypeDefinition, ref) } + if r.IsEntityInlineFragment(r.walker.Ancestor()) { + r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message.AppendTypeNameField(r.entityInfo.typeName) + } + // Add the current response message to the ancestors and set the current response message to the current field message r.planInfo.responseMessageAncestors = append(r.planInfo.responseMessageAncestors, r.planInfo.currentResponseMessage) r.planInfo.currentResponseMessage = r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message @@ -375,12 +379,16 @@ func (r *rpcPlanVisitorFederation) FederationConfigDataByEntityTypeName(entityTy return federationConfigData{}, false } -func (r *rpcPlanVisitorFederation) IsEntityInlineFragment(ref int) bool { +func (r *rpcPlanVisitorFederation) IsEntityInlineFragment(node ast.Node) bool { + if node.Kind != ast.NodeKindInlineFragment { + return false + } + if r.entityInfo.entityInlineFragmentRef == ast.InvalidRef { return false } - return r.entityInfo.entityInlineFragmentRef == ref + return r.entityInfo.entityInlineFragmentRef == node.Ref } func parseFederationConfigData(federationConfigs plan.FederationFieldConfigurations) []federationConfigData { diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 5ee811dcd2..665ff8c817 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -9,9 +9,8 @@ package grpcdatasource import ( "bytes" "context" - "errors" "fmt" - "strconv" + "sync" "github.com/tidwall/gjson" "github.com/wundergraph/astjson" @@ -19,10 +18,8 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - protoref "google.golang.org/protobuf/reflect/protoreflect" ) // Verify DataSource implements the resolve.DataSource interface @@ -80,14 +77,15 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { + // get variables from input + variables := gjson.Parse(string(input)).Get("body.variables") + builder := newJSONBuilder(d.mapping, variables) + if d.disabled { - out.Write(writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used"))) + out.Write(builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used"))) return nil } - // get variables from input - variables := gjson.Parse(string(input)).Get("body.variables") - // get invocations from plan invocations, err := d.rc.Compile(d.plan, variables) if err != nil { @@ -95,38 +93,59 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) } a := astjson.Arena{} - root := a.NewObject() + responses := make([]*astjson.Value, len(invocations)) + errGrp, errGrpCtx := errgroup.WithContext(ctx) + + mu := sync.Mutex{} // make gRPC calls - for _, invocation := range invocations { - // Invoke the gRPC method - this will populate invocation.Output - methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) + for index, invocation := range invocations { + errGrp.Go(func() error { + // Invoke the gRPC method - this will populate invocation.Output + methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) - err := d.cc.Invoke(ctx, methodName, invocation.Input, invocation.Output) - if err != nil { - out.Write(writeErrorBytes(err)) + err := d.cc.Invoke(errGrpCtx, methodName, invocation.Input, invocation.Output) + if err != nil { + return err + } + + response, err := builder.marshalResponseJSON(&a, &invocation.Call.Response, invocation.Output) + if err != nil { + return err + } + + d.synchronizedSetResponse(&mu, responses, index, response) return nil - } + }) + } - responseJSON, err := d.marshalResponseJSON(&a, &invocation.Call.Response, invocation.Output) - if err != nil { - return err - } + if err := errGrp.Wait(); err != nil { + out.Write(builder.writeErrorBytes(err)) + return nil + } - root, _, err = astjson.MergeValues(root, responseJSON) + root := a.NewObject() + for _, response := range responses { + root, err = builder.mergeValues(root, response) if err != nil { - out.Write(writeErrorBytes(err)) - return nil + out.Write(builder.writeErrorBytes(err)) + return err } } - data := a.NewObject() - data.Set("data", root) + data := builder.toDataObject(root) out.Write(data.MarshalTo(nil)) return nil } +func (d *DataSource) synchronizedSetResponse(mu *sync.Mutex, responses []*astjson.Value, index int, response *astjson.Value) { + mu.Lock() + defer mu.Unlock() + + responses[index] = response +} + // LoadWithFiles implements resolve.DataSource interface. // Similar to Load, but handles file uploads if needed. // @@ -137,336 +156,3 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { panic("unimplemented") } - -func (d *DataSource) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { - if message == nil { - return arena.NewNull(), nil - } - - root := arena.NewObject() - - if message.IsOneOf() { - oneof := data.Descriptor().Oneofs().ByName(protoref.Name(message.OneOfType.FieldName())) - if oneof == nil { - return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) - } - - oneofDescriptor := data.WhichOneof(oneof) - if oneofDescriptor == nil { - return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) - } - - if oneofDescriptor.Kind() == protoref.MessageKind { - data = data.Get(oneofDescriptor).Message() - } - } - - validFields := message.Fields - if message.IsOneOf() { - validFields = append(validFields, message.FieldSelectionSet.SelectFieldsForTypes(message.SelectValidTypes(string(data.Type().Descriptor().Name())))...) - } - - for _, field := range validFields { - if field.StaticValue != "" { - if len(message.MemberTypes) == 0 { - root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) - continue - } - - for _, memberTypes := range message.MemberTypes { - if memberTypes == string(data.Type().Descriptor().Name()) { - root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) - break - } - } - - continue - } - - fd := data.Descriptor().Fields().ByName(protoref.Name(field.Name)) - if fd == nil { - continue - } - - if fd.IsList() { - list := data.Get(fd).List() - arr := arena.NewArray() - root.Set(field.AliasOrPath(), arr) - - if !list.IsValid() { - continue - } - - for i := 0; i < list.Len(); i++ { - - switch fd.Kind() { - case protoref.MessageKind: - message := list.Get(i).Message() - value, err := d.marshalResponseJSON(arena, field.Message, message) - if err != nil { - return nil, err - } - - arr.SetArrayItem(i, value) - default: - d.setArrayItem(i, arena, arr, list.Get(i), fd) - } - - } - - continue - } - - if fd.Kind() == protoref.MessageKind { - msg := data.Get(fd).Message() - if !msg.IsValid() { - root.Set(field.AliasOrPath(), arena.NewNull()) - continue - } - - if field.IsListType { - arr, err := d.flattenListStructure(arena, field.ListMetadata, msg, field.Message) - if err != nil { - return nil, fmt.Errorf("unable to flatten list structure for field %q: %w", field.AliasOrPath(), err) - } - - root.Set(field.AliasOrPath(), arr) - continue - } - - if field.IsOptionalScalar() { - err := d.resolveOptionalField(arena, root, field.AliasOrPath(), msg) - if err != nil { - return nil, err - } - - continue - } - - value, err := d.marshalResponseJSON(arena, field.Message, msg) - if err != nil { - return nil, err - } - - if field.JSONPath == "" { - root, _, err = astjson.MergeValues(root, value) - if err != nil { - return nil, err - } - } else { - root.Set(field.AliasOrPath(), value) - } - - continue - } - - d.setJSONValue(arena, root, field.AliasOrPath(), data, fd) - } - - return root, nil -} - -func (d *DataSource) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { - if md == nil { - return arena.NewNull(), errors.New("list metadata not found") - } - - if len(md.LevelInfo) < md.NestingLevel { - return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") - } - - if !data.IsValid() { - if md.LevelInfo[0].Optional { - return arena.NewNull(), nil - } - - return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") - } - - root := arena.NewArray() - return d.traverseList(0, arena, root, md, data, message) -} - -func (d *DataSource) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { - if level > md.NestingLevel { - return current, nil - } - - // List wrappers always use field number 1 - fd := data.Descriptor().Fields().ByNumber(1) - if fd == nil { - return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) - } - - if fd.Kind() != protoref.MessageKind { - return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) - } - - msg := data.Get(fd).Message() - if !msg.IsValid() { - // If the message is not valid we can either return null if the list is nullable or an error if it is non nullable. - if md.LevelInfo[level].Optional { - return arena.NewNull(), nil - } - - return arena.NewArray(), fmt.Errorf("cannot add null item to response for non nullable list") - } - - fd = msg.Descriptor().Fields().ByNumber(1) - if !fd.IsList() { - return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) - } - - if level < md.NestingLevel-1 { - list := msg.Get(fd).List() - for i := 0; i < list.Len(); i++ { - next := arena.NewArray() - val, err := d.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) - if err != nil { - return nil, err - } - - current.SetArrayItem(i, val) - } - - return current, nil - } - - list := msg.Get(fd).List() - if !list.IsValid() { - // If the list is not valid, we return an empty array here as the - // nullabilty is checked on the outer List wrapper type. - return arena.NewArray(), nil - } - - for i := 0; i < list.Len(); i++ { - if message != nil { - val, err := d.marshalResponseJSON(arena, message, list.Get(i).Message()) - if err != nil { - return nil, err - } - - current.SetArrayItem(i, val) - } else { - d.setArrayItem(i, arena, current, list.Get(i), fd) - } - } - - return current, nil -} - -func (d *DataSource) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { - fd := data.Descriptor().Fields().ByName(protoref.Name("value")) - if fd == nil { - return fmt.Errorf("unable to resolve optional field: field %q not found in message %s", "value", data.Descriptor().Name()) - } - - d.setJSONValue(arena, root, name, data, fd) - return nil -} - -func (d *DataSource) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { - if !data.IsValid() { - root.Set(name, arena.NewNull()) - return - } - - switch fd.Kind() { - case protoref.BoolKind: - boolValue := data.Get(fd).Bool() - if boolValue { - root.Set(name, arena.NewTrue()) - } else { - root.Set(name, arena.NewFalse()) - } - case protoref.StringKind: - root.Set(name, arena.NewString(data.Get(fd).String())) - case protoref.Int32Kind, protoref.Int64Kind: - root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) - case protoref.Uint32Kind, protoref.Uint64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) - case protoref.FloatKind, protoref.DoubleKind: - root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float())) - case protoref.BytesKind: - root.Set(name, arena.NewStringBytes(data.Get(fd).Bytes())) - case protoref.EnumKind: - enumDesc := fd.Enum() - enumValueDesc := enumDesc.Values().ByNumber(data.Get(fd).Enum()) - if enumValueDesc == nil { - root.Set(name, arena.NewNull()) - return - } - - graphqlValue, ok := d.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) - if !ok { - root.Set(name, arena.NewNull()) - return - } - - root.Set(name, arena.NewString(graphqlValue)) - } -} - -func (d *DataSource) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { - if !data.IsValid() { - array.SetArrayItem(index, arena.NewNull()) - return - } - - switch fd.Kind() { - case protoref.BoolKind: - boolValue := data.Bool() - if boolValue { - array.SetArrayItem(index, arena.NewTrue()) - } else { - array.SetArrayItem(index, arena.NewFalse()) - } - case protoref.StringKind: - array.SetArrayItem(index, arena.NewString(data.String())) - case protoref.Int32Kind, protoref.Int64Kind: - array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) - case protoref.Uint32Kind, protoref.Uint64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) - case protoref.FloatKind, protoref.DoubleKind: - array.SetArrayItem(index, arena.NewNumberFloat64(data.Float())) - case protoref.BytesKind: - array.SetArrayItem(index, arena.NewStringBytes(data.Bytes())) - case protoref.EnumKind: - enumDesc := fd.Enum() - enumValueDesc := enumDesc.Values().ByNumber(data.Enum()) - if enumValueDesc == nil { - array.SetArrayItem(index, arena.NewNull()) - return - } - - graphqlValue, ok := d.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) - if !ok { - array.SetArrayItem(index, arena.NewNull()) - return - } - - array.SetArrayItem(index, arena.NewString(graphqlValue)) - } -} - -func writeErrorBytes(err error) []byte { - a := astjson.Arena{} - errorRoot := a.NewObject() - errorArray := a.NewArray() - errorRoot.Set("errors", errorArray) - - errorItem := a.NewObject() - errorItem.Set("message", a.NewString(err.Error())) - - extensions := a.NewObject() - if st, ok := status.FromError(err); ok { - extensions.Set("code", a.NewString(st.Code().String())) - } else { - extensions.Set("code", a.NewString(codes.Internal.String())) - } - - errorItem.Set("extensions", extensions) - errorArray.SetArrayItem(0, errorItem) - - return errorRoot.MarshalTo(nil) -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 85381fb426..68964e7a06 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -11,8 +11,10 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest/productv1" "google.golang.org/grpc" @@ -495,10 +497,9 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - ds := &DataSource{} - arena := astjson.Arena{} - responseJSON, err := ds.marshalResponseJSON(&arena, &response, responseMessage) + jsonBuilder := newJSONBuilder(nil, gjson.Result{}) + responseJSON, err := jsonBuilder.marshalResponseJSON(&arena, &response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, responseJSON.String()) } @@ -3487,3 +3488,121 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { }) } } + +func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { + conn, cleanup := setupTestGRPCServer(t) + t.Cleanup(cleanup) + + testCases := []struct { + name string + query string + vars string + federationConfigs plan.FederationFieldConfigurations + validate func(t *testing.T, data map[string]interface{}) + }{ + { + name: "Query nullable fields type with all fields", + query: `query { _entities(representations: $representations) { ...on Product { id name } ...on Storage { id name } } }`, + vars: `{"variables":{"representations":[ + {"__typename":"Product","id":"1"}, + {"__typename":"Storage","id":"3"}, + {"__typename":"Product","id":"2"}, + {"__typename":"Storage","id":"4"} + ]}}`, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "Product", + SelectionSet: "id", + }, + { + TypeName: "Storage", + SelectionSet: "id", + }, + }, + validate: func(t *testing.T, data map[string]interface{}) { + entities, ok := data["_entities"].([]interface{}) + require.True(t, ok, "_entities should be an array") + require.NotEmpty(t, entities, "_entities should not be empty") + + // Check required fields are present + require.Contains(t, entities[0], "id") + require.Contains(t, entities[0], "name") + require.Contains(t, entities[1], "id") + require.Contains(t, entities[1], "name") + + require.Len(t, entities, 4, "Should return 4 entities") + + product, ok := entities[0].(map[string]interface{}) + require.True(t, ok, "product should be an object") + require.Equal(t, "1", product["id"]) + require.Equal(t, "Product 1", product["name"]) + + storage, ok := entities[1].(map[string]interface{}) + require.True(t, ok, "storage should be an object") + require.Equal(t, "3", storage["id"]) + require.Equal(t, "Storage 3", storage["name"]) + + product2, ok := entities[2].(map[string]interface{}) + require.True(t, ok, "product2 should be an object") + require.Equal(t, "2", product2["id"]) + require.Equal(t, "Product 2", product2["name"]) + + storage2, ok := entities[3].(map[string]interface{}) + require.True(t, ok, "storage2 should be an object") + require.Equal(t, "4", storage2["id"]) + require.Equal(t, "Storage 4", storage2["name"]) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Parse the GraphQL schema + schemaDoc := grpctest.MustGraphQLSchema(t) + + // Parse the GraphQL query + queryDoc, report := astparser.ParseGraphqlDocumentString(tc.query) + if report.HasErrors() { + t.Fatalf("failed to parse query: %s", report.Error()) + } + + compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) + if err != nil { + t.Fatalf("failed to compile proto: %v", err) + } + + // Create the datasource + ds, err := NewDataSource(conn, DataSourceConfig{ + Operation: &queryDoc, + Definition: &schemaDoc, + SubgraphName: "Products", + Mapping: testMapping(), + Compiler: compiler, + FederationConfigs: tc.federationConfigs, + }) + require.NoError(t, err) + + // Execute the query through our datasource + output := new(bytes.Buffer) + input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) + err = ds.Load(context.Background(), []byte(input), output) + require.NoError(t, err) + + // Parse the response + var resp struct { + Data map[string]interface{} `json:"data"` + Errors []struct { + Message string `json:"message"` + } `json:"errors,omitempty"` + } + + err = json.Unmarshal(output.Bytes(), &resp) + require.NoError(t, err, "Failed to unmarshal response") + require.Empty(t, resp.Errors, "Response should not contain errors") + require.NotEmpty(t, resp.Data, "Response should contain data") + + // Run the validation function + tc.validate(t, resp.Data) + }) + } +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go new file mode 100644 index 0000000000..839f33e7a3 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -0,0 +1,483 @@ +package grpcdatasource + +import ( + "errors" + "fmt" + "strconv" + + "github.com/tidwall/gjson" + "github.com/wundergraph/astjson" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + protoref "google.golang.org/protobuf/reflect/protoreflect" +) + +var ( + entityPath = "_entities" + dataPath = "data" + errorsPath = "errors" +) + +type entityIndex struct { + representationIndex int + resultIndex int +} + +type indexMap map[string][]entityIndex + +func (i indexMap) getResultIndex(val *astjson.Value, representationIndex int) int { + if i == nil { + return representationIndex + } + + if val == nil { + return representationIndex + } + + typeName := val.Get("__typename").GetStringBytes() + + for _, entityIndex := range i[string(typeName)] { + if entityIndex.representationIndex == representationIndex { + return entityIndex.resultIndex + } + } + + return representationIndex +} + +func createRepresentationIndexMap(variables gjson.Result) indexMap { + var representations []gjson.Result + r := variables.Get("representations") + if !r.Exists() { + return nil + } + + representations = r.Array() + im := make(indexMap) + indexSet := make(map[string]int) + for i, representation := range representations { + typeName := representation.Get("__typename").String() + if _, ok := indexSet[typeName]; !ok { + indexSet[typeName] = -1 + } + + indexSet[typeName]++ + + im[typeName] = append(im[typeName], entityIndex{ + representationIndex: indexSet[typeName], + resultIndex: i, + }) + } + return im +} + +type jsonBuilder struct { + mapping *GRPCMapping + variables gjson.Result + indexMap indexMap +} + +func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { + return &jsonBuilder{ + mapping: mapping, + variables: variables, + indexMap: createRepresentationIndexMap(variables), + } +} + +func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { + if len(j.indexMap) == 0 { + // We don't have a representation index map, so we can just merge the values. + root, _, err := astjson.MergeValues(left, right) + if err != nil { + return nil, err + } + return root, nil + } + + // When we have an index map, we need to ensure to keep the order of the representations. + leftObject, err := left.Object() + if err != nil { + return nil, err + } + + if leftObject.Len() == 0 { + return right, nil + } + + return j.mergeEntities(left, right) +} + +func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { + + root := astjson.Arena{} + defer root.Reset() + + entities := root.NewObject() + entities.Set(entityPath, root.NewArray()) + + arr := entities.Get(entityPath) + + leftRepresentations, err := left.Get(entityPath).Array() + if err != nil { + return nil, err + } + + rightRepresentations, err := right.Get(entityPath).Array() + if err != nil { + return nil, err + } + + for index, lr := range leftRepresentations { + resultIndex := j.indexMap.getResultIndex(lr, index) + arr.SetArrayItem(resultIndex, lr) + } + + for index, rr := range rightRepresentations { + resultIndex := j.indexMap.getResultIndex(rr, index) + arr.SetArrayItem(resultIndex, rr) + } + + return entities, nil +} + +func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { + if message == nil { + return arena.NewNull(), nil + } + + root := arena.NewObject() + + if message.IsOneOf() { + oneof := data.Descriptor().Oneofs().ByName(protoref.Name(message.OneOfType.FieldName())) + if oneof == nil { + return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) + } + + oneofDescriptor := data.WhichOneof(oneof) + if oneofDescriptor == nil { + return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) + } + + if oneofDescriptor.Kind() == protoref.MessageKind { + data = data.Get(oneofDescriptor).Message() + } + } + + validFields := message.Fields + if message.IsOneOf() { + validFields = append(validFields, message.FieldSelectionSet.SelectFieldsForTypes(message.SelectValidTypes(string(data.Type().Descriptor().Name())))...) + } + + for _, field := range validFields { + if field.StaticValue != "" { + if len(message.MemberTypes) == 0 { + root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) + continue + } + + for _, memberTypes := range message.MemberTypes { + if memberTypes == string(data.Type().Descriptor().Name()) { + root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) + break + } + } + + continue + } + + fd := data.Descriptor().Fields().ByName(protoref.Name(field.Name)) + if fd == nil { + continue + } + + if fd.IsList() { + list := data.Get(fd).List() + arr := arena.NewArray() + root.Set(field.AliasOrPath(), arr) + + if !list.IsValid() { + continue + } + + for i := 0; i < list.Len(); i++ { + switch fd.Kind() { + case protoref.MessageKind: + message := list.Get(i).Message() + value, err := j.marshalResponseJSON(arena, field.Message, message) + if err != nil { + return nil, err + } + + arr.SetArrayItem(i, value) + default: + j.setArrayItem(i, arena, arr, list.Get(i), fd) + } + + } + + continue + } + + if fd.Kind() == protoref.MessageKind { + msg := data.Get(fd).Message() + if !msg.IsValid() { + root.Set(field.AliasOrPath(), arena.NewNull()) + continue + } + + if field.IsListType { + arr, err := j.flattenListStructure(arena, field.ListMetadata, msg, field.Message) + if err != nil { + return nil, fmt.Errorf("unable to flatten list structure for field %q: %w", field.AliasOrPath(), err) + } + + root.Set(field.AliasOrPath(), arr) + continue + } + + if field.IsOptionalScalar() { + err := j.resolveOptionalField(arena, root, field.AliasOrPath(), msg) + if err != nil { + return nil, err + } + + continue + } + + value, err := j.marshalResponseJSON(arena, field.Message, msg) + if err != nil { + return nil, err + } + + if field.JSONPath == "" { + root, _, err = astjson.MergeValues(root, value) + if err != nil { + return nil, err + } + } else { + root.Set(field.AliasOrPath(), value) + } + + continue + } + + j.setJSONValue(arena, root, field.AliasOrPath(), data, fd) + } + + return root, nil +} + +func (j *jsonBuilder) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { + if md == nil { + return arena.NewNull(), errors.New("list metadata not found") + } + + if len(md.LevelInfo) < md.NestingLevel { + return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") + } + + if !data.IsValid() { + if md.LevelInfo[0].Optional { + return arena.NewNull(), nil + } + + return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") + } + + root := arena.NewArray() + return j.traverseList(0, arena, root, md, data, message) +} + +func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { + if level > md.NestingLevel { + return current, nil + } + + // List wrappers always use field number 1 + fd := data.Descriptor().Fields().ByNumber(1) + if fd == nil { + return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) + } + + if fd.Kind() != protoref.MessageKind { + return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) + } + + msg := data.Get(fd).Message() + if !msg.IsValid() { + // If the message is not valid we can either return null if the list is nullable or an error if it is non nullable. + if md.LevelInfo[level].Optional { + return arena.NewNull(), nil + } + + return arena.NewArray(), fmt.Errorf("cannot add null item to response for non nullable list") + } + + fd = msg.Descriptor().Fields().ByNumber(1) + if !fd.IsList() { + return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) + } + + if level < md.NestingLevel-1 { + list := msg.Get(fd).List() + for i := 0; i < list.Len(); i++ { + next := arena.NewArray() + val, err := j.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) + if err != nil { + return nil, err + } + + current.SetArrayItem(i, val) + } + + return current, nil + } + + list := msg.Get(fd).List() + if !list.IsValid() { + // If the list is not valid, we return an empty array here as the + // nullabilty is checked on the outer List wrapper type. + return arena.NewArray(), nil + } + + for i := 0; i < list.Len(); i++ { + if message != nil { + val, err := j.marshalResponseJSON(arena, message, list.Get(i).Message()) + if err != nil { + return nil, err + } + + current.SetArrayItem(i, val) + } else { + j.setArrayItem(i, arena, current, list.Get(i), fd) + } + } + + return current, nil +} + +func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { + fd := data.Descriptor().Fields().ByName(protoref.Name("value")) + if fd == nil { + return fmt.Errorf("unable to resolve optional field: field %q not found in message %s", "value", data.Descriptor().Name()) + } + + j.setJSONValue(arena, root, name, data, fd) + return nil +} + +func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { + if !data.IsValid() { + root.Set(name, arena.NewNull()) + return + } + + switch fd.Kind() { + case protoref.BoolKind: + boolValue := data.Get(fd).Bool() + if boolValue { + root.Set(name, arena.NewTrue()) + } else { + root.Set(name, arena.NewFalse()) + } + case protoref.StringKind: + root.Set(name, arena.NewString(data.Get(fd).String())) + case protoref.Int32Kind, protoref.Int64Kind: + root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) + case protoref.Uint32Kind, protoref.Uint64Kind: + root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) + case protoref.FloatKind, protoref.DoubleKind: + root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float())) + case protoref.BytesKind: + root.Set(name, arena.NewStringBytes(data.Get(fd).Bytes())) + case protoref.EnumKind: + enumDesc := fd.Enum() + enumValueDesc := enumDesc.Values().ByNumber(data.Get(fd).Enum()) + if enumValueDesc == nil { + root.Set(name, arena.NewNull()) + return + } + + graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) + if !ok { + root.Set(name, arena.NewNull()) + return + } + + root.Set(name, arena.NewString(graphqlValue)) + } +} + +func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { + if !data.IsValid() { + array.SetArrayItem(index, arena.NewNull()) + return + } + + switch fd.Kind() { + case protoref.BoolKind: + boolValue := data.Bool() + if boolValue { + array.SetArrayItem(index, arena.NewTrue()) + } else { + array.SetArrayItem(index, arena.NewFalse()) + } + case protoref.StringKind: + array.SetArrayItem(index, arena.NewString(data.String())) + case protoref.Int32Kind, protoref.Int64Kind: + array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) + case protoref.Uint32Kind, protoref.Uint64Kind: + array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) + case protoref.FloatKind, protoref.DoubleKind: + array.SetArrayItem(index, arena.NewNumberFloat64(data.Float())) + case protoref.BytesKind: + array.SetArrayItem(index, arena.NewStringBytes(data.Bytes())) + case protoref.EnumKind: + enumDesc := fd.Enum() + enumValueDesc := enumDesc.Values().ByNumber(data.Enum()) + if enumValueDesc == nil { + array.SetArrayItem(index, arena.NewNull()) + return + } + + graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) + if !ok { + array.SetArrayItem(index, arena.NewNull()) + return + } + + array.SetArrayItem(index, arena.NewString(graphqlValue)) + } +} + +func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { + a := astjson.Arena{} + defer a.Reset() + data := a.NewObject() + data.Set(dataPath, root) + return data +} + +func (j *jsonBuilder) writeErrorBytes(err error) []byte { + a := astjson.Arena{} + defer a.Reset() + errorRoot := a.NewObject() + errorArray := a.NewArray() + errorRoot.Set(errorsPath, errorArray) + + errorItem := a.NewObject() + errorItem.Set("message", a.NewString(err.Error())) + + extensions := a.NewObject() + if st, ok := status.FromError(err); ok { + extensions.Set("code", a.NewString(st.Code().String())) + } else { + extensions.Set("code", a.NewString(codes.Internal.String())) + } + + errorItem.Set("extensions", extensions) + errorArray.SetArrayItem(0, errorItem) + + return errorRoot.MarshalTo(nil) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go index 7a6809f513..cd54e7ccf4 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go +++ b/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go @@ -49,6 +49,7 @@ func (r *requiredFieldsVisitor) visitRequiredFields(definition *ast.Document, ty return report } + r.message.MemberTypes = []string{typeName} r.walker.Walk(doc, definition, report) if report.HasErrors() { return report From 6fcc131195bedb7f3ed466600cd4a96a897d3157 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 09:17:14 +0200 Subject: [PATCH 14/24] chore: add error handling when mapping was not found --- .../grpc_datasource/execution_plan_visitor_federation.go | 1 + 1 file changed, 1 insertion(+) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index 1e33f52118..d8c50ba9f5 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -321,6 +321,7 @@ func (r *rpcPlanVisitorFederation) resolveEntityInformation(inlineFragmentRef in rpcConfig, exists := r.mapping.ResolveEntityRPCConfig(fc.entityTypeName, fc.keyFields) if !exists { + r.walker.StopWithInternalErr(fmt.Errorf("entity type %s not found in mapping", fc.entityTypeName)) return } From 94457e95aa6da6ec399ce38c637db257265cb8a1 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 09:19:20 +0200 Subject: [PATCH 15/24] chore: improve error handling --- .../execution_plan_visitor_federation.go | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index d8c50ba9f5..7f2aab2615 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -118,7 +118,11 @@ func (r *rpcPlanVisitorFederation) EnterInlineFragment(ref int) { r.entityInfo.entityInlineFragmentRef = ref r.entityInfo.typeName = fragmentName - r.resolveEntityInformation(ref, fc) + if err := r.resolveEntityInformation(ref, fc); err != nil { + r.walker.StopWithInternalErr(err) + return + } + r.scaffoldEntityLookup(fc) } @@ -305,29 +309,29 @@ func (r *rpcPlanVisitorFederation) LeaveField(ref int) { r.planInfo.currentResponseFieldIndex = 0 } -func (r *rpcPlanVisitorFederation) resolveEntityInformation(inlineFragmentRef int, fc federationConfigData) { +func (r *rpcPlanVisitorFederation) resolveEntityInformation(inlineFragmentRef int, fc federationConfigData) error { fragmentName := r.operation.InlineFragmentTypeConditionNameString(inlineFragmentRef) node, found := r.definition.NodeByNameStr(r.operation.InlineFragmentTypeConditionNameString(inlineFragmentRef)) if !found { - r.walker.StopWithInternalErr(errors.New("definition node not found for inline fragment: " + fragmentName)) - return + return errors.New("definition node not found for inline fragment: " + fragmentName) } // Only process object type definitions // TODO: handle interfaces if node.Kind != ast.NodeKindObjectTypeDefinition { - return + return nil } rpcConfig, exists := r.mapping.ResolveEntityRPCConfig(fc.entityTypeName, fc.keyFields) if !exists { - r.walker.StopWithInternalErr(fmt.Errorf("entity type %s not found in mapping", fc.entityTypeName)) - return + return fmt.Errorf("entity type %s not found in mapping", fc.entityTypeName) } r.currentCall.Request.Name = rpcConfig.Request r.currentCall.Response.Name = rpcConfig.Response r.currentCall.MethodName = rpcConfig.RPC + + return nil } // scaffoldEntityLookup creates the entity lookup call structure From 3aaa45c03ba9a003654e8d4bd34952b871bf0736 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 09:24:26 +0200 Subject: [PATCH 16/24] chore: improve handling for shared objects --- .../grpc_datasource/grpc_datasource.go | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index 665ff8c817..935c764363 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -92,8 +92,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) return err } - a := astjson.Arena{} - responses := make([]*astjson.Value, len(invocations)) errGrp, errGrpCtx := errgroup.WithContext(ctx) @@ -101,6 +99,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) // make gRPC calls for index, invocation := range invocations { errGrp.Go(func() error { + a := astjson.Arena{} // Invoke the gRPC method - this will populate invocation.Output methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) @@ -109,12 +108,15 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) return err } + mu.Lock() + defer mu.Unlock() + response, err := builder.marshalResponseJSON(&a, &invocation.Call.Response, invocation.Output) if err != nil { return err } - d.synchronizedSetResponse(&mu, responses, index, response) + responses[index] = response return nil }) } @@ -124,6 +126,7 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) return nil } + a := astjson.Arena{} root := a.NewObject() for _, response := range responses { root, err = builder.mergeValues(root, response) @@ -139,13 +142,6 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) return nil } -func (d *DataSource) synchronizedSetResponse(mu *sync.Mutex, responses []*astjson.Value, index int, response *astjson.Value) { - mu.Lock() - defer mu.Unlock() - - responses[index] = response -} - // LoadWithFiles implements resolve.DataSource interface. // Similar to Load, but handles file uploads if needed. // From ae0493bc6cddf90b40ea5719e17eff75a4bffd4d Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 09:32:26 +0200 Subject: [PATCH 17/24] chore: remove type from schema --- v2/pkg/grpctest/testdata/products.graphqls | 7 ------- 1 file changed, 7 deletions(-) diff --git a/v2/pkg/grpctest/testdata/products.graphqls b/v2/pkg/grpctest/testdata/products.graphqls index ac009e83e6..6db5cfd402 100644 --- a/v2/pkg/grpctest/testdata/products.graphqls +++ b/v2/pkg/grpctest/testdata/products.graphqls @@ -11,13 +11,6 @@ type Storage @key(fields: "id") { location: String! } -type Warehouse @key(fields: "id") @key(fields: "slug") { - id: ID! - slug: String! - name: String! - location: String! -} - type User { id: ID! name: String! From 7c51e396f0011bd3f3a728100447ff2663c11483 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 09:33:06 +0200 Subject: [PATCH 18/24] chore: remove type from entity union --- v2/pkg/grpctest/testdata/products.graphqls | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/v2/pkg/grpctest/testdata/products.graphqls b/v2/pkg/grpctest/testdata/products.graphqls index 6db5cfd402..18197a3320 100644 --- a/v2/pkg/grpctest/testdata/products.graphqls +++ b/v2/pkg/grpctest/testdata/products.graphqls @@ -393,5 +393,5 @@ type Mutation { -union _Entity = Product | Storage | Warehouse +union _Entity = Product | Storage scalar _Any From 4c9efdb23933efed494452da8f815bc337a3266d Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 10:33:22 +0200 Subject: [PATCH 19/24] chore: address some review comments --- .../engine/datasource/grpc_datasource/grpc_datasource_test.go | 2 +- v2/pkg/engine/datasource/grpc_datasource/json_builder.go | 3 --- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 68964e7a06..bd393ddf14 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -3502,7 +3502,7 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { }{ { name: "Query nullable fields type with all fields", - query: `query { _entities(representations: $representations) { ...on Product { id name } ...on Storage { id name } } }`, + query: `query($representations: [_Any!]!) { _entities(representations: $representations) { ...on Product { id name } ...on Storage { id name } } }`, vars: `{"variables":{"representations":[ {"__typename":"Product","id":"1"}, {"__typename":"Storage","id":"3"}, diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 839f33e7a3..172a710ec2 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -109,9 +109,7 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a } func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { - root := astjson.Arena{} - defer root.Reset() entities := root.NewObject() entities.Set(entityPath, root.NewArray()) @@ -453,7 +451,6 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { a := astjson.Arena{} - defer a.Reset() data := a.NewObject() data.Set(dataPath, root) return data From ca89a2ac4405ca19be5ef7359c938640b68ae25a Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 10:35:52 +0200 Subject: [PATCH 20/24] chore: check if typename already exists --- v2/pkg/engine/datasource/grpc_datasource/execution_plan.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index 166ec3166a..7c7290fef9 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -105,6 +105,10 @@ func (r *RPCMessage) SelectValidTypes(typeName string) []string { } func (r *RPCMessage) AppendTypeNameField(typeName string) { + if r.Fields != nil && r.Fields.Exists("__typename", "") { + return + } + r.Fields = append(r.Fields, RPCField{ Name: "__typename", TypeName: DataTypeString.String(), From 6cb861726d6dbe9fe3acd2a392d4e0da1c76fdb3 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 11:12:41 +0200 Subject: [PATCH 21/24] chore: add proper comments for json builder --- .../grpc_datasource/json_builder.go | 151 +++++++++++++++--- 1 file changed, 132 insertions(+), 19 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 172a710ec2..110bf179a3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -12,19 +12,28 @@ import ( protoref "google.golang.org/protobuf/reflect/protoreflect" ) +// Standard GraphQL response paths var ( - entityPath = "_entities" - dataPath = "data" - errorsPath = "errors" + entityPath = "_entities" // Path for federated entities in response + dataPath = "data" // Standard GraphQL data wrapper + errorsPath = "errors" // Standard GraphQL errors array ) +// entityIndex represents the mapping between representation order and result order +// for GraphQL federation entities. This is crucial for maintaining correct entity +// order when multiple subgraphs return entities in different orders. type entityIndex struct { - representationIndex int - resultIndex int + representationIndex int // Index in the original representation array + resultIndex int // Index where this entity should appear in the final result } +// indexMap maps GraphQL type names to their corresponding entity indices +// This allows proper ordering of federated entities by type type indexMap map[string][]entityIndex +// getResultIndex returns the correct result index for an entity based on its type +// and representation index. This ensures federated entities maintain proper ordering +// across multiple subgraph responses. func (i indexMap) getResultIndex(val *astjson.Value, representationIndex int) int { if i == nil { return representationIndex @@ -34,17 +43,24 @@ func (i indexMap) getResultIndex(val *astjson.Value, representationIndex int) in return representationIndex } + // Extract the __typename field to determine entity type typeName := val.Get("__typename").GetStringBytes() + // Find the correct result index for this type and representation index for _, entityIndex := range i[string(typeName)] { if entityIndex.representationIndex == representationIndex { return entityIndex.resultIndex } } + // Fallback to representation index if no mapping found return representationIndex } +// createRepresentationIndexMap builds an index mapping for GraphQL federation entities +// from the variables containing entity representations. This map is used to ensure +// that entities are returned in the correct order when merging responses from multiple +// subgraphs, which is critical for GraphQL federation correctness. func createRepresentationIndexMap(variables gjson.Result) indexMap { var representations []gjson.Result r := variables.Get("representations") @@ -54,29 +70,44 @@ func createRepresentationIndexMap(variables gjson.Result) indexMap { representations = r.Array() im := make(indexMap) - indexSet := make(map[string]int) + indexSet := make(map[string]int) // Track count per type name + + // Build mapping for each representation for i, representation := range representations { typeName := representation.Get("__typename").String() + + // Initialize counter for new type names if _, ok := indexSet[typeName]; !ok { indexSet[typeName] = -1 } + // Increment index for this type indexSet[typeName]++ + // Create mapping entry for this entity im[typeName] = append(im[typeName], entityIndex{ - representationIndex: indexSet[typeName], - resultIndex: i, + representationIndex: indexSet[typeName], // Position within entities of this type + resultIndex: i, // Position in the overall result array }) } return im } +// jsonBuilder is the core component responsible for converting gRPC protobuf responses +// into GraphQL-compatible JSON format. It handles complex scenarios including: +// - GraphQL federation entity merging and ordering +// - Nested list structures with proper nullability handling +// - Protobuf to GraphQL type conversion +// - Error response formatting type jsonBuilder struct { - mapping *GRPCMapping - variables gjson.Result - indexMap indexMap + mapping *GRPCMapping // Mapping configuration for GraphQL to gRPC translation + variables gjson.Result // GraphQL variables containing entity representations + indexMap indexMap // Entity index mapping for federation ordering } +// newJSONBuilder creates a new JSON builder instance with the provided mapping +// and variables. The builder automatically creates an index map for proper +// federation entity ordering if representations are present in the variables. func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { return &jsonBuilder{ mapping: mapping, @@ -85,9 +116,13 @@ func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { } } +// mergeValues combines two JSON values while preserving proper federation entity ordering. +// This is a critical function for GraphQL federation where multiple subgraphs may +// return entities that need to be merged in the correct order. func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { if len(j.indexMap) == 0 { - // We don't have a representation index map, so we can just merge the values. + // No federation index map available - use simple merge + // This path is taken for non-federated queries root, _, err := astjson.MergeValues(left, right) if err != nil { return nil, err @@ -95,27 +130,33 @@ func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*a return root, nil } - // When we have an index map, we need to ensure to keep the order of the representations. + // Federation entities present - must preserve representation order leftObject, err := left.Object() if err != nil { return nil, err } + // If left side is empty, just return right side if leftObject.Len() == 0 { return right, nil } + // Perform federation-aware entity merging return j.mergeEntities(left, right) } +// mergeEntities performs federation-aware merging of entity arrays from multiple subgraph responses. +// This function ensures that entities are placed in the correct positions in the final response +// array based on their original representation order, which is critical for GraphQL federation. func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { root := astjson.Arena{} + // Create the response structure with _entities array entities := root.NewObject() entities.Set(entityPath, root.NewArray()) - arr := entities.Get(entityPath) + // Extract entity arrays from both responses leftRepresentations, err := left.Get(entityPath).Array() if err != nil { return nil, err @@ -126,11 +167,13 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( return nil, err } + // Merge left entities using index mapping to preserve order for index, lr := range leftRepresentations { resultIndex := j.indexMap.getResultIndex(lr, index) arr.SetArrayItem(resultIndex, lr) } + // Merge right entities using index mapping to preserve order for index, rr := range rightRepresentations { resultIndex := j.indexMap.getResultIndex(rr, index) arr.SetArrayItem(resultIndex, rr) @@ -139,6 +182,9 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( return entities, nil } +// marshalResponseJSON converts a protobuf message into a GraphQL-compatible JSON response. +// This is the core marshaling function that handles all the complex type conversions, +// including oneOf types, nested messages, lists, and scalar values. func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { if message == nil { return arena.NewNull(), nil @@ -146,34 +192,43 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess root := arena.NewObject() + // Handle protobuf oneOf types - these represent GraphQL union/interface types if message.IsOneOf() { oneof := data.Descriptor().Oneofs().ByName(protoref.Name(message.OneOfType.FieldName())) if oneof == nil { return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) } + // Determine which oneOf field is actually set oneofDescriptor := data.WhichOneof(oneof) if oneofDescriptor == nil { return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) } + // Extract the actual message data from the oneOf wrapper if oneofDescriptor.Kind() == protoref.MessageKind { data = data.Get(oneofDescriptor).Message() } } + // Determine which fields to include in the response validFields := message.Fields if message.IsOneOf() { + // For oneOf types, add type-specific fields based on the actual concrete type validFields = append(validFields, message.FieldSelectionSet.SelectFieldsForTypes(message.SelectValidTypes(string(data.Type().Descriptor().Name())))...) } + // Process each field in the message for _, field := range validFields { + // Handle static values (like __typename fields) if field.StaticValue != "" { if len(message.MemberTypes) == 0 { + // Simple static value - use as-is root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) continue } + // Type-specific static value - match against member types for _, memberTypes := range message.MemberTypes { if memberTypes == string(data.Type().Descriptor().Name()) { root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) @@ -184,23 +239,29 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess continue } + // Get the protobuf field descriptor for this GraphQL field fd := data.Descriptor().Fields().ByName(protoref.Name(field.Name)) if fd == nil { + // Field not found in protobuf message - skip it continue } + // Handle list fields (repeated in protobuf) if fd.IsList() { list := data.Get(fd).List() arr := arena.NewArray() root.Set(field.AliasOrPath(), arr) if !list.IsValid() { + // Invalid list - leave as empty array continue } + // Process each list item for i := 0; i < list.Len(); i++ { switch fd.Kind() { case protoref.MessageKind: + // List of messages - recursively marshal each message message := list.Get(i).Message() value, err := j.marshalResponseJSON(arena, field.Message, message) if err != nil { @@ -209,21 +270,24 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess arr.SetArrayItem(i, value) default: + // List of scalar values - convert directly j.setArrayItem(i, arena, arr, list.Get(i), fd) } - } continue } + // Handle message fields (nested objects) if fd.Kind() == protoref.MessageKind { msg := data.Get(fd).Message() if !msg.IsValid() { + // Invalid message - set to null root.Set(field.AliasOrPath(), arena.NewNull()) continue } + // Handle special list wrapper types for complex nested lists if field.IsListType { arr, err := j.flattenListStructure(arena, field.ListMetadata, msg, field.Message) if err != nil { @@ -234,6 +298,7 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess continue } + // Handle optional scalar wrapper types (e.g., google.protobuf.StringValue) if field.IsOptionalScalar() { err := j.resolveOptionalField(arena, root, field.AliasOrPath(), msg) if err != nil { @@ -243,38 +308,48 @@ func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMess continue } + // Regular nested message - recursively marshal value, err := j.marshalResponseJSON(arena, field.Message, msg) if err != nil { return nil, err } if field.JSONPath == "" { + // Field should be merged into parent object (flattened) root, _, err = astjson.MergeValues(root, value) if err != nil { return nil, err } } else { + // Field should be nested under its own key root.Set(field.AliasOrPath(), value) } continue } + // Handle scalar fields (string, int, bool, etc.) j.setJSONValue(arena, root, field.AliasOrPath(), data, fd) } return root, nil } +// flattenListStructure handles complex nested list structures that are wrapped in protobuf +// messages to support nullable and multi-dimensional lists. This is necessary because +// protobuf doesn't directly support nullable list items or complex nesting scenarios +// that GraphQL allows. func (j *jsonBuilder) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if md == nil { return arena.NewNull(), errors.New("list metadata not found") } + // Validate metadata consistency if len(md.LevelInfo) < md.NestingLevel { return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") } + // Handle null data with proper nullability checking if !data.IsValid() { if md.LevelInfo[0].Optional { return arena.NewNull(), nil @@ -283,16 +358,20 @@ func (j *jsonBuilder) flattenListStructure(arena *astjson.Arena, md *ListMetadat return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") } + // Start recursive traversal of the nested list structure root := arena.NewArray() return j.traverseList(0, arena, root, md, data, message) } +// traverseList recursively traverses nested list wrapper structures to extract the actual +// list data. This handles multi-dimensional lists like [[String]] or [[[User]]] by +// unwrapping the protobuf message wrappers at each level. func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { if level > md.NestingLevel { return current, nil } - // List wrappers always use field number 1 + // List wrappers always use field number 1 in the generated protobuf fd := data.Descriptor().Fields().ByNumber(1) if fd == nil { return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) @@ -302,9 +381,10 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) } + // Get the wrapper message containing the list msg := data.Get(fd).Message() if !msg.IsValid() { - // If the message is not valid we can either return null if the list is nullable or an error if it is non nullable. + // Handle null wrapper based on nullability rules if md.LevelInfo[level].Optional { return arena.NewNull(), nil } @@ -312,14 +392,17 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast return arena.NewArray(), fmt.Errorf("cannot add null item to response for non nullable list") } + // The actual list is always at field number 1 in the wrapper fd = msg.Descriptor().Fields().ByNumber(1) if !fd.IsList() { return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) } + // Handle intermediate nesting levels (not the final level) if level < md.NestingLevel-1 { list := msg.Get(fd).List() for i := 0; i < list.Len(); i++ { + // Create nested array for next level next := arena.NewArray() val, err := j.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) if err != nil { @@ -332,15 +415,18 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast return current, nil } + // Handle the final nesting level - extract actual data list := msg.Get(fd).List() if !list.IsValid() { - // If the list is not valid, we return an empty array here as the - // nullabilty is checked on the outer List wrapper type. + // Invalid list at final level - return empty array + // Nullability is checked at the wrapper level, not the list level return arena.NewArray(), nil } + // Process each item in the final list for i := 0; i < list.Len(); i++ { if message != nil { + // List of complex objects - recursively marshal each item val, err := j.marshalResponseJSON(arena, message, list.Get(i).Message()) if err != nil { return nil, err @@ -348,6 +434,7 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast current.SetArrayItem(i, val) } else { + // List of scalar values - convert directly j.setArrayItem(i, arena, current, list.Get(i), fd) } } @@ -355,16 +442,24 @@ func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *ast return current, nil } +// resolveOptionalField extracts the value from optional scalar wrapper types like +// google.protobuf.StringValue, google.protobuf.Int32Value, etc. These wrappers +// are used to represent nullable scalar values in protobuf. func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { + // Optional scalar wrappers always have a "value" field fd := data.Descriptor().Fields().ByName(protoref.Name("value")) if fd == nil { return fmt.Errorf("unable to resolve optional field: field %q not found in message %s", "value", data.Descriptor().Name()) } + // Extract and set the wrapped value j.setJSONValue(arena, root, name, data, fd) return nil } +// setJSONValue converts a protobuf field value to the appropriate JSON representation +// and sets it on the provided JSON object. This handles all protobuf scalar types +// and enum values with proper GraphQL mapping. func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { if !data.IsValid() { root.Set(name, arena.NewNull()) @@ -397,8 +492,10 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na return } + // Look up the GraphQL enum value mapping graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { + // No mapping found - set to null root.Set(name, arena.NewNull()) return } @@ -407,6 +504,9 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na } } +// setArrayItem converts a protobuf list item value to JSON and sets it at the specified +// array index. This is similar to setJSONValue but operates on array elements rather +// than object properties, and works with protobuf Value types rather than Message types. func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { if !data.IsValid() { array.SetArrayItem(index, arena.NewNull()) @@ -439,8 +539,10 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs return } + // Look up GraphQL enum mapping graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) if !ok { + // No mapping found - use null array.SetArrayItem(index, arena.NewNull()) return } @@ -449,6 +551,8 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs } } +// toDataObject wraps a response value in the standard GraphQL data envelope. +// This creates the top-level structure { "data": ... } that GraphQL clients expect. func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { a := astjson.Arena{} data := a.NewObject() @@ -456,20 +560,29 @@ func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { return data } +// writeErrorBytes creates a properly formatted GraphQL error response in JSON format. +// This includes the error message and gRPC status code information in the extensions +// field, following GraphQL error specification standards. func (j *jsonBuilder) writeErrorBytes(err error) []byte { a := astjson.Arena{} defer a.Reset() + + // Create standard GraphQL error structure errorRoot := a.NewObject() errorArray := a.NewArray() errorRoot.Set(errorsPath, errorArray) + // Create individual error object errorItem := a.NewObject() errorItem.Set("message", a.NewString(err.Error())) + // Add gRPC status code information to extensions extensions := a.NewObject() if st, ok := status.FromError(err); ok { + // gRPC error - include the specific status code extensions.Set("code", a.NewString(st.Code().String())) } else { + // Generic error - default to INTERNAL status extensions.Set("code", a.NewString(codes.Internal.String())) } From f28d67fb986b5b328599f6c7e4bf9bd67ec03e4d Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 11:19:36 +0200 Subject: [PATCH 22/24] chore: shorten function --- v2/pkg/engine/datasource/grpc_datasource/json_builder.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index 110bf179a3..ae6c2e55d8 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -169,14 +169,12 @@ func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) ( // Merge left entities using index mapping to preserve order for index, lr := range leftRepresentations { - resultIndex := j.indexMap.getResultIndex(lr, index) - arr.SetArrayItem(resultIndex, lr) + arr.SetArrayItem(j.indexMap.getResultIndex(lr, index), lr) } // Merge right entities using index mapping to preserve order for index, rr := range rightRepresentations { - resultIndex := j.indexMap.getResultIndex(rr, index) - arr.SetArrayItem(resultIndex, rr) + arr.SetArrayItem(j.indexMap.getResultIndex(rr, index), rr) } return entities, nil From 42a4532af5624e2e1e1fafe98bec4ffcdb1fbb59 Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 11:47:08 +0200 Subject: [PATCH 23/24] chore: prevent protential int overflow --- v2/pkg/engine/datasource/grpc_datasource/json_builder.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go index ae6c2e55d8..1166fc9b92 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -474,8 +474,10 @@ func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, na } case protoref.StringKind: root.Set(name, arena.NewString(data.Get(fd).String())) - case protoref.Int32Kind, protoref.Int64Kind: + case protoref.Int32Kind: root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) + case protoref.Int64Kind: + root.Set(name, arena.NewNumberString(strconv.FormatInt(data.Get(fd).Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: @@ -521,8 +523,10 @@ func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjs } case protoref.StringKind: array.SetArrayItem(index, arena.NewString(data.String())) - case protoref.Int32Kind, protoref.Int64Kind: + case protoref.Int32Kind: array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) + case protoref.Int64Kind: + array.SetArrayItem(index, arena.NewNumberString(strconv.FormatInt(data.Int(), 10))) case protoref.Uint32Kind, protoref.Uint64Kind: array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) case protoref.FloatKind, protoref.DoubleKind: From b454dd17f2558259db9b68c191ec7559b7bfabca Mon Sep 17 00:00:00 2001 From: Ludwig Bedacht Date: Fri, 8 Aug 2025 11:55:49 +0200 Subject: [PATCH 24/24] chore: improve comments --- .../grpc_datasource/configuration.go | 24 +++++++++++++------ 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration.go b/v2/pkg/engine/datasource/grpc_datasource/configuration.go index 18860026d2..9bc74cb31a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/configuration.go +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration.go @@ -8,8 +8,12 @@ import ( type ( // RPCConfigMap is a map of RPC names to RPC configurations + // The key is the field name in the GraphQL operation type (query, mutation, subscription). + // The value is the RPC configuration for that field. RPCConfigMap map[string]RPCConfig // FieldMap defines the mapping between a GraphQL field and a gRPC field + // The key is the field name in the GraphQL type. + // The value is the FieldMapData for that field which contains the target name and the argument mappings. FieldMap map[string]FieldMapData ) @@ -23,22 +27,27 @@ type GRPCMapping struct { // SubscriptionRPCs maps GraphQL subscription fields to the corresponding gRPC RPC configurations SubscriptionRPCs RPCConfigMap // EntityRPCs defines how GraphQL types are resolved as entities using specific RPCs + // The key is the type name and the value is a list of EntityRPCConfig for that type EntityRPCs map[string][]EntityRPCConfig // Fields defines the field mappings between GraphQL types and gRPC messages + // The key is the type name and the value is a map of field name to FieldMapData for that type Fields map[string]FieldMap // EnumValues defines the enum values for each enum type + // The key is the enum type name and the value is a list of EnumValueMapping for that enum type EnumValues map[string][]EnumValueMapping } +// EnumValueMapping defines the mapping between a GraphQL enum value and a gRPC enum value type EnumValueMapping struct { - Value string - TargetValue string + Value string // The GraphQL enum value + TargetValue string // The gRPC enum value } +// GRPCConfiguration defines the configuration for a gRPC datasource type GRPCConfiguration struct { - Disabled bool - Mapping *GRPCMapping - Compiler *RPCCompiler + Disabled bool // Whether the RPC is disabled + Mapping *GRPCMapping // The mapping between GraphQL types and gRPC messages + Compiler *RPCCompiler // The compiler for the RPC } // RPCConfig defines the configuration for a specific RPC operation @@ -59,9 +68,10 @@ type EntityRPCConfig struct { RPCConfig } +// FieldMapData defines the mapping between a GraphQL field and a gRPC field type FieldMapData struct { - TargetName string - ArgumentMappings FieldArgumentMap + TargetName string // The name of the gRPC field + ArgumentMappings FieldArgumentMap // The mapping between GraphQL field arguments and gRPC request arguments } // FieldArgumentMap defines the mapping between a GraphQL field argument and a gRPC field