diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index 1354584f27..3a1c12b00f 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -779,11 +779,16 @@ func (r *rpcPlanningContext) buildFieldMessage(fieldTypeNode ast.Node, fieldRef } for _, fieldRef := range fieldRefs { - if r.isFieldResolver(fieldRef, false) { + fieldDefRef, found := r.definition.NodeFieldDefinitionByName(fieldTypeNode, r.operation.FieldNameBytes(fieldRef)) + if !found { + return nil, fmt.Errorf("unable to build required field: field definition not found for field %s", r.operation.FieldNameString(fieldRef)) + } + + if r.isFieldResolver(fieldDefRef, false) { continue } - field, err := r.buildRequiredField(fieldTypeNode, fieldRef) + field, err := r.buildRequiredField(fieldTypeNode, fieldRef, fieldDefRef) if err != nil { return nil, err } @@ -853,12 +858,12 @@ func (r *rpcPlanningContext) enterResolverCompositeSelectionSet(oneOfType OneOfT } // isFieldResolver checks if a field is a field resolver. -func (r *rpcPlanningContext) isFieldResolver(fieldRef int, isRootField bool) bool { - if isRootField { +func (r *rpcPlanningContext) isFieldResolver(fieldDefRef int, isRootField bool) bool { + if isRootField || fieldDefRef == ast.InvalidRef { return false } - return len(r.operation.FieldArguments(fieldRef)) > 0 + return r.definition.FieldDefinitionHasArgumentsDefinitions(fieldDefRef) } // getCompositeType checks whether the node is an interface or union type. @@ -1097,7 +1102,12 @@ func (r *rpcPlanningContext) buildFieldResolverTypeMessage(typeName string, reso message.Fields = make(RPCFields, 0, len(fieldRefs)) for _, fieldRef := range fieldRefs { - if r.isFieldResolver(fieldRef, false) { + fieldDefRef, found := r.definition.NodeFieldDefinitionByName(parentTypeNode, r.operation.FieldNameBytes(fieldRef)) + if !found { + return nil, fmt.Errorf("unable to build required field: field definition not found for field %s", r.operation.FieldNameString(fieldRef)) + } + + if r.isFieldResolver(fieldDefRef, false) { continue } @@ -1105,7 +1115,7 @@ func (r *rpcPlanningContext) buildFieldResolverTypeMessage(typeName string, reso continue } - field, err := r.buildRequiredField(parentTypeNode, fieldRef) + field, err := r.buildRequiredField(parentTypeNode, fieldRef, fieldDefRef) if err != nil { return nil, err } @@ -1117,23 +1127,17 @@ func (r *rpcPlanningContext) buildFieldResolverTypeMessage(typeName string, reso return message, nil } -func (r *rpcPlanningContext) buildRequiredField(typeNode ast.Node, fieldRef int) (RPCField, error) { - fieldName := r.operation.FieldNameString(fieldRef) - fieldDef, found := r.definition.NodeFieldDefinitionByName(typeNode, r.operation.FieldNameBytes(fieldRef)) - if !found { - return RPCField{}, fmt.Errorf("unable to build required field: field definition not found for field %s", fieldName) - } - - field, err := r.buildField(typeNode, fieldDef, r.operation.FieldNameString(fieldRef), r.operation.FieldAliasString(fieldRef)) +func (r *rpcPlanningContext) buildRequiredField(typeNode ast.Node, fieldRef, fieldDefinitionRef int) (RPCField, error) { + field, err := r.buildField(typeNode, fieldDefinitionRef, r.operation.FieldNameString(fieldRef), r.operation.FieldAliasString(fieldRef)) if err != nil { return RPCField{}, err } // If the field is a message type and has selections, we need to build a nested message. if field.ProtoTypeName == DataTypeMessage && r.operation.FieldHasSelections(fieldRef) { - fieldTypeNode, found := r.definition.ResolveNodeFromTypeRef(r.definition.FieldDefinitionType(fieldDef)) + fieldTypeNode, found := r.definition.ResolveNodeFromTypeRef(r.definition.FieldDefinitionType(fieldDefinitionRef)) if !found { - return RPCField{}, fmt.Errorf("unable to build required field: unable to resolve field type node for field %s", fieldName) + return RPCField{}, fmt.Errorf("unable to build required field: unable to resolve field type node for field %s", r.operation.FieldNameString(fieldRef)) } message, err := r.buildFieldMessage(fieldTypeNode, fieldRef) @@ -1154,22 +1158,22 @@ func (r *rpcPlanningContext) buildCompositeFields(inlineFragmentNode ast.Node, f result := make([]RPCField, 0, len(fieldRefs)) for _, fieldRef := range fieldRefs { - if r.isFieldResolver(fieldRef, false) { - continue + fieldDefRef := r.fieldDefinitionRefForType(r.operation.FieldNameString(fieldRef), fragmentSelection.typeName) + if fieldDefRef == ast.InvalidRef { + return nil, fmt.Errorf("unable to build composite field: field definition not found for field %s", r.operation.FieldNameString(fieldRef)) } - fieldDef := r.fieldDefinitionRefForType(r.operation.FieldNameString(fieldRef), fragmentSelection.typeName) - if fieldDef == ast.InvalidRef { - return nil, fmt.Errorf("unable to build composite field: field definition not found for field %s", r.operation.FieldNameString(fieldRef)) + if r.isFieldResolver(fieldDefRef, false) { + continue } - field, err := r.buildField(inlineFragmentNode, fieldDef, r.operation.FieldNameString(fieldRef), r.operation.FieldAliasString(fieldRef)) + field, err := r.buildField(inlineFragmentNode, fieldDefRef, r.operation.FieldNameString(fieldRef), r.operation.FieldAliasString(fieldRef)) if err != nil { return nil, err } if field.ProtoTypeName == DataTypeMessage && r.operation.FieldHasSelections(fieldRef) { - fieldTypeNode, found := r.definition.ResolveNodeFromTypeRef(r.definition.FieldDefinitionType(fieldDef)) + fieldTypeNode, found := r.definition.ResolveNodeFromTypeRef(r.definition.FieldDefinitionType(fieldDefRef)) if !found { return nil, fmt.Errorf("unable to build composite field: unable to resolve field type node for field %s", r.operation.FieldNameString(fieldRef)) } @@ -1256,15 +1260,17 @@ func (r *rpcPlanningContext) createResolverRPCCalls(subgraphName string, resolve contextMessage.Fields[i] = field } - fieldArgsMessage.Fields = make(RPCFields, len(resolvedField.fieldArguments)) - for i := range resolvedField.fieldArguments { - field, err := r.createRPCFieldFromFieldArgument(resolvedField.fieldArguments[i]) + if argLen := len(resolvedField.fieldArguments); argLen > 0 { + fieldArgsMessage.Fields = make(RPCFields, argLen) + for i := range resolvedField.fieldArguments { + field, err := r.createRPCFieldFromFieldArgument(resolvedField.fieldArguments[i]) - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - fieldArgsMessage.Fields[i] = field + fieldArgsMessage.Fields[i] = field + } } calls = append(calls, call) diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go index 41b2d0c10e..b9d534368e 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_field_resolvers_test.go @@ -2618,6 +2618,114 @@ func TestExecutionPlanFieldResolvers_CustomSchemas(t *testing.T) { }, }, }, + { + name: "Should correctly create a resolve call for a resolver with an optional argument", + subgraphName: "Foo", + operation: ` + query FooQuery { + foo { + fooResolverOptionalArgument { + __typename + } + } + }`, + schema: schemaWithNestedResolverAndCompositeType(t), + mapping: mappingWithNestedResolverAndCompositeType(t), + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Foo", + MethodName: "QueryFoo", + Request: RPCMessage{ + Name: "QueryFooRequest", + }, + Response: RPCMessage{ + Name: "QueryFooResponse", + Fields: []RPCField{ + { + Name: "foo", + ProtoTypeName: DataTypeMessage, + JSONPath: "foo", + Message: &RPCMessage{ + Name: "Foo", + Fields: RPCFields{}, + }, + }, + }, + }, + }, + { + Kind: CallKindResolve, + DependentCalls: []int{0}, + ResponsePath: buildPath("foo.fooResolverOptionalArgument"), + ServiceName: "Foo", + MethodName: "ResolveFooFooResolverOptionalArgument", + Request: RPCMessage{ + Name: "ResolveFooFooResolverOptionalArgumentRequest", + Fields: []RPCField{ + { + Name: "context", + ProtoTypeName: DataTypeMessage, + Repeated: true, + Message: &RPCMessage{ + Name: "ResolveFooFooResolverOptionalArgumentContext", + Fields: []RPCField{ + { + Name: "id", + ProtoTypeName: DataTypeString, + JSONPath: "id", + ResolvePath: buildPath("foo.id"), + }, + }, + }, + }, + { + Name: "field_args", + ProtoTypeName: DataTypeMessage, + Message: &RPCMessage{ + Name: "ResolveFooFooResolverOptionalArgumentArgs", + }, + }, + }, + }, + Response: RPCMessage{ + Name: "ResolveFooFooResolverOptionalArgumentResponse", + Fields: []RPCField{ + { + Name: "result", + ProtoTypeName: DataTypeMessage, + Repeated: true, + JSONPath: "result", + Message: &RPCMessage{ + Name: "ResolveFooFooResolverOptionalArgumentResult", + Fields: RPCFields{ + { + Name: "foo_resolver_optional_argument", + ProtoTypeName: DataTypeMessage, + JSONPath: "fooResolverOptionalArgument", + Message: &RPCMessage{ + Name: "Bar", + OneOfType: OneOfTypeInterface, + MemberTypes: []string{"Baz"}, + Fields: RPCFields{ + { + Name: "__typename", + ProtoTypeName: DataTypeString, + JSONPath: "__typename", + StaticValue: "Bar", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, } for _, tt := range tests { @@ -2651,6 +2759,7 @@ schema { type Foo { id: ID! fooResolver(foo: String!): Bar! @connect__fieldResolver(context: "id") + fooResolverOptionalArgument(foo: String): Bar! @connect__fieldResolver(context: "id") } interface Bar { @@ -2727,6 +2836,17 @@ func mappingWithNestedResolverAndCompositeType(_ *testing.T) *GRPCMapping { Request: "ResolveFooFooResolverRequest", Response: "ResolveFooFooResolverResponse", }, + "fooResolverOptionalArgument": { + FieldMappingData: FieldMapData{ + TargetName: "foo_resolver_optional_argument", + ArgumentMappings: FieldArgumentMap{ + "foo": "foo", + }, + }, + RPC: "ResolveFooFooResolverOptionalArgument", + Request: "ResolveFooFooResolverOptionalArgumentRequest", + Response: "ResolveFooFooResolverOptionalArgumentResponse", + }, }, "Baz": { "bazResolver": { diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go index 307bba52b9..5e9391c1df 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go @@ -358,7 +358,7 @@ func (r *rpcPlanVisitor) EnterField(ref int) { // If the field is a field resolver, we need to handle it later in a separate resolver call. // We only store the information about the field and create the call later. - if r.planCtx.isFieldResolver(ref, inRootField) { + if r.planCtx.isFieldResolver(fieldDefRef, inRootField) { r.enterFieldResolver(ref, fieldDefRef) return } @@ -418,7 +418,15 @@ func (r *rpcPlanVisitor) LeaveField(ref int) { return } - if r.planCtx.isFieldResolver(ref, inRootField) { + fieldDefRef, ok := r.walker.FieldDefinition(ref) + if !ok { + r.walker.Report.AddExternalError(operationreport.ExternalError{ + Message: fmt.Sprintf("Field %s not found in definition %s", r.operation.FieldNameString(ref), r.walker.EnclosingTypeDefinition.NameString(r.definition)), + }) + return + } + + if r.planCtx.isFieldResolver(fieldDefRef, inRootField) { // Pop the field resolver ancestor only when leaving a field resolver field. r.fieldResolverAncestors.pop() diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go index d3dad5f348..c64f50fc55 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -331,7 +331,7 @@ func (r *rpcPlanVisitorFederation) EnterField(ref int) { // If the field is a field resolver, we need to handle it later in a separate resolver call. // We only store the information about the field and create the call later. - if r.planCtx.isFieldResolver(ref, inRootField) { + if r.planCtx.isFieldResolver(fieldDefRef, inRootField) { r.enterFieldResolver(ref, fieldDefRef) return } @@ -387,7 +387,15 @@ func (r *rpcPlanVisitorFederation) LeaveField(ref int) { return } - if r.planCtx.isFieldResolver(ref, inRootField) { + fieldDefRef, ok := r.walker.FieldDefinition(ref) + if !ok { + r.walker.Report.AddExternalError(operationreport.ExternalError{ + Message: fmt.Sprintf("Field %s not found in definition %s", r.operation.FieldNameString(ref), r.walker.EnclosingTypeDefinition.NameString(r.definition)), + }) + return + } + + if r.planCtx.isFieldResolver(fieldDefRef, inRootField) { // Pop the field resolver ancestor only when leaving a field resolver field. r.fieldResolverAncestors.pop() diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index 76da7be5de..b4e4ef5cb3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -4539,6 +4539,33 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) { require.Empty(t, errData) }, }, + { + name: "Query with field resolver without parentheses (nullable parameter)", + query: "query CategoriesWithFieldResolverNoParens { categories { id name popularityScore } }", + vars: `{"variables":{}}`, + validate: func(t *testing.T, data map[string]interface{}) { + require.NotEmpty(t, data) + + categories, ok := data["categories"].([]interface{}) + require.True(t, ok, "categories should be an array") + require.NotEmpty(t, categories, "categories should not be empty") + require.Len(t, categories, 4, "Should return 4 categories") + + for _, category := range categories { + category, ok := category.(map[string]interface{}) + require.True(t, ok, "category should be an object") + require.NotEmpty(t, category["id"]) + require.NotEmpty(t, category["name"]) + // popularityScore should return 50 when called without threshold + // Based on mockservice implementation: threshold defaults to 0, which is <= 50, so returns baseScore of 50 + require.NotEmpty(t, category["popularityScore"], "popularityScore should not be empty when called without parameters") + require.Equal(t, float64(50), category["popularityScore"], "popularityScore should be 50 when threshold is not provided") + } + }, + validateError: func(t *testing.T, errData []graphqlError) { + require.Empty(t, errData) + }, + }, } for _, tc := range testCases {