Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions v2/pkg/engine/datasource/grpc_datasource/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -549,8 +549,8 @@ func (p *RPCCompiler) buildProtoMessageWithContext(inputMessage Message, rpcMess

rootMessage := dynamicpb.NewMessage(inputMessage.Desc)

if len(inputMessage.Fields) != 2 {
return nil, fmt.Errorf("message %s must have exactly two fields: context and field_args", inputMessage.Name)
if len(inputMessage.Fields) < 1 {
return nil, fmt.Errorf("message %s must have at least the context field", inputMessage.Name)
}

contextSchemaField := inputMessage.GetField("context")
Expand Down Expand Up @@ -587,24 +587,22 @@ func (p *RPCCompiler) buildProtoMessageWithContext(inputMessage Message, rpcMess
contextList.Append(val)
}

argsSchemaField := inputMessage.GetField("field_args")
if argsSchemaField == nil {
return nil, fmt.Errorf("field_args field not found in message %s", inputMessage.Name)
}

argsMessage := p.doc.Messages[argsSchemaField.MessageRef]
argsRPCField := rpcMessage.Fields.ByName("field_args")
if argsRPCField == nil {
return nil, fmt.Errorf("field_args field not found in message %s", rpcMessage.Name)
}
if argsRPCField != nil {
argsSchemaField := inputMessage.GetField("field_args")
if argsSchemaField == nil {
return nil, fmt.Errorf("field_args field not found in message %s", inputMessage.Name)
}

args, err := p.buildProtoMessage(argsMessage, argsRPCField.Message, data)
if err != nil {
return nil, err
}
// Set the args field
if err := p.setMessageValue(rootMessage, argsRPCField.Name, protoref.ValueOfMessage(args)); err != nil {
return nil, err
argsMessage := p.doc.Messages[argsSchemaField.MessageRef]
args, err := p.buildProtoMessage(argsMessage, argsRPCField.Message, data)
if err != nil {
return nil, err
}
// Set the args field
if err := p.setMessageValue(rootMessage, argsRPCField.Name, protoref.ValueOfMessage(args)); err != nil {
return nil, err
}
}

return rootMessage, nil
Expand Down
61 changes: 36 additions & 25 deletions v2/pkg/engine/datasource/grpc_datasource/execution_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -925,7 +925,12 @@ func (r *rpcPlanningContext) isFieldResolver(fieldDefRef int, isRootField bool)
return false
}

return r.definition.FieldDefinitionHasArgumentsDefinitions(fieldDefRef)
return r.definition.FieldDefinitionHasArgumentsDefinitions(fieldDefRef) || r.hasFieldResolverDirective(fieldDefRef)
}

// hasFieldResolverDirective checks if a field has a field resolver directive.
func (r *rpcPlanningContext) hasFieldResolverDirective(fieldDefRef int) bool {
return r.definition.FieldDefinitionHasNamedDirective(fieldDefRef, fieldResolverDirectiveName)
}

// isRequiredField checks if a field is a required field.
Expand Down Expand Up @@ -986,12 +991,14 @@ func (r *rpcPlanningContext) setResolvedField(walker *astvisitor.Walker, fieldDe
})
}

fieldArguments, err := r.parseFieldArguments(walker, fieldDefRef, fieldArgs)
if err != nil {
return err
}
if len(fieldArgs) > 0 {
fieldArguments, err := r.parseFieldArguments(walker, fieldDefRef, fieldArgs)
if err != nil {
return err
}

resolvedField.fieldArguments = fieldArguments
resolvedField.fieldArguments = fieldArguments
}

fieldDefType := r.definition.FieldDefinitionType(fieldDefRef)
if r.typeIsNullableOrNestedList(fieldDefType) {
Expand Down Expand Up @@ -1446,8 +1453,11 @@ func (r *rpcPlanningContext) createResolverRPCCalls(subgraphName string, resolve
Name: resolveConfig.RPC + "Context",
}

fieldArgsMessage := &RPCMessage{
Name: resolveConfig.RPC + "Args",
var fieldArgsMessage *RPCMessage
if len(resolvedField.fieldArguments) > 0 {
fieldArgsMessage = &RPCMessage{
Name: resolveConfig.RPC + "Args",
}
}
Comment thread
Noroth marked this conversation as resolved.

call, err := r.newResolveRPCCall(&resolveRPCCallConfig{
Expand Down Expand Up @@ -1555,28 +1565,29 @@ func (r *rpcPlanningContext) newResolveRPCCall(config *resolveRPCCallConfig) (RP
},
}

requestFields := RPCFields{
{
Name: contextFieldName,
ProtoTypeName: DataTypeMessage,
Repeated: true,
Message: config.contextMessage,
},
}
if config.fieldArgsMessage != nil {
requestFields = append(requestFields, RPCField{
Name: fieldArgsFieldName,
ProtoTypeName: DataTypeMessage,
Message: config.fieldArgsMessage,
})
}

return RPCCall{
ID: resolvedField.id,
DependentCalls: []int{resolvedField.callerRef},
ResponsePath: resolvedField.responsePath,
MethodName: resolveConfig.RPC,
Kind: CallKindResolve,
Request: RPCMessage{
Name: resolveConfig.Request,
Fields: RPCFields{
{
Name: contextFieldName,
ProtoTypeName: DataTypeMessage,
Repeated: true,
Message: config.contextMessage,
},
{
Name: fieldArgsFieldName,
ProtoTypeName: DataTypeMessage,
Message: config.fieldArgsMessage,
},
},
},
Response: response,
Request: RPCMessage{Name: resolveConfig.Request, Fields: requestFields},
Response: response,
}, nil
}
Loading
Loading