Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions v2/pkg/engine/datasource/grpc_datasource/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/bufbuild/protocompile"
"github.com/tidwall/gjson"
"google.golang.org/protobuf/proto"
protoref "google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/dynamicpb"

Expand Down Expand Up @@ -389,6 +390,7 @@ func (p *RPCCompiler) processFile(f protoref.FileDescriptor, mapping *GRPCMappin

// ServiceCall represents a single gRPC service call with its input and output messages.
type ServiceCall struct {
ID int
// ServiceName is the name of the gRPC service to call
ServiceName string
// MethodName is the name of the method on the service to call
Expand All @@ -401,6 +403,10 @@ type ServiceCall struct {
RPC *RPCCall
}

func (s *ServiceCall) CloneOutput() protoref.Message {
return proto.Clone(s.Output.Interface()).ProtoReflect()
}

func (s *ServiceCall) MethodFullName() string {
var builder strings.Builder

Expand Down Expand Up @@ -464,7 +470,7 @@ func (p *RPCCompiler) CompileFetches(graph *DependencyGraph, fetches []FetchItem
return nil, err
}

graph.SetFetchData(node.ID, &serviceCall)
serviceCall.ID = node.ID
serviceCalls = append(serviceCalls, serviceCall)
}

Expand Down Expand Up @@ -600,10 +606,10 @@ func (p *RPCCompiler) buildProtoMessageWithContext(inputMessage Message, rpcMess
contextList := p.newEmptyListMessageByName(rootMessage, contextSchemaField.Name)
contextData := p.resolveContextData(context[0], contextRPCField) // TODO handle multiple contexts (resolver requires another resolver)

for _, data := range contextData {
for _, contextElement := range contextData {
val := contextList.NewElement()
valMsg := val.Message()
for fieldName, value := range data {
for fieldName, value := range contextElement {
if err := p.setMessageValue(valMsg, fieldName, value); err != nil {
return nil, err
}
Expand Down Expand Up @@ -636,13 +642,13 @@ func (p *RPCCompiler) buildProtoMessageWithContext(inputMessage Message, rpcMess
}

func (p *RPCCompiler) resolveContextData(context FetchItem, contextField *RPCField) []map[string]protoref.Value {
if context.ServiceCall == nil || context.ServiceCall.Output == nil {
if context.Output == nil {
return []map[string]protoref.Value{}
}

contextValues := make([]map[string]protoref.Value, 0)
for _, field := range contextField.Message.Fields {
values := p.resolveContextDataForPath(context.ServiceCall.Output, field.ResolvePath)
values := p.resolveContextDataForPath(context.Output, field.ResolvePath)

for index, value := range values {
if index >= len(contextValues) {
Expand Down Expand Up @@ -676,7 +682,6 @@ func (p *RPCCompiler) resolveContextDataForPath(message protoref.Message, path a
}

return p.resolveDataForPath(msg.Message(), path)

}

// resolveListDataForPath resolves the data for a given path in a list message.
Expand All @@ -696,13 +701,14 @@ func (p *RPCCompiler) resolveListDataForPath(message protoref.List, fd protoref.

for _, val := range values {
if list, isList := val.Interface().(protoref.List); isList {
values := p.resolveListDataForPath(list, fd, path)
values := p.resolveListDataForPath(list, fd, path[1:])
result = append(result, values...)
continue
} else {
result = append(result, val)
}
}

result = append(result, values...)
default:
result = append(result, item)
}
Expand Down
88 changes: 52 additions & 36 deletions v2/pkg/engine/datasource/grpc_datasource/execution_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ const (
// RPCCall represents a single call to a gRPC service method.
// It contains all the information needed to make the call and process the response.
type RPCCall struct {
// ID indicates the expected index of the call in the execution plan
ID int
// Kind of call, used to decide how to execute the call
// This is used to identify the call type and execution behaviour.
Kind CallKind
Expand Down Expand Up @@ -131,20 +133,6 @@ func (r *RPCMessage) SelectValidTypes(typeName string) []string {
return []string{r.Name, typeName}
}

// AppendTypeNameField appends a typename field to the message.
func (r *RPCMessage) AppendTypeNameField(typeName string) {
if r.Fields != nil && r.Fields.Exists(typenameFieldName, "") {
return
}

r.Fields = append(r.Fields, RPCField{
Name: typenameFieldName,
ProtoTypeName: DataTypeString,
StaticValue: typeName,
JSONPath: typenameFieldName,
})
}

// RPCFieldSelectionSet is a map of field selections based on inline fragments
type RPCFieldSelectionSet map[string]RPCFields

Expand Down Expand Up @@ -292,8 +280,8 @@ func (r *RPCExecutionPlan) String() string {

result.WriteString("RPCExecutionPlan:\n")

for j, call := range r.Calls {
result.WriteString(fmt.Sprintf(" Call %d:\n", j))
for _, call := range r.Calls {
result.WriteString(fmt.Sprintf(" Call %d:\n", call.ID))

if len(call.DependentCalls) > 0 {
result.WriteString(" DependentCalls: [")
Expand Down Expand Up @@ -362,6 +350,7 @@ func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
fmt.Fprintf(sb, "%s TypeName: %s\n", indentStr, field.ProtoTypeName)
fmt.Fprintf(sb, "%s Repeated: %v\n", indentStr, field.Repeated)
fmt.Fprintf(sb, "%s JSONPath: %s\n", indentStr, field.JSONPath)
fmt.Fprintf(sb, "%s ResolvePath: %s\n", indentStr, field.ResolvePath.String())

if field.Message != nil {
fmt.Fprintf(sb, "%s Message:\n", indentStr)
Expand Down Expand Up @@ -474,6 +463,15 @@ func (r *rpcPlanningContext) newMessageFromSelectionSet(enclosingTypeNode ast.No
return message
}

func (r *rpcPlanningContext) findResolverFieldMapping(typeName, fieldName string) string {
resolveConfig := r.mapping.FindResolveTypeFieldMapping(typeName, fieldName)
if resolveConfig == nil {
return fieldName
}

return resolveConfig.FieldMappingData.TargetName
}

// resolveFieldMapping resolves the field mapping for a field.
// This applies both for complex types in the input and for all fields in the response.
func (r *rpcPlanningContext) resolveFieldMapping(typeName, fieldName string) string {
Expand Down Expand Up @@ -531,25 +529,25 @@ func (r *rpcPlanningContext) createListMetadata(typeRef int) (*ListMetadata, err

// buildField builds a field from a field definition.
// It handles lists, enums, and other types.
func (r *rpcPlanningContext) buildField(enclosingTypeNode ast.Node, fd int, fieldName, fieldAlias string) (RPCField, error) {
fdt := r.definition.FieldDefinitionType(fd)
typeName := r.toDataType(&r.definition.Types[fdt])
func (r *rpcPlanningContext) buildField(enclosingTypeNode ast.Node, fieldDef int, fieldName, fieldAlias string) (RPCField, error) {
fieldDefType := r.definition.FieldDefinitionType(fieldDef)
typeName := r.toDataType(&r.definition.Types[fieldDefType])
parentTypeName := enclosingTypeNode.NameString(r.definition)

field := RPCField{
Name: r.resolveFieldMapping(parentTypeName, fieldName),
Alias: fieldAlias,
Optional: !r.definition.TypeIsNonNull(fdt),
Optional: !r.definition.TypeIsNonNull(fieldDefType),
JSONPath: fieldName,
ProtoTypeName: typeName,
}

if r.definition.TypeIsList(fdt) {
if r.definition.TypeIsList(fieldDefType) {
switch {
// for nullable or nested lists we need to build a wrapper message
// Nullability is handled by the datasource during the execution.
case r.typeIsNullableOrNestedList(fdt):
md, err := r.createListMetadata(fdt)
case r.typeIsNullableOrNestedList(fieldDefType):
md, err := r.createListMetadata(fieldDefType)
if err != nil {
return field, err
}
Expand All @@ -562,7 +560,7 @@ func (r *rpcPlanningContext) buildField(enclosingTypeNode ast.Node, fd int, fiel
}

if typeName == DataTypeEnum {
field.EnumName = r.definition.FieldDefinitionTypeNameString(fd)
field.EnumName = r.definition.FieldDefinitionTypeNameString(fieldDef)
}

if fieldName == typenameFieldName {
Expand Down Expand Up @@ -808,17 +806,20 @@ func (r *rpcPlanningContext) resolveServiceName(subgraphName string) string {
}

type resolverField struct {
id int
callerRef int
parentTypeNode ast.Node
fieldRef int
fieldDefinitionTypeRef int
fieldsSelectionSetRef int
responsePath ast.Path
contextPath ast.Path

contextFields []contextField
fieldArguments []fieldArgument
fragmentSelections []fragmentSelection
fragmentType OneOfType
listNestingLevel int
memberTypes []string
}

Expand Down Expand Up @@ -907,8 +908,11 @@ func (r *rpcPlanningContext) setResolvedField(walker *astvisitor.Walker, fieldDe
}

for _, contextFieldRef := range contextFields {
contextFieldName := r.definition.FieldDefinitionNameBytes(contextFieldRef)
resolvedPath := fieldPath.WithFieldNameItem(contextFieldName)
mapping := r.resolveFieldMapping(
walker.EnclosingTypeDefinition.NameString(r.definition),
r.definition.FieldDefinitionNameString(contextFieldRef),
)
resolvedPath := fieldPath.WithFieldNameItem([]byte(mapping))

resolvedField.contextFields = append(resolvedField.contextFields, contextField{
fieldRef: contextFieldRef,
Expand All @@ -922,6 +926,12 @@ func (r *rpcPlanningContext) setResolvedField(walker *astvisitor.Walker, fieldDe
}

resolvedField.fieldArguments = fieldArguments

fieldDefType := r.definition.FieldDefinitionType(fieldDefRef)
if r.typeIsNullableOrNestedList(fieldDefType) {
resolvedField.listNestingLevel = r.definition.TypeNumberOfListWraps(fieldDefType)
}

return nil
}

Expand Down Expand Up @@ -1306,6 +1316,19 @@ func (r *rpcPlanningContext) newResolveRPCCall(config *resolveRPCCallConfig) (RP
}
}

fd := r.fieldDefinitionRefForType(r.operation.FieldNameString(resolvedField.fieldRef), resolvedField.parentTypeNode.NameString(r.definition))
if fd == ast.InvalidRef {
return RPCCall{}, fmt.Errorf("unable to build response field: field definition not found for field %s", r.operation.FieldNameString(resolvedField.fieldRef))
}

field, err := r.buildField(resolvedField.parentTypeNode, fd, r.operation.FieldNameString(resolvedField.fieldRef), r.operation.FieldAliasString(resolvedField.fieldRef))
if err != nil {
return RPCCall{}, err
}

field.Name = resolveConfig.FieldMappingData.TargetName
field.Message = responseFieldsMessage

response := RPCMessage{
Name: resolveConfig.Response,
Fields: RPCFields{
Expand All @@ -1315,22 +1338,15 @@ func (r *rpcPlanningContext) newResolveRPCCall(config *resolveRPCCallConfig) (RP
JSONPath: resultFieldName,
Repeated: true,
Message: &RPCMessage{
Name: resolveConfig.RPC + "Result",
Fields: RPCFields{
{
Name: resolveConfig.FieldMappingData.TargetName,
ProtoTypeName: dataType,
JSONPath: r.operation.FieldAliasOrNameString(resolvedField.fieldRef),
Message: responseFieldsMessage,
Optional: !r.definition.TypeIsNonNull(resolvedField.fieldDefinitionTypeRef),
},
},
Name: resolveConfig.RPC + "Result",
Fields: RPCFields{field},
},
},
},
}

return RPCCall{
ID: resolvedField.id,
DependentCalls: []int{resolvedField.callerRef},
ResponsePath: resolvedField.responsePath,
MethodName: resolveConfig.RPC,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ func TestEntityLookup(t *testing.T) {
},
},
{
ID: 1,
ServiceName: "Products",
MethodName: "LookupStorageById",
Kind: CallKindEntity,
Expand Down Expand Up @@ -1078,6 +1079,7 @@ func TestEntityLookupWithFieldResolvers(t *testing.T) {
},
},
{
ID: 1,
ServiceName: "Products",
MethodName: "ResolveProductShippingEstimate",
Kind: CallKindResolve,
Expand Down Expand Up @@ -1254,6 +1256,7 @@ func TestEntityLookupWithFieldResolvers(t *testing.T) {
},
},
{
ID: 1,
ServiceName: "Products",
MethodName: "LookupProductById",
Kind: CallKindEntity,
Expand Down Expand Up @@ -1318,6 +1321,7 @@ func TestEntityLookupWithFieldResolvers(t *testing.T) {
},
},
{
ID: 2,
ServiceName: "Products",
MethodName: "ResolveProductShippingEstimate",
Kind: CallKindResolve,
Expand Down Expand Up @@ -1525,6 +1529,7 @@ func TestEntityLookupWithFieldResolvers_WithCompositeTypes(t *testing.T) {
},
},
{
ID: 1,
ServiceName: "Products",
MethodName: "ResolveProductMascotRecommendation",
Kind: CallKindResolve,
Expand Down Expand Up @@ -1702,6 +1707,7 @@ func TestEntityLookupWithFieldResolvers_WithCompositeTypes(t *testing.T) {
},
},
{
ID: 1,
ServiceName: "Products",
MethodName: "ResolveProductStockStatus",
Kind: CallKindResolve,
Expand Down Expand Up @@ -1889,6 +1895,7 @@ func TestEntityLookupWithFieldResolvers_WithCompositeTypes(t *testing.T) {
},
},
{
ID: 1,
ServiceName: "Products",
MethodName: "ResolveProductProductDetails",
Kind: CallKindResolve,
Expand Down
Loading
Loading