diff --git a/execution/engine/execution_engine_grpc_test.go b/execution/engine/execution_engine_grpc_test.go index 3a58ba841a..3c506417b2 100644 --- a/execution/engine/execution_engine_grpc_test.go +++ b/execution/engine/execution_engine_grpc_test.go @@ -234,7 +234,7 @@ func TestGRPCSubgraphExecution(t *testing.T) { Query: "query UserQuery { users { id name } }", } - response, err := executeOperation(t, conn, operation) + response, err := executeOperation(t, conn, operation, withGRPCMapping(mapping.MustDefaultGRPCMapping(t))) require.NoError(t, err) require.Equal(t, `{"data":{"users":[{"id":"user-1","name":"User 1"},{"id":"user-2","name":"User 2"},{"id":"user-3","name":"User 3"}]}}`, response) }) @@ -255,7 +255,7 @@ func TestGRPCSubgraphExecution(t *testing.T) { `, } - response, err := executeOperation(t, conn, operation) + response, err := executeOperation(t, conn, operation, withGRPCMapping(mapping.MustDefaultGRPCMapping(t))) require.NoError(t, err) require.Equal(t, `{"data":{"user":{"id":"1","name":"User 1"}}}`, response) }) diff --git a/v2/pkg/astvisitor/visitor.go b/v2/pkg/astvisitor/visitor.go index f7e71d4174..a2cbb102da 100644 --- a/v2/pkg/astvisitor/visitor.go +++ b/v2/pkg/astvisitor/visitor.go @@ -3998,3 +3998,27 @@ func (w *Walker) FieldDefinitionDirectiveArgumentValueByName(field int, directiv return w.definition.DirectiveArgumentValueByName(directive, argumentName) } + +// InRootField returns true if the current field is a root field. +func (w *Walker) InRootField() bool { + return w.CurrentKind == ast.NodeKindField && + len(w.Ancestors) == 2 && + w.Ancestors[0].Kind == ast.NodeKindOperationDefinition +} + +// ResolveInlineFragment returns the inline fragment ref if the current field is inside of +// an inline fragment. +// It returns the inline fragment ref and true if the current field is inside of an inline fragment. +// It returns -1 and false if the current field is not inside of an inline fragment. 
+func (w *Walker) ResolveInlineFragment() (int, bool) { + if len(w.Ancestors) < 2 { + return ast.InvalidRef, false + } + + node := w.Ancestors[len(w.Ancestors)-2] + if node.Kind != ast.NodeKindInlineFragment { + return ast.InvalidRef, false + } + + return node.Ref, true +} diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index 6f35dbee3c..316e2fcdad 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -368,11 +368,12 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { } dataSource, err = grpcdatasource.NewDataSource(p.grpcClient, grpcdatasource.DataSourceConfig{ - Operation: &opDocument, - Definition: p.config.schemaConfiguration.upstreamSchemaAst, - Mapping: p.config.grpc.Mapping, - Compiler: p.config.grpc.Compiler, - Disabled: p.config.grpc.Disabled, + Operation: &opDocument, + Definition: p.config.schemaConfiguration.upstreamSchemaAst, + Mapping: p.config.grpc.Mapping, + Compiler: p.config.grpc.Compiler, + Disabled: p.config.grpc.Disabled, + FederationConfigs: p.dataSourcePlannerConfig.RequiredFields, // TODO: remove fallback logic in visitor for subgraph name and // add proper error handling if the subgraph name is not set in the mapping SubgraphName: p.dataSourceConfig.Name(), diff --git a/v2/pkg/engine/datasource/grpc_datasource/compiler.go b/v2/pkg/engine/datasource/grpc_datasource/compiler.go index ee49e8176c..959cb43c05 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/compiler.go +++ b/v2/pkg/engine/datasource/grpc_datasource/compiler.go @@ -3,6 +3,7 @@ package grpcdatasource import ( "context" "fmt" + "slices" "github.com/bufbuild/protocompile" "github.com/tidwall/gjson" @@ -433,6 +434,12 @@ func (p *RPCCompiler) buildProtoMessage(inputMessage Message, rpcMessage *RPCMes for _, element := range elements { switch field.Type { case 
DataTypeMessage: + // If we handle entity lookups, we get a list of representation variables that need to + // be applied to the correct type names. + if !isAllowedForTypename(rpcField.Message, element) { + continue + } + fieldMsg := p.buildProtoMessage(p.doc.Messages[field.MessageRef], rpcField.Message, element) list.Append(protoref.ValueOfMessage(fieldMsg)) default: @@ -785,3 +792,26 @@ func (p *RPCCompiler) parseField(f protoref.FieldDescriptor) Field { MessageRef: -1, } } + +func isAllowedForTypename(message *RPCMessage, element gjson.Result) bool { + if message == nil { + // We assume that having a nil message expects a null value. + return true + } + + // If we don't have a member types, we assume that the message is allowed for all types. + if message.MemberTypes == nil { + return true + } + + typeName := element.Get("__typename") + if !typeName.Exists() { + // If we don't have a type name, we assume that the message is allowed for all types. + return true + } + + typeString := typeName.String() + + // If we have a type name, we need to check if the message is restricted to a specific type. + return slices.Contains(message.MemberTypes, typeString) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration.go b/v2/pkg/engine/datasource/grpc_datasource/configuration.go index 6ee194004b..9bc74cb31a 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/configuration.go +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration.go @@ -1,9 +1,19 @@ package grpcdatasource +import ( + "strings" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/lexer/runes" +) + type ( // RPCConfigMap is a map of RPC names to RPC configurations + // The key is the field name in the GraphQL operation type (query, mutation, subscription). + // The value is the RPC configuration for that field. RPCConfigMap map[string]RPCConfig // FieldMap defines the mapping between a GraphQL field and a gRPC field + // The key is the field name in the GraphQL type. 
+ // The value is the FieldMapData for that field which contains the target name and the argument mappings. FieldMap map[string]FieldMapData ) @@ -17,22 +27,27 @@ type GRPCMapping struct { // SubscriptionRPCs maps GraphQL subscription fields to the corresponding gRPC RPC configurations SubscriptionRPCs RPCConfigMap // EntityRPCs defines how GraphQL types are resolved as entities using specific RPCs - EntityRPCs map[string]EntityRPCConfig + // The key is the type name and the value is a list of EntityRPCConfig for that type + EntityRPCs map[string][]EntityRPCConfig // Fields defines the field mappings between GraphQL types and gRPC messages + // The key is the type name and the value is a map of field name to FieldMapData for that type Fields map[string]FieldMap // EnumValues defines the enum values for each enum type + // The key is the enum type name and the value is a list of EnumValueMapping for that enum type EnumValues map[string][]EnumValueMapping } +// EnumValueMapping defines the mapping between a GraphQL enum value and a gRPC enum value type EnumValueMapping struct { - Value string - TargetValue string + Value string // The GraphQL enum value + TargetValue string // The gRPC enum value } +// GRPCConfiguration defines the configuration for a gRPC datasource type GRPCConfiguration struct { - Disabled bool - Mapping *GRPCMapping - Compiler *RPCCompiler + Disabled bool // Whether the RPC is disabled + Mapping *GRPCMapping // The mapping between GraphQL types and gRPC messages + Compiler *RPCCompiler // The compiler for the RPC } // RPCConfig defines the configuration for a specific RPC operation @@ -53,9 +68,10 @@ type EntityRPCConfig struct { RPCConfig } +// FieldMapData defines the mapping between a GraphQL field and a gRPC field type FieldMapData struct { - TargetName string - ArgumentMappings FieldArgumentMap + TargetName string // The name of the gRPC field + ArgumentMappings FieldArgumentMap // The mapping between GraphQL field arguments and gRPC request 
arguments } // FieldArgumentMap defines the mapping between a GraphQL field argument and a gRPC field @@ -121,3 +137,97 @@ func (g *GRPCMapping) ResolveEnumValue(enumName, enumValue string) (string, bool return "", false } + +func (g *GRPCMapping) ResolveEntityRPCConfig(typeName, key string) (RPCConfig, bool) { + rpcConfig, ok := g.EntityRPCs[typeName] + if !ok { + return RPCConfig{}, false + } + + for _, ei := range rpcConfig { + if compareKeyFields(ei.Key, key) { + return ei.RPCConfig, true + } + + } + + return RPCConfig{}, false +} + +type keySet map[string]struct{} + +func (k keySet) add(keys ...string) { + for _, key := range keys { + trimmedKey := strings.TrimSpace(key) + if trimmedKey != "" { + k[trimmedKey] = struct{}{} + } + } +} + +func (k keySet) equals(other keySet) bool { + if len(k) != len(other) { + return false + } + + for key := range k { + if _, ok := other[key]; !ok { + return false + } + } + + return true +} + +// We compare only top level key +func compareKeyFields(left, right string) bool { + if left == right { + return true + } + + left = stripSelectionSets(left) + right = stripSelectionSets(right) + + leftKeys := strings.Split(left, " ") + rightKeys := strings.Split(right, " ") + + leftSet := make(keySet) + leftSet.add(leftKeys...) + + rightSet := make(keySet) + rightSet.add(rightKeys...) 
+ + return leftSet.equals(rightSet) +} + +func stripSelectionSets(keyString string) string { + depth := 0 + + var prev rune + + keyString = strings.ReplaceAll(keyString, ",", " ") + + var sb strings.Builder + + for _, r := range keyString { + switch r { + case runes.LBRACE: + depth++ + case runes.RBRACE: + if depth == 0 { + continue + } + + depth-- + default: + if depth != 0 || (r == runes.SPACE && prev == runes.SPACE) { + break + } + + sb.WriteRune(r) + prev = r + } + } + + return strings.TrimSpace(sb.String()) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go b/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go new file mode 100644 index 0000000000..b34c11bea3 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/configuration_test.go @@ -0,0 +1,420 @@ +package grpcdatasource + +import "testing" + +func TestCompareKeyFields(t *testing.T) { + tests := []struct { + name string + left string + right string + expected bool + }{ + { + name: "identical strings", + left: "id name", + right: "id name", + expected: true, + }, + { + name: "empty strings", + left: "", + right: "", + expected: true, + }, + { + name: "single field same", + left: "id", + right: "id", + expected: true, + }, + { + name: "single field different", + left: "id", + right: "name", + expected: false, + }, + { + name: "different order same fields", + left: "id name email", + right: "name email id", + expected: true, + }, + { + name: "comma separated same fields", + left: "id,name,email", + right: "name,email,id", + expected: true, + }, + { + name: "mixed separators same fields", + left: "id,name email", + right: "name email,id", + expected: true, + }, + { + name: "different number of fields", + left: "id name", + right: "id name email", + expected: false, + }, + { + name: "completely different fields", + left: "id name", + right: "email phone", + expected: false, + }, + { + name: "one field missing", + left: "id name email", + right: "id name", + 
expected: false, + }, + { + name: "extra whitespace handling", + left: " id name ", + right: "name id", + expected: true, + }, + { + name: "extra whitespace with commas", + left: " id , name , email ", + right: "email,name,id", + expected: true, + }, + { + name: "multiple consecutive spaces", + left: "id name email", + right: "email name id", + expected: true, + }, + { + name: "mixed spaces and commas with whitespace", + left: "id, name email", + right: "email, name id", + expected: true, + }, + { + name: "empty fields filtered out", + left: "id , , name", + right: "name id", + expected: true, + }, + { + name: "only spaces and commas", + left: " , , ", + right: "", + expected: true, + }, + { + name: "one empty one with fields", + left: "", + right: "id name", + expected: false, + }, + { + name: "duplicate fields in left", + left: "id name id", + right: "id name", + expected: true, + }, + { + name: "duplicate fields in both", + left: "id name id email", + right: "name email name id", + expected: true, + }, + { + name: "case sensitive comparison", + left: "ID name", + right: "id name", + expected: false, + }, + { + name: "single character fields", + left: "a b c", + right: "c a b", + expected: true, + }, + { + name: "complex field names", + left: "user_id created_at updated_at", + right: "updated_at user_id created_at", + expected: true, + }, + { + name: "simple selection set", + left: "id name { firstName lastName }", + right: "name id", + expected: true, + }, + { + name: "nested selection sets", + left: "id user { profile { name email } }", + right: "user id", + expected: true, + }, + { + name: "deeply nested selection sets", + left: "id foo { bar { baz { qux } } }", + right: "foo id", + expected: true, + }, + { + name: "multiple fields with selection sets", + left: "id name user { email } address { street city }", + right: "address user name id", + expected: true, + }, + { + name: "selection sets with different order", + left: "user { name email } id address { 
street }", + right: "id address user", + expected: true, + }, + { + name: "mixed fields and selection sets", + left: "id name user { profile { personal { firstName lastName } work { company } } } status", + right: "status user name id", + expected: true, + }, + { + name: "selection sets with spaces inside braces", + left: "id user { name email } address", + right: "address user id", + expected: true, + }, + { + name: "selection sets with commas and spaces", + left: "id, user { name, email }, address { street, city }", + right: "address, user, id", + expected: true, + }, + { + name: "empty selection sets", + left: "id user { } address { }", + right: "address user id", + expected: true, + }, + { + name: "selection sets with nested empty braces", + left: "id user { profile { } contact { email } }", + right: "user id", + expected: true, + }, + { + name: "different nested structures same top level", + left: "user { name email } product { title price { amount currency } }", + right: "product { id sku } user { firstName lastName }", + expected: true, + }, + { + name: "selection sets vs simple fields different", + left: "id user { name }", + right: "id name", + expected: false, + }, + { + name: "only selection set fields", + left: "user { name email } address { street city }", + right: "address user", + expected: true, + }, + { + name: "complex real-world example", + left: "id user { profile { personal { name age } professional { company role } } } orders { items { product { name price } quantity } total }", + right: "orders user id", + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := compareKeyFields(tt.left, tt.right) + if result != tt.expected { + t.Errorf("compareKeyFields(%q, %q) = %v, expected %v", + tt.left, tt.right, result, tt.expected) + } + }) + } +} + +func TestKeySet(t *testing.T) { + t.Run("add method", func(t *testing.T) { + ks := make(keySet) + ks.add("id", "name", "", " ", "email") + + expected := 
keySet{ + "id": struct{}{}, + "name": struct{}{}, + "email": struct{}{}, + } + + if !ks.equals(expected) { + t.Errorf("keySet.add() did not produce expected result. Got %v, expected %v", ks, expected) + } + }) + + t.Run("equals method same sets", func(t *testing.T) { + ks1 := keySet{ + "id": struct{}{}, + "name": struct{}{}, + } + ks2 := keySet{ + "id": struct{}{}, + "name": struct{}{}, + } + + if !ks1.equals(ks2) { + t.Error("keySet.equals() should return true for identical sets") + } + }) + + t.Run("equals method different sizes", func(t *testing.T) { + ks1 := keySet{ + "id": struct{}{}, + "name": struct{}{}, + } + ks2 := keySet{ + "id": struct{}{}, + } + + if ks1.equals(ks2) { + t.Error("keySet.equals() should return false for sets of different sizes") + } + }) + + t.Run("equals method different keys", func(t *testing.T) { + ks1 := keySet{ + "id": struct{}{}, + "name": struct{}{}, + } + ks2 := keySet{ + "id": struct{}{}, + "email": struct{}{}, + } + + if ks1.equals(ks2) { + t.Error("keySet.equals() should return false for sets with different keys") + } + }) + + t.Run("equals method empty sets", func(t *testing.T) { + ks1 := make(keySet) + ks2 := make(keySet) + + if !ks1.equals(ks2) { + t.Error("keySet.equals() should return true for empty sets") + } + }) +} + +func TestStripSelectionSets(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "no selection sets", + input: "id name email", + expected: "id name email", + }, + { + name: "simple selection set", + input: "id name { firstName lastName }", + expected: "id name", + }, + { + name: "nested selection sets", + input: "id user { profile { name email } }", + expected: "id user", + }, + { + name: "deeply nested selection sets", + input: "id foo { bar { baz { qux } } }", + expected: "id foo", + }, + { + name: "multiple selection sets", + input: "id user { name } address { street }", + expected: "id user address", + }, + { + name: "empty selection sets", + input: "id 
user { } address { }", + expected: "id user address", + }, + { + name: "selection sets with commas", + input: "id, user { name, email }, address", + expected: "id user address", + }, + { + name: "complex nested example", + input: "id user { profile { personal { name age } work { company } } } orders { items { product } }", + expected: "id user orders", + }, + { + name: "spaces inside selection sets", + input: "id user { name email } address", + expected: "id user address", + }, + { + name: "empty input", + input: "", + expected: "", + }, + { + name: "only selection sets", + input: "user { name } address { street }", + expected: "user address", + }, + { + name: "unbalanced opening brace", + input: "id user { name email", + expected: "id user", + }, + { + name: "extra closing braces", + input: "id user { name } } extra", + expected: "id user extra", + }, + { + name: "consecutive spaces", + input: "id name user { email }", + expected: "id name user", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := stripSelectionSets(tt.input) + if result != tt.expected { + t.Errorf("stripSelectionSets(%q) = %q, expected %q", + tt.input, result, tt.expected) + } + }) + } +} + +// Benchmark tests for performance +func BenchmarkCompareKeyFields(b *testing.B) { + testCases := []struct { + name string + left string + right string + }{ + {"simple", "id name", "name id"}, + {"complex", "id,name,email,phone,address", "address phone email name id"}, + {"long", "field1 field2 field3 field4 field5 field6 field7 field8 field9 field10", "field10 field9 field8 field7 field6 field5 field4 field3 field2 field1"}, + {"long and nested", "field1 field2 field3 field4 field5 field6 field7 field8 field9 field10 { field11 field12 field13 field14 field15 field16 field17 field18 field19 field20 }", "field20 field19 field18 field17 field16 field15 field14 field13 field12 field11 field10 field9 field8 field7 field6 field5 field4 field3 field2 field1"}, + } + + for _, tc 
:= range testCases { + b.Run(tc.name, func(b *testing.B) { + for i := 0; i < b.N; i++ { + compareKeyFields(tc.left, tc.right) + } + }) + } +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go index ad41e4baf8..7c7290fef9 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan.go @@ -5,12 +5,10 @@ import ( "strings" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" - "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" ) const ( - federationKeyDirectiveName = "key" // knownTypeOptionalFieldValueName is the name of the field that is used to wrap optional scalar values // in a message as protobuf scalar types are not nullable. knownTypeOptionalFieldValueName = "value" @@ -56,8 +54,6 @@ type RPCExecutionPlan struct { // RPCCall represents a single call to a gRPC service method. // It contains all the information needed to make the call and process the response. 
type RPCCall struct { - // CallID is the unique identifier for the call - CallID int // DependentCalls is a list of calls that must be executed before this call DependentCalls []int // ServiceName is the name of the gRPC service to call @@ -108,6 +104,19 @@ func (r *RPCMessage) SelectValidTypes(typeName string) []string { return []string{r.Name, typeName} } +func (r *RPCMessage) AppendTypeNameField(typeName string) { + if r.Fields != nil && r.Fields.Exists("__typename", "") { + return + } + + r.Fields = append(r.Fields, RPCField{ + Name: "__typename", + TypeName: DataTypeString.String(), + StaticValue: typeName, + JSONPath: "__typename", + }) +} + // RPCFieldSelectionSet is a map of field selections based on inline fragments type RPCFieldSelectionSet map[string]RPCFields @@ -259,7 +268,6 @@ func (r *RPCExecutionPlan) String() string { for j, call := range r.Calls { result.WriteString(fmt.Sprintf(" Call %d:\n", j)) - result.WriteString(fmt.Sprintf(" CallID: %d\n", call.CallID)) if len(call.DependentCalls) > 0 { result.WriteString(" DependentCalls: [") @@ -287,40 +295,36 @@ func (r *RPCExecutionPlan) String() string { return result.String() } -type Planner struct { - visitor *rpcPlanVisitor - walker *astvisitor.Walker +type PlanVisitor interface { + PlanOperation(operation, definition *ast.Document) (*RPCExecutionPlan, error) } -// NewPlanner returns a new Planner instance. +// NewPlanner returns a new PlanVisitor instance. // // The planner is responsible for creating an RPCExecutionPlan from a given // GraphQL operation. It is used by the engine to execute operations against // gRPC services. 
-func NewPlanner(subgraphName string, mapping *GRPCMapping) *Planner { - walker := astvisitor.NewWalker(48) - +func NewPlanner(subgraphName string, mapping *GRPCMapping, federationConfigs plan.FederationFieldConfigurations) PlanVisitor { if mapping == nil { mapping = new(GRPCMapping) } - return &Planner{ - visitor: newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ - subgraphName: subgraphName, - mapping: mapping, - }), - walker: &walker, + var visitor PlanVisitor + if len(federationConfigs) > 0 { + visitor = newRPCPlanVisitorFederation(rpcPlanVisitorConfig{ + subgraphName: subgraphName, + mapping: mapping, + federationConfigs: federationConfigs, + }) + } else { + visitor = newRPCPlanVisitor(rpcPlanVisitorConfig{ + subgraphName: subgraphName, + mapping: mapping, + federationConfigs: federationConfigs, + }) } -} -func (p *Planner) PlanOperation(operation, definition *ast.Document) (*RPCExecutionPlan, error) { - report := &operationreport.Report{} - p.walker.Walk(operation, definition, report) - if report.HasErrors() { - return nil, fmt.Errorf("unable to plan operation: %w", report) - } - - return p.visitor.plan, nil + return visitor } // formatRPCMessage formats an RPCMessage and adds it to the string builder with the specified indentation @@ -342,3 +346,201 @@ func formatRPCMessage(sb *strings.Builder, message RPCMessage, indent int) { } } } + +type rpcPlanningContext struct { + operation *ast.Document + definition *ast.Document + mapping *GRPCMapping +} + +func newRPCPlanningContext(operation *ast.Document, definition *ast.Document, mapping *GRPCMapping) *rpcPlanningContext { + return &rpcPlanningContext{ + operation: operation, + definition: definition, + mapping: mapping, + } +} + +// toDataType converts an ast.Type to a DataType. +// It handles the different type kinds and non-null types. 
+func (r *rpcPlanningContext) toDataType(t *ast.Type) DataType { + switch t.TypeKind { + case ast.TypeKindNamed: + return r.parseGraphQLType(t) + case ast.TypeKindList: + return r.toDataType(&r.definition.Types[t.OfType]) + case ast.TypeKindNonNull: + return r.toDataType(&r.definition.Types[t.OfType]) + } + + return DataTypeUnknown +} + +// parseGraphQLType parses an ast.Type and returns the corresponding DataType. +// It handles the different type kinds and non-null types. +func (r *rpcPlanningContext) parseGraphQLType(t *ast.Type) DataType { + dt := r.definition.Input.ByteSliceString(t.Name) + + // Retrieve the node to check the kind + node, found := r.definition.NodeByNameStr(dt) + if !found { + return DataTypeUnknown + } + + // For non-scalar types, return the corresponding DataType + switch node.Kind { + case ast.NodeKindInterfaceTypeDefinition: + fallthrough + case ast.NodeKindUnionTypeDefinition: + fallthrough + case ast.NodeKindObjectTypeDefinition, ast.NodeKindInputObjectTypeDefinition: + return DataTypeMessage + case ast.NodeKindEnumTypeDefinition: + return DataTypeEnum + default: + return fromGraphQLType(dt) + } +} + +func (r *rpcPlanningContext) resolveRPCMethodMapping(operationType ast.OperationType, operationFieldName string) (RPCConfig, error) { + if r.mapping == nil { + return RPCConfig{}, nil + } + + var rpcConfig RPCConfig + var ok bool + + switch operationType { + case ast.OperationTypeQuery: + rpcConfig, ok = r.mapping.QueryRPCs[operationFieldName] + case ast.OperationTypeMutation: + rpcConfig, ok = r.mapping.MutationRPCs[operationFieldName] + case ast.OperationTypeSubscription: + rpcConfig, ok = r.mapping.SubscriptionRPCs[operationFieldName] + } + + if !ok { + return RPCConfig{}, nil + } + + // We require all fields to be present when defining a mapping for the operation + if rpcConfig.RPC == "" { + return RPCConfig{}, fmt.Errorf("no rpc method name mapping found for operation %s", operationFieldName) + } + + if rpcConfig.Request == "" { + 
return RPCConfig{}, fmt.Errorf("no request message name mapping found for operation %s", operationFieldName) + } + + if rpcConfig.Response == "" { + return RPCConfig{}, fmt.Errorf("no response message name mapping found for operation %s", operationFieldName) + } + + return rpcConfig, nil +} + +// newMessageFromSelectionSet creates a new message from the enclosing type node and the selection set reference. +func (r *rpcPlanningContext) newMessageFromSelectionSet(enclosingTypeNode ast.Node, selectSetRef int) *RPCMessage { + message := &RPCMessage{ + Name: enclosingTypeNode.NameString(r.definition), + Fields: make(RPCFields, 0, len(r.operation.SelectionSets[selectSetRef].SelectionRefs)), + } + + return message +} + +// resolveFieldMapping resolves the field mapping for a field. +// This applies both for complex types in the input and for all fields in the response. +func (r *rpcPlanningContext) resolveFieldMapping(typeName, fieldName string) string { + grpcFieldName, ok := r.mapping.ResolveFieldMapping(typeName, fieldName) + if !ok { + return fieldName + } + + return grpcFieldName +} + +func (r *rpcPlanningContext) typeIsNullableOrNestedList(typeRef int) bool { + if !r.definition.TypeIsNonNull(typeRef) && r.definition.TypeIsList(typeRef) { + return true + } + + if r.definition.TypeNumberOfListWraps(typeRef) > 1 { + return true + } + + return false +} + +func (r *rpcPlanningContext) createListMetadata(typeRef int) (*ListMetadata, error) { + nestingLevel := r.definition.TypeNumberOfListWraps(typeRef) + + md := &ListMetadata{ + NestingLevel: nestingLevel, + LevelInfo: make([]LevelInfo, nestingLevel), + } + + for i := 0; i < nestingLevel; i++ { + md.LevelInfo[i] = LevelInfo{ + Optional: !r.definition.TypeIsNonNull(typeRef), + } + + typeRef = r.definition.ResolveNestedListOrListType(typeRef) + if typeRef == ast.InvalidRef { + return nil, fmt.Errorf("unable to resolve underlying list type for ref: %d", typeRef) + } + } + + return md, nil +} + +// buildField builds a field 
from a field definition. +// It handles lists, enums, and other types. +func (r *rpcPlanningContext) buildField(enclosingTypeNode ast.Node, fd int, fieldName, fieldAlias string) (RPCField, error) { + fdt := r.definition.FieldDefinitionType(fd) + typeName := r.toDataType(&r.definition.Types[fdt]) + parentTypeName := enclosingTypeNode.NameString(r.definition) + + field := RPCField{ + Name: r.resolveFieldMapping(parentTypeName, fieldName), + Alias: fieldAlias, + Optional: !r.definition.TypeIsNonNull(fdt), + JSONPath: fieldName, + TypeName: typeName.String(), + } + + if r.definition.TypeIsList(fdt) { + switch { + // for nullable or nested lists we need to build a wrapper message + // Nullability is handled by the datasource during the execution. + case r.typeIsNullableOrNestedList(fdt): + md, err := r.createListMetadata(fdt) + if err != nil { + return field, err + } + field.ListMetadata = md + field.IsListType = true + default: + // For non-nullable single lists we can directly use the repeated syntax in protobuf. 
+ field.Repeated = true + } + } + + if typeName == DataTypeEnum { + field.EnumName = r.definition.FieldDefinitionTypeNameString(fd) + } + + if fieldName == "__typename" { + field.StaticValue = parentTypeName + } + + return field, nil +} + +func (r *rpcPlanningContext) resolveServiceName(subgraphName string) string { + if r.mapping == nil || r.mapping.Service == "" { + return subgraphName + } + + return r.mapping.Service +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go index 30ae95163e..f6a81aec3c 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_composite_test.go @@ -6,7 +6,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/stretchr/testify/require" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" grpctest "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" ) @@ -604,23 +603,21 @@ func TestCompositeTypeExecutionPlan(t *testing.T) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: "Products", mapping: testMapping(), }) - walker.Walk(&queryDoc, &schemaDoc, &report) + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) - if report.HasErrors() { + if err != nil { require.NotEmpty(t, tt.expectedError) - require.Contains(t, report.Error(), tt.expectedError) + require.Contains(t, err.Error(), tt.expectedError) return } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(tt.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } @@ -871,23 +868,21 @@ func TestMutationUnionExecutionPlan(t 
*testing.T) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: "Products", mapping: testMapping(), }) - walker.Walk(&queryDoc, &schemaDoc, &report) + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) - if report.HasErrors() { + if err != nil { require.NotEmpty(t, tt.expectedError) - require.Contains(t, report.Error(), tt.expectedError) + require.Contains(t, err.Error(), tt.expectedError) return } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(tt.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go index 5756d9d167..f0184b93cd 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_federation_test.go @@ -1,37 +1,52 @@ package grpcdatasource import ( + "fmt" + "strings" "testing" "github.com/google/go-cmp/cmp" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" grpctest "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/internal/unsafeparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" ) func TestEntityLookup(t *testing.T) { tests := []struct { - name string - query string - expectedPlan *RPCExecutionPlan - mapping *GRPCMapping + name string + query string + expectedPlan *RPCExecutionPlan + 
mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations }{ { name: "Should create an execution plan for an entity lookup with one key field", query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on Product { __typename id name price } } }`, mapping: &GRPCMapping{ Service: "Products", - EntityRPCs: map[string]EntityRPCConfig{ + EntityRPCs: map[string][]EntityRPCConfig{ "Product": { - Key: "id", - RPCConfig: RPCConfig{ - RPC: "LookupProductById", - Request: "LookupProductByIdRequest", - Response: "LookupProductByIdResponse", + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupProductById", + Request: "LookupProductByIdRequest", + Response: "LookupProductByIdResponse", + }, }, }, }, }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "Product", + SelectionSet: "id", + }, + }, expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -47,7 +62,8 @@ func TestEntityLookup(t *testing.T) { Repeated: true, JSONPath: "representations", Message: &RPCMessage{ - Name: "LookupProductByIdKey", + Name: "LookupProductByIdKey", + MemberTypes: []string{"Product"}, Fields: []RPCField{ { Name: "id", @@ -101,6 +117,152 @@ func TestEntityLookup(t *testing.T) { }, }, }, + { + name: "Should create an execution plan for an entity lookup multiple types", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on Product { __typename id name price } ... 
on Storage { __typename id name location } } }`, + mapping: testMapping(), + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "Product", + SelectionSet: "id", + }, + { + TypeName: "Storage", + SelectionSet: "id", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupProductById", + Request: RPCMessage{ + Name: "LookupProductByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupProductByIdKey", + MemberTypes: []string{"Product"}, + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupProductByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "Product", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "Product", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + + { + Name: "price", + TypeName: string(DataTypeDouble), + JSONPath: "price", + }, + }, + }, + }, + }, + }, + }, + { + ServiceName: "Products", + MethodName: "LookupStorageById", + Request: RPCMessage{ + Name: "LookupStorageByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupStorageByIdKey", + MemberTypes: []string{"Storage"}, + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupStorageByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: 
string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "Storage", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "Storage", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + { + Name: "location", + TypeName: string(DataTypeString), + JSONPath: "location", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, // TODO implement multiple entity lookup types // { @@ -304,23 +466,640 @@ func TestEntityLookup(t *testing.T) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ - subgraphName: "Products", - mapping: tt.mapping, - }) - - walker.Walk(&queryDoc, &schemaDoc, &report) - - if report.HasErrors() { - t.Fatalf("failed to walk AST: %s", report.Error()) + planner := NewPlanner("Products", tt.mapping, tt.federationConfigs) + plan, err := planner.PlanOperation(&queryDoc, &schemaDoc) + if err != nil { + t.Fatalf("failed to plan operation: %s", err) } - diff := cmp.Diff(tt.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(tt.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } }) } } + +func TestEntityKeys(t *testing.T) { + tests := []struct { + name string + query string + schema string + expectedPlan *RPCExecutionPlan + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations + }{ + { + name: "Should create an execution plan for an entity lookup with a key field", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + _entities(representations: [_Any!]!): [_Entity]! + } + type User @key(fields: "id") { + id: ID! + name: String! 
+ } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupUserById", + Request: "LookupUserByIdRequest", + Response: "LookupUserByIdResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id", + }, + }, + + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserById", + // Define the structure of the request message + Request: RPCMessage{ + Name: "LookupUserByIdRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdKey", + MemberTypes: []string{"User"}, + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + // Define the structure of the response message + Response: RPCMessage{ + Name: "LookupUserByIdResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Should create an execution plan for an entity lookup with a nested key", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + _entities(representations: [_Any!]!): [_Entity]! + } + + type Address { + id: ID! + street: String! + city: String! 
+ state: String! + zip: String! + } + + type User @key(fields: "id address { id }") { + id: ID! + name: String! + address: Address! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "id address { id }", + RPCConfig: RPCConfig{ + RPC: "LookupUserByIdAndAddress", + Request: "LookupUserByIdAndAddressRequest", + Response: "LookupUserByIdAndAddressResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id address { id }", + }, + }, + + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserByIdAndAddress", + // Define the structure of the request message + Request: RPCMessage{ + Name: "LookupUserByIdAndAddressRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdAndAddressKey", + MemberTypes: []string{"User"}, + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "address", + TypeName: string(DataTypeMessage), + JSONPath: "address", + Message: &RPCMessage{ + Name: "Address", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + }, + }, + }, + // Define the structure of the response message + Response: RPCMessage{ + Name: "LookupUserByIdAndAddressResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: 
"name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Should create an execution plan for an entity lookup with a compound key", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + _entities(representations: [_Any!]!): [_Entity]! + } + type User @key(fields: "id name") { + id: ID! + name: String! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "id name", + RPCConfig: RPCConfig{ + RPC: "LookupUserByIdAndName", + Request: "LookupUserByIdAndNameRequest", + Response: "LookupUserByIdAndNameResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id name", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserByIdAndName", + Request: RPCMessage{ + Name: "LookupUserByIdAndNameRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdAndNameKey", + MemberTypes: []string{"User"}, + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupUserByIdAndNameResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: 
string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Order in a compound key should not matter", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + _entities(representations: [_Any!]!): [_Entity]! + } + type User @key(fields: "id name") { + id: ID! + name: String! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "name id", + RPCConfig: RPCConfig{ + RPC: "LookupUserByIdAndName", + Request: "LookupUserByIdAndNameRequest", + Response: "LookupUserByIdAndNameResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id name", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserByIdAndName", + Request: RPCMessage{ + Name: "LookupUserByIdAndNameRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdAndNameKey", + MemberTypes: []string{"User"}, + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: "LookupUserByIdAndNameResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + 
TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + { + name: "Nested fields in a compound key should be ignored", + query: `query EntityLookup($representations: [_Any!]!) { _entities(representations: $representations) { ... on User { __typename id name } } }`, + schema: testFederationSchemaString(` + type Query { + _entities(representations: [_Any!]!): [_Entity]! + } + + type Address { + id: ID! + street: String! + } + type User @key(fields: "id name address { id }") { + id: ID! + name: String! + address: Address! + } + `, []string{"User"}), + mapping: &GRPCMapping{ + Service: "Products", + EntityRPCs: map[string][]EntityRPCConfig{ + "User": { + { + Key: "name id address", + RPCConfig: RPCConfig{ + RPC: "LookupUserByIdAndNameAndAddress", + Request: "LookupUserByIdAndNameAndAddressRequest", + Response: "LookupUserByIdAndNameAndAddressResponse", + }, + }, + }, + }, + }, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "User", + SelectionSet: "id name address { id }", + }, + }, + expectedPlan: &RPCExecutionPlan{ + Calls: []RPCCall{ + { + ServiceName: "Products", + MethodName: "LookupUserByIdAndNameAndAddress", + Request: RPCMessage{ + Name: "LookupUserByIdAndNameAndAddressRequest", + Fields: []RPCField{ + { + Name: "keys", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "representations", + Message: &RPCMessage{ + Name: "LookupUserByIdAndNameAndAddressKey", + MemberTypes: []string{"User"}, + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + { + Name: "address", + TypeName: string(DataTypeMessage), + JSONPath: "address", + Message: &RPCMessage{ + Name: "Address", + Fields: []RPCField{ + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + }, + }, + }, + }, + }, + }, + }, + }, + Response: RPCMessage{ + Name: 
"LookupUserByIdAndNameAndAddressResponse", + Fields: []RPCField{ + { + Name: "result", + TypeName: string(DataTypeMessage), + Repeated: true, + JSONPath: "_entities", + Message: &RPCMessage{ + Name: "User", + Fields: []RPCField{ + { + Name: "__typename", + TypeName: string(DataTypeString), + JSONPath: "__typename", + StaticValue: "User", + }, + { + Name: "id", + TypeName: string(DataTypeString), + JSONPath: "id", + }, + { + Name: "name", + TypeName: string(DataTypeString), + JSONPath: "name", + }, + }, + }, + }, + }, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + runFederationTest(t, tt) + } +} + +func runFederationTest(t *testing.T, tt struct { + name string + query string + schema string + expectedPlan *RPCExecutionPlan + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations +}) { + + t.Helper() + + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + var operation, definition ast.Document + + definition = unsafeparser.ParseGraphqlDocumentStringWithBaseSchema(tt.schema) + + report := operationreport.Report{} + astvalidation.DefaultDefinitionValidator().Validate(&definition, &report) + if report.HasErrors() { + t.Fatalf("failed to validate schema: %s", report.Error()) + } + + operation, report = astparser.ParseGraphqlDocumentString(tt.query) + if report.HasErrors() { + t.Fatalf("failed to parse query: %s", report.Error()) + } + + astvalidation.DefaultOperationValidator().Validate(&operation, &definition, &report) + if report.HasErrors() { + t.Fatalf("failed to validate query: %s", report.Error()) + } + + planner := NewPlanner("Products", tt.mapping, tt.federationConfigs) + plan, err := planner.PlanOperation(&operation, &definition) + if err != nil { + t.Fatalf("failed to plan operation: %s", err) + } + + diff := cmp.Diff(tt.expectedPlan, plan) + if diff != "" { + t.Fatalf("execution plan mismatch: %s", diff) + } + }) + +} + +func testFederationSchemaString(schema string, entities []string) string { + entityUnion := 
strings.Join(entities, " | ") + return fmt.Sprintf(` + schema { + query: Query + } + %s + + union _Entity = %s + scalar _Any + `, schema, entityUnion) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go index ea9625552e..a19ef4a8ff 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_test.go @@ -11,7 +11,6 @@ import ( "github.com/google/go-cmp/cmp" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" - "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" ) type testCase struct { @@ -30,23 +29,21 @@ func runTest(t *testing.T, testCase testCase) { t.Fatalf("failed to parse query: %s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: "Products", mapping: testMapping(), }) - walker.Walk(&queryDoc, &schemaDoc, &report) + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) - if report.HasErrors() { - require.NotEmpty(t, testCase.expectedError, "expected error to be empty, got: %s", report.Error()) - require.Contains(t, report.Error(), testCase.expectedError, "expected error to contain: %s, got: %s", testCase.expectedError, report.Error()) + if err != nil { + require.NotEmpty(t, testCase.expectedError, "expected error to be empty, got: %s", err.Error()) + require.Contains(t, err.Error(), testCase.expectedError, "expected error to contain: %s, got: %s", testCase.expectedError, err.Error()) return } require.Empty(t, testCase.expectedError) - diff := cmp.Diff(testCase.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(testCase.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } @@ -153,7 +150,6 @@ func TestQueryExecutionPlans(t *testing.T) { { ServiceName: "Products", MethodName: 
"QueryUser", - CallID: 1, Request: RPCMessage{ Name: "QueryUserRequest", Fields: []RPCField{ @@ -359,8 +355,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query with a complex input type and variables", - query: `query ComplexFilterTypeQuery($filter: ComplexFilterTypeInput!) { complexFilterType(filter: $filter) { id name } }`, + name: "Should create an execution plan for a query with a complex input type and variables", + query: `query ComplexFilterTypeQuery($filter: ComplexFilterTypeInput!) { complexFilterType(filter: $filter) { id name } }`, + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -389,12 +386,12 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "name", }, { - Name: "filterField1", + Name: "filter_field_1", TypeName: string(DataTypeString), JSONPath: "filterField1", }, { - Name: "filterField2", + Name: "filter_field_2", TypeName: string(DataTypeString), JSONPath: "filterField2", }, @@ -412,7 +409,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "page", }, { - Name: "perPage", + Name: "per_page", TypeName: string(DataTypeInt32), JSONPath: "perPage", }, @@ -432,7 +429,7 @@ func TestQueryExecutionPlans(t *testing.T) { Fields: []RPCField{ { Repeated: true, - Name: "complexFilterType", + Name: "complex_filter_type", TypeName: string(DataTypeMessage), JSONPath: "complexFilterType", Message: &RPCMessage{ @@ -459,8 +456,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, { - name: "Should create an execution plan for a query with a complex input type and variables with different name", - query: `query ComplexFilterTypeQuery($foobar: ComplexFilterTypeInput!) { complexFilterType(filter: $foobar) { id name } }`, + name: "Should create an execution plan for a query with a complex input type and variables with different name", + query: `query ComplexFilterTypeQuery($foobar: ComplexFilterTypeInput!) 
{ complexFilterType(filter: $foobar) { id name } }`, + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -489,12 +487,12 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "name", }, { - Name: "filterField1", + Name: "filter_field_1", TypeName: string(DataTypeString), JSONPath: "filterField1", }, { - Name: "filterField2", + Name: "filter_field_2", TypeName: string(DataTypeString), JSONPath: "filterField2", }, @@ -512,7 +510,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "page", }, { - Name: "perPage", + Name: "per_page", TypeName: string(DataTypeInt32), JSONPath: "perPage", }, @@ -532,7 +530,7 @@ func TestQueryExecutionPlans(t *testing.T) { Fields: []RPCField{ { Repeated: true, - Name: "complexFilterType", + Name: "complex_filter_type", TypeName: string(DataTypeMessage), JSONPath: "complexFilterType", Message: &RPCMessage{ @@ -558,8 +556,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query with a type filter with arguments and variables", - query: "query TypeWithMultipleFilterFieldsQuery($filter: FilterTypeInput!) { typeWithMultipleFilterFields(filter: $filter) { id name } }", + name: "Should create an execution plan for a query with a type filter with arguments and variables", + query: "query TypeWithMultipleFilterFieldsQuery($filter: FilterTypeInput!) 
{ typeWithMultipleFilterFields(filter: $filter) { id name } }", + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -577,13 +576,13 @@ func TestQueryExecutionPlans(t *testing.T) { Fields: []RPCField{ { Repeated: false, - Name: "filterField1", + Name: "filter_field_1", TypeName: string(DataTypeString), JSONPath: "filterField1", }, { Repeated: false, - Name: "filterField2", + Name: "filter_field_2", TypeName: string(DataTypeString), JSONPath: "filterField2", }, @@ -596,7 +595,7 @@ func TestQueryExecutionPlans(t *testing.T) { Name: "QueryTypeWithMultipleFilterFieldsResponse", Fields: []RPCField{ { - Name: "typeWithMultipleFilterFields", + Name: "type_with_multiple_filter_fields", TypeName: string(DataTypeMessage), Repeated: true, JSONPath: "typeWithMultipleFilterFields", @@ -623,8 +622,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query", - query: "query UserQuery { users { id name } }", + name: "Should create an execution plan for a query", + query: "query UserQuery { users { id name } }", + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -678,8 +678,9 @@ func TestQueryExecutionPlans(t *testing.T) { expectedError: "no request message name mapping found for operation user", }, { - name: "Should create an execution plan for a query with a user", - query: `query UserQuery { user(id: "abc123") { id name } }`, + name: "Should create an execution plan for a query with a user", + query: `query UserQuery { user(id: "abc123") { id name } }`, + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -726,8 +727,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query with a nested type", - query: "query NestedTypeQuery { nestedType { id name b { id name c { id name } } } }", + name: "Should create an execution plan for a query with a nested type", + query: "query NestedTypeQuery { 
nestedType { id name b { id name c { id name } } } }", + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -740,7 +742,7 @@ func TestQueryExecutionPlans(t *testing.T) { Name: "QueryNestedTypeResponse", Fields: []RPCField{ { - Name: "nestedType", + Name: "nested_type", TypeName: string(DataTypeMessage), Repeated: true, JSONPath: "nestedType", @@ -807,8 +809,9 @@ func TestQueryExecutionPlans(t *testing.T) { }, }, { - name: "Should create an execution plan for a query with a recursive type", - query: "query RecursiveTypeQuery { recursiveType { id name recursiveType { id recursiveType { id name recursiveType { id name } } name } } }", + name: "Should create an execution plan for a query with a recursive type", + query: "query RecursiveTypeQuery { recursiveType { id name recursiveType { id recursiveType { id name recursiveType { id name } } name } } }", + mapping: testMapping(), expectedPlan: &RPCExecutionPlan{ Calls: []RPCCall{ { @@ -821,7 +824,7 @@ func TestQueryExecutionPlans(t *testing.T) { Name: "QueryRecursiveTypeResponse", Fields: []RPCField{ { - Name: "recursiveType", + Name: "recursive_type", TypeName: string(DataTypeMessage), JSONPath: "recursiveType", Message: &RPCMessage{ @@ -838,7 +841,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "name", }, { - Name: "recursiveType", + Name: "recursive_type", TypeName: string(DataTypeMessage), JSONPath: "recursiveType", Message: &RPCMessage{ @@ -850,7 +853,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "id", }, { - Name: "recursiveType", + Name: "recursive_type", TypeName: string(DataTypeMessage), JSONPath: "recursiveType", Message: &RPCMessage{ @@ -867,7 +870,7 @@ func TestQueryExecutionPlans(t *testing.T) { JSONPath: "name", }, { - Name: "recursiveType", + Name: "recursive_type", TypeName: string(DataTypeMessage), JSONPath: "recursiveType", Message: &RPCMessage{ @@ -1087,23 +1090,21 @@ func TestQueryExecutionPlans(t *testing.T) { t.Fatalf("failed to parse query: 
%s", report.Error()) } - walker := astvisitor.NewWalker(48) - - rpcPlanVisitor := newRPCPlanVisitor(&walker, rpcPlanVisitorConfig{ + rpcPlanVisitor := newRPCPlanVisitor(rpcPlanVisitorConfig{ subgraphName: "Products", mapping: tt.mapping, }) - walker.Walk(&queryDoc, &schemaDoc, &report) + plan, err := rpcPlanVisitor.PlanOperation(&queryDoc, &schemaDoc) - if report.HasErrors() { - require.Contains(t, report.Error(), tt.expectedError) + if err != nil { + require.Contains(t, err.Error(), tt.expectedError) require.NotEmpty(t, tt.expectedError) return } require.Empty(t, tt.expectedError) - diff := cmp.Diff(tt.expectedPlan, rpcPlanVisitor.plan) + diff := cmp.Diff(tt.expectedPlan, plan) if diff != "" { t.Fatalf("execution plan mismatch: %s", diff) } @@ -1340,7 +1341,7 @@ func TestProductExecutionPlan(t *testing.T) { t.Fatalf("failed to validate query: %s", report.Error()) } - planner := NewPlanner("Products", testMapping()) + planner := NewPlanner("Products", testMapping(), nil) outPlan, err := planner.PlanOperation(&queryDoc, &schemaDoc) if tt.expectedError != "" { @@ -1507,7 +1508,6 @@ func TestProductExecutionPlanWithAliases(t *testing.T) { { ServiceName: "Products", MethodName: "QueryCategories", - CallID: 1, Request: RPCMessage{ Name: "QueryCategoriesRequest", }, @@ -2026,7 +2026,6 @@ func TestProductExecutionPlanWithAliases(t *testing.T) { { ServiceName: "Products", MethodName: "QueryUser", - CallID: 1, Request: RPCMessage{ Name: "QueryUserRequest", Fields: []RPCField{ @@ -2068,7 +2067,6 @@ func TestProductExecutionPlanWithAliases(t *testing.T) { { ServiceName: "Products", MethodName: "QueryUser", - CallID: 2, Request: RPCMessage{ Name: "QueryUserRequest", Fields: []RPCField{ @@ -2446,7 +2444,7 @@ func TestProductExecutionPlanWithAliases(t *testing.T) { t.Fatalf("failed to validate query: %s", report.Error()) } - planner := NewPlanner("Products", testMapping()) + planner := NewPlanner("Products", testMapping(), nil) outPlan, err := planner.PlanOperation(&queryDoc, 
&schemaDoc) if tt.expectedError != "" { diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go index edf2dd880b..431357af14 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor.go @@ -1,36 +1,20 @@ package grpcdatasource import ( + "errors" "fmt" - "strings" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" "golang.org/x/text/cases" "golang.org/x/text/language" ) -type keyField struct { - fieldName string - fieldType string -} - -type entityInfo struct { - name string - keyFields []keyField - keyTypeName string - entityRootFieldRef int - entityInlineFragmentRef int -} - type planningInfo struct { - entityInfo entityInfo - // resolvers []string operationType ast.OperationType operationFieldName string - isEntityLookup bool - methodName string requestMessageAncestors []*RPCMessage currentRequestMessage *RPCMessage @@ -46,61 +30,68 @@ type rpcPlanVisitor struct { walker *astvisitor.Walker operation *ast.Document definition *ast.Document + planCtx *rpcPlanningContext planInfo planningInfo - subgraphName string - mapping *GRPCMapping - plan *RPCExecutionPlan - operationDefinitionRef int - operationFieldRef int - operationFieldRefs []int - currentCall *RPCCall - currentCallID int + subgraphName string + mapping *GRPCMapping + plan *RPCExecutionPlan + operationFieldRef int + operationFieldRefs []int + currentCall *RPCCall + currentCallID int } type rpcPlanVisitorConfig struct { - subgraphName string - mapping *GRPCMapping + subgraphName string + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations } // newRPCPlanVisitor creates a new RPCPlanVisitor. 
// It registers the visitor with the walker and returns it. -func newRPCPlanVisitor(walker *astvisitor.Walker, config rpcPlanVisitorConfig) *rpcPlanVisitor { +func newRPCPlanVisitor(config rpcPlanVisitorConfig) *rpcPlanVisitor { + walker := astvisitor.NewWalker(48) visitor := &rpcPlanVisitor{ - walker: walker, - plan: &RPCExecutionPlan{}, - subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), - mapping: config.mapping, - operationDefinitionRef: -1, - operationFieldRef: -1, + walker: &walker, + plan: &RPCExecutionPlan{}, + subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), + mapping: config.mapping, + operationFieldRef: -1, } walker.RegisterEnterDocumentVisitor(visitor) walker.RegisterEnterOperationVisitor(visitor) walker.RegisterFieldVisitor(visitor) walker.RegisterSelectionSetVisitor(visitor) - walker.RegisterInlineFragmentVisitor(visitor) walker.RegisterEnterArgumentVisitor(visitor) return visitor } +func (r *rpcPlanVisitor) PlanOperation(operation, definition *ast.Document) (*RPCExecutionPlan, error) { + report := &operationreport.Report{} + r.walker.Walk(operation, definition, report) + if report.HasErrors() { + return nil, fmt.Errorf("unable to plan operation: %w", report) + } + + return r.plan, nil +} + // EnterDocument implements astvisitor.EnterDocumentVisitor. func (r *rpcPlanVisitor) EnterDocument(operation *ast.Document, definition *ast.Document) { r.definition = definition r.operation = operation + + r.planCtx = newRPCPlanningContext(operation, definition, r.mapping) } // EnterOperationDefinition implements astvisitor.EnterOperationDefinitionVisitor. // This is called when entering the operation definition node. // It retrieves information about the operation // and creates a new group in the plan. -// -// The function also checks if this is an entity lookup operation, -// which requires special handling. 
func (r *rpcPlanVisitor) EnterOperationDefinition(ref int) { - r.operationDefinitionRef = ref - // Retrieves the fields from the root selection set. // These fields determine the names for the RPC functions to call. // TODO: handle fragments on root level `... on Query {}` @@ -117,10 +108,6 @@ func (r *rpcPlanVisitor) EnterOperationDefinition(ref int) { // // TODO handle field arguments to define resolvers func (r *rpcPlanVisitor) EnterArgument(ref int) { - if r.planInfo.isEntityLookup { - return - } - a := r.walker.Ancestor() if a.Kind != ast.NodeKindField && a.Ref != r.operationFieldRef { return @@ -137,9 +124,6 @@ func (r *rpcPlanVisitor) EnterArgument(ref int) { // EnterSelectionSet implements astvisitor.EnterSelectionSetVisitor. // Checks if this is in the root level below the operation definition. -// -// TODO handle multiple entity lookups in a single query. -// We need to create a new call for each entity lookup. func (r *rpcPlanVisitor) EnterSelectionSet(ref int) { if r.walker.Ancestor().Kind == ast.NodeKindOperationDefinition { return @@ -151,7 +135,7 @@ func (r *rpcPlanVisitor) EnterSelectionSet(ref int) { // In nested selection sets, a new message needs to be created, which will be added to the current response message. 
if r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message == nil { - r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message = r.newMessageFromSelectionSet(ref) + r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message = r.planCtx.newMessageFromSelectionSet(r.walker.EnclosingTypeDefinition, ref) } // Add the current response message to the ancestors and set the current response message to the current field message @@ -176,7 +160,7 @@ func (r *rpcPlanVisitor) EnterSelectionSet(ref int) { } func (r *rpcPlanVisitor) handleCompositeType(node ast.Node) error { - if node.Ref < 0 { + if node.Ref == ast.InvalidRef { return nil } @@ -230,66 +214,34 @@ func (r *rpcPlanVisitor) LeaveSelectionSet(ref int) { } } -// EnterInlineFragment implements astvisitor.InlineFragmentVisitor. -func (r *rpcPlanVisitor) EnterInlineFragment(ref int) { - entityInfo := &r.planInfo.entityInfo - if entityInfo.entityRootFieldRef != -1 && entityInfo.entityInlineFragmentRef == -1 { - entityInfo.entityInlineFragmentRef = ref - r.resolveEntityInformation(ref) - r.scaffoldEntityLookup() - - return - } -} - -// LeaveInlineFragment implements astvisitor.InlineFragmentVisitor. 
-func (r *rpcPlanVisitor) LeaveInlineFragment(ref int) { - if ref == r.planInfo.entityInfo.entityInlineFragmentRef { - r.planInfo.entityInfo.entityInlineFragmentRef = -1 - } -} - -func (r *rpcPlanVisitor) IsRootField() bool { - return len(r.walker.Ancestors) == 2 && r.walker.Ancestors[0].Kind == ast.NodeKindOperationDefinition -} - -func (r *rpcPlanVisitor) IsInlineFragmentField() (int, bool) { - if len(r.walker.Ancestors) < 2 { - return -1, false - } - - node := r.walker.Ancestors[len(r.walker.Ancestors)-2] - if node.Kind != ast.NodeKindInlineFragment { - return -1, false - } - - return node.Ref, true -} - func (r *rpcPlanVisitor) handleRootField(ref int) error { r.operationFieldRef = ref r.planInfo.operationFieldName = r.operation.FieldNameString(ref) r.currentCall = &RPCCall{ - CallID: r.currentCallID, - ServiceName: r.resolveServiceName(), + ServiceName: r.planCtx.resolveServiceName(r.subgraphName), } r.planInfo.currentRequestMessage = &r.currentCall.Request r.planInfo.currentResponseMessage = &r.currentCall.Response // attempt to resolve the name from the mapping - if err := r.resolveRPCMethodMapping(); err != nil { + rpcConfig, err := r.planCtx.resolveRPCMethodMapping(r.planInfo.operationType, r.planInfo.operationFieldName) + if err != nil { return err } + r.currentCall.MethodName = rpcConfig.RPC + r.currentCall.Request.Name = rpcConfig.Request + r.currentCall.Response.Name = rpcConfig.Response + return nil } // EnterField implements astvisitor.EnterFieldVisitor. 
func (r *rpcPlanVisitor) EnterField(ref int) { fieldName := r.operation.FieldNameString(ref) - if r.IsRootField() { + if r.walker.InRootField() { if err := r.handleRootField(ref); err != nil { r.walker.StopWithInternalErr(err) return @@ -297,16 +249,7 @@ func (r *rpcPlanVisitor) EnterField(ref int) { } if fieldName == "_entities" { - // _entities is a special field that is used to look up entities - // Entity lookups are handled differently as we use special types for - // Providing variables (_Any) and the response type is a Union that needs to be - // determined from the first inline fragment. - r.planInfo.entityInfo = entityInfo{ - entityRootFieldRef: ref, - entityInlineFragmentRef: -1, - } - r.planInfo.isEntityLookup = true - r.planInfo.entityInfo.entityRootFieldRef = ref + r.walker.StopWithInternalErr(errors.New("entities field is not supported in this visitor")) return } @@ -324,9 +267,14 @@ func (r *rpcPlanVisitor) EnterField(ref int) { return } - field := r.buildField(fd, fieldName, fieldAlias) + field, err := r.planCtx.buildField(r.walker.EnclosingTypeDefinition, fd, fieldName, fieldAlias) + if err != nil { + r.walker.StopWithInternalErr(err) + return + } - if ref, ok := r.IsInlineFragmentField(); ok && !r.planInfo.isEntityLookup { + // check if we are inside of an inline fragment + if ref, ok := r.walker.ResolveInlineFragment(); ok { if r.planInfo.currentResponseMessage.FieldSelectionSet == nil { r.planInfo.currentResponseMessage.FieldSelectionSet = make(RPCFieldSelectionSet) } @@ -339,65 +287,14 @@ func (r *rpcPlanVisitor) EnterField(ref int) { r.planInfo.currentResponseMessage.Fields = append(r.planInfo.currentResponseMessage.Fields, field) } -// buildField builds a field from a field definition. -// It handles lists, enums, and other types. 
-func (r *rpcPlanVisitor) buildField(fd int, fieldName, fieldAlias string) RPCField { - fdt := r.definition.FieldDefinitionType(fd) - typeName := r.toDataType(&r.definition.Types[fdt]) - parentTypeName := r.walker.EnclosingTypeDefinition.NameString(r.definition) - - field := RPCField{ - Name: r.resolveFieldMapping(parentTypeName, fieldName), - Alias: fieldAlias, - Optional: !r.definition.TypeIsNonNull(fdt), - JSONPath: fieldName, - TypeName: typeName.String(), - } - - if r.definition.TypeIsList(fdt) { - switch { - // for nullable or nested lists we need to build a wrapper message - // Nullability is handled by the datasource during the execution. - case r.typeIsNullableOrNestedList(fdt): - field.ListMetadata = r.createListMetadata(fdt) - field.IsListType = true - default: - // For non-nullable single lists we can directly use the repeated syntax in protobuf. - field.Repeated = true - } - } - - if typeName == DataTypeEnum { - field.EnumName = r.definition.FieldDefinitionTypeNameString(fd) - } - - if fieldName == "__typename" { - field.StaticValue = parentTypeName - } - - return field -} - // LeaveField implements astvisitor.FieldVisitor. func (r *rpcPlanVisitor) LeaveField(ref int) { - if ref == r.planInfo.entityInfo.entityRootFieldRef { - r.planInfo.entityInfo.entityRootFieldRef = -1 - } - // If we are not in the operation field, we can increment the response field index. - if !r.IsRootField() { + if !r.walker.InRootField() { r.planInfo.currentResponseFieldIndex++ return } - // If we left the operation field, we need to finalize the current call and prepare the next one. 
- if r.currentCall.MethodName == "" { - methodName := r.rpcMethodName() - r.currentCall.MethodName = methodName - r.currentCall.Request.Name = methodName + "Request" - r.currentCall.Response.Name = methodName + "Response" - } - r.plan.Calls[r.currentCallID] = *r.currentCall r.currentCall = &RPCCall{} @@ -409,16 +306,6 @@ func (r *rpcPlanVisitor) LeaveField(ref int) { r.planInfo.currentResponseFieldIndex = 0 } -// newMessageFromSelectionSet creates a new message from a selection set. -func (r *rpcPlanVisitor) newMessageFromSelectionSet(ref int) *RPCMessage { - message := &RPCMessage{ - Name: r.walker.EnclosingTypeDefinition.NameString(r.definition), - Fields: make(RPCFields, 0, len(r.operation.SelectionSets[ref].SelectionRefs)), - } - - return message -} - // enrichRequestMessageFromInputArgument constructs a request message from an input argument based on its type. // It retrieves the underlying type and builds the request message from the underlying type. // If the underlying type is an input object type, it creates a new message and adds it to the current request message. 
@@ -466,7 +353,7 @@ func (r *rpcPlanVisitor) enrichRequestMessageFromInputArgument(argRef, typeRef i r.planInfo.requestMessageAncestors = r.planInfo.requestMessageAncestors[:len(r.planInfo.requestMessageAncestors)-1] case ast.NodeKindScalarTypeDefinition, ast.NodeKindEnumTypeDefinition: - dt := r.toDataType(&r.definition.Types[typeRef]) + dt := r.planCtx.toDataType(&r.definition.Types[typeRef]) r.planInfo.currentRequestMessage.Fields = append(r.planInfo.currentRequestMessage.Fields, r.buildInputMessageField(typeRef, mappedInputName, jsonPath, dt)) @@ -507,7 +394,7 @@ func (r *rpcPlanVisitor) buildMessageField(fieldName string, typeRef, parentType // If the type is not an object, directly add the field to the request message if underlyingTypeNode.Kind != ast.NodeKindInputObjectTypeDefinition { - dt := r.toDataType(&inputValueDefinitionType) + dt := r.planCtx.toDataType(&inputValueDefinitionType) r.planInfo.currentRequestMessage.Fields = append(r.planInfo.currentRequestMessage.Fields, r.buildInputMessageField(typeRef, mappedName, fieldName, dt)) @@ -545,8 +432,13 @@ func (r *rpcPlanVisitor) buildInputMessageField(typeRef int, fieldName, jsonPath switch { // for nullable or nested lists we need to build a wrapper message // Nullability is handled by the datasource during the execution. - case r.typeIsNullableOrNestedList(typeRef): - field.ListMetadata = r.createListMetadata(typeRef) + case r.planCtx.typeIsNullableOrNestedList(typeRef): + md, err := r.planCtx.createListMetadata(typeRef) + if err != nil { + r.walker.StopWithInternalErr(err) + return field + } + field.ListMetadata = md field.IsListType = true default: // For non-nullable single lists we can directly use the repeated syntax in protobuf. 
@@ -561,156 +453,6 @@ func (r *rpcPlanVisitor) buildInputMessageField(typeRef int, fieldName, jsonPath return field } -func (r *rpcPlanVisitor) createListMetadata(typeRef int) *ListMetadata { - nestingLevel := r.definition.TypeNumberOfListWraps(typeRef) - - md := &ListMetadata{ - NestingLevel: nestingLevel, - LevelInfo: make([]LevelInfo, nestingLevel), - } - - for i := 0; i < nestingLevel; i++ { - md.LevelInfo[i] = LevelInfo{ - Optional: !r.definition.TypeIsNonNull(typeRef), - } - - typeRef = r.definition.ResolveNestedListOrListType(typeRef) - if typeRef == ast.InvalidRef { - r.walker.StopWithInternalErr(fmt.Errorf("unable to resolve underlying list type for ref: %d", typeRef)) - return nil - } - } - - return md -} - -func (r *rpcPlanVisitor) resolveEntityInformation(inlineFragmentRef int) { - // TODO support multiple entities in a single query - if !r.planInfo.isEntityLookup || r.planInfo.entityInfo.name != "" { - return - } - - fragmentName := r.operation.InlineFragmentTypeConditionNameString(inlineFragmentRef) - node, found := r.definition.NodeByNameStr(fragmentName) - if !found { - return - } - - // Only process object type definitions - // TODO: handle interfaces - if node.Kind != ast.NodeKindObjectTypeDefinition { - return - } - - // An entity must at least have a key directive - def := r.definition.ObjectTypeDefinitions[node.Ref] - if !def.HasDirectives { - return - } - - // TODO: We currently only support one key directive per entity - // We need to get the used key from the graphql datasource. 
- for _, directiveRef := range def.Directives.Refs { - if r.definition.DirectiveNameString(directiveRef) != federationKeyDirectiveName { - continue - } - - r.planInfo.entityInfo.name = fragmentName - - directive := r.definition.Directives[directiveRef] - for _, argRef := range directive.Arguments.Refs { - if r.definition.ArgumentNameString(argRef) != "fields" { - continue - } - argument := r.definition.Arguments[argRef] - keyFieldName := r.definition.ValueContentString(argument.Value) - - fieldDef, ok := r.definition.NodeFieldDefinitionByName(node, ast.ByteSlice(keyFieldName)) - if !ok { - r.walker.Report.AddExternalError(operationreport.ExternalError{ - Message: fmt.Sprintf("Field %s not found in definition", keyFieldName), - }) - return - } - - fdt := r.definition.FieldDefinitionType(fieldDef) - ft := r.definition.Types[fdt] - - r.planInfo.entityInfo.keyFields = - append(r.planInfo.entityInfo.keyFields, keyField{ - fieldName: keyFieldName, - fieldType: r.toDataType(&ft).String(), - }) - } - - break - } - - keyFields := make([]string, 0, len(r.planInfo.entityInfo.keyFields)) - for _, key := range r.planInfo.entityInfo.keyFields { - keyFields = append(keyFields, key.fieldName) - } - - if ei, exists := r.mapping.EntityRPCs[r.planInfo.entityInfo.name]; exists { - r.currentCall.Request.Name = ei.RPCConfig.Request - r.currentCall.Response.Name = ei.RPCConfig.Response - r.planInfo.methodName = ei.RPCConfig.RPC - } - - r.planInfo.entityInfo.keyTypeName = r.planInfo.entityInfo.name + "By" + strings.Join(titleSlice(keyFields), "And") -} - -// scaffoldEntityLookup creates the entity lookup call structure -// by creating the key field message and adding it to the current request message. -// It also adds the results message to the current response message. 
-func (r *rpcPlanVisitor) scaffoldEntityLookup() { - if !r.planInfo.isEntityLookup { - return - } - - entityInfo := &r.planInfo.entityInfo - keyFieldMessage := &RPCMessage{ - Name: r.rpcMethodName() + "Key", - } - for _, key := range entityInfo.keyFields { - keyFieldMessage.Fields = append(keyFieldMessage.Fields, RPCField{ - Name: key.fieldName, - TypeName: key.fieldType, - JSONPath: key.fieldName, - }) - } - - r.planInfo.currentRequestMessage.Fields = []RPCField{ - { - Name: "keys", - TypeName: DataTypeMessage.String(), - Repeated: true, // The inputs are always a list of objects - JSONPath: "representations", - Message: keyFieldMessage, - }, - } - - // The proto response message has a field `result` which is a list of entities. - // As this is a special case we directly map it to _entities. - r.planInfo.currentResponseMessage.Fields = []RPCField{ - { - Name: "result", - TypeName: DataTypeMessage.String(), - JSONPath: "_entities", - Repeated: true, - }, - } -} - -func (r *rpcPlanVisitor) resolveServiceName() string { - if r.mapping == nil || r.mapping.Service == "" { - return r.subgraphName - } - - return r.mapping.Service -} - -// resolveFieldMapping resolves the field mapping for a field. // This applies both for complex types in the input and for all fields in the response. 
func (r *rpcPlanVisitor) resolveFieldMapping(typeName, fieldName string) string { grpcFieldName, ok := r.mapping.ResolveFieldMapping(typeName, fieldName) @@ -733,150 +475,3 @@ func (r *rpcPlanVisitor) resolveInputArgument(baseType string, fieldRef int, arg return grpcFieldName } - -func (r *rpcPlanVisitor) resolveRPCMethodMapping() error { - if r.mapping == nil { - return nil - } - - if r.planInfo.isEntityLookup && r.planInfo.entityInfo.name != "" { - // Resolving the entity lookup method name is done differently - return nil - } - - var rpcConfig RPCConfig - var ok bool - - switch r.planInfo.operationType { - case ast.OperationTypeQuery: - rpcConfig, ok = r.mapping.QueryRPCs[r.planInfo.operationFieldName] - case ast.OperationTypeMutation: - rpcConfig, ok = r.mapping.MutationRPCs[r.planInfo.operationFieldName] - case ast.OperationTypeSubscription: - rpcConfig, ok = r.mapping.SubscriptionRPCs[r.planInfo.operationFieldName] - } - - // if we don't have a mapping, we can skip the operation - if !ok { - return nil - } - - // We require all fields to be present when defining a mapping for the operation - if rpcConfig.RPC == "" { - return fmt.Errorf("no rpc method name mapping found for operation %s", r.planInfo.operationFieldName) - } - - if rpcConfig.Request == "" { - return fmt.Errorf("no request message name mapping found for operation %s", r.planInfo.operationFieldName) - } - - if rpcConfig.Response == "" { - return fmt.Errorf("no response message name mapping found for operation %s", r.planInfo.operationFieldName) - } - - r.currentCall.MethodName = rpcConfig.RPC - r.currentCall.Request.Name = rpcConfig.Request - r.currentCall.Response.Name = rpcConfig.Response - - return nil -} - -// rpcMethodName determines the appropriate method name based on operation type. 
-func (r *rpcPlanVisitor) rpcMethodName() string { - if r.planInfo.methodName != "" { - return r.planInfo.methodName - } - - switch r.planInfo.operationType { - case ast.OperationTypeQuery: - r.planInfo.methodName = r.buildQueryMethodName() - case ast.OperationTypeMutation: - r.planInfo.methodName = r.buildMutationMethodName() - case ast.OperationTypeSubscription: - r.planInfo.methodName = r.buildSubscriptionMethodName() - } - - return r.planInfo.methodName -} - -// buildQueryMethodName constructs a method name for query operations. -func (r *rpcPlanVisitor) buildQueryMethodName() string { - if r.planInfo.isEntityLookup && r.planInfo.entityInfo.name != "" { - return "Lookup" + r.planInfo.entityInfo.keyTypeName - } - - return "Query" + cases.Title(language.Und, cases.NoLower).String(r.planInfo.operationFieldName) -} - -// buildMutationMethodName constructs a method name for mutation operations. -func (r *rpcPlanVisitor) buildMutationMethodName() string { - // TODO implement mutation method name handling - return "Mutation" + cases.Title(language.Und, cases.NoLower).String(r.planInfo.operationFieldName) -} - -// buildSubscriptionMethodName constructs a method name for subscription operations. -func (r *rpcPlanVisitor) buildSubscriptionMethodName() string { - // TODO implement subscription method name handling - return "Subscription" + cases.Title(language.Und, cases.NoLower).String(r.planInfo.operationFieldName) -} - -// toDataType converts an ast.Type to a DataType. -// It handles the different type kinds and non-null types. -func (r *rpcPlanVisitor) toDataType(t *ast.Type) DataType { - switch t.TypeKind { - case ast.TypeKindNamed: - return r.parseGraphQLType(t) - case ast.TypeKindList: - return r.toDataType(&r.definition.Types[t.OfType]) - case ast.TypeKindNonNull: - return r.toDataType(&r.definition.Types[t.OfType]) - } - - return DataTypeUnknown -} - -// parseGraphQLType parses an ast.Type and returns the corresponding DataType. 
-// It handles the different type kinds and non-null types. -func (r *rpcPlanVisitor) parseGraphQLType(t *ast.Type) DataType { - dt := r.definition.Input.ByteSliceString(t.Name) - - // Retrieve the node to check the kind - node, found := r.definition.NodeByNameStr(dt) - if !found { - return DataTypeUnknown - } - - // For non-scalar types, return the corresponding DataType - switch node.Kind { - case ast.NodeKindInterfaceTypeDefinition: - fallthrough - case ast.NodeKindUnionTypeDefinition: - fallthrough - case ast.NodeKindObjectTypeDefinition, ast.NodeKindInputObjectTypeDefinition: - return DataTypeMessage - case ast.NodeKindEnumTypeDefinition: - return DataTypeEnum - default: - return fromGraphQLType(dt) - } -} - -func (r *rpcPlanVisitor) typeIsNullableOrNestedList(typeRef int) bool { - if !r.definition.TypeIsNonNull(typeRef) && r.definition.TypeIsList(typeRef) { - return true - } - - if r.definition.TypeNumberOfListWraps(typeRef) > 1 { - return true - } - - return false -} - -// titleSlice capitalizes the first letter of each string in a slice. -func titleSlice(s []string) []string { - for i, v := range s { - s[i] = cases.Title(language.Und, cases.NoLower).String(v) - } - return s -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go new file mode 100644 index 0000000000..7f2aab2615 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/execution_plan_visitor_federation.go @@ -0,0 +1,431 @@ +package grpcdatasource + +import ( + "errors" + "fmt" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" + "golang.org/x/text/cases" + "golang.org/x/text/language" +) + +// entityInfo contains the information about the entity that is being looked up. 
+type entityInfo struct { + typeName string + entityRootFieldRef int + entityInlineFragmentRef int +} + +type federationConfigData struct { + entityTypeName string + keyFields string + requiredFields string +} + +func newFederationConfigData(entityTypeName string) federationConfigData { + return federationConfigData{ + entityTypeName: entityTypeName, + keyFields: "", + requiredFields: "", + } +} + +type rpcPlanVisitorFederation struct { + walker *astvisitor.Walker + operation *ast.Document + definition *ast.Document + planCtx *rpcPlanningContext + mapping *GRPCMapping + + planInfo planningInfo + entityInfo entityInfo + federationConfigData []federationConfigData + + plan *RPCExecutionPlan + subgraphName string + currentCall *RPCCall + currentCallIndex int +} + +func newRPCPlanVisitorFederation(config rpcPlanVisitorConfig) *rpcPlanVisitorFederation { + walker := astvisitor.NewWalker(48) + visitor := &rpcPlanVisitorFederation{ + walker: &walker, + plan: &RPCExecutionPlan{}, + subgraphName: cases.Title(language.Und, cases.NoLower).String(config.subgraphName), + mapping: config.mapping, + entityInfo: entityInfo{ + entityRootFieldRef: ast.InvalidRef, + entityInlineFragmentRef: ast.InvalidRef, + }, + federationConfigData: parseFederationConfigData(config.federationConfigs), + } + + walker.RegisterEnterDocumentVisitor(visitor) + walker.RegisterEnterOperationVisitor(visitor) + walker.RegisterInlineFragmentVisitor(visitor) + walker.RegisterSelectionSetVisitor(visitor) + walker.RegisterFieldVisitor(visitor) + + return visitor +} + +func (r *rpcPlanVisitorFederation) PlanOperation(operation, definition *ast.Document) (*RPCExecutionPlan, error) { + report := &operationreport.Report{} + r.walker.Walk(operation, definition, report) + if report.HasErrors() { + return nil, fmt.Errorf("unable to plan operation: %w", report) + } + + return r.plan, nil +} + +// EnterDocument implements astvisitor.EnterDocumentVisitor. 
+func (r *rpcPlanVisitorFederation) EnterDocument(operation *ast.Document, definition *ast.Document) { + r.operation = operation + r.definition = definition + + r.planCtx = newRPCPlanningContext(operation, definition, r.mapping) +} + +// EnterOperationDefinition implements astvisitor.EnterOperationDefinitionVisitor. +func (r *rpcPlanVisitorFederation) EnterOperationDefinition(ref int) { + if r.operation.OperationDefinitions[ref].OperationType != ast.OperationTypeQuery { + r.walker.StopWithInternalErr(errors.New("only query operations are supported for the federation plan visitor")) + return + } + + r.planInfo.operationType = r.operation.OperationDefinitions[ref].OperationType +} + +// EnterInlineFragment implements astvisitor.InlineFragmentVisitor. +func (r *rpcPlanVisitorFederation) EnterInlineFragment(ref int) { + fragmentName := r.operation.InlineFragmentTypeConditionNameString(ref) + fc, ok := r.FederationConfigDataByEntityTypeName(fragmentName) + if !ok { + return + } + + r.currentCall = &RPCCall{ + ServiceName: r.planCtx.resolveServiceName(r.subgraphName), + } + + r.planInfo.currentRequestMessage = &r.currentCall.Request + r.planInfo.currentResponseMessage = &r.currentCall.Response + + r.entityInfo.entityInlineFragmentRef = ref + r.entityInfo.typeName = fragmentName + if err := r.resolveEntityInformation(ref, fc); err != nil { + r.walker.StopWithInternalErr(err) + return + } + + r.scaffoldEntityLookup(fc) +} + +// LeaveInlineFragment implements astvisitor.InlineFragmentVisitor. 
+func (r *rpcPlanVisitorFederation) LeaveInlineFragment(ref int) { + if r.entityInfo.entityInlineFragmentRef != ref { + // We only handle the entity inline fragment + return + } + + r.plan.Calls = append(r.plan.Calls, *r.currentCall) + r.currentCall = &RPCCall{} + r.currentCallIndex++ + + r.planInfo = planningInfo{ + operationType: r.planInfo.operationType, + operationFieldName: r.planInfo.operationFieldName, + currentRequestMessage: &RPCMessage{}, + currentResponseMessage: &RPCMessage{}, + currentResponseFieldIndex: 0, + responseMessageAncestors: []*RPCMessage{}, + responseFieldIndexAncestors: []int{}, + } + + r.entityInfo.entityInlineFragmentRef = ast.InvalidRef +} + +// EnterSelectionSet implements astvisitor.SelectionSetVisitor. +func (r *rpcPlanVisitorFederation) EnterSelectionSet(ref int) { + if r.walker.Ancestor().Kind == ast.NodeKindOperationDefinition { + return + } + + if r.planInfo.currentRequestMessage == nil || len(r.planInfo.currentResponseMessage.Fields) == 0 || len(r.planInfo.currentResponseMessage.Fields) <= r.planInfo.currentResponseFieldIndex { + return + } + + // In nested selection sets, a new message needs to be created, which will be added to the current response message. 
+ if r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message == nil { + r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message = r.planCtx.newMessageFromSelectionSet(r.walker.EnclosingTypeDefinition, ref) + } + + if r.IsEntityInlineFragment(r.walker.Ancestor()) { + r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message.AppendTypeNameField(r.entityInfo.typeName) + } + + // Add the current response message to the ancestors and set the current response message to the current field message + r.planInfo.responseMessageAncestors = append(r.planInfo.responseMessageAncestors, r.planInfo.currentResponseMessage) + r.planInfo.currentResponseMessage = r.planInfo.currentResponseMessage.Fields[r.planInfo.currentResponseFieldIndex].Message + + // Check if the ancestor type is a composite type (interface or union) + // and set the oneof type and member types. + if err := r.handleCompositeType(r.walker.Ancestor()); err != nil { + // If the ancestor is a composite type, but we were unable to resolve the member types, + // we stop the walker and return an internal error. + r.walker.StopWithInternalErr(err) + return + } + + // Keep track of the field indices for the current response message. + // This is used to set the correct field index for the current response message + // when leaving the selection set. 
+ r.planInfo.responseFieldIndexAncestors = append(r.planInfo.responseFieldIndexAncestors, r.planInfo.currentResponseFieldIndex) + + r.planInfo.currentResponseFieldIndex = 0 // reset the field index for the current selection set +} + +func (r *rpcPlanVisitorFederation) handleCompositeType(node ast.Node) error { + if node.Ref == ast.InvalidRef { + return nil + } + + var ( + ok bool + oneOfType OneOfType + memberTypes []string + ) + + switch node.Kind { + case ast.NodeKindField: + return r.handleCompositeType(r.walker.EnclosingTypeDefinition) + case ast.NodeKindInterfaceTypeDefinition: + oneOfType = OneOfTypeInterface + memberTypes, ok = r.definition.InterfaceTypeDefinitionImplementedByObjectWithNames(node.Ref) + if !ok { + return fmt.Errorf("interface type %s is not implemented by any object", r.definition.InterfaceTypeDefinitionNameString(node.Ref)) + } + case ast.NodeKindUnionTypeDefinition: + oneOfType = OneOfTypeUnion + memberTypes, ok = r.definition.UnionTypeDefinitionMemberTypeNames(node.Ref) + if !ok { + return fmt.Errorf("union type %s is not defined", r.definition.UnionTypeDefinitionNameString(node.Ref)) + } + default: + return nil + } + + r.planInfo.currentResponseMessage.OneOfType = oneOfType + r.planInfo.currentResponseMessage.MemberTypes = memberTypes + + return nil +} + +// LeaveSelectionSet implements astvisitor.SelectionSetVisitor. 
+func (r *rpcPlanVisitorFederation) LeaveSelectionSet(ref int) { + if r.walker.Ancestor().Kind == ast.NodeKindInlineFragment { + return + } + + if len(r.planInfo.responseFieldIndexAncestors) > 0 { + r.planInfo.currentResponseFieldIndex = r.planInfo.responseFieldIndexAncestors[len(r.planInfo.responseFieldIndexAncestors)-1] + r.planInfo.responseFieldIndexAncestors = r.planInfo.responseFieldIndexAncestors[:len(r.planInfo.responseFieldIndexAncestors)-1] + } + + if len(r.planInfo.responseMessageAncestors) > 0 { + r.planInfo.currentResponseMessage = r.planInfo.responseMessageAncestors[len(r.planInfo.responseMessageAncestors)-1] + r.planInfo.responseMessageAncestors = r.planInfo.responseMessageAncestors[:len(r.planInfo.responseMessageAncestors)-1] + } +} + +// EnterField implements astvisitor.FieldVisitor. +func (r *rpcPlanVisitorFederation) EnterField(ref int) { + fieldName := r.operation.FieldNameString(ref) + if r.walker.InRootField() { + r.planInfo.operationFieldName = r.operation.FieldNameString(ref) + } + + if fieldName == "_entities" { + // _entities is a special field that is used to look up entities + // Entity lookups are handled differently as we use special types for + // Providing variables (_Any) and the response type is a Union that needs to be + // determined from the first inline fragment. 
+ r.entityInfo = entityInfo{ + entityRootFieldRef: ref, + entityInlineFragmentRef: ast.InvalidRef, + } + + r.entityInfo.entityRootFieldRef = ref + return + } + + // prevent duplicate fields + fieldAlias := r.operation.FieldAliasString(ref) + if r.planInfo.currentResponseMessage.Fields.Exists(fieldName, fieldAlias) { + return + } + + fd, ok := r.walker.FieldDefinition(ref) + if !ok { + r.walker.Report.AddExternalError(operationreport.ExternalError{ + Message: fmt.Sprintf("Field %s not found in definition %s", r.operation.FieldNameString(ref), r.walker.EnclosingTypeDefinition.NameString(r.definition)), + }) + return + } + + field, err := r.planCtx.buildField(r.walker.EnclosingTypeDefinition, fd, fieldName, fieldAlias) + if err != nil { + r.walker.StopWithInternalErr(err) + return + } + + // check if we are inside of an inline fragment and not the entity inline fragment + if ref, ok := r.walker.ResolveInlineFragment(); ok && r.entityInfo.entityInlineFragmentRef != ref { + if r.planInfo.currentResponseMessage.FieldSelectionSet == nil { + r.planInfo.currentResponseMessage.FieldSelectionSet = make(RPCFieldSelectionSet) + } + + inlineFragmentName := r.operation.InlineFragmentTypeConditionNameString(ref) + r.planInfo.currentResponseMessage.FieldSelectionSet.Add(inlineFragmentName, field) + return + } + + r.planInfo.currentResponseMessage.Fields = append(r.planInfo.currentResponseMessage.Fields, field) +} + +// LeaveField implements astvisitor.FieldVisitor. +func (r *rpcPlanVisitorFederation) LeaveField(ref int) { + // If we are not in the operation field, we can increment the response field index. 
+	if !r.walker.InRootField() {
+		r.planInfo.currentResponseFieldIndex++
+		return
+	}
+
+	r.planInfo.currentResponseFieldIndex = 0
+}
+
+func (r *rpcPlanVisitorFederation) resolveEntityInformation(inlineFragmentRef int, fc federationConfigData) error {
+	fragmentName := r.operation.InlineFragmentTypeConditionNameString(inlineFragmentRef)
+	node, found := r.definition.NodeByNameStr(fragmentName)
+	if !found {
+		return errors.New("definition node not found for inline fragment: " + fragmentName)
+	}
+
+	// Only process object type definitions
+	// TODO: handle interfaces
+	if node.Kind != ast.NodeKindObjectTypeDefinition {
+		return nil
+	}
+
+	rpcConfig, exists := r.mapping.ResolveEntityRPCConfig(fc.entityTypeName, fc.keyFields)
+	if !exists {
+		return fmt.Errorf("entity type %s not found in mapping", fc.entityTypeName)
+	}
+
+	r.currentCall.Request.Name = rpcConfig.Request
+	r.currentCall.Response.Name = rpcConfig.Response
+	r.currentCall.MethodName = rpcConfig.RPC
+
+	return nil
+}
+
+// scaffoldEntityLookup creates the entity lookup call structure
+// by creating the key field message and adding it to the current request message.
+// It also adds the results message to the current response message. 
+func (r *rpcPlanVisitorFederation) scaffoldEntityLookup(fc federationConfigData) { + keyFieldMessage := &RPCMessage{ + Name: r.currentCall.MethodName + "Key", + } + + walker := astvisitor.WalkerFromPool() + defer walker.Release() + + requiredFieldsVisitor := newRequiredFieldsVisitor(walker, keyFieldMessage, r.planCtx) + err := requiredFieldsVisitor.visitRequiredFields(r.definition, fc.entityTypeName, fc.keyFields) + if err != nil { + r.walker.StopWithInternalErr(err) + return + } + + r.planInfo.currentRequestMessage.Fields = []RPCField{ + { + Name: "keys", + TypeName: DataTypeMessage.String(), + Repeated: true, // The inputs are always a list of objects + JSONPath: "representations", + Message: keyFieldMessage, + }, + } + + // The proto response message has a field `result` which is a list of entities. + // As this is a special case we directly map it to _entities. + r.planInfo.currentResponseMessage.Fields = []RPCField{ + { + Name: "result", + TypeName: DataTypeMessage.String(), + JSONPath: "_entities", + Repeated: true, + }, + } +} + +// FederationConfigDataByEntityTypeName returns the entity config data for the given entity type name. 
+func (r *rpcPlanVisitorFederation) FederationConfigDataByEntityTypeName(entityTypeName string) (federationConfigData, bool) { + for _, fc := range r.federationConfigData { + if fc.entityTypeName == entityTypeName { + return fc, true + } + } + return federationConfigData{}, false +} + +func (r *rpcPlanVisitorFederation) IsEntityInlineFragment(node ast.Node) bool { + if node.Kind != ast.NodeKindInlineFragment { + return false + } + + if r.entityInfo.entityInlineFragmentRef == ast.InvalidRef { + return false + } + + return r.entityInfo.entityInlineFragmentRef == node.Ref +} + +func parseFederationConfigData(federationConfigs plan.FederationFieldConfigurations) []federationConfigData { + var out []federationConfigData + + typeNameIndexSet := map[string]int{} + typeNameIndex := 0 + + for _, fc := range federationConfigs { + // Create a new entity type if it doesn't exist + if _, ok := typeNameIndexSet[fc.TypeName]; !ok { + out = append(out, newFederationConfigData(fc.TypeName)) + typeNameIndexSet[fc.TypeName] = typeNameIndex + typeNameIndex++ + } + + data := &out[typeNameIndexSet[fc.TypeName]] + + // Selection set determines whether we have key fields or additional required fields + if fc.SelectionSet == "" { + continue + } + + // This is a required field, so we add it to the required fields + if fc.FieldName != "" { + data.requiredFields = fc.SelectionSet + continue + } + + // This is a key field, so we add it to the key fields + data.keyFields = fc.SelectionSet + } + + return out +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index a718b71aac..935c764363 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -9,19 +9,17 @@ package grpcdatasource import ( "bytes" "context" - "errors" "fmt" - "strconv" + "sync" "github.com/tidwall/gjson" "github.com/wundergraph/astjson" 
"github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "golang.org/x/sync/errgroup" "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" - protoref "google.golang.org/protobuf/reflect/protoreflect" ) // Verify DataSource implements the resolve.DataSource interface @@ -32,11 +30,12 @@ var _ resolve.DataSource = (*DataSource)(nil) // transforms the responses back to GraphQL format. type DataSource struct { // Invocations is a list of gRPC invocations to be executed - plan *RPCExecutionPlan - cc grpc.ClientConnInterface - rc *RPCCompiler - mapping *GRPCMapping - disabled bool + plan *RPCExecutionPlan + cc grpc.ClientConnInterface + rc *RPCCompiler + mapping *GRPCMapping + federationConfigs plan.FederationFieldConfigurations + disabled bool } type ProtoConfig struct { @@ -44,28 +43,30 @@ type ProtoConfig struct { } type DataSourceConfig struct { - Operation *ast.Document - Definition *ast.Document - Compiler *RPCCompiler - SubgraphName string - Mapping *GRPCMapping - Disabled bool + Operation *ast.Document + Definition *ast.Document + Compiler *RPCCompiler + SubgraphName string + Mapping *GRPCMapping + FederationConfigs plan.FederationFieldConfigurations + Disabled bool } // NewDataSource creates a new gRPC datasource func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*DataSource, error) { - planner := NewPlanner(config.SubgraphName, config.Mapping) + planner := NewPlanner(config.SubgraphName, config.Mapping, config.FederationConfigs) plan, err := planner.PlanOperation(config.Operation, config.Definition) if err != nil { return nil, err } return &DataSource{ - plan: plan, - cc: client, - rc: config.Compiler, - mapping: config.Mapping, - disabled: config.Disabled, + plan: plan, + cc: client, + rc: 
config.Compiler, + mapping: config.Mapping, + federationConfigs: config.FederationConfigs, + disabled: config.Disabled, }, nil } @@ -76,47 +77,66 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D // The input is expected to contain the necessary information to make // a gRPC call, including service name, method name, and request data. func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) (err error) { + // get variables from input + variables := gjson.Parse(string(input)).Get("body.variables") + builder := newJSONBuilder(d.mapping, variables) + if d.disabled { - out.Write(writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used"))) + out.Write(builder.writeErrorBytes(fmt.Errorf("gRPC datasource needs to be enabled to be used"))) return nil } - // get variables from input - variables := gjson.Parse(string(input)).Get("body.variables") - // get invocations from plan invocations, err := d.rc.Compile(d.plan, variables) if err != nil { return err } - a := astjson.Arena{} - root := a.NewObject() + responses := make([]*astjson.Value, len(invocations)) + errGrp, errGrpCtx := errgroup.WithContext(ctx) + mu := sync.Mutex{} // make gRPC calls - for _, invocation := range invocations { - // Invoke the gRPC method - this will populate invocation.Output - methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) - err := d.cc.Invoke(ctx, methodName, invocation.Input, invocation.Output) - if err != nil { - out.Write(writeErrorBytes(err)) + for index, invocation := range invocations { + errGrp.Go(func() error { + a := astjson.Arena{} + // Invoke the gRPC method - this will populate invocation.Output + methodName := fmt.Sprintf("/%s/%s", invocation.ServiceName, invocation.MethodName) + + err := d.cc.Invoke(errGrpCtx, methodName, invocation.Input, invocation.Output) + if err != nil { + return err + } + + mu.Lock() + defer mu.Unlock() + + response, err := 
builder.marshalResponseJSON(&a, &invocation.Call.Response, invocation.Output) + if err != nil { + return err + } + + responses[index] = response return nil - } + }) + } - responseJSON, err := d.marshalResponseJSON(&a, &invocation.Call.Response, invocation.Output) - if err != nil { - return err - } + if err := errGrp.Wait(); err != nil { + out.Write(builder.writeErrorBytes(err)) + return nil + } - root, _, err = astjson.MergeValues(root, responseJSON) + a := astjson.Arena{} + root := a.NewObject() + for _, response := range responses { + root, err = builder.mergeValues(root, response) if err != nil { - out.Write(writeErrorBytes(err)) - return nil + out.Write(builder.writeErrorBytes(err)) + return err } } - data := a.NewObject() - data.Set("data", root) + data := builder.toDataObject(root) out.Write(data.MarshalTo(nil)) return nil @@ -132,336 +152,3 @@ func (d *DataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) func (d *DataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { panic("unimplemented") } - -func (d *DataSource) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { - if message == nil { - return arena.NewNull(), nil - } - - root := arena.NewObject() - - if message.IsOneOf() { - oneof := data.Descriptor().Oneofs().ByName(protoref.Name(message.OneOfType.FieldName())) - if oneof == nil { - return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) - } - - oneofDescriptor := data.WhichOneof(oneof) - if oneofDescriptor == nil { - return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) - } - - if oneofDescriptor.Kind() == protoref.MessageKind { - data = data.Get(oneofDescriptor).Message() - } - } - - validFields := message.Fields - if message.IsOneOf() { - 
validFields = append(validFields, message.FieldSelectionSet.SelectFieldsForTypes(message.SelectValidTypes(string(data.Type().Descriptor().Name())))...) - } - - for _, field := range validFields { - if field.StaticValue != "" { - if len(message.MemberTypes) == 0 { - root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) - continue - } - - for _, memberTypes := range message.MemberTypes { - if memberTypes == string(data.Type().Descriptor().Name()) { - root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) - break - } - } - - continue - } - - fd := data.Descriptor().Fields().ByName(protoref.Name(field.Name)) - if fd == nil { - continue - } - - if fd.IsList() { - list := data.Get(fd).List() - arr := arena.NewArray() - root.Set(field.AliasOrPath(), arr) - - if !list.IsValid() { - continue - } - - for i := 0; i < list.Len(); i++ { - - switch fd.Kind() { - case protoref.MessageKind: - message := list.Get(i).Message() - value, err := d.marshalResponseJSON(arena, field.Message, message) - if err != nil { - return nil, err - } - - arr.SetArrayItem(i, value) - default: - d.setArrayItem(i, arena, arr, list.Get(i), fd) - } - - } - - continue - } - - if fd.Kind() == protoref.MessageKind { - msg := data.Get(fd).Message() - if !msg.IsValid() { - root.Set(field.AliasOrPath(), arena.NewNull()) - continue - } - - if field.IsListType { - arr, err := d.flattenListStructure(arena, field.ListMetadata, msg, field.Message) - if err != nil { - return nil, fmt.Errorf("unable to flatten list structure for field %q: %w", field.AliasOrPath(), err) - } - - root.Set(field.AliasOrPath(), arr) - continue - } - - if field.IsOptionalScalar() { - err := d.resolveOptionalField(arena, root, field.JSONPath, msg) - if err != nil { - return nil, err - } - - continue - } - - value, err := d.marshalResponseJSON(arena, field.Message, msg) - if err != nil { - return nil, err - } - - if field.JSONPath == "" { - root, _, err = astjson.MergeValues(root, value) - if err != nil { - return nil, err - 
} - } else { - root.Set(field.AliasOrPath(), value) - } - - continue - } - - d.setJSONValue(arena, root, field.AliasOrPath(), data, fd) - } - - return root, nil -} - -func (d *DataSource) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { - if md == nil { - return arena.NewNull(), errors.New("list metadata not found") - } - - if len(md.LevelInfo) < md.NestingLevel { - return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") - } - - if !data.IsValid() { - if md.LevelInfo[0].Optional { - return arena.NewNull(), nil - } - - return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") - } - - root := arena.NewArray() - return d.traverseList(0, arena, root, md, data, message) -} - -func (d *DataSource) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { - if level > md.NestingLevel { - return current, nil - } - - // List wrappers always use field number 1 - fd := data.Descriptor().Fields().ByNumber(1) - if fd == nil { - return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) - } - - if fd.Kind() != protoref.MessageKind { - return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) - } - - msg := data.Get(fd).Message() - if !msg.IsValid() { - // If the message is not valid we can either return null if the list is nullable or an error if it is non nullable. 
- if md.LevelInfo[level].Optional { - return arena.NewNull(), nil - } - - return arena.NewArray(), fmt.Errorf("cannot add null item to response for non nullable list") - } - - fd = msg.Descriptor().Fields().ByNumber(1) - if !fd.IsList() { - return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) - } - - if level < md.NestingLevel-1 { - list := msg.Get(fd).List() - for i := 0; i < list.Len(); i++ { - next := arena.NewArray() - val, err := d.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) - if err != nil { - return nil, err - } - - current.SetArrayItem(i, val) - } - - return current, nil - } - - list := msg.Get(fd).List() - if !list.IsValid() { - // If the list is not valid, we return an empty array here as the - // nullabilty is checked on the outer List wrapper type. - return arena.NewArray(), nil - } - - for i := 0; i < list.Len(); i++ { - if message != nil { - val, err := d.marshalResponseJSON(arena, message, list.Get(i).Message()) - if err != nil { - return nil, err - } - - current.SetArrayItem(i, val) - } else { - d.setArrayItem(i, arena, current, list.Get(i), fd) - } - } - - return current, nil -} - -func (d *DataSource) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { - fd := data.Descriptor().Fields().ByName(protoref.Name("value")) - if fd == nil { - return fmt.Errorf("unable to resolve optional field: field %q not found in message %s", "value", data.Descriptor().Name()) - } - - d.setJSONValue(arena, root, name, data, fd) - return nil -} - -func (d *DataSource) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { - if !data.IsValid() { - root.Set(name, arena.NewNull()) - return - } - - switch fd.Kind() { - case protoref.BoolKind: - boolValue := data.Get(fd).Bool() - if boolValue { - root.Set(name, arena.NewTrue()) - } else { - root.Set(name, arena.NewFalse()) - } - case protoref.StringKind: - 
root.Set(name, arena.NewString(data.Get(fd).String())) - case protoref.Int32Kind, protoref.Int64Kind: - root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) - case protoref.Uint32Kind, protoref.Uint64Kind: - root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) - case protoref.FloatKind, protoref.DoubleKind: - root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float())) - case protoref.BytesKind: - root.Set(name, arena.NewStringBytes(data.Get(fd).Bytes())) - case protoref.EnumKind: - enumDesc := fd.Enum() - enumValueDesc := enumDesc.Values().ByNumber(data.Get(fd).Enum()) - if enumValueDesc == nil { - root.Set(name, arena.NewNull()) - return - } - - graphqlValue, ok := d.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) - if !ok { - root.Set(name, arena.NewNull()) - return - } - - root.Set(name, arena.NewString(graphqlValue)) - } -} - -func (d *DataSource) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { - if !data.IsValid() { - array.SetArrayItem(index, arena.NewNull()) - return - } - - switch fd.Kind() { - case protoref.BoolKind: - boolValue := data.Bool() - if boolValue { - array.SetArrayItem(index, arena.NewTrue()) - } else { - array.SetArrayItem(index, arena.NewFalse()) - } - case protoref.StringKind: - array.SetArrayItem(index, arena.NewString(data.String())) - case protoref.Int32Kind, protoref.Int64Kind: - array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) - case protoref.Uint32Kind, protoref.Uint64Kind: - array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) - case protoref.FloatKind, protoref.DoubleKind: - array.SetArrayItem(index, arena.NewNumberFloat64(data.Float())) - case protoref.BytesKind: - array.SetArrayItem(index, arena.NewStringBytes(data.Bytes())) - case protoref.EnumKind: - enumDesc := fd.Enum() - enumValueDesc := enumDesc.Values().ByNumber(data.Enum()) - if 
enumValueDesc == nil { - array.SetArrayItem(index, arena.NewNull()) - return - } - - graphqlValue, ok := d.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) - if !ok { - array.SetArrayItem(index, arena.NewNull()) - return - } - - array.SetArrayItem(index, arena.NewString(graphqlValue)) - } -} - -func writeErrorBytes(err error) []byte { - a := astjson.Arena{} - errorRoot := a.NewObject() - errorArray := a.NewArray() - errorRoot.Set("errors", errorArray) - - errorItem := a.NewObject() - errorItem.Set("message", a.NewString(err.Error())) - - extensions := a.NewObject() - if st, ok := status.FromError(err); ok { - extensions.Set("code", a.NewString(st.Code().String())) - } else { - extensions.Set("code", a.NewString(codes.Internal.String())) - } - - errorItem.Set("extensions", extensions) - errorArray.SetArrayItem(0, errorItem) - - return errorRoot.MarshalTo(nil) -} diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index aae0711585..bd393ddf14 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -11,8 +11,10 @@ import ( "testing" "github.com/stretchr/testify/require" + "github.com/tidwall/gjson" "github.com/wundergraph/astjson" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest/productv1" "google.golang.org/grpc" @@ -125,6 +127,14 @@ func Test_DataSource_Load(t *testing.T) { SubgraphName: "Products", Compiler: compiler, Mapping: &GRPCMapping{ + Service: "Products", + QueryRPCs: RPCConfigMap{ + "complexFilterType": { + RPC: "QueryComplexFilterType", + Request: "QueryComplexFilterTypeRequest", + Response: "QueryComplexFilterTypeResponse", + }, + }, Fields: 
map[string]FieldMap{ "Query": { "complexFilterType": { @@ -176,10 +186,21 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { SubgraphName: "Products", Compiler: compiler, Mapping: &GRPCMapping{ + Service: "Products", + QueryRPCs: RPCConfigMap{ + "complexFilterType": { + RPC: "QueryComplexFilterType", + Request: "QueryComplexFilterTypeRequest", + Response: "QueryComplexFilterTypeResponse", + }, + }, Fields: map[string]FieldMap{ "Query": { "complexFilterType": { TargetName: "complex_filter_type", + ArgumentMappings: map[string]string{ + "filter": "filter", + }, }, }, "FilterType": { @@ -255,10 +276,21 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { SubgraphName: "Products", Compiler: compiler, Mapping: &GRPCMapping{ + Service: "Products", + QueryRPCs: RPCConfigMap{ + "complexFilterType": { + RPC: "QueryComplexFilterType", + Request: "QueryComplexFilterTypeRequest", + Response: "QueryComplexFilterTypeResponse", + }, + }, Fields: map[string]FieldMap{ "Query": { "complexFilterType": { TargetName: "complex_filter_type", + ArgumentMappings: map[string]string{ + "filter": "filter", + }, }, }, "FilterType": { @@ -346,6 +378,31 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { Definition: &schemaDoc, SubgraphName: "Products", Compiler: compiler, + Mapping: &GRPCMapping{ + Service: "Products", + QueryRPCs: RPCConfigMap{ + "user": { + RPC: "QueryUser", + Request: "QueryUserRequest", + Response: "QueryUserResponse", + }, + }, + Fields: map[string]FieldMap{ + "Query": { + "user": { + TargetName: "user", + }, + }, + "User": { + "id": { + TargetName: "id", + }, + "name": { + TargetName: "name", + }, + }, + }, + }, }) require.NoError(t, err) @@ -440,10 +497,9 @@ func TestMarshalResponseJSON(t *testing.T) { responseMessage := dynamicpb.NewMessage(responseMessageDesc) responseMessage.Mutable(responseMessageDesc.Fields().ByName("result")).List().Append(protoref.ValueOfMessage(productMessage)) - ds := &DataSource{} - arena := 
astjson.Arena{} - responseJSON, err := ds.marshalResponseJSON(&arena, &response, responseMessage) + jsonBuilder := newJSONBuilder(nil, gjson.Result{}) + responseJSON, err := jsonBuilder.marshalResponseJSON(&arena, &response, responseMessage) require.NoError(t, err) require.Equal(t, `{"_entities":[{"__typename":"Product","id":"123","name_different":"test","price_different":123.45}]}`, responseJSON.String()) } @@ -1803,6 +1859,47 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { } }, }, + { + name: "Query nullable fields type with all aliased fields", + query: `query { nullableFieldsType { id name optionalString1: optionalString optionalInt1: optionalInt optionalFloat1: optionalFloat optionalBoolean1: optionalBoolean requiredString1: requiredString requiredInt1: requiredInt } }`, + vars: "{}", + validate: func(t *testing.T, data map[string]interface{}) { + nullableFieldsType, ok := data["nullableFieldsType"].(map[string]interface{}) + require.True(t, ok, "nullableFieldsType should be an object") + require.NotEmpty(t, nullableFieldsType, "nullableFieldsType should not be empty") + + // Check required fields are present + require.Contains(t, nullableFieldsType, "id") + require.Contains(t, nullableFieldsType, "name") + require.Contains(t, nullableFieldsType, "requiredString1") + require.Contains(t, nullableFieldsType, "requiredInt1") + + require.NotEmpty(t, nullableFieldsType["id"], "id should not be empty") + require.NotEmpty(t, nullableFieldsType["name"], "name should not be empty") + require.NotEmpty(t, nullableFieldsType["requiredString1"], "requiredString1 should not be empty") + require.NotEmpty(t, nullableFieldsType["requiredInt1"], "requiredInt1 should not be empty") + + // Check optional fields are present (but may be null) + require.Contains(t, nullableFieldsType, "optionalString1") + require.Contains(t, nullableFieldsType, "optionalInt1") + require.Contains(t, nullableFieldsType, "optionalFloat1") + require.Contains(t, nullableFieldsType, 
"optionalBoolean1") + + // Verify types of non-null optional fields + if nullableFieldsType["optionalString1"] != nil { + require.IsType(t, "", nullableFieldsType["optionalString1"]) + } + if nullableFieldsType["optionalInt1"] != nil { + require.IsType(t, float64(0), nullableFieldsType["optionalInt1"]) // JSON numbers are float64 + } + if nullableFieldsType["optionalFloat1"] != nil { + require.IsType(t, float64(0), nullableFieldsType["optionalFloat1"]) + } + if nullableFieldsType["optionalBoolean1"] != nil { + require.IsType(t, false, nullableFieldsType["optionalBoolean1"]) + } + }, + }, { name: "Query nullable fields type by ID", query: `query($id: ID!) { nullableFieldsTypeById(id: $id) { id name optionalString requiredString } }`, @@ -3391,3 +3488,121 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { }) } } + +func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { + conn, cleanup := setupTestGRPCServer(t) + t.Cleanup(cleanup) + + testCases := []struct { + name string + query string + vars string + federationConfigs plan.FederationFieldConfigurations + validate func(t *testing.T, data map[string]interface{}) + }{ + { + name: "Query nullable fields type with all fields", + query: `query($representations: [_Any!]!) 
{ _entities(representations: $representations) { ...on Product { id name } ...on Storage { id name } } }`, + vars: `{"variables":{"representations":[ + {"__typename":"Product","id":"1"}, + {"__typename":"Storage","id":"3"}, + {"__typename":"Product","id":"2"}, + {"__typename":"Storage","id":"4"} + ]}}`, + federationConfigs: plan.FederationFieldConfigurations{ + { + TypeName: "Product", + SelectionSet: "id", + }, + { + TypeName: "Storage", + SelectionSet: "id", + }, + }, + validate: func(t *testing.T, data map[string]interface{}) { + entities, ok := data["_entities"].([]interface{}) + require.True(t, ok, "_entities should be an array") + require.NotEmpty(t, entities, "_entities should not be empty") + + // Check required fields are present + require.Contains(t, entities[0], "id") + require.Contains(t, entities[0], "name") + require.Contains(t, entities[1], "id") + require.Contains(t, entities[1], "name") + + require.Len(t, entities, 4, "Should return 4 entities") + + product, ok := entities[0].(map[string]interface{}) + require.True(t, ok, "product should be an object") + require.Equal(t, "1", product["id"]) + require.Equal(t, "Product 1", product["name"]) + + storage, ok := entities[1].(map[string]interface{}) + require.True(t, ok, "storage should be an object") + require.Equal(t, "3", storage["id"]) + require.Equal(t, "Storage 3", storage["name"]) + + product2, ok := entities[2].(map[string]interface{}) + require.True(t, ok, "product2 should be an object") + require.Equal(t, "2", product2["id"]) + require.Equal(t, "Product 2", product2["name"]) + + storage2, ok := entities[3].(map[string]interface{}) + require.True(t, ok, "storage2 should be an object") + require.Equal(t, "4", storage2["id"]) + require.Equal(t, "Storage 4", storage2["name"]) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Parse the GraphQL schema + schemaDoc := grpctest.MustGraphQLSchema(t) + + // Parse the GraphQL query + queryDoc, report := 
astparser.ParseGraphqlDocumentString(tc.query) + if report.HasErrors() { + t.Fatalf("failed to parse query: %s", report.Error()) + } + + compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) + if err != nil { + t.Fatalf("failed to compile proto: %v", err) + } + + // Create the datasource + ds, err := NewDataSource(conn, DataSourceConfig{ + Operation: &queryDoc, + Definition: &schemaDoc, + SubgraphName: "Products", + Mapping: testMapping(), + Compiler: compiler, + FederationConfigs: tc.federationConfigs, + }) + require.NoError(t, err) + + // Execute the query through our datasource + output := new(bytes.Buffer) + input := fmt.Sprintf(`{"query":%q,"body":%s}`, tc.query, tc.vars) + err = ds.Load(context.Background(), []byte(input), output) + require.NoError(t, err) + + // Parse the response + var resp struct { + Data map[string]interface{} `json:"data"` + Errors []struct { + Message string `json:"message"` + } `json:"errors,omitempty"` + } + + err = json.Unmarshal(output.Bytes(), &resp) + require.NoError(t, err, "Failed to unmarshal response") + require.Empty(t, resp.Errors, "Response should not contain errors") + require.NotEmpty(t, resp.Data, "Response should contain data") + + // Run the validation function + tc.validate(t, resp.Data) + }) + } +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/json_builder.go b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go new file mode 100644 index 0000000000..1166fc9b92 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/json_builder.go @@ -0,0 +1,595 @@ +package grpcdatasource + +import ( + "errors" + "fmt" + "strconv" + + "github.com/tidwall/gjson" + "github.com/wundergraph/astjson" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + protoref "google.golang.org/protobuf/reflect/protoreflect" +) + +// Standard GraphQL response paths +var ( + entityPath = "_entities" // Path for federated entities in response + dataPath = "data" // Standard GraphQL data wrapper + 
errorsPath = "errors" // Standard GraphQL errors array +) + +// entityIndex represents the mapping between representation order and result order +// for GraphQL federation entities. This is crucial for maintaining correct entity +// order when multiple subgraphs return entities in different orders. +type entityIndex struct { + representationIndex int // Index in the original representation array + resultIndex int // Index where this entity should appear in the final result +} + +// indexMap maps GraphQL type names to their corresponding entity indices +// This allows proper ordering of federated entities by type +type indexMap map[string][]entityIndex + +// getResultIndex returns the correct result index for an entity based on its type +// and representation index. This ensures federated entities maintain proper ordering +// across multiple subgraph responses. +func (i indexMap) getResultIndex(val *astjson.Value, representationIndex int) int { + if i == nil { + return representationIndex + } + + if val == nil { + return representationIndex + } + + // Extract the __typename field to determine entity type + typeName := val.Get("__typename").GetStringBytes() + + // Find the correct result index for this type and representation index + for _, entityIndex := range i[string(typeName)] { + if entityIndex.representationIndex == representationIndex { + return entityIndex.resultIndex + } + } + + // Fallback to representation index if no mapping found + return representationIndex +} + +// createRepresentationIndexMap builds an index mapping for GraphQL federation entities +// from the variables containing entity representations. This map is used to ensure +// that entities are returned in the correct order when merging responses from multiple +// subgraphs, which is critical for GraphQL federation correctness. 
+func createRepresentationIndexMap(variables gjson.Result) indexMap { + var representations []gjson.Result + r := variables.Get("representations") + if !r.Exists() { + return nil + } + + representations = r.Array() + im := make(indexMap) + indexSet := make(map[string]int) // Track count per type name + + // Build mapping for each representation + for i, representation := range representations { + typeName := representation.Get("__typename").String() + + // Initialize counter for new type names + if _, ok := indexSet[typeName]; !ok { + indexSet[typeName] = -1 + } + + // Increment index for this type + indexSet[typeName]++ + + // Create mapping entry for this entity + im[typeName] = append(im[typeName], entityIndex{ + representationIndex: indexSet[typeName], // Position within entities of this type + resultIndex: i, // Position in the overall result array + }) + } + return im +} + +// jsonBuilder is the core component responsible for converting gRPC protobuf responses +// into GraphQL-compatible JSON format. It handles complex scenarios including: +// - GraphQL federation entity merging and ordering +// - Nested list structures with proper nullability handling +// - Protobuf to GraphQL type conversion +// - Error response formatting +type jsonBuilder struct { + mapping *GRPCMapping // Mapping configuration for GraphQL to gRPC translation + variables gjson.Result // GraphQL variables containing entity representations + indexMap indexMap // Entity index mapping for federation ordering +} + +// newJSONBuilder creates a new JSON builder instance with the provided mapping +// and variables. The builder automatically creates an index map for proper +// federation entity ordering if representations are present in the variables. 
+func newJSONBuilder(mapping *GRPCMapping, variables gjson.Result) *jsonBuilder { + return &jsonBuilder{ + mapping: mapping, + variables: variables, + indexMap: createRepresentationIndexMap(variables), + } +} + +// mergeValues combines two JSON values while preserving proper federation entity ordering. +// This is a critical function for GraphQL federation where multiple subgraphs may +// return entities that need to be merged in the correct order. +func (j *jsonBuilder) mergeValues(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { + if len(j.indexMap) == 0 { + // No federation index map available - use simple merge + // This path is taken for non-federated queries + root, _, err := astjson.MergeValues(left, right) + if err != nil { + return nil, err + } + return root, nil + } + + // Federation entities present - must preserve representation order + leftObject, err := left.Object() + if err != nil { + return nil, err + } + + // If left side is empty, just return right side + if leftObject.Len() == 0 { + return right, nil + } + + // Perform federation-aware entity merging + return j.mergeEntities(left, right) +} + +// mergeEntities performs federation-aware merging of entity arrays from multiple subgraph responses. +// This function ensures that entities are placed in the correct positions in the final response +// array based on their original representation order, which is critical for GraphQL federation. 
+func (j *jsonBuilder) mergeEntities(left *astjson.Value, right *astjson.Value) (*astjson.Value, error) { + root := astjson.Arena{} + + // Create the response structure with _entities array + entities := root.NewObject() + entities.Set(entityPath, root.NewArray()) + arr := entities.Get(entityPath) + + // Extract entity arrays from both responses + leftRepresentations, err := left.Get(entityPath).Array() + if err != nil { + return nil, err + } + + rightRepresentations, err := right.Get(entityPath).Array() + if err != nil { + return nil, err + } + + // Merge left entities using index mapping to preserve order + for index, lr := range leftRepresentations { + arr.SetArrayItem(j.indexMap.getResultIndex(lr, index), lr) + } + + // Merge right entities using index mapping to preserve order + for index, rr := range rightRepresentations { + arr.SetArrayItem(j.indexMap.getResultIndex(rr, index), rr) + } + + return entities, nil +} + +// marshalResponseJSON converts a protobuf message into a GraphQL-compatible JSON response. +// This is the core marshaling function that handles all the complex type conversions, +// including oneOf types, nested messages, lists, and scalar values. 
+func (j *jsonBuilder) marshalResponseJSON(arena *astjson.Arena, message *RPCMessage, data protoref.Message) (*astjson.Value, error) { + if message == nil { + return arena.NewNull(), nil + } + + root := arena.NewObject() + + // Handle protobuf oneOf types - these represent GraphQL union/interface types + if message.IsOneOf() { + oneof := data.Descriptor().Oneofs().ByName(protoref.Name(message.OneOfType.FieldName())) + if oneof == nil { + return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) + } + + // Determine which oneOf field is actually set + oneofDescriptor := data.WhichOneof(oneof) + if oneofDescriptor == nil { + return nil, fmt.Errorf("unable to build response JSON: oneof %s not found in message %s", message.OneOfType.FieldName(), message.Name) + } + + // Extract the actual message data from the oneOf wrapper + if oneofDescriptor.Kind() == protoref.MessageKind { + data = data.Get(oneofDescriptor).Message() + } + } + + // Determine which fields to include in the response + validFields := message.Fields + if message.IsOneOf() { + // For oneOf types, add type-specific fields based on the actual concrete type + validFields = append(validFields, message.FieldSelectionSet.SelectFieldsForTypes(message.SelectValidTypes(string(data.Type().Descriptor().Name())))...) 
+ } + + // Process each field in the message + for _, field := range validFields { + // Handle static values (like __typename fields) + if field.StaticValue != "" { + if len(message.MemberTypes) == 0 { + // Simple static value - use as-is + root.Set(field.AliasOrPath(), arena.NewString(field.StaticValue)) + continue + } + + // Type-specific static value - match against member types + for _, memberTypes := range message.MemberTypes { + if memberTypes == string(data.Type().Descriptor().Name()) { + root.Set(field.AliasOrPath(), arena.NewString(memberTypes)) + break + } + } + + continue + } + + // Get the protobuf field descriptor for this GraphQL field + fd := data.Descriptor().Fields().ByName(protoref.Name(field.Name)) + if fd == nil { + // Field not found in protobuf message - skip it + continue + } + + // Handle list fields (repeated in protobuf) + if fd.IsList() { + list := data.Get(fd).List() + arr := arena.NewArray() + root.Set(field.AliasOrPath(), arr) + + if !list.IsValid() { + // Invalid list - leave as empty array + continue + } + + // Process each list item + for i := 0; i < list.Len(); i++ { + switch fd.Kind() { + case protoref.MessageKind: + // List of messages - recursively marshal each message + message := list.Get(i).Message() + value, err := j.marshalResponseJSON(arena, field.Message, message) + if err != nil { + return nil, err + } + + arr.SetArrayItem(i, value) + default: + // List of scalar values - convert directly + j.setArrayItem(i, arena, arr, list.Get(i), fd) + } + } + + continue + } + + // Handle message fields (nested objects) + if fd.Kind() == protoref.MessageKind { + msg := data.Get(fd).Message() + if !msg.IsValid() { + // Invalid message - set to null + root.Set(field.AliasOrPath(), arena.NewNull()) + continue + } + + // Handle special list wrapper types for complex nested lists + if field.IsListType { + arr, err := j.flattenListStructure(arena, field.ListMetadata, msg, field.Message) + if err != nil { + return nil, fmt.Errorf("unable to 
flatten list structure for field %q: %w", field.AliasOrPath(), err) + } + + root.Set(field.AliasOrPath(), arr) + continue + } + + // Handle optional scalar wrapper types (e.g., google.protobuf.StringValue) + if field.IsOptionalScalar() { + err := j.resolveOptionalField(arena, root, field.AliasOrPath(), msg) + if err != nil { + return nil, err + } + + continue + } + + // Regular nested message - recursively marshal + value, err := j.marshalResponseJSON(arena, field.Message, msg) + if err != nil { + return nil, err + } + + if field.JSONPath == "" { + // Field should be merged into parent object (flattened) + root, _, err = astjson.MergeValues(root, value) + if err != nil { + return nil, err + } + } else { + // Field should be nested under its own key + root.Set(field.AliasOrPath(), value) + } + + continue + } + + // Handle scalar fields (string, int, bool, etc.) + j.setJSONValue(arena, root, field.AliasOrPath(), data, fd) + } + + return root, nil +} + +// flattenListStructure handles complex nested list structures that are wrapped in protobuf +// messages to support nullable and multi-dimensional lists. This is necessary because +// protobuf doesn't directly support nullable list items or complex nesting scenarios +// that GraphQL allows. 
+func (j *jsonBuilder) flattenListStructure(arena *astjson.Arena, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { + if md == nil { + return arena.NewNull(), errors.New("list metadata not found") + } + + // Validate metadata consistency + if len(md.LevelInfo) < md.NestingLevel { + return arena.NewNull(), errors.New("nesting level data does not match the number of levels in the list metadata") + } + + // Handle null data with proper nullability checking + if !data.IsValid() { + if md.LevelInfo[0].Optional { + return arena.NewNull(), nil + } + + return arena.NewNull(), errors.New("cannot add null item to response for non nullable list") + } + + // Start recursive traversal of the nested list structure + root := arena.NewArray() + return j.traverseList(0, arena, root, md, data, message) +} + +// traverseList recursively traverses nested list wrapper structures to extract the actual +// list data. This handles multi-dimensional lists like [[String]] or [[[User]]] by +// unwrapping the protobuf message wrappers at each level. 
+func (j *jsonBuilder) traverseList(level int, arena *astjson.Arena, current *astjson.Value, md *ListMetadata, data protoref.Message, message *RPCMessage) (*astjson.Value, error) { + if level > md.NestingLevel { + return current, nil + } + + // List wrappers always use field number 1 in the generated protobuf + fd := data.Descriptor().Fields().ByNumber(1) + if fd == nil { + return arena.NewNull(), fmt.Errorf("field with number %d not found in message %q", 1, data.Descriptor().Name()) + } + + if fd.Kind() != protoref.MessageKind { + return arena.NewNull(), fmt.Errorf("field %q is not a message", fd.Name()) + } + + // Get the wrapper message containing the list + msg := data.Get(fd).Message() + if !msg.IsValid() { + // Handle null wrapper based on nullability rules + if md.LevelInfo[level].Optional { + return arena.NewNull(), nil + } + + return arena.NewArray(), fmt.Errorf("cannot add null item to response for non nullable list") + } + + // The actual list is always at field number 1 in the wrapper + fd = msg.Descriptor().Fields().ByNumber(1) + if !fd.IsList() { + return arena.NewNull(), fmt.Errorf("field %q is not a list", fd.Name()) + } + + // Handle intermediate nesting levels (not the final level) + if level < md.NestingLevel-1 { + list := msg.Get(fd).List() + for i := 0; i < list.Len(); i++ { + // Create nested array for next level + next := arena.NewArray() + val, err := j.traverseList(level+1, arena, next, md, list.Get(i).Message(), message) + if err != nil { + return nil, err + } + + current.SetArrayItem(i, val) + } + + return current, nil + } + + // Handle the final nesting level - extract actual data + list := msg.Get(fd).List() + if !list.IsValid() { + // Invalid list at final level - return empty array + // Nullability is checked at the wrapper level, not the list level + return arena.NewArray(), nil + } + + // Process each item in the final list + for i := 0; i < list.Len(); i++ { + if message != nil { + // List of complex objects - recursively marshal 
each item + val, err := j.marshalResponseJSON(arena, message, list.Get(i).Message()) + if err != nil { + return nil, err + } + + current.SetArrayItem(i, val) + } else { + // List of scalar values - convert directly + j.setArrayItem(i, arena, current, list.Get(i), fd) + } + } + + return current, nil +} + +// resolveOptionalField extracts the value from optional scalar wrapper types like +// google.protobuf.StringValue, google.protobuf.Int32Value, etc. These wrappers +// are used to represent nullable scalar values in protobuf. +func (j *jsonBuilder) resolveOptionalField(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message) error { + // Optional scalar wrappers always have a "value" field + fd := data.Descriptor().Fields().ByName(protoref.Name("value")) + if fd == nil { + return fmt.Errorf("unable to resolve optional field: field %q not found in message %s", "value", data.Descriptor().Name()) + } + + // Extract and set the wrapped value + j.setJSONValue(arena, root, name, data, fd) + return nil +} + +// setJSONValue converts a protobuf field value to the appropriate JSON representation +// and sets it on the provided JSON object. This handles all protobuf scalar types +// and enum values with proper GraphQL mapping. 
+func (j *jsonBuilder) setJSONValue(arena *astjson.Arena, root *astjson.Value, name string, data protoref.Message, fd protoref.FieldDescriptor) { + if !data.IsValid() { + root.Set(name, arena.NewNull()) + return + } + + switch fd.Kind() { + case protoref.BoolKind: + boolValue := data.Get(fd).Bool() + if boolValue { + root.Set(name, arena.NewTrue()) + } else { + root.Set(name, arena.NewFalse()) + } + case protoref.StringKind: + root.Set(name, arena.NewString(data.Get(fd).String())) + case protoref.Int32Kind: + root.Set(name, arena.NewNumberInt(int(data.Get(fd).Int()))) + case protoref.Int64Kind: + root.Set(name, arena.NewNumberString(strconv.FormatInt(data.Get(fd).Int(), 10))) + case protoref.Uint32Kind, protoref.Uint64Kind: + root.Set(name, arena.NewNumberString(strconv.FormatUint(data.Get(fd).Uint(), 10))) + case protoref.FloatKind, protoref.DoubleKind: + root.Set(name, arena.NewNumberFloat64(data.Get(fd).Float())) + case protoref.BytesKind: + root.Set(name, arena.NewStringBytes(data.Get(fd).Bytes())) + case protoref.EnumKind: + enumDesc := fd.Enum() + enumValueDesc := enumDesc.Values().ByNumber(data.Get(fd).Enum()) + if enumValueDesc == nil { + root.Set(name, arena.NewNull()) + return + } + + // Look up the GraphQL enum value mapping + graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) + if !ok { + // No mapping found - set to null + root.Set(name, arena.NewNull()) + return + } + + root.Set(name, arena.NewString(graphqlValue)) + } +} + +// setArrayItem converts a protobuf list item value to JSON and sets it at the specified +// array index. This is similar to setJSONValue but operates on array elements rather +// than object properties, and works with protobuf Value types rather than Message types. 
+func (j *jsonBuilder) setArrayItem(index int, arena *astjson.Arena, array *astjson.Value, data protoref.Value, fd protoref.FieldDescriptor) { + if !data.IsValid() { + array.SetArrayItem(index, arena.NewNull()) + return + } + + switch fd.Kind() { + case protoref.BoolKind: + boolValue := data.Bool() + if boolValue { + array.SetArrayItem(index, arena.NewTrue()) + } else { + array.SetArrayItem(index, arena.NewFalse()) + } + case protoref.StringKind: + array.SetArrayItem(index, arena.NewString(data.String())) + case protoref.Int32Kind: + array.SetArrayItem(index, arena.NewNumberInt(int(data.Int()))) + case protoref.Int64Kind: + array.SetArrayItem(index, arena.NewNumberString(strconv.FormatInt(data.Int(), 10))) + case protoref.Uint32Kind, protoref.Uint64Kind: + array.SetArrayItem(index, arena.NewNumberString(strconv.FormatUint(data.Uint(), 10))) + case protoref.FloatKind, protoref.DoubleKind: + array.SetArrayItem(index, arena.NewNumberFloat64(data.Float())) + case protoref.BytesKind: + array.SetArrayItem(index, arena.NewStringBytes(data.Bytes())) + case protoref.EnumKind: + enumDesc := fd.Enum() + enumValueDesc := enumDesc.Values().ByNumber(data.Enum()) + if enumValueDesc == nil { + array.SetArrayItem(index, arena.NewNull()) + return + } + + // Look up GraphQL enum mapping + graphqlValue, ok := j.mapping.ResolveEnumValue(string(enumDesc.Name()), string(enumValueDesc.Name())) + if !ok { + // No mapping found - use null + array.SetArrayItem(index, arena.NewNull()) + return + } + + array.SetArrayItem(index, arena.NewString(graphqlValue)) + } +} + +// toDataObject wraps a response value in the standard GraphQL data envelope. +// This creates the top-level structure { "data": ... } that GraphQL clients expect. +func (j *jsonBuilder) toDataObject(root *astjson.Value) *astjson.Value { + a := astjson.Arena{} + data := a.NewObject() + data.Set(dataPath, root) + return data +} + +// writeErrorBytes creates a properly formatted GraphQL error response in JSON format. 
+// This includes the error message and gRPC status code information in the extensions +// field, following GraphQL error specification standards. +func (j *jsonBuilder) writeErrorBytes(err error) []byte { + a := astjson.Arena{} + defer a.Reset() + + // Create standard GraphQL error structure + errorRoot := a.NewObject() + errorArray := a.NewArray() + errorRoot.Set(errorsPath, errorArray) + + // Create individual error object + errorItem := a.NewObject() + errorItem.Set("message", a.NewString(err.Error())) + + // Add gRPC status code information to extensions + extensions := a.NewObject() + if st, ok := status.FromError(err); ok { + // gRPC error - include the specific status code + extensions.Set("code", a.NewString(st.Code().String())) + } else { + // Generic error - default to INTERNAL status + extensions.Set("code", a.NewString(codes.Internal.String())) + } + + errorItem.Set("extensions", extensions) + errorArray.SetArrayItem(0, errorItem) + + return errorRoot.MarshalTo(nil) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go b/v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go index 146f371778..bb52f42954 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go +++ b/v2/pkg/engine/datasource/grpc_datasource/mapping_test_helper.go @@ -218,21 +218,25 @@ func testMapping() *GRPCMapping { }, }, SubscriptionRPCs: RPCConfigMap{}, - EntityRPCs: map[string]EntityRPCConfig{ + EntityRPCs: map[string][]EntityRPCConfig{ "Product": { - Key: "id", - RPCConfig: RPCConfig{ - RPC: "LookupProductById", - Request: "LookupProductByIdRequest", - Response: "LookupProductByIdResponse", + { + Key: "id", + RPCConfig: RPCConfig{ + RPC: "LookupProductById", + Request: "LookupProductByIdRequest", + Response: "LookupProductByIdResponse", + }, }, }, "Storage": { - Key: "id", - RPCConfig: RPCConfig{ - RPC: "LookupStorageById", - Request: "LookupStorageByIdRequest", - Response: "LookupStorageByIdResponse", + { + Key: "id", + 
RPCConfig: RPCConfig{ + RPC: "LookupStorageById", + Request: "LookupStorageByIdRequest", + Response: "LookupStorageByIdResponse", + }, }, }, }, @@ -624,6 +628,14 @@ func testMapping() *GRPCMapping { TargetName: "filter", }, }, + "FilterTypeInput": { + "filterField1": { + TargetName: "filter_field_1", + }, + "filterField2": { + TargetName: "filter_field_2", + }, + }, "Category": { "id": { TargetName: "id", diff --git a/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go b/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go new file mode 100644 index 0000000000..cd54e7ccf4 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/required_fields_visitor.go @@ -0,0 +1,170 @@ +package grpcdatasource + +import ( + "errors" + "fmt" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/astvisitor" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/v2/pkg/operationreport" +) + +// requiredFieldsVisitor is a visitor that visits the required fields of a message. +type requiredFieldsVisitor struct { + operation *ast.Document + definition *ast.Document + + walker *astvisitor.Walker + message *RPCMessage + + planCtx *rpcPlanningContext + + messageAncestors []*RPCMessage +} + +// newRequiredFieldsVisitor creates a new requiredFieldsVisitor. +// It registers the visitor with the walker and returns it. +func newRequiredFieldsVisitor(walker *astvisitor.Walker, message *RPCMessage, planCtx *rpcPlanningContext) *requiredFieldsVisitor { + visitor := &requiredFieldsVisitor{ + walker: walker, + message: message, + planCtx: planCtx, + messageAncestors: []*RPCMessage{}, + } + + walker.RegisterEnterDocumentVisitor(visitor) + walker.RegisterSelectionSetVisitor(visitor) + walker.RegisterEnterFieldVisitor(visitor) + + return visitor +} + +// visitRequiredFields visits the required fields of a message. 
+// It creates a new document with the required fields and walks it. +// To achieve that we create a fragment with the required fields and walk it. +func (r *requiredFieldsVisitor) visitRequiredFields(definition *ast.Document, typeName, requiredFields string) error { + doc, report := plan.RequiredFieldsFragment(typeName, requiredFields, false) + if report.HasErrors() { + return report + } + + r.message.MemberTypes = []string{typeName} + r.walker.Walk(doc, definition, report) + if report.HasErrors() { + return report + } + + return nil +} + +// EnterDocument implements astvisitor.EnterDocumentVisitor. +func (r *requiredFieldsVisitor) EnterDocument(operation *ast.Document, definition *ast.Document) { + if r.message == nil { + r.walker.StopWithInternalErr(errors.New("unable to visit required fields. Message is required")) + return + } + + r.operation = operation + r.definition = definition +} + +// EnterSelectionSet implements astvisitor.SelectionSetVisitor. +func (r *requiredFieldsVisitor) EnterSelectionSet(ref int) { + // Ignore the root selection set + if r.walker.Ancestor().Kind == ast.NodeKindFragmentDefinition { + return + } + + if len(r.message.Fields) == 0 { + r.walker.StopWithInternalErr(errors.New("cannot access last field: message has no fields")) + return + } + + lastField := &r.message.Fields[len(r.message.Fields)-1] + if lastField.Message == nil { + lastField.Message = r.planCtx.newMessageFromSelectionSet(r.walker.EnclosingTypeDefinition, ref) + } + + r.messageAncestors = append(r.messageAncestors, r.message) + r.message = lastField.Message + + if err := r.handleCompositeType(r.walker.EnclosingTypeDefinition); err != nil { + r.walker.StopWithInternalErr(err) + return + } +} + +// LeaveSelectionSet implements astvisitor.SelectionSetVisitor. 
+func (r *requiredFieldsVisitor) LeaveSelectionSet(ref int) { + if r.walker.Ancestor().Kind == ast.NodeKindFragmentDefinition { + return + } + + if len(r.messageAncestors) > 0 { + r.message = r.messageAncestors[len(r.messageAncestors)-1] + r.messageAncestors = r.messageAncestors[:len(r.messageAncestors)-1] + } +} + +// EnterField implements astvisitor.EnterFieldVisitor. +func (r *requiredFieldsVisitor) EnterField(ref int) { + fieldName := r.operation.FieldNameString(ref) + + // prevent duplicate fields + if r.message.Fields.Exists(fieldName, "") { + return + } + + fd, ok := r.walker.FieldDefinition(ref) + if !ok { + r.walker.Report.AddExternalError(operationreport.ExternalError{ + Message: fmt.Sprintf("Field %s not found in definition %s", fieldName, r.walker.EnclosingTypeDefinition.NameString(r.definition)), + }) + return + } + + field, err := r.planCtx.buildField(r.walker.EnclosingTypeDefinition, fd, fieldName, "") + if err != nil { + r.walker.StopWithInternalErr(err) + return + } + + r.message.Fields = append(r.message.Fields, field) +} + +func (r *requiredFieldsVisitor) handleCompositeType(node ast.Node) error { + if node.Ref < 0 { + return nil + } + + var ( + ok bool + oneOfType OneOfType + memberTypes []string + ) + + switch node.Kind { + case ast.NodeKindField: + return r.handleCompositeType(r.walker.EnclosingTypeDefinition) + case ast.NodeKindInterfaceTypeDefinition: + oneOfType = OneOfTypeInterface + memberTypes, ok = r.definition.InterfaceTypeDefinitionImplementedByObjectWithNames(node.Ref) + if !ok { + return fmt.Errorf("interface type %s is not implemented by any object", r.definition.InterfaceTypeDefinitionNameString(node.Ref)) + } + case ast.NodeKindUnionTypeDefinition: + oneOfType = OneOfTypeUnion + memberTypes, ok = r.definition.UnionTypeDefinitionMemberTypeNames(node.Ref) + if !ok { + return fmt.Errorf("union type %s is not defined", r.definition.UnionTypeDefinitionNameString(node.Ref)) + } + default: + return nil + } + + r.message.OneOfType = 
oneOfType + r.message.MemberTypes = memberTypes + + return nil +} diff --git a/v2/pkg/grpctest/mapping/mapping.go b/v2/pkg/grpctest/mapping/mapping.go index 842a029997..1476d084ab 100644 --- a/v2/pkg/grpctest/mapping/mapping.go +++ b/v2/pkg/grpctest/mapping/mapping.go @@ -225,21 +225,25 @@ func DefaultGRPCMapping() *grpcdatasource.GRPCMapping { }, }, SubscriptionRPCs: grpcdatasource.RPCConfigMap{}, - EntityRPCs: map[string]grpcdatasource.EntityRPCConfig{ + EntityRPCs: map[string][]grpcdatasource.EntityRPCConfig{ "Product": { - Key: "id", - RPCConfig: grpcdatasource.RPCConfig{ - RPC: "LookupProductById", - Request: "LookupProductByIdRequest", - Response: "LookupProductByIdResponse", + { + Key: "id", + RPCConfig: grpcdatasource.RPCConfig{ + RPC: "LookupProductById", + Request: "LookupProductByIdRequest", + Response: "LookupProductByIdResponse", + }, }, }, "Storage": { - Key: "id", - RPCConfig: grpcdatasource.RPCConfig{ - RPC: "LookupStorageById", - Request: "LookupStorageByIdRequest", - Response: "LookupStorageByIdResponse", + { + Key: "id", + RPCConfig: grpcdatasource.RPCConfig{ + RPC: "LookupStorageById", + Request: "LookupStorageByIdRequest", + Response: "LookupStorageByIdResponse", + }, }, }, },