Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 37 additions & 31 deletions v2/pkg/engine/datasource/grpc_datasource/execution_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -779,11 +779,16 @@ func (r *rpcPlanningContext) buildFieldMessage(fieldTypeNode ast.Node, fieldRef
}

for _, fieldRef := range fieldRefs {
if r.isFieldResolver(fieldRef, false) {
fieldDefRef, found := r.definition.NodeFieldDefinitionByName(fieldTypeNode, r.operation.FieldNameBytes(fieldRef))
Comment thread
Noroth marked this conversation as resolved.
if !found {
return nil, fmt.Errorf("unable to build required field: field definition not found for field %s", r.operation.FieldNameString(fieldRef))
}

if r.isFieldResolver(fieldDefRef, false) {
continue
}

field, err := r.buildRequiredField(fieldTypeNode, fieldRef)
field, err := r.buildRequiredField(fieldTypeNode, fieldRef, fieldDefRef)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -853,12 +858,12 @@ func (r *rpcPlanningContext) enterResolverCompositeSelectionSet(oneOfType OneOfT
}

// isFieldResolver checks if a field is a field resolver.
func (r *rpcPlanningContext) isFieldResolver(fieldRef int, isRootField bool) bool {
if isRootField {
func (r *rpcPlanningContext) isFieldResolver(fieldDefRef int, isRootField bool) bool {
if isRootField || fieldDefRef == ast.InvalidRef {
return false
}

return len(r.operation.FieldArguments(fieldRef)) > 0
return r.definition.FieldDefinitionHasArgumentsDefinitions(fieldDefRef)
}

// getCompositeType checks whether the node is an interface or union type.
Expand Down Expand Up @@ -1097,15 +1102,20 @@ func (r *rpcPlanningContext) buildFieldResolverTypeMessage(typeName string, reso
message.Fields = make(RPCFields, 0, len(fieldRefs))

for _, fieldRef := range fieldRefs {
if r.isFieldResolver(fieldRef, false) {
fieldDefRef, found := r.definition.NodeFieldDefinitionByName(parentTypeNode, r.operation.FieldNameBytes(fieldRef))
if !found {
return nil, fmt.Errorf("unable to build required field: field definition not found for field %s", r.operation.FieldNameString(fieldRef))
}

if r.isFieldResolver(fieldDefRef, false) {
continue
}

if message.Fields.Exists(r.operation.FieldNameString(fieldRef), "") {
continue
}

field, err := r.buildRequiredField(parentTypeNode, fieldRef)
field, err := r.buildRequiredField(parentTypeNode, fieldRef, fieldDefRef)
if err != nil {
return nil, err
}
Expand All @@ -1117,23 +1127,17 @@ func (r *rpcPlanningContext) buildFieldResolverTypeMessage(typeName string, reso
return message, nil
}

func (r *rpcPlanningContext) buildRequiredField(typeNode ast.Node, fieldRef int) (RPCField, error) {
fieldName := r.operation.FieldNameString(fieldRef)
fieldDef, found := r.definition.NodeFieldDefinitionByName(typeNode, r.operation.FieldNameBytes(fieldRef))
if !found {
return RPCField{}, fmt.Errorf("unable to build required field: field definition not found for field %s", fieldName)
}

field, err := r.buildField(typeNode, fieldDef, r.operation.FieldNameString(fieldRef), r.operation.FieldAliasString(fieldRef))
func (r *rpcPlanningContext) buildRequiredField(typeNode ast.Node, fieldRef, fieldDefinitionRef int) (RPCField, error) {
field, err := r.buildField(typeNode, fieldDefinitionRef, r.operation.FieldNameString(fieldRef), r.operation.FieldAliasString(fieldRef))
if err != nil {
return RPCField{}, err
}

// If the field is a message type and has selections, we need to build a nested message.
if field.ProtoTypeName == DataTypeMessage && r.operation.FieldHasSelections(fieldRef) {
fieldTypeNode, found := r.definition.ResolveNodeFromTypeRef(r.definition.FieldDefinitionType(fieldDef))
fieldTypeNode, found := r.definition.ResolveNodeFromTypeRef(r.definition.FieldDefinitionType(fieldDefinitionRef))
if !found {
return RPCField{}, fmt.Errorf("unable to build required field: unable to resolve field type node for field %s", fieldName)
return RPCField{}, fmt.Errorf("unable to build required field: unable to resolve field type node for field %s", r.operation.FieldNameString(fieldRef))
}

message, err := r.buildFieldMessage(fieldTypeNode, fieldRef)
Expand All @@ -1154,22 +1158,22 @@ func (r *rpcPlanningContext) buildCompositeFields(inlineFragmentNode ast.Node, f
result := make([]RPCField, 0, len(fieldRefs))

for _, fieldRef := range fieldRefs {
if r.isFieldResolver(fieldRef, false) {
continue
fieldDefRef := r.fieldDefinitionRefForType(r.operation.FieldNameString(fieldRef), fragmentSelection.typeName)
if fieldDefRef == ast.InvalidRef {
return nil, fmt.Errorf("unable to build composite field: field definition not found for field %s", r.operation.FieldNameString(fieldRef))
}

fieldDef := r.fieldDefinitionRefForType(r.operation.FieldNameString(fieldRef), fragmentSelection.typeName)
if fieldDef == ast.InvalidRef {
return nil, fmt.Errorf("unable to build composite field: field definition not found for field %s", r.operation.FieldNameString(fieldRef))
if r.isFieldResolver(fieldDefRef, false) {
continue
}

field, err := r.buildField(inlineFragmentNode, fieldDef, r.operation.FieldNameString(fieldRef), r.operation.FieldAliasString(fieldRef))
field, err := r.buildField(inlineFragmentNode, fieldDefRef, r.operation.FieldNameString(fieldRef), r.operation.FieldAliasString(fieldRef))
if err != nil {
return nil, err
}

if field.ProtoTypeName == DataTypeMessage && r.operation.FieldHasSelections(fieldRef) {
fieldTypeNode, found := r.definition.ResolveNodeFromTypeRef(r.definition.FieldDefinitionType(fieldDef))
fieldTypeNode, found := r.definition.ResolveNodeFromTypeRef(r.definition.FieldDefinitionType(fieldDefRef))
if !found {
return nil, fmt.Errorf("unable to build composite field: unable to resolve field type node for field %s", r.operation.FieldNameString(fieldRef))
}
Expand Down Expand Up @@ -1256,15 +1260,17 @@ func (r *rpcPlanningContext) createResolverRPCCalls(subgraphName string, resolve
contextMessage.Fields[i] = field
}

fieldArgsMessage.Fields = make(RPCFields, len(resolvedField.fieldArguments))
for i := range resolvedField.fieldArguments {
field, err := r.createRPCFieldFromFieldArgument(resolvedField.fieldArguments[i])
if argLen := len(resolvedField.fieldArguments); argLen > 0 {
fieldArgsMessage.Fields = make(RPCFields, argLen)
for i := range resolvedField.fieldArguments {
field, err := r.createRPCFieldFromFieldArgument(resolvedField.fieldArguments[i])

if err != nil {
return nil, err
}
if err != nil {
return nil, err
}

fieldArgsMessage.Fields[i] = field
fieldArgsMessage.Fields[i] = field
}
}

calls = append(calls, call)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2618,6 +2618,114 @@ func TestExecutionPlanFieldResolvers_CustomSchemas(t *testing.T) {
},
},
},
{
name: "Should correctly create a resolve call for a resolver with an optional argument",
subgraphName: "Foo",
operation: `
query FooQuery {
foo {
fooResolverOptionalArgument {
__typename
}
}
}`,
schema: schemaWithNestedResolverAndCompositeType(t),
mapping: mappingWithNestedResolverAndCompositeType(t),
expectedPlan: &RPCExecutionPlan{
Calls: []RPCCall{
{
ServiceName: "Foo",
MethodName: "QueryFoo",
Request: RPCMessage{
Name: "QueryFooRequest",
},
Response: RPCMessage{
Name: "QueryFooResponse",
Fields: []RPCField{
{
Name: "foo",
ProtoTypeName: DataTypeMessage,
JSONPath: "foo",
Message: &RPCMessage{
Name: "Foo",
Fields: RPCFields{},
},
},
},
},
},
{
Kind: CallKindResolve,
DependentCalls: []int{0},
ResponsePath: buildPath("foo.fooResolverOptionalArgument"),
ServiceName: "Foo",
MethodName: "ResolveFooFooResolverOptionalArgument",
Request: RPCMessage{
Name: "ResolveFooFooResolverOptionalArgumentRequest",
Fields: []RPCField{
{
Name: "context",
ProtoTypeName: DataTypeMessage,
Repeated: true,
Message: &RPCMessage{
Name: "ResolveFooFooResolverOptionalArgumentContext",
Fields: []RPCField{
{
Name: "id",
ProtoTypeName: DataTypeString,
JSONPath: "id",
ResolvePath: buildPath("foo.id"),
},
},
},
},
{
Name: "field_args",
ProtoTypeName: DataTypeMessage,
Message: &RPCMessage{
Name: "ResolveFooFooResolverOptionalArgumentArgs",
},
},
},
},
Response: RPCMessage{
Name: "ResolveFooFooResolverOptionalArgumentResponse",
Fields: []RPCField{
{
Name: "result",
ProtoTypeName: DataTypeMessage,
Repeated: true,
JSONPath: "result",
Message: &RPCMessage{
Name: "ResolveFooFooResolverOptionalArgumentResult",
Fields: RPCFields{
{
Name: "foo_resolver_optional_argument",
ProtoTypeName: DataTypeMessage,
JSONPath: "fooResolverOptionalArgument",
Message: &RPCMessage{
Name: "Bar",
OneOfType: OneOfTypeInterface,
MemberTypes: []string{"Baz"},
Fields: RPCFields{
{
Name: "__typename",
ProtoTypeName: DataTypeString,
JSONPath: "__typename",
StaticValue: "Bar",
},
},
},
},
},
},
},
},
},
},
},
},
},
}

for _, tt := range tests {
Expand Down Expand Up @@ -2651,6 +2759,7 @@ schema {
type Foo {
id: ID!
fooResolver(foo: String!): Bar! @connect__fieldResolver(context: "id")
fooResolverOptionalArgument(foo: String): Bar! @connect__fieldResolver(context: "id")
}

interface Bar {
Expand Down Expand Up @@ -2727,6 +2836,17 @@ func mappingWithNestedResolverAndCompositeType(_ *testing.T) *GRPCMapping {
Request: "ResolveFooFooResolverRequest",
Response: "ResolveFooFooResolverResponse",
},
"fooResolverOptionalArgument": {
FieldMappingData: FieldMapData{
TargetName: "foo_resolver_optional_argument",
ArgumentMappings: FieldArgumentMap{
"foo": "foo",
},
},
RPC: "ResolveFooFooResolverOptionalArgument",
Request: "ResolveFooFooResolverOptionalArgumentRequest",
Response: "ResolveFooFooResolverOptionalArgumentResponse",
},
},
"Baz": {
"bazResolver": {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ func (r *rpcPlanVisitor) EnterField(ref int) {

// If the field is a field resolver, we need to handle it later in a separate resolver call.
// We only store the information about the field and create the call later.
if r.planCtx.isFieldResolver(ref, inRootField) {
if r.planCtx.isFieldResolver(fieldDefRef, inRootField) {
r.enterFieldResolver(ref, fieldDefRef)
return
}
Expand Down Expand Up @@ -418,7 +418,15 @@ func (r *rpcPlanVisitor) LeaveField(ref int) {
return
}

if r.planCtx.isFieldResolver(ref, inRootField) {
fieldDefRef, ok := r.walker.FieldDefinition(ref)
if !ok {
r.walker.Report.AddExternalError(operationreport.ExternalError{
Message: fmt.Sprintf("Field %s not found in definition %s", r.operation.FieldNameString(ref), r.walker.EnclosingTypeDefinition.NameString(r.definition)),
})
return
}

if r.planCtx.isFieldResolver(fieldDefRef, inRootField) {
// Pop the field resolver ancestor only when leaving a field resolver field.
r.fieldResolverAncestors.pop()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ func (r *rpcPlanVisitorFederation) EnterField(ref int) {

// If the field is a field resolver, we need to handle it later in a separate resolver call.
// We only store the information about the field and create the call later.
if r.planCtx.isFieldResolver(ref, inRootField) {
if r.planCtx.isFieldResolver(fieldDefRef, inRootField) {
r.enterFieldResolver(ref, fieldDefRef)
return
}
Expand Down Expand Up @@ -387,7 +387,15 @@ func (r *rpcPlanVisitorFederation) LeaveField(ref int) {
return
}

if r.planCtx.isFieldResolver(ref, inRootField) {
fieldDefRef, ok := r.walker.FieldDefinition(ref)
if !ok {
r.walker.Report.AddExternalError(operationreport.ExternalError{
Message: fmt.Sprintf("Field %s not found in definition %s", r.operation.FieldNameString(ref), r.walker.EnclosingTypeDefinition.NameString(r.definition)),
})
return
}

if r.planCtx.isFieldResolver(fieldDefRef, inRootField) {
// Pop the field resolver ancestor only when leaving a field resolver field.
r.fieldResolverAncestors.pop()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4539,6 +4539,33 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) {
require.Empty(t, errData)
},
},
{
name: "Query with field resolver without parentheses (nullable parameter)",
query: "query CategoriesWithFieldResolverNoParens { categories { id name popularityScore } }",
vars: `{"variables":{}}`,
validate: func(t *testing.T, data map[string]interface{}) {
require.NotEmpty(t, data)

categories, ok := data["categories"].([]interface{})
require.True(t, ok, "categories should be an array")
require.NotEmpty(t, categories, "categories should not be empty")
require.Len(t, categories, 4, "Should return 4 categories")

for _, category := range categories {
category, ok := category.(map[string]interface{})
require.True(t, ok, "category should be an object")
require.NotEmpty(t, category["id"])
require.NotEmpty(t, category["name"])
// popularityScore should return 50 when called without threshold
// Based on mockservice implementation: threshold defaults to 0, which is <= 50, so returns baseScore of 50
require.NotEmpty(t, category["popularityScore"], "popularityScore should not be empty when called without parameters")
require.Equal(t, float64(50), category["popularityScore"], "popularityScore should be 50 when threshold is not provided")
}
},
validateError: func(t *testing.T, errData []graphqlError) {
require.Empty(t, errData)
},
},
}

for _, tc := range testCases {
Expand Down