Skip to content
Merged
31 changes: 24 additions & 7 deletions v2/pkg/engine/datasource/grpc_datasource/execution_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -383,17 +383,19 @@ func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
}

type rpcPlanningContext struct {
operation *ast.Document
definition *ast.Document
mapping *GRPCMapping
operation *ast.Document
definition *ast.Document
mapping *GRPCMapping
visitedInputTypes map[string]*RPCMessage
}

// newRPCPlanningContext creates a new RPCPlanningContext.
func newRPCPlanningContext(operation *ast.Document, definition *ast.Document, mapping *GRPCMapping) *rpcPlanningContext {
return &rpcPlanningContext{
operation: operation,
definition: definition,
mapping: mapping,
operation: operation,
definition: definition,
mapping: mapping,
visitedInputTypes: make(map[string]*RPCMessage, len(definition.InputObjectTypeDefinitions)),
}
}

Expand Down Expand Up @@ -684,11 +686,26 @@ func (r *rpcPlanningContext) buildMessageFromInputObjectType(node *ast.Node) (*R
return nil, fmt.Errorf("unable to build message from input object type definition - incorrect type: %s", node.Kind)
}

typeName := node.NameString(r.definition)

// If we've already started building this type, return the in-progress message
// pointer to break the recursion cycle. The message's fields are populated by
// the caller that first entered this type, so the pointer will be complete once
// the top-level call returns.
if existing, ok := r.visitedInputTypes[typeName]; ok {
return existing, nil
}

inputObjectDefinition := r.definition.InputObjectTypeDefinitions[node.Ref]
message := &RPCMessage{
Name: node.NameString(r.definition),
Name: typeName,
Fields: make(RPCFields, 0, len(inputObjectDefinition.InputFieldsDefinition.Refs)),
}

// Register the message before recursing into fields so that recursive
// references resolve to this same pointer.
r.visitedInputTypes[typeName] = message
Comment thread
coderabbitai[bot] marked this conversation as resolved.

for _, inputFieldRef := range inputObjectDefinition.InputFieldsDefinition.Refs {
field, err := r.buildMessageFieldFromInputValueDefinition(inputFieldRef, node)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
package grpcdatasource

import (
"testing"

"github.com/stretchr/testify/require"

Check failure on line 6 in v2/pkg/engine/datasource/grpc_datasource/execution_plan_recursive_test.go

View workflow job for this annotation

GitHub Actions / Linters (1.25)

File is not properly formatted (gci)
"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func TestExecutionPlan_RecursiveInputTypes(t *testing.T) {
t.Parallel()

t.Run("Should not stack overflow on recursive input object with and/or fields", func(t *testing.T) {
t.Parallel()

schema := `
type Query {
search(conditions: ConditionsInput): [Result!]!
}

type Result {
id: ID!
name: String!
}

input ConditionsInput {
and: [ConditionsInput!]
or: [ConditionsInput!]
key: String
value: String
}`

mapping := &GRPCMapping{
Service: "Search",
QueryRPCs: map[string]RPCConfig{
"search": {
RPC: "Search",
Request: "SearchRequest",
Response: "SearchResponse",
},
},
}

query := `query SearchQuery($conditions: ConditionsInput) { search(conditions: $conditions) { id name } }`

plan := planRecursiveTest(t, query, schema, mapping)

require.Len(t, plan.Calls, 1)
call := plan.Calls[0]
require.Equal(t, "Search", call.MethodName)

// The request should have a conditions field with a recursive message.
require.Len(t, call.Request.Fields, 1)
conditionsField := call.Request.Fields[0]
require.Equal(t, "conditions", conditionsField.JSONPath)
require.NotNil(t, conditionsField.Message)
require.Equal(t, "ConditionsInput", conditionsField.Message.Name)
require.Len(t, conditionsField.Message.Fields, 4)

// The and/or fields should reference the same ConditionsInput message (cycle).
andField := conditionsField.Message.Fields[0]
orField := conditionsField.Message.Fields[1]
require.Equal(t, "and", andField.JSONPath)
require.Equal(t, "or", orField.JSONPath)
require.True(t, andField.Message == conditionsField.Message, "and field should reference the same ConditionsInput message")
require.True(t, orField.Message == conditionsField.Message, "or field should reference the same ConditionsInput message")
})

t.Run("Should not stack overflow on self-referencing input object", func(t *testing.T) {
t.Parallel()

schema := `
type Query {
filter(input: FilterInput): [Item!]!
}

type Item {
id: ID!
}

input FilterInput {
child: FilterInput
value: String
}`

mapping := &GRPCMapping{
Service: "Items",
QueryRPCs: map[string]RPCConfig{
"filter": {
RPC: "Filter",
Request: "FilterRequest",
Response: "FilterResponse",
},
},
}

query := `query FilterQuery($input: FilterInput) { filter(input: $input) { id } }`

plan := planRecursiveTest(t, query, schema, mapping)

require.Len(t, plan.Calls, 1)
call := plan.Calls[0]

Comment thread
coderabbitai[bot] marked this conversation as resolved.
require.Len(t, call.Request.Fields, 1)
inputField := call.Request.Fields[0]
require.Equal(t, "input", inputField.JSONPath)
require.NotNil(t, inputField.Message)
require.Equal(t, "FilterInput", inputField.Message.Name)
require.Len(t, inputField.Message.Fields, 2)

// The child field should reference the same FilterInput message.
childField := inputField.Message.Fields[0]
require.Equal(t, "child", childField.JSONPath)
require.True(t, childField.Message == inputField.Message, "child field should reference the same FilterInput message")
})

t.Run("Should not stack overflow on mutually recursive input objects", func(t *testing.T) {
t.Parallel()

schema := `
type Query {
evaluate(expr: ExprInput): Boolean!
}

input ExprInput {
not: NotExprInput
value: String
}

input NotExprInput {
expr: ExprInput
}`

mapping := &GRPCMapping{
Service: "Eval",
QueryRPCs: map[string]RPCConfig{
"evaluate": {
RPC: "Evaluate",
Request: "EvaluateRequest",
Response: "EvaluateResponse",
},
},
}

query := `query EvalQuery($expr: ExprInput) { evaluate(expr: $expr) }`

plan := planRecursiveTest(t, query, schema, mapping)

require.Len(t, plan.Calls, 1)
call := plan.Calls[0]

require.Len(t, call.Request.Fields, 1)
exprField := call.Request.Fields[0]
require.Equal(t, "expr", exprField.JSONPath)
require.NotNil(t, exprField.Message)
require.Equal(t, "ExprInput", exprField.Message.Name)
require.Len(t, exprField.Message.Fields, 2)

// ExprInput.not -> NotExprInput.expr -> ExprInput (cycle)
notField := exprField.Message.Fields[0]
require.Equal(t, "not", notField.JSONPath)
require.NotNil(t, notField.Message)
require.Equal(t, "NotExprInput", notField.Message.Name)
require.Len(t, notField.Message.Fields, 1)

backRef := notField.Message.Fields[0]
require.Equal(t, "expr", backRef.JSONPath)
require.True(t, backRef.Message == exprField.Message, "NotExprInput.expr should reference the same ExprInput message")
})
}

func planRecursiveTest(t *testing.T, query, schema string, mapping *GRPCMapping) *RPCExecutionPlan {
t.Helper()

schemaDoc := testSchema(t, schema)

queryDoc, report := astparser.ParseGraphqlDocumentString(query)
require.False(t, report.HasErrors())

rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{
subgraphName: mapping.Service,
mapping: mapping,
})

plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc)
require.NoError(t, err)

return plan
}
Loading