Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions v2/pkg/engine/datasource/grpc_datasource/execution_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,10 +323,10 @@ func (r *RPCExecutionPlan) String() string {
fmt.Fprintf(&result, " Method: %s\n", call.MethodName)

result.WriteString(" Request:\n")
formatRPCMessage(&result, call.Request, 8)
formatRPCMessage(&result, &call.Request, 8, map[*RPCMessage]struct{}{})

result.WriteString(" Response:\n")
formatRPCMessage(&result, call.Response, 8)
formatRPCMessage(&result, &call.Response, 8, map[*RPCMessage]struct{}{})
}

return result.String()
Expand Down Expand Up @@ -361,8 +361,20 @@ func NewPlanner(subgraphName string, mapping *GRPCMapping, federationConfigs pla
}), nil
}

// formatRPCMessage formats an RPCMessage and adds it to the string builder with the specified indentation
func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
// formatRPCMessage formats an RPCMessage and adds it to the string builder with the specified indentation.
// The seen parameter tracks visited message pointers to prevent infinite recursion on cyclic message graphs.
func formatRPCMessage(sb *strings.Builder, message *RPCMessage, indent int, seen map[*RPCMessage]struct{}) {
if message == nil {
return
}
if _, ok := seen[message]; ok {
indentStr := strings.Repeat(" ", indent)
fmt.Fprintf(sb, "%s<recursive message: %s>\n", indentStr, message.Name)
return
}
seen[message] = struct{}{}
defer delete(seen, message)

indentStr := strings.Repeat(" ", indent)

fmt.Fprintf(sb, "%sName: %s\n", indentStr, message.Name)
Expand All @@ -377,23 +389,25 @@ func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {

if field.Message != nil {
fmt.Fprintf(sb, "%s Message:\n", indentStr)
formatRPCMessage(sb, *field.Message, indent+6)
formatRPCMessage(sb, field.Message, indent+6, seen)
}
}
}

type rpcPlanningContext struct {
operation *ast.Document
definition *ast.Document
mapping *GRPCMapping
operation *ast.Document
definition *ast.Document
mapping *GRPCMapping
requestMessageCache map[string]*RPCMessage
}

// newRPCPlanningContext creates a new RPCPlanningContext.
func newRPCPlanningContext(operation *ast.Document, definition *ast.Document, mapping *GRPCMapping) *rpcPlanningContext {
return &rpcPlanningContext{
operation: operation,
definition: definition,
mapping: mapping,
operation: operation,
definition: definition,
mapping: mapping,
requestMessageCache: make(map[string]*RPCMessage),
}
}

Expand Down Expand Up @@ -684,11 +698,18 @@ func (r *rpcPlanningContext) buildMessageFromInputObjectType(node *ast.Node) (*R
return nil, fmt.Errorf("unable to build message from input object type definition - incorrect type: %s", node.Kind)
}

typeName := node.NameString(r.definition)
if message, ok := r.requestMessageCache[typeName]; ok {
return message, nil
}

inputObjectDefinition := r.definition.InputObjectTypeDefinitions[node.Ref]
message := &RPCMessage{
Name: node.NameString(r.definition),
Name: typeName,
Fields: make(RPCFields, 0, len(inputObjectDefinition.InputFieldsDefinition.Refs)),
}
r.requestMessageCache[typeName] = message

Comment thread
coderabbitai[bot] marked this conversation as resolved.
for _, inputFieldRef := range inputObjectDefinition.InputFieldsDefinition.Refs {
field, err := r.buildMessageFieldFromInputValueDefinition(inputFieldRef, node)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
package grpcdatasource

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/wundergraph/graphql-go-tools/v2/pkg/ast"
"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
"github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform"
"github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation"
"github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport"
)

func TestMutationExecutionPlanWithRecursiveInputType(t *testing.T) {
schemaDoc := mustParseRecursiveInputSchema(t, `
scalar JSON

type Query {
noop: Boolean
}

type Mutation {
updateNode(input: UpdateNodeInput!): Node!
}

type Node {
id: ID!
}

input UpdateNodeInput {
id: ID!
conditions: RecursiveFilterInput
}

input RecursiveFilterInput {
and: [RecursiveFilterInput!]
or: [RecursiveFilterInput!]
key: String
value: JSON
}
`)

queryDoc, report := astparser.ParseGraphqlDocumentString(`
mutation UpdateNode($input: UpdateNodeInput!) {
updateNode(input: $input) {
id
}
}
`)
require.False(t, report.HasErrors(), report.Error())

plan, err := newRPCPlanVisitor(rpcPlanVisitorConfig{
subgraphName: "Products",
mapping: &GRPCMapping{
Service: "Products",
MutationRPCs: RPCConfigMap[RPCConfig]{
"updateNode": {
RPC: "MutationUpdateNode",
Request: "MutationUpdateNodeRequest",
Response: "MutationUpdateNodeResponse",
},
},
},
}).PlanOperation(&queryDoc, &schemaDoc)
require.NoError(t, err)
require.NotNil(t, plan)
require.NotPanics(t, func() { _ = plan.String() })
require.Len(t, plan.Calls, 1)

inputField := lookupField(plan.Calls[0].Request.Fields, "input")
require.NotNil(t, inputField)
require.NotNil(t, inputField.Message)

conditionsField := lookupField(inputField.Message.Fields, "conditions")
require.NotNil(t, conditionsField)
require.NotNil(t, conditionsField.Message)

andField := lookupField(conditionsField.Message.Fields, "and")
require.NotNil(t, andField)
require.True(t, andField.Repeated || andField.IsListType)
require.Same(t, conditionsField.Message, andField.Message)

orField := lookupField(conditionsField.Message.Fields, "or")
require.NotNil(t, orField)
require.True(t, orField.Repeated || orField.IsListType)
require.Same(t, conditionsField.Message, orField.Message)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When looking at the test, you haven't checked if key and value are also part of message fields.
Also can you add more test cases that handle multiple arguments and nested arguments?

something like this:

input A {
  b: B  
}

input B {
  c: C
}

input C {
  a: A
  b: B
}


keyField := lookupField(conditionsField.Message.Fields, "key")
require.NotNil(t, keyField)

valueField := lookupField(conditionsField.Message.Fields, "value")
require.NotNil(t, valueField)
}

func TestMutationExecutionPlanWithNestedRecursiveInputTypes(t *testing.T) {
schemaDoc := mustParseRecursiveInputSchema(t, `
type Query {
noop: Boolean
}

type Mutation {
processA(input: A!): Boolean!
}

input A {
b: B
}

input B {
c: C
}

input C {
a: A
b: B
}
`)

queryDoc, report := astparser.ParseGraphqlDocumentString(`
mutation ProcessA($input: A!) {
processA(input: $input)
}
`)
require.False(t, report.HasErrors(), report.Error())

plan, err := newRPCPlanVisitor(rpcPlanVisitorConfig{
subgraphName: "Products",
mapping: &GRPCMapping{
Service: "Products",
MutationRPCs: RPCConfigMap[RPCConfig]{
"processA": {
RPC: "MutationProcessA",
Request: "MutationProcessARequest",
Response: "MutationProcessAResponse",
},
},
},
}).PlanOperation(&queryDoc, &schemaDoc)
require.NoError(t, err)
require.NotNil(t, plan)
require.NotPanics(t, func() { _ = plan.String() })
require.Len(t, plan.Calls, 1)

inputField := lookupField(plan.Calls[0].Request.Fields, "input")
require.NotNil(t, inputField)
require.NotNil(t, inputField.Message)

// A.b -> B
aMessage := inputField.Message
require.Equal(t, "A", aMessage.Name)
bFieldInA := lookupField(aMessage.Fields, "b")
require.NotNil(t, bFieldInA)
require.NotNil(t, bFieldInA.Message)

// B.c -> C
bMessage := bFieldInA.Message
require.Equal(t, "B", bMessage.Name)
cFieldInB := lookupField(bMessage.Fields, "c")
require.NotNil(t, cFieldInB)
require.NotNil(t, cFieldInB.Message)

// C.a -> A (same pointer as top-level A)
cMessage := cFieldInB.Message
require.Equal(t, "C", cMessage.Name)
aFieldInC := lookupField(cMessage.Fields, "a")
require.NotNil(t, aFieldInC)
require.NotNil(t, aFieldInC.Message)
require.Same(t, aMessage, aFieldInC.Message)

// C.b -> B (same pointer as A.b's message)
bFieldInC := lookupField(cMessage.Fields, "b")
require.NotNil(t, bFieldInC)
require.NotNil(t, bFieldInC.Message)
require.Same(t, bMessage, bFieldInC.Message)
}

func TestMutationExecutionPlanWithMultipleRecursiveArguments(t *testing.T) {
schemaDoc := mustParseRecursiveInputSchema(t, `
type Query {
noop: Boolean
}

type Mutation {
processFilters(filter: RecursiveFilter!, exclude: RecursiveFilter!): Boolean!
}

input RecursiveFilter {
and: [RecursiveFilter!]
or: [RecursiveFilter!]
key: String
}
`)

queryDoc, report := astparser.ParseGraphqlDocumentString(`
mutation ProcessFilters($filter: RecursiveFilter!, $exclude: RecursiveFilter!) {
processFilters(filter: $filter, exclude: $exclude)
}
`)
require.False(t, report.HasErrors(), report.Error())

plan, err := newRPCPlanVisitor(rpcPlanVisitorConfig{
subgraphName: "Products",
mapping: &GRPCMapping{
Service: "Products",
MutationRPCs: RPCConfigMap[RPCConfig]{
"processFilters": {
RPC: "MutationProcessFilters",
Request: "MutationProcessFiltersRequest",
Response: "MutationProcessFiltersResponse",
},
},
},
}).PlanOperation(&queryDoc, &schemaDoc)
require.NoError(t, err)
require.NotNil(t, plan)
require.NotPanics(t, func() { _ = plan.String() })
require.Len(t, plan.Calls, 1)

filterField := lookupField(plan.Calls[0].Request.Fields, "filter")
require.NotNil(t, filterField)
require.NotNil(t, filterField.Message)

excludeField := lookupField(plan.Calls[0].Request.Fields, "exclude")
require.NotNil(t, excludeField)
require.NotNil(t, excludeField.Message)

// Both arguments share the same RecursiveFilter message via cache
require.Same(t, filterField.Message, excludeField.Message)

// Verify self-referencing fields
andField := lookupField(filterField.Message.Fields, "and")
require.NotNil(t, andField)
require.True(t, andField.Repeated || andField.IsListType)
require.Same(t, filterField.Message, andField.Message)

orField := lookupField(filterField.Message.Fields, "or")
require.NotNil(t, orField)
require.True(t, orField.Repeated || orField.IsListType)
require.Same(t, filterField.Message, orField.Message)

keyField := lookupField(filterField.Message.Fields, "key")
require.NotNil(t, keyField)
}

func lookupField(fields RPCFields, name string) *RPCField {
for i := range fields {
if fields[i].Name == name {
return &fields[i]
}
}

return nil
}

func mustParseRecursiveInputSchema(t *testing.T, schema string) ast.Document {
t.Helper()

doc, report := astparser.ParseGraphqlDocumentString(schema)
require.False(t, report.HasErrors(), report.Error())

err := asttransform.MergeDefinitionWithBaseSchema(&doc)
require.NoError(t, err)

report = operationreport.Report{}
astvalidation.DefaultDefinitionValidator().Validate(&doc, &report)
require.False(t, report.HasErrors(), report.Error())

return doc
}