Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
634 changes: 632 additions & 2 deletions execution/engine/execution_engine_grpc_test.go

Large diffs are not rendered by default.

187 changes: 156 additions & 31 deletions v2/pkg/engine/datasource/grpc_datasource/compiler.go
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,14 @@ func (d *Document) ServiceByRef(ref int) Service {

// MessageByName returns a Message by its name.
// Returns an empty Message if no message with the given name exists.
func (d *Document) MessageByName(name string) Message {
func (d *Document) MessageByName(name string) (Message, bool) {
for _, m := range d.Messages {
if m.Name == name {
return m
return m, true
Comment thread
devsergiy marked this conversation as resolved.
}
}

return Message{}
return Message{}, false
}

// MessageRefByName returns the index of a Message in the Messages slice by its name.
Expand Down Expand Up @@ -336,13 +336,13 @@ func (p *RPCCompiler) Compile(executionPlan *RPCExecutionPlan, inputData gjson.R
invocations := make([]Invocation, 0, len(executionPlan.Calls))

for _, call := range executionPlan.Calls {
inputMessage := p.doc.MessageByName(call.Request.Name)
if inputMessage.Name == "" {
inputMessage, ok := p.doc.MessageByName(call.Request.Name)
if !ok {
return nil, fmt.Errorf("input message %s not found in document", call.Request.Name)
}

outputMessage := p.doc.MessageByName(call.Response.Name)
if outputMessage.Name == "" {
outputMessage, ok := p.doc.MessageByName(call.Response.Name)
if !ok {
return nil, fmt.Errorf("output message %s not found in document", call.Response.Name)
}

Expand Down Expand Up @@ -445,11 +445,33 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes
if field.MessageRef >= 0 {
var fieldMsg *dynamicpb.Message

// If the field is optional, we are handling a scalar value that is wrapped in a message
// as protobuf scalar types are not nullable.
if rpcField.Optional {
// If we don't have a value for an optional field, we skip it to provide a null message.
switch {
case rpcField.IsListType:
// nested and nullable lists are wrapped in a message, therefore we need to handle them differently
Comment thread
devsergiy marked this conversation as resolved.
Outdated
// than repeated fields.
if !data.Get(rpcField.JSONPath).Exists() {
if !rpcField.Optional {
p.report.AddInternalError(fmt.Errorf("field %s is required but has no value", rpcField.JSONPath))
}

continue
}

if rpcField.ListMetadata == nil {
p.report.AddInternalError(fmt.Errorf("list metadata not found for field %s", rpcField.JSONPath))
continue
}

fieldMsg = p.buildListMessage(inputMessage.Desc, field, rpcField, data)
if fieldMsg == nil {
continue
}
case rpcField.IsOptionalScalar():
// If the field is optional, we are handling a scalar value that is wrapped in a message
// as protobuf scalar types are not nullable.

if !data.Get(rpcField.JSONPath).Exists() {
// If we don't have a value for an optional field, we skip it to provide a null message.
continue
}

Expand All @@ -459,7 +481,7 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes
rpcField.ToOptionalTypeMessage(p.doc.Messages[field.MessageRef].Name),
data,
)
} else {
default:
fieldMsg = p.buildProtoMessage(p.doc.Messages[field.MessageRef], rpcField.Message, data.Get(rpcField.JSONPath))
}

Expand All @@ -468,21 +490,11 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes
}

if field.Type == DataTypeEnum {
enum, ok := p.doc.EnumByName(rpcField.EnumName)
if !ok {
p.report.AddInternalError(fmt.Errorf("enum %s not found in document", rpcField.EnumName))
continue
}

for _, enumValue := range enum.Values {
if enumValue.GraphqlValue == data.Get(rpcField.JSONPath).String() {
message.Set(
fd.ByName(protoref.Name(field.Name)),
protoref.ValueOfEnum(protoref.EnumNumber(enumValue.Number)),
)

break
}
if val := p.getEnumValue(rpcField.EnumName, data.Get(rpcField.JSONPath)); val != nil {
message.Set(
fd.ByName(protoref.Name(field.Name)),
*val,
)
}

continue
Expand All @@ -497,6 +509,117 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes
return message
}

func (p *RPCCompiler) buildListMessage(desc protoref.MessageDescriptor, field Field, rpcField *RPCField, data gjson.Result) *dynamicpb.Message {
rootMsg := dynamicpb.NewMessage(desc.Fields().ByName(protoref.Name(field.Name)).Message())
p.traverseList(rootMsg, 1, field, rpcField, data.Get(rpcField.JSONPath))
return rootMsg
}

func (p *RPCCompiler) traverseList(rootMsg protoref.Message, level int, field Field, rpcField *RPCField, data gjson.Result) protoref.Message {
if level >= rpcField.ListMetadata.NestingLevel {
arr := data.Array()
if len(arr) == 0 {
if rpcField.ListMetadata.LevelInfo[level-1].Optional {
return nil
}

return rootMsg
}

// List wrappers always use field number 1
fieldDesc := rootMsg.Descriptor().Fields().ByNumber(1)
if fieldDesc == nil {
p.report.AddInternalError(fmt.Errorf("field with number %d not found in message %s", 1, rootMsg.Descriptor().Name()))
return nil
}

itemsField := rootMsg.Mutable(fieldDesc).List()

switch DataType(rpcField.TypeName) {
case DataTypeMessage:
itemsFieldMsg, ok := p.doc.MessageByName(rpcField.Message.Name)
if !ok {
p.report.AddInternalError(fmt.Errorf("message %s not found in document", rpcField.Message.Name))
return nil
}

for _, element := range arr {
if msg := p.buildProtoMessage(itemsFieldMsg, rpcField.Message, element); msg != nil {
itemsField.Append(protoref.ValueOfMessage(msg))
}
}
case DataTypeEnum:
for _, element := range arr {
if val := p.getEnumValue(rpcField.EnumName, element); val != nil {
itemsField.Append(*val)
}
}
default:
for _, element := range arr {
itemsField.Append(p.setValueForKind(DataType(fieldDesc.Kind().String()), element))
}
}

rootMsg.Set(fieldDesc, protoref.ValueOfList(itemsField))
return rootMsg
}

// For nested Lists we always expect a "list" field in the root message with field number 1
listFieldDesc := rootMsg.Descriptor().Fields().ByNumber(1)
if listFieldDesc == nil {
p.report.AddInternalError(fmt.Errorf("field with number %d not found in message %s", 1, rootMsg.Descriptor().Name()))
return nil
}

elements := data.Array()
newListField := rootMsg.NewField(listFieldDesc)
if len(elements) == 0 {
if rpcField.ListMetadata.LevelInfo[level-1].Optional {
return nil
}

rootMsg.Set(listFieldDesc, newListField)
return rootMsg
}

// Inside of a List message type we expect a repeated "items" field with field number 1
itemsFieldMsg := newListField.Message()
itemsFieldDesc := itemsFieldMsg.Descriptor().Fields().ByNumber(1)
if itemsFieldDesc == nil {
p.report.AddInternalError(fmt.Errorf("field with number %d not found in message %s", 1, itemsFieldMsg.Descriptor().Name()))
return nil
}

itemsField := itemsFieldMsg.Mutable(itemsFieldDesc)
for _, element := range elements {
newElement := itemsField.List().NewElement()
if val := p.traverseList(newElement.Message(), level+1, field, rpcField, element); val != nil {
itemsField.List().Append(protoref.ValueOfMessage(val))
}
}

rootMsg.Set(listFieldDesc, newListField)

return rootMsg
}

func (p *RPCCompiler) getEnumValue(enumName string, data gjson.Result) *protoref.Value {
enum, ok := p.doc.EnumByName(enumName)
if !ok {
p.report.AddInternalError(fmt.Errorf("enum %s not found in document", enumName))
return nil
}

for _, enumValue := range enum.Values {
if enumValue.GraphqlValue == data.String() {
v := protoref.ValueOfEnum(protoref.EnumNumber(enumValue.Number))
return &v
}
}

return nil
}

// setValueForKind converts a gjson.Result value to the appropriate protobuf value
// based on its kind/type.
func (p *RPCCompiler) setValueForKind(kind DataType, data gjson.Result) protoref.Value {
Expand Down Expand Up @@ -599,18 +722,20 @@ func (p *RPCCompiler) parseMethod(m protoref.MethodDescriptor) Method {
// parseMessageDefinitions extracts information from a protobuf message descriptor.
// It returns a slice of Message objects with the name and descriptor.
func (p *RPCCompiler) parseMessageDefinitions(messages protoref.MessageDescriptors) []Message {
extractedMessage := make([]Message, 0, messages.Len())
extractedMessages := make([]Message, 0, messages.Len())

for i := 0; i < messages.Len(); i++ {
protoMessage := messages.Get(i)

extractedMessage = append(extractedMessage, Message{
message := Message{
Name: string(protoMessage.Name()),
Desc: protoMessage,
})
}

extractedMessages = append(extractedMessages, message)
}

return extractedMessage
return extractedMessages
}

// enrichMessageData enriches the message data with the field information.
Expand Down
Loading
Loading