Skip to content
Merged
49 changes: 39 additions & 10 deletions v2/pkg/engine/datasource/grpc_datasource/execution_plan.go
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,11 @@ func NewPlanner(subgraphName string, mapping *GRPCMapping, federationConfigs pla

// formatRPCMessage formats an RPCMessage and adds it to the string builder with the specified indentation
func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
visited := make(map[*RPCMessage]struct{})
formatRPCMessageVisited(sb, message, indent, visited)
}

func formatRPCMessageVisited(sb *strings.Builder, message RPCMessage, indent int, visited map[*RPCMessage]struct{}) {
indentStr := strings.Repeat(" ", indent)

fmt.Fprintf(sb, "%sName: %s\n", indentStr, message.Name)
Expand All @@ -375,25 +380,34 @@ func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) {
fmt.Fprintf(sb, "%s JSONPath: %s\n", indentStr, field.JSONPath)
fmt.Fprintf(sb, "%s ResolvePath: %s\n", indentStr, field.ResolvePath.String())

if field.Message != nil {
fmt.Fprintf(sb, "%s Message:\n", indentStr)
formatRPCMessage(sb, *field.Message, indent+6)
if field.Message == nil {
return
}

fmt.Fprintf(sb, "%s Message:\n", indentStr)
if _, seen := visited[field.Message]; seen {
fmt.Fprintf(sb, "%s <recursive: %s>\n", indentStr, field.Message.Name)
continue
}
visited[field.Message] = struct{}{}
formatRPCMessageVisited(sb, *field.Message, indent+6, visited)
}
}

type rpcPlanningContext struct {
operation *ast.Document
definition *ast.Document
mapping *GRPCMapping
operation *ast.Document
definition *ast.Document
mapping *GRPCMapping
visitedInputTypes map[string]*RPCMessage
}

// newRPCPlanningContext creates a new RPCPlanningContext.
func newRPCPlanningContext(operation *ast.Document, definition *ast.Document, mapping *GRPCMapping) *rpcPlanningContext {
return &rpcPlanningContext{
operation: operation,
definition: definition,
mapping: mapping,
operation: operation,
definition: definition,
mapping: mapping,
visitedInputTypes: make(map[string]*RPCMessage, len(definition.InputObjectTypeDefinitions)),
}
}

Expand Down Expand Up @@ -684,11 +698,26 @@ func (r *rpcPlanningContext) buildMessageFromInputObjectType(node *ast.Node) (*R
return nil, fmt.Errorf("unable to build message from input object type definition - incorrect type: %s", node.Kind)
}

typeName := node.NameString(r.definition)

// If we've already started building this type, return the in-progress message
// pointer to break the recursion cycle. The message's fields are populated by
// the caller that first entered this type, so the pointer will be complete once
// the top-level call returns.
if existing, ok := r.visitedInputTypes[typeName]; ok {
return existing, nil
}

inputObjectDefinition := r.definition.InputObjectTypeDefinitions[node.Ref]
message := &RPCMessage{
Name: node.NameString(r.definition),
Name: typeName,
Fields: make(RPCFields, 0, len(inputObjectDefinition.InputFieldsDefinition.Refs)),
}

// Register the message before recursing into fields so that recursive
// references resolve to this same pointer.
r.visitedInputTypes[typeName] = message
Comment thread
coderabbitai[bot] marked this conversation as resolved.

for _, inputFieldRef := range inputObjectDefinition.InputFieldsDefinition.Refs {
field, err := r.buildMessageFieldFromInputValueDefinition(inputFieldRef, node)
if err != nil {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
package grpcdatasource

import (
"testing"

"github.com/stretchr/testify/require"

"github.com/wundergraph/graphql-go-tools/v2/pkg/astparser"
)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func TestExecutionPlan_RecursiveInputTypes_String(t *testing.T) {
// Verify stringer method does not overflow on recursive inputs
t.Parallel()

schema := `
type Query {
search(conditions: ConditionsInput): [Result!]!
}

type Result {
id: ID!
name: String!
}

input ConditionsInput {
and: [ConditionsInput!]
or: [ConditionsInput!]
key: String
value: String
}`

mapping := &GRPCMapping{
Service: "Search",
QueryRPCs: map[string]RPCConfig{
"search": {
RPC: "Search",
Request: "SearchRequest",
Response: "SearchResponse",
},
},
}

query := `query SearchQuery($conditions: ConditionsInput) { search(conditions: $conditions) { id name } }`

plan := planRecursiveTest(t, query, schema, mapping)

result := plan.String()
// formatRPCMessage must emit the recursive placeholder instead of overflowing.
require.Contains(t, result, "ConditionsInput")
require.Contains(t, result, "<recursive: ConditionsInput>")
}

func TestExecutionPlan_RecursiveInputTypes(t *testing.T) {
t.Parallel()

t.Run("Should not stack overflow on recursive input object with and/or fields", func(t *testing.T) {
t.Parallel()

schema := `
type Query {
search(conditions: ConditionsInput): [Result!]!
}

type Result {
id: ID!
name: String!
}

input ConditionsInput {
and: [ConditionsInput!]
or: [ConditionsInput!]
key: String
value: String
}`

mapping := &GRPCMapping{
Service: "Search",
QueryRPCs: map[string]RPCConfig{
"search": {
RPC: "Search",
Request: "SearchRequest",
Response: "SearchResponse",
},
},
}

query := `query SearchQuery($conditions: ConditionsInput) { search(conditions: $conditions) { id name } }`

plan := planRecursiveTest(t, query, schema, mapping)

require.Len(t, plan.Calls, 1)
call := plan.Calls[0]
require.Equal(t, "Search", call.MethodName)

// The request should have a conditions field with a recursive message.
require.Len(t, call.Request.Fields, 1)
conditionsField := call.Request.Fields[0]
require.Equal(t, "conditions", conditionsField.JSONPath)
require.NotNil(t, conditionsField.Message)
require.Equal(t, "ConditionsInput", conditionsField.Message.Name)
require.Len(t, conditionsField.Message.Fields, 4)

// The and/or fields should reference the same ConditionsInput message (cycle).
andField := findField(t, conditionsField.Message.Fields, "and")
orField := findField(t, conditionsField.Message.Fields, "or")
require.True(t, andField.Message == conditionsField.Message, "and field should reference the same ConditionsInput message")
require.True(t, orField.Message == conditionsField.Message, "or field should reference the same ConditionsInput message")
})

t.Run("Should not stack overflow on self-referencing input object", func(t *testing.T) {
t.Parallel()

schema := `
type Query {
filter(input: FilterInput): [Item!]!
}

type Item {
id: ID!
}

input FilterInput {
child: FilterInput
value: String
}`

mapping := &GRPCMapping{
Service: "Items",
QueryRPCs: map[string]RPCConfig{
"filter": {
RPC: "Filter",
Request: "FilterRequest",
Response: "FilterResponse",
},
},
}

query := `query FilterQuery($input: FilterInput) { filter(input: $input) { id } }`

plan := planRecursiveTest(t, query, schema, mapping)

require.Len(t, plan.Calls, 1)
call := plan.Calls[0]
require.Equal(t, "Filter", call.MethodName)

Comment thread
coderabbitai[bot] marked this conversation as resolved.
require.Len(t, call.Request.Fields, 1)
inputField := call.Request.Fields[0]
require.Equal(t, "input", inputField.JSONPath)
require.NotNil(t, inputField.Message)
require.Equal(t, "FilterInput", inputField.Message.Name)
require.Len(t, inputField.Message.Fields, 2)

// The child field should reference the same FilterInput message.
childField := findField(t, inputField.Message.Fields, "child")
require.True(t, childField.Message == inputField.Message, "child field should reference the same FilterInput message")
})

t.Run("Should not stack overflow on mutually recursive input objects", func(t *testing.T) {
t.Parallel()

schema := `
type Query {
evaluate(expr: ExprInput): Boolean!
}

input ExprInput {
not: NotExprInput
value: String
}

input NotExprInput {
expr: ExprInput
}`

mapping := &GRPCMapping{
Service: "Eval",
QueryRPCs: map[string]RPCConfig{
"evaluate": {
RPC: "Evaluate",
Request: "EvaluateRequest",
Response: "EvaluateResponse",
},
},
}

query := `query EvalQuery($expr: ExprInput) { evaluate(expr: $expr) }`

plan := planRecursiveTest(t, query, schema, mapping)

require.Len(t, plan.Calls, 1)
call := plan.Calls[0]
require.Equal(t, "Evaluate", call.MethodName)

require.Len(t, call.Request.Fields, 1)
exprField := call.Request.Fields[0]
require.Equal(t, "expr", exprField.JSONPath)
require.NotNil(t, exprField.Message)
require.Equal(t, "ExprInput", exprField.Message.Name)
require.Len(t, exprField.Message.Fields, 2)

// ExprInput.not -> NotExprInput.expr -> ExprInput (cycle)
notField := findField(t, exprField.Message.Fields, "not")
require.NotNil(t, notField.Message)
require.Equal(t, "NotExprInput", notField.Message.Name)
require.Len(t, notField.Message.Fields, 1)

backRef := findField(t, notField.Message.Fields, "expr")
require.True(t, backRef.Message == exprField.Message, "NotExprInput.expr should reference the same ExprInput message")
})
}

func findField(t *testing.T, fields RPCFields, jsonPath string) RPCField {
t.Helper()

for _, f := range fields {
if f.JSONPath == jsonPath {
return f
}
}

t.Fatalf("field with JSONPath %q not found", jsonPath)
return RPCField{}
}

func planRecursiveTest(t *testing.T, query, schema string, mapping *GRPCMapping) *RPCExecutionPlan {
t.Helper()

schemaDoc := testSchema(t, schema)

queryDoc, report := astparser.ParseGraphqlDocumentString(query)
require.False(t, report.HasErrors())

rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{
subgraphName: mapping.Service,
mapping: mapping,
})

plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc)
require.NoError(t, err)

return plan
}
6 changes: 2 additions & 4 deletions v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,15 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D
// It processes the input JSON data to make gRPC calls and returns
// the response data.
//
// Headers are converted to gRPC metadata and part of gRPC calls.
// Headers are converted to gRPC metadata and are part of gRPC calls.
//
// The input is expected to contain the necessary information to make
// a gRPC call, including service name, method name, and request data.
func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte) (data []byte, err error) {
// get variables from input
variables := gjson.Parse(unsafebytes.BytesToString(input)).Get("body.variables")

var (
poolItems []*arena.PoolItem
)
var poolItems []*arena.PoolItem
defer func() {
d.pool.ReleaseMany(poolItems)
}()
Expand Down
Loading
Loading