Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
635 changes: 629 additions & 6 deletions execution/engine/execution_engine_grpc_test.go

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions v2/pkg/ast/ast_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,16 @@ func (d *Document) TypeIsList(ref int) bool {
}
}

// TypeIsNonNullList checks if the type is a non-nullable list.
// e.g.:
// * [String!]! -> true
// * [String]! -> true
// * [String!] -> false
// * [String] -> false
func (d *Document) TypeIsNonNullList(ref int) bool {
return d.Types[ref].TypeKind == TypeKindNonNull && d.TypeIsList(d.Types[ref].OfType)
}

func (d *Document) TypeNumberOfListWraps(ref int) int {
count := 0
for {
Expand Down Expand Up @@ -260,3 +270,27 @@ func (d *Document) ResolveListOrNameType(ref int) (typeRef int) {
}
return
}

// ResolveNestedListOrListType returns the underlying type of a list.
// In contrast to ResolveListOrNameType, this function does not unwrap a non-null type.
// e.g.:
// * [[String]] -> [String]
// * [[String]!] -> [String]!
// * [String!]! -> String!
// * [String]! -> String
func (d *Document) ResolveNestedListOrListType(ref int) int {
if !d.TypeIsList(ref) {
return InvalidRef
}

graphqlType := d.Types[ref]
if graphqlType.TypeKind == TypeKindNonNull {
graphqlType = d.Types[graphqlType.OfType]
}

if graphqlType.TypeKind == TypeKindList {
return graphqlType.OfType
}

return ref
}
197 changes: 165 additions & 32 deletions v2/pkg/engine/datasource/grpc_datasource/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,16 @@ func (d *Document) ServiceByRef(ref int) Service {

// MessageByName returns a Message by its name.
// Returns an empty Message if no message with the given name exists.
func (d *Document) MessageByName(name string) Message {
// We only expect this function to return false if either the message name was provided incorrectly,
// or the schema and mapping was not properly configured.
func (d *Document) MessageByName(name string) (Message, bool) {
for _, m := range d.Messages {
if m.Name == name {
return m
return m, true
Comment thread
devsergiy marked this conversation as resolved.
}
}

return Message{}
return Message{}, false
}

// MessageRefByName returns the index of a Message in the Messages slice by its name.
Expand Down Expand Up @@ -336,13 +338,13 @@ func (p *RPCCompiler) Compile(executionPlan *RPCExecutionPlan, inputData gjson.R
invocations := make([]Invocation, 0, len(executionPlan.Calls))

for _, call := range executionPlan.Calls {
inputMessage := p.doc.MessageByName(call.Request.Name)
if inputMessage.Name == "" {
inputMessage, ok := p.doc.MessageByName(call.Request.Name)
if !ok {
return nil, fmt.Errorf("input message %s not found in document", call.Request.Name)
}

outputMessage := p.doc.MessageByName(call.Response.Name)
if outputMessage.Name == "" {
outputMessage, ok := p.doc.MessageByName(call.Response.Name)
if !ok {
return nil, fmt.Errorf("output message %s not found in document", call.Response.Name)
}

Expand Down Expand Up @@ -445,11 +447,39 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes
if field.MessageRef >= 0 {
var fieldMsg *dynamicpb.Message

// If the field is optional, we are handling a scalar value that is wrapped in a message
// as protobuf scalar types are not nullable.
if rpcField.Optional {
// If we don't have a value for an optional field, we skip it to provide a null message.
switch {

case rpcField.IsListType:
// Nested and nullable lists are wrapped in a message, therefore we need to handle them differently
// than repeated fields. We need to do this because protobuf repeated fields are not nullable and cannot be nested.
//
// message BlogPost {
// ListOfBoolean is_published = 1;
// ListOfListOfString related_topics = 2;
// }
if !data.Get(rpcField.JSONPath).Exists() {
if !rpcField.Optional {
p.report.AddInternalError(fmt.Errorf("field %s is required but has no value", rpcField.JSONPath))
}

continue
}

if rpcField.ListMetadata == nil {
p.report.AddInternalError(fmt.Errorf("list metadata not found for field %s", rpcField.JSONPath))
continue
}

fieldMsg = p.buildListMessage(inputMessage.Desc, field, rpcField, data)
if fieldMsg == nil {
continue
}
case rpcField.IsOptionalScalar():
// If the field is optional, we are handling a scalar value that is wrapped in a message
// as protobuf scalar types are not nullable.

if !data.Get(rpcField.JSONPath).Exists() {
// If we don't have a value for an optional field, we skip it to provide a null message.
continue
}

Expand All @@ -459,7 +489,7 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes
rpcField.ToOptionalTypeMessage(p.doc.Messages[field.MessageRef].Name),
data,
)
} else {
default:
fieldMsg = p.buildProtoMessage(p.doc.Messages[field.MessageRef], rpcField.Message, data.Get(rpcField.JSONPath))
}

Expand All @@ -468,35 +498,136 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes
}

if field.Type == DataTypeEnum {
enum, ok := p.doc.EnumByName(rpcField.EnumName)
if !ok {
p.report.AddInternalError(fmt.Errorf("enum %s not found in document", rpcField.EnumName))
continue
}

for _, enumValue := range enum.Values {
if enumValue.GraphqlValue == data.Get(rpcField.JSONPath).String() {
message.Set(
fd.ByName(protoref.Name(field.Name)),
protoref.ValueOfEnum(protoref.EnumNumber(enumValue.Number)),
)

break
}
if val := p.getEnumValue(rpcField.EnumName, data.Get(rpcField.JSONPath)); val != nil {
message.Set(
fd.ByName(protoref.Name(field.Name)),
*val,
)
}

continue
}

// Handle scalar fields
// TODO handle optional fields
value := data.Get(rpcField.JSONPath)
message.Set(fd.ByName(protoref.Name(field.Name)), p.setValueForKind(field.Type, value))
}

return message
}

// buildListMessage creates a new protobuf message, which reflects a wrapper type to work with a list in GraphQL.
// A list wrapper type has an inner message type, which contains a repeated field.
// We need this to make sure we can differentiate between a null list and an empty list, as repeated fields are not nullable.
//
// message ListOfFloat {
// message List {
// repeated double items = 1;
// }
// List list = 1;
// }
func (p *RPCCompiler) buildListMessage(desc protoref.MessageDescriptor, field Field, rpcField *RPCField, data gjson.Result) *dynamicpb.Message {
rootMsg := dynamicpb.NewMessage(desc.Fields().ByName(protoref.Name(field.Name)).Message())
p.traverseList(rootMsg, 1, field, rpcField, data.Get(rpcField.JSONPath))
return rootMsg
}

// traverseList makes sure we can handle nested lists properly.
// A nested list follows the same structure as a regular list, but references the lower nested message list wrapper.
//
// message ListOfListOfString {
// message List {
// repeated ListOfString items = 1;
// }
// List list = 1;
// }
func (p *RPCCompiler) traverseList(rootMsg protoref.Message, level int, field Field, rpcField *RPCField, data gjson.Result) protoref.Message {
listFieldDesc := rootMsg.Descriptor().Fields().ByNumber(1)
if listFieldDesc == nil {
p.report.AddInternalError(fmt.Errorf("field with number %d not found in message %s", 1, rootMsg.Descriptor().Name()))
return nil
}

elements := data.Array()
newListField := rootMsg.NewField(listFieldDesc)
if len(elements) == 0 {
if rpcField.ListMetadata.LevelInfo[level-1].Optional {
return nil
}

rootMsg.Set(listFieldDesc, newListField)
return rootMsg
}

// Inside of a List message type we expect a repeated "items" field with field number 1
Comment thread
devsergiy marked this conversation as resolved.
itemsFieldMsg := newListField.Message()
itemsFieldDesc := itemsFieldMsg.Descriptor().Fields().ByNumber(1)
if itemsFieldDesc == nil {
p.report.AddInternalError(fmt.Errorf("field with number %d not found in message %s", 1, itemsFieldMsg.Descriptor().Name()))
return nil
}

itemsField := itemsFieldMsg.Mutable(itemsFieldDesc).List()

if level >= rpcField.ListMetadata.NestingLevel {
switch DataType(rpcField.TypeName) {
case DataTypeMessage:
itemsFieldMsg, ok := p.doc.MessageByName(rpcField.Message.Name)
if !ok {
p.report.AddInternalError(fmt.Errorf("message %s not found in document", rpcField.Message.Name))
return nil
}

for _, element := range elements {
if msg := p.buildProtoMessage(itemsFieldMsg, rpcField.Message, element); msg != nil {
itemsField.Append(protoref.ValueOfMessage(msg))
}
}
case DataTypeEnum:
for _, element := range elements {
if val := p.getEnumValue(rpcField.EnumName, element); val != nil {
itemsField.Append(*val)
}
}
default:
for _, element := range elements {
itemsField.Append(p.setValueForKind(DataType(itemsFieldDesc.Kind().String()), element))
}
}

itemsFieldMsg.Set(itemsFieldDesc, protoref.ValueOfList(itemsField))
rootMsg.Set(listFieldDesc, newListField)
return rootMsg
}

for _, element := range elements {
newElement := itemsField.NewElement()
if val := p.traverseList(newElement.Message(), level+1, field, rpcField, element); val != nil {
itemsField.Append(protoref.ValueOfMessage(val))
}
}

rootMsg.Set(listFieldDesc, newListField)
return rootMsg
}

func (p *RPCCompiler) getEnumValue(enumName string, data gjson.Result) *protoref.Value {
enum, ok := p.doc.EnumByName(enumName)
if !ok {
p.report.AddInternalError(fmt.Errorf("enum %s not found in document", enumName))
return nil
}

for _, enumValue := range enum.Values {
if enumValue.GraphqlValue == data.String() {
v := protoref.ValueOfEnum(protoref.EnumNumber(enumValue.Number))
return &v
}
}

return nil
}

// setValueForKind converts a gjson.Result value to the appropriate protobuf value
// based on its kind/type.
func (p *RPCCompiler) setValueForKind(kind DataType, data gjson.Result) protoref.Value {
Expand Down Expand Up @@ -599,18 +730,20 @@ func (p *RPCCompiler) parseMethod(m protoref.MethodDescriptor) Method {
// parseMessageDefinitions extracts information from a protobuf message descriptor.
// It returns a slice of Message objects with the name and descriptor.
func (p *RPCCompiler) parseMessageDefinitions(messages protoref.MessageDescriptors) []Message {
extractedMessage := make([]Message, 0, messages.Len())
extractedMessages := make([]Message, 0, messages.Len())

for i := 0; i < messages.Len(); i++ {
protoMessage := messages.Get(i)

extractedMessage = append(extractedMessage, Message{
message := Message{
Name: string(protoMessage.Name()),
Desc: protoMessage,
})
}

extractedMessages = append(extractedMessages, message)
}

return extractedMessage
return extractedMessages
}

// enrichMessageData enriches the message data with the field information.
Expand Down
Loading
Loading