diff --git a/v2/pkg/engine/datasource/graphql_datasource/configuration.go b/v2/pkg/engine/datasource/graphql_datasource/configuration.go index ec398c0e05..508807b419 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/configuration.go +++ b/v2/pkg/engine/datasource/graphql_datasource/configuration.go @@ -20,7 +20,18 @@ type ConfigurationInput struct { SchemaConfiguration *SchemaConfiguration CustomScalarTypeFields []SingleTypeField - GRPC *grpcdatasource.GRPCConfiguration + GRPC *grpcdatasource.GRPCConfiguration + Connect *ConnectConfiguration +} + +// ConnectConfiguration holds Connect protocol-specific configuration. +// The GRPC field in ConfigurationInput is still used for Mapping and Compiler, +// as Connect shares the same protobuf schema definitions. +type ConnectConfiguration struct { + // BaseURL is the base URL of the Connect service (e.g., "http://localhost:8080"). + BaseURL string + // Encoding specifies the serialization format (Protobuf or JSON). + Encoding grpcdatasource.ConnectEncoding } type Configuration struct { @@ -29,7 +40,8 @@ type Configuration struct { schemaConfiguration SchemaConfiguration customScalarTypeFields []SingleTypeField - grpc *grpcdatasource.GRPCConfiguration + grpc *grpcdatasource.GRPCConfiguration + connect *ConnectConfiguration } func NewConfiguration(input ConfigurationInput) (Configuration, error) { @@ -46,8 +58,8 @@ func NewConfiguration(input ConfigurationInput) (Configuration, error) { cfg.schemaConfiguration = *input.SchemaConfiguration - if input.Fetch == nil && input.Subscription == nil && input.GRPC == nil { - return Configuration{}, errors.New("fetch or subscription or grpc configuration is required") + if input.Fetch == nil && input.Subscription == nil && input.GRPC == nil && input.Connect == nil { + return Configuration{}, errors.New("fetch or subscription or grpc or connect configuration is required") } if input.Fetch != nil { @@ -76,6 +88,15 @@ func NewConfiguration(input ConfigurationInput) (Configuration, error) { cfg.grpc = input.GRPC } + if input.Connect != nil { + cfg.connect = input.Connect + // Connect uses the same GRPC mapping/compiler for proto schema definitions. + // Ensure GRPC config is also provided when using Connect. + if input.GRPC == nil { + return Configuration{}, errors.New("GRPC configuration (mapping/compiler) is required when using Connect") + } + } + return cfg, nil } @@ -99,6 +120,10 @@ func (c *Configuration) IsGRPC() bool { return c.grpc != nil } +func (c *Configuration) IsConnect() bool { + return c.connect != nil +} + type SingleTypeField struct { TypeName string FieldName string diff --git a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go index f4268d1f6a..292412d0c4 100644 --- a/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go +++ b/v2/pkg/engine/datasource/graphql_datasource/graphql_datasource.go @@ -90,8 +90,9 @@ type Planner[T Configuration] struct { minifier *astminify.Minifier - // gRPC - grpcClient grpc.ClientConnInterface + // gRPC / Connect + grpcClient grpc.ClientConnInterface + connectTransport grpcdatasource.RPCTransport printKitPool *sync.Pool } @@ -360,7 +361,15 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { return resolve.FetchConfiguration{} } - dataSource, err = grpcdatasource.NewDataSource(p.grpcClient, grpcdatasource.DataSourceConfig{ + // Determine which transport to use: Connect or gRPC. + var transport grpcdatasource.RPCTransport + if p.connectTransport != nil { + transport = p.connectTransport + } else { + transport = grpcdatasource.NewGRPCTransport(p.grpcClient) + } + + dataSource, err = grpcdatasource.NewDataSource(transport, grpcdatasource.DataSourceConfig{ Operation: &opDocument, Definition: p.config.schemaConfiguration.upstreamSchemaAst, Mapping: p.config.grpc.Mapping, @@ -372,7 +381,7 @@ func (p *Planner[T]) ConfigureFetch() resolve.FetchConfiguration { SubgraphName: p.dataSourceConfig.Name(), }) if err != nil { - p.stopWithError(errors.WithStack(fmt.Errorf("failed to create gRPC datasource: %w", err))) + p.stopWithError(errors.WithStack(fmt.Errorf("failed to create datasource: %w", err))) return resolve.FetchConfiguration{} } } @@ -1728,10 +1737,11 @@ func getRelaxedPrintKitPool() *sync.Pool { } type Factory[T Configuration] struct { - executionContext context.Context + executionContext context.Context httpClient *http.Client grpcClient grpc.ClientConnInterface grpcClientProvider func() grpc.ClientConnInterface + connectTransport grpcdatasource.RPCTransport subscriptionClient GraphQLSubscriptionClient printKitPool *sync.Pool } @@ -1795,6 +1805,23 @@ func NewFactoryGRPCClientProvider(executionContext context.Context, clientProvid }, nil } +// NewFactoryConnect creates a Connect protocol factory for the GraphQL datasource planner. +// It uses the Connect protocol (HTTP-native RPC) instead of gRPC for transport. +func NewFactoryConnect(executionContext context.Context, connectTransport grpcdatasource.RPCTransport) (*Factory[Configuration], error) { + if executionContext == nil { + return nil, fmt.Errorf("execution context is required") + } + + if connectTransport == nil { + return nil, fmt.Errorf("connect transport is required") + } + + return &Factory[Configuration]{ + executionContext: executionContext, + connectTransport: connectTransport, + }, nil +} + func (p *Planner[T]) getKit() *printKit { pool := p.printKitPool if pool == nil { @@ -1837,6 +1864,7 @@ func (f *Factory[T]) Planner(logger abstractlogger.Logger) plan.DataSourcePlanne return &Planner[T]{ fetchClient: f.httpClient, grpcClient: grpcClient, + connectTransport: f.connectTransport, subscriptionClient: f.subscriptionClient, printKitPool: f.getPrintKitPool(), } @@ -1862,7 +1890,7 @@ func (f *Factory[T]) PlanningBehavior() plan.DataSourcePlanningBehavior { MergeAliasedRootNodes: true, OverrideFieldPathFromAlias: true, AllowPlanningTypeName: true, - AlwaysFlattenFragments: f.grpcClient != nil || f.grpcClientProvider != nil, + AlwaysFlattenFragments: f.grpcClient != nil || f.grpcClientProvider != nil || f.connectTransport != nil, } return b } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go index a4f4309926..c1c9bb4cc3 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource.go @@ -43,7 +43,7 @@ var _ resolve.DataSource = (*DataSource)(nil) // transforms the responses back to GraphQL format. type DataSource struct { plan *RPCExecutionPlan - cc grpc.ClientConnInterface + transport RPCTransport rc *RPCCompiler mapping *GRPCMapping federationConfigs plan.FederationFieldConfigurations @@ -66,8 +66,8 @@ type DataSourceConfig struct { Disabled bool } -// NewDataSource creates a new gRPC datasource -func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*DataSource, error) { +// NewDataSource creates a new datasource with the given RPCTransport. +func NewDataSource(transport RPCTransport, config DataSourceConfig) (*DataSource, error) { planner, err := NewPlanner(config.SubgraphName, config.Mapping, config.FederationConfigs) if err != nil { return nil, err @@ -79,7 +79,7 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D return &DataSource{ plan: plan, - cc: client, + transport: transport, rc: config.Compiler, mapping: config.Mapping, federationConfigs: config.FederationConfigs, @@ -88,6 +88,12 @@ func NewDataSource(client grpc.ClientConnInterface, config DataSourceConfig) (*D }, nil } +// NewDataSourceGRPC creates a new gRPC datasource using a gRPC ClientConnInterface. +// This is a convenience function that wraps the connection in a grpcTransport. +func NewDataSourceGRPC(client grpc.ClientConnInterface, config DataSourceConfig) (*DataSource, error) { + return NewDataSource(NewGRPCTransport(client), config) +} + // Load implements resolve.DataSource interface. // It processes the input JSON data to make gRPC calls and returns // the response data. @@ -149,7 +155,7 @@ func (d *DataSource) Load(ctx context.Context, headers http.Header, input []byte errGrp.Go(func() error { // Invoke the gRPC method - this will populate serviceCall.Output - err := d.cc.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) + err := d.transport.Invoke(errGrpCtx, serviceCall.MethodFullName(), serviceCall.Input, serviceCall.Output) if err != nil { return err } diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_federation_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_federation_test.go index bdffde44c3..4e67891107 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_federation_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_federation_test.go @@ -178,7 +178,7 @@ func Test_DataSource_Load_WithEntity_Calls(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -478,7 +478,7 @@ func Test_DataSource_Load_WithEntity_Calls_WithCompositeTypes(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1267,7 +1267,7 @@ func Test_DataSource_Load_WithEntity_Calls_And_Requires(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1682,7 +1682,7 @@ func Test_DataSource_Load_WithEntity_Calls_And_Requires_And_FieldResolvers(t *te } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_spy_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_spy_test.go index 2ad9172916..1c3b358bd4 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_spy_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_spy_test.go @@ -150,7 +150,7 @@ func Test_DataSource_Load_NullMetrics_NestedResolversNotInvoked(t *testing.T) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -191,7 +191,7 @@ func Test_DataSource_Load_NullCategory_FieldResolversNotInvoked(t *testing.T) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -226,7 +226,7 @@ func Test_DataSource_Load_ArgumentLessFieldResolversCalled(t *testing.T) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -267,7 +267,7 @@ func Test_DataSource_Load_NullCategory_ArgumentLessFieldResolversNotInvoked(t *t compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) require.NoError(t, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSource(NewGRPCTransport(conn), DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", diff --git a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go index cb5e0025cc..d6166bd8af 100644 --- a/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go +++ b/v2/pkg/engine/datasource/grpc_datasource/grpc_datasource_test.go @@ -44,7 +44,7 @@ func Benchmark_DataSource_Load(b *testing.B) { compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(b), testMapping()) require.NoError(b, err) - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -81,7 +81,7 @@ func Benchmark_DataSource_Load_WithFieldArguments(b *testing.B) { const subgraphName = "Products" - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: subgraphName, @@ -192,7 +192,7 @@ func Test_DataSource_Load(t *testing.T) { } mi := mockInterface{} - ds, err := NewDataSource(mi, DataSourceConfig{ + ds, err := NewDataSourceGRPC(mi, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -247,7 +247,7 @@ func Test_DataSource_Load_WithMockService(t *testing.T) { } // 2. Create a datasource with the real gRPC client connection - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -336,7 +336,7 @@ func Test_DataSource_Load_WithMockService_WithResponseMapping(t *testing.T) { } // 2. Create a datasource with the real gRPC client connection - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -437,7 +437,7 @@ func Test_DataSource_Load_WithGrpcError(t *testing.T) { } // 3. Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -786,7 +786,7 @@ func Test_DataSource_Load_WithAnimalInterface(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1056,7 +1056,7 @@ func Test_Datasource_Load_WithUnionTypes(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1192,7 +1192,7 @@ func Test_DataSource_Load_WithCategoryQueries(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1272,7 +1272,7 @@ func Test_DataSource_Load_WithTotalCalculation(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1362,7 +1362,7 @@ func Test_DataSource_Load_WithTypename(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -1831,7 +1831,7 @@ func Test_DataSource_Load_WithAliases(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -2209,7 +2209,7 @@ func Test_DataSource_Load_WithNullableFieldsType(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -3510,7 +3510,7 @@ func Test_DataSource_Load_WithNestedLists(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -4744,7 +4744,7 @@ func Test_Datasource_Load_WithFieldResolvers(t *testing.T) { } // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -4953,7 +4953,7 @@ func Test_Datasource_Load_WithHeaders(t *testing.T) { require.NoError(t, err) // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", @@ -5004,7 +5004,7 @@ func Test_Datasource_Load_PreservesExistingContextMetadata(t *testing.T) { require.NoError(t, err) // Create the datasource - ds, err := NewDataSource(conn, DataSourceConfig{ + ds, err := NewDataSourceGRPC(conn, DataSourceConfig{ Operation: &queryDoc, Definition: &schemaDoc, SubgraphName: "Products", diff --git a/v2/pkg/engine/datasource/grpc_datasource/transport.go b/v2/pkg/engine/datasource/grpc_datasource/transport.go new file mode 100644 index 0000000000..6bdd0208d1 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/transport.go @@ -0,0 +1,36 @@ +package grpcdatasource + +import ( + "context" + "errors" + + "google.golang.org/grpc" + protoref "google.golang.org/protobuf/reflect/protoreflect" +) + +// RPCTransport abstracts the transport protocol for RPC calls. +// Both gRPC and Connect protocol implement this interface. +type RPCTransport interface { + Invoke(ctx context.Context, methodFullName string, input, output protoref.Message) error +} + +// grpcTransport wraps grpc.ClientConnInterface to implement RPCTransport. +type grpcTransport struct { + cc grpc.ClientConnInterface +} + +// NewGRPCTransport creates an RPCTransport that delegates to a gRPC ClientConnInterface. +func NewGRPCTransport(cc grpc.ClientConnInterface) RPCTransport { + return &grpcTransport{cc: cc} +} + +func (t *grpcTransport) Invoke(ctx context.Context, method string, input, output protoref.Message) error { + if t.cc == nil { + return errors.New("grpc transport: nil client connection") + } + // grpc.ClientConnInterface.Invoke accepts (ctx, method, args any, reply any, opts ...grpc.CallOption). + // protoref.Message satisfies the any constraint; variadic opts can be omitted. + // This wrapper intentionally does not forward grpc.CallOption, as RPCTransport + // is protocol-agnostic. The existing grpc_datasource code does not use any CallOption at the Invoke site. + return t.cc.Invoke(ctx, method, input, output) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/transport_connect.go b/v2/pkg/engine/datasource/grpc_datasource/transport_connect.go new file mode 100644 index 0000000000..ff7c38ec90 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/transport_connect.go @@ -0,0 +1,174 @@ +package grpcdatasource + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + protoref "google.golang.org/protobuf/reflect/protoreflect" +) + +// ConnectEncoding represents the encoding format for Connect protocol requests. +type ConnectEncoding int + +const ( + // ConnectEncodingProtobuf uses binary protobuf encoding. + ConnectEncodingProtobuf ConnectEncoding = iota + // ConnectEncodingJSON uses JSON encoding via protojson. + ConnectEncodingJSON +) + +// maxConnectResponseSize limits the response body read from a Connect service to 10 MB +// to prevent memory exhaustion from unexpectedly large or malicious responses. +const maxConnectResponseSize = 10 * 1024 * 1024 + +// ConnectTransportConfig holds the configuration for creating a Connect transport. +type ConnectTransportConfig struct { + // BaseURL is the base URL of the Connect service (e.g., "http://localhost:8080"). + BaseURL string + // HTTPClient is the HTTP client to use. If nil, http.DefaultClient is used. + HTTPClient *http.Client + // Encoding specifies the serialization format (Protobuf or JSON). + Encoding ConnectEncoding +} + +// connectTransport implements RPCTransport using the Connect protocol over HTTP. +type connectTransport struct { + baseURL string + httpClient *http.Client + encoding ConnectEncoding +} + +// NewConnectTransport creates an RPCTransport that uses the Connect protocol. +func NewConnectTransport(config ConnectTransportConfig) RPCTransport { + httpClient := config.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + + return &connectTransport{ + baseURL: strings.TrimRight(config.BaseURL, "/"), + httpClient: httpClient, + encoding: config.Encoding, + } +} + +func (t *connectTransport) Invoke(ctx context.Context, methodFullName string, input, output protoref.Message) error { + url := t.baseURL + methodFullName + + // protoref.Message (protoreflect.Message) is the reflection interface. + // The underlying runtime type (*dynamicpb.Message) implements proto.Message, + // which is required by proto.Marshal / protojson.Marshal. + // We need to get the proto.Message interface via the ProtoReflect().Interface() path, + // or directly type-assert since dynamicpb.Message implements proto.Message. + inputMsg, ok := input.Interface().(proto.Message) + if !ok { + return fmt.Errorf("connect: input does not implement proto.Message") + } + + var reqBody []byte + var contentType string + var err error + + switch t.encoding { + case ConnectEncodingProtobuf: + contentType = "application/proto" + reqBody, err = proto.Marshal(inputMsg) + case ConnectEncodingJSON: + contentType = "application/json" + reqBody, err = protojson.Marshal(inputMsg) + default: + return fmt.Errorf("connect: unsupported encoding: %d", t.encoding) + } + if err != nil { + return fmt.Errorf("connect: marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(reqBody)) + if err != nil { + return fmt.Errorf("connect: create request: %w", err) + } + req.Header.Set("Content-Type", contentType) + req.Header.Set("Connect-Protocol-Version", "1") + + // Forward gRPC metadata as HTTP headers. + // Keys ending in "-bin" carry binary values; the Connect protocol requires + // these to be base64-encoded before placing them in an HTTP header. + if md, ok := metadata.FromOutgoingContext(ctx); ok { + for k, vals := range md { + isBin := strings.HasSuffix(k, "-bin") + for _, v := range vals { + if isBin { + req.Header.Add(k, base64.StdEncoding.EncodeToString([]byte(v))) + } else { + req.Header.Add(k, v) + } + } + } + } + + resp, err := t.httpClient.Do(req) + if err != nil { + return fmt.Errorf("connect: send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(io.LimitReader(resp.Body, maxConnectResponseSize+1)) + if err != nil { + return fmt.Errorf("connect: read response: %w", err) + } + if len(respBody) > maxConnectResponseSize { + return fmt.Errorf("connect: response body exceeds %d bytes", maxConnectResponseSize) + } + + if resp.StatusCode != http.StatusOK { + return parseConnectError(resp.StatusCode, respBody) + } + + // Unmarshal response into the output message. + outputMsg, ok := output.Interface().(proto.Message) + if !ok { + return fmt.Errorf("connect: output does not implement proto.Message") + } + + switch t.encoding { + case ConnectEncodingProtobuf: + err = proto.Unmarshal(respBody, outputMsg) + case ConnectEncodingJSON: + err = protojson.Unmarshal(respBody, outputMsg) + } + if err != nil { + return fmt.Errorf("connect: unmarshal response: %w", err) + } + + return nil +} + +// connectError represents an error response from a Connect protocol service. +type connectError struct { + Code string `json:"code"` + Message string `json:"message"` +} + +// parseConnectError parses an error response body from a Connect service. +// Connect errors are always JSON-encoded regardless of the request encoding. +// If the body is not valid JSON (e.g., a 502 from a reverse proxy), falls back to raw status code. +func parseConnectError(statusCode int, body []byte) error { + var ce connectError + if err := json.Unmarshal(body, &ce); err != nil { + bodyStr := string(body) + if len(bodyStr) > 256 { + bodyStr = bodyStr[:256] + "... (truncated)" + } + return fmt.Errorf("connect: HTTP %d: %s", statusCode, bodyStr) + } + return fmt.Errorf("connect: %s: %s", ce.Code, ce.Message) +} diff --git a/v2/pkg/engine/datasource/grpc_datasource/transport_connect_test.go b/v2/pkg/engine/datasource/grpc_datasource/transport_connect_test.go new file mode 100644 index 0000000000..4665568bc3 --- /dev/null +++ b/v2/pkg/engine/datasource/grpc_datasource/transport_connect_test.go @@ -0,0 +1,264 @@ +package grpcdatasource + +import ( + "context" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + protoref "google.golang.org/protobuf/reflect/protoreflect" + "google.golang.org/protobuf/types/dynamicpb" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/grpctest" +) + +func newTestCompiler(t *testing.T) *RPCCompiler { + t.Helper() + compiler, err := NewProtoCompiler(grpctest.MustProtoSchema(t), testMapping()) + require.NoError(t, err) + return compiler +} + +func findMessageDesc(t *testing.T, compiler *RPCCompiler, fullName string) protoref.MessageDescriptor { + t.Helper() + for _, m := range compiler.doc.Messages { + if string(m.Desc.FullName()) == fullName { + return m.Desc + } + } + t.Fatalf("message %q not found in proto document", fullName) + return nil +} + +func TestConnectTransport_Invoke_Protobuf(t *testing.T) { + compiler := newTestCompiler(t) + + reqDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryRequest") + respDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryResponse") + categoryDesc := findMessageDesc(t, compiler, "productv1.Category") + + // Build a response message. + respMsg := dynamicpb.NewMessage(respDesc) + categoryMsg := dynamicpb.NewMessage(categoryDesc) + categoryMsg.Set(categoryDesc.Fields().ByName("id"), protoref.ValueOfString("cat-123")) + categoryMsg.Set(categoryDesc.Fields().ByName("name"), protoref.ValueOfString("Electronics")) + respMsg.Set(respDesc.Fields().ByName("category"), protoref.ValueOfMessage(categoryMsg)) + + respBytes, err := proto.Marshal(respMsg) + require.NoError(t, err) + + var receivedContentType string + var receivedProtocolVersion string + var receivedBody []byte + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedContentType = r.Header.Get("Content-Type") + receivedProtocolVersion = r.Header.Get("Connect-Protocol-Version") + receivedBody, _ = io.ReadAll(r.Body) + w.Header().Set("Content-Type", "application/proto") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(respBytes) + })) + defer server.Close() + + transport := NewConnectTransport(ConnectTransportConfig{ + BaseURL: server.URL, + Encoding: ConnectEncodingProtobuf, + }) + + inputMsg := dynamicpb.NewMessage(reqDesc) + inputMsg.Set(reqDesc.Fields().ByName("id"), protoref.ValueOfString("cat-123")) + + outputMsg := dynamicpb.NewMessage(respDesc) + + err = transport.Invoke(context.Background(), "/productv1.ProductService/QueryCategory", inputMsg, outputMsg) + require.NoError(t, err) + + require.Equal(t, "application/proto", receivedContentType) + require.Equal(t, "1", receivedProtocolVersion) + require.NotEmpty(t, receivedBody) + + outputJSON, err := protojson.Marshal(outputMsg) + require.NoError(t, err) + require.Contains(t, string(outputJSON), "cat-123") + require.Contains(t, string(outputJSON), "Electronics") +} + +func TestConnectTransport_Invoke_JSON(t *testing.T) { + compiler := newTestCompiler(t) + + reqDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryRequest") + respDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryResponse") + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "application/json", r.Header.Get("Content-Type")) + + body, _ := io.ReadAll(r.Body) + require.True(t, json.Valid(body), "request body should be valid JSON: %s", string(body)) + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte(`{"category":{"id":"cat-456","name":"Books"}}`)) + })) + defer server.Close() + + transport := NewConnectTransport(ConnectTransportConfig{ + BaseURL: server.URL, + Encoding: ConnectEncodingJSON, + }) + + inputMsg := dynamicpb.NewMessage(reqDesc) + inputMsg.Set(reqDesc.Fields().ByName("id"), protoref.ValueOfString("cat-456")) + + outputMsg := dynamicpb.NewMessage(respDesc) + + err := transport.Invoke(context.Background(), "/productv1.ProductService/QueryCategory", inputMsg, outputMsg) + require.NoError(t, err) + + outputJSON, err := protojson.Marshal(outputMsg) + require.NoError(t, err) + require.Contains(t, string(outputJSON), "cat-456") + require.Contains(t, string(outputJSON), "Books") +} + +func TestConnectTransport_Invoke_ConnectError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"code":"not_found","message":"category not found"}`)) + })) + defer server.Close() + + transport := NewConnectTransport(ConnectTransportConfig{ + BaseURL: server.URL, + Encoding: ConnectEncodingProtobuf, + }) + + compiler := newTestCompiler(t) + reqDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryRequest") + respDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryResponse") + + inputMsg := dynamicpb.NewMessage(reqDesc) + outputMsg := dynamicpb.NewMessage(respDesc) + + err := transport.Invoke(context.Background(), "/productv1.ProductService/QueryCategory", inputMsg, outputMsg) + require.Error(t, err) + require.Contains(t, err.Error(), "not_found") + require.Contains(t, err.Error(), "category not found") +} + +func TestConnectTransport_Invoke_NonJSONError(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.WriteHeader(http.StatusBadGateway) + _, _ = w.Write([]byte("Bad Gateway")) + })) + defer server.Close() + + transport := NewConnectTransport(ConnectTransportConfig{ + BaseURL: server.URL, + Encoding: ConnectEncodingProtobuf, + }) + + compiler := newTestCompiler(t) + reqDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryRequest") + respDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryResponse") + + inputMsg := dynamicpb.NewMessage(reqDesc) + outputMsg := dynamicpb.NewMessage(respDesc) + + err := transport.Invoke(context.Background(), "/productv1.ProductService/QueryCategory", inputMsg, outputMsg) + require.Error(t, err) + require.Contains(t, err.Error(), "HTTP 502") +} + +func TestConnectTransport_Invoke_HeaderForwarding(t *testing.T) { + var receivedHeaders http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte{}) + })) + defer server.Close() + + transport := NewConnectTransport(ConnectTransportConfig{ + BaseURL: server.URL, + Encoding: ConnectEncodingProtobuf, + }) + + compiler := newTestCompiler(t) + reqDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryRequest") + respDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryResponse") + + inputMsg := dynamicpb.NewMessage(reqDesc) + outputMsg := dynamicpb.NewMessage(respDesc) + + ctx := metadata.AppendToOutgoingContext(context.Background(), + "authorization", "Bearer test-token", + "x-custom-header", "custom-value", + ) + + err := transport.Invoke(ctx, "/productv1.ProductService/QueryCategory", inputMsg, outputMsg) + require.NoError(t, err) + + require.Equal(t, "Bearer test-token", receivedHeaders.Get("Authorization")) + require.Equal(t, "custom-value", receivedHeaders.Get("X-Custom-Header")) +} + +func TestConnectTransport_Invoke_BinaryHeaderForwarding(t *testing.T) { + var receivedHeaders http.Header + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + receivedHeaders = r.Header.Clone() + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte{}) + })) + defer server.Close() + + transport := NewConnectTransport(ConnectTransportConfig{ + BaseURL: server.URL, + Encoding: ConnectEncodingProtobuf, + }) + + compiler := newTestCompiler(t) + reqDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryRequest") + respDesc := findMessageDesc(t, compiler, "productv1.QueryCategoryResponse") + + inputMsg := dynamicpb.NewMessage(reqDesc) + outputMsg := dynamicpb.NewMessage(respDesc) + + binaryValue := "\x00\x01\x02\x03" + ctx := metadata.AppendToOutgoingContext(context.Background(), + "x-trace-id-bin", binaryValue, + "authorization", "Bearer token", + ) + + err := transport.Invoke(ctx, "/productv1.ProductService/QueryCategory", inputMsg, outputMsg) + require.NoError(t, err) + + // Binary header must be base64-encoded per the Connect protocol spec. + require.Equal(t, base64.StdEncoding.EncodeToString([]byte(binaryValue)), receivedHeaders.Get("X-Trace-Id-Bin")) + // String header must be forwarded as-is. + require.Equal(t, "Bearer token", receivedHeaders.Get("Authorization")) +} + +func TestGRPCTransport_Invoke(t *testing.T) { + // Use the existing mockInterface from grpc_datasource_test.go. + mi := mockInterface{} + transport := NewGRPCTransport(mi) + + compiler := newTestCompiler(t) + reqDesc := findMessageDesc(t, compiler, "productv1.QueryComplexFilterTypeRequest") + respDesc := findMessageDesc(t, compiler, "productv1.QueryComplexFilterTypeResponse") + + inputMsg := dynamicpb.NewMessage(reqDesc) + outputMsg := dynamicpb.NewMessage(respDesc) + + err := transport.Invoke(context.Background(), "/productv1.ProductService/QueryComplexFilterType", inputMsg, outputMsg) + require.NoError(t, err) +}