diff --git a/router/core/executor.go b/router/core/executor.go index 71bb08432c..0774531b9d 100644 --- a/router/core/executor.go +++ b/router/core/executor.go @@ -13,6 +13,8 @@ import ( "github.com/wundergraph/cosmo/router/pkg/grpcconnector" pubsub_datasource "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + grpcdatasource "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/grpc_datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" "github.com/wundergraph/graphql-go-tools/v2/pkg/asttransform" @@ -37,6 +39,7 @@ type ExecutorConfigurationBuilder struct { instanceData InstanceData subscriptionHooks subscriptionHooks + connectTransports map[string]grpcdatasource.RPCTransport } type Executor struct { @@ -218,6 +221,7 @@ func (b *ExecutorConfigurationBuilder) buildPlannerConfiguration(ctx context.Con b.logger, routerEngineCfg.Execution.EnableNetPoll, b.instanceData, + b.connectTransports, ), b.logger, b.subscriptionHooks) // this generates the plan config using the data source factories from the config package diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index 6fb0af8384..52827746e3 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -77,6 +77,7 @@ type DefaultFactoryResolver struct { transportFactory ApiTransportFactory defaultSubgraphRequestTimeout time.Duration subscriptionClientOptions []graphql_datasource.Options + connectTransports map[string]grpcdatasource.RPCTransport } func NewDefaultFactoryResolver( @@ -89,6 +90,7 @@ func NewDefaultFactoryResolver( log *zap.Logger, enableNetPoll bool, instanceData InstanceData, + connectTransports map[string]grpcdatasource.RPCTransport, ) *DefaultFactoryResolver { transportFactory := NewTransport(transportOptions) @@ -164,10 +166,19 @@ func NewDefaultFactoryResolver( transportFactory: transportFactory, defaultSubgraphRequestTimeout: transportOptions.SubgraphTransportOptions.RequestTimeout, subscriptionClientOptions: options, + connectTransports: connectTransports, } } func (d *DefaultFactoryResolver) ResolveGraphqlFactory(subgraphName string) (plan.PlannerFactory[graphql_datasource.Configuration], error) { + // Check Connect transports first — they use HTTP via the Connect protocol + // instead of native gRPC, so they bypass the gRPC connector entirely. + if d.connectTransports != nil { + if ct, ok := d.connectTransports[subgraphName]; ok { + return graphql_datasource.NewFactoryConnect(d.engineCtx, ct) + } + } + if d.connector != nil { // If the connector is not nil, we try to get the provider for the subgraph. // In case of a provider, we use the gRPC client provider to create the factory. diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 40bd0ca716..f189159a8b 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -45,6 +45,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/cors" "github.com/wundergraph/cosmo/router/pkg/execution_config" "github.com/wundergraph/cosmo/router/pkg/grpcconnector" + "github.com/wundergraph/cosmo/router/pkg/grpcprotocol" "github.com/wundergraph/cosmo/router/pkg/grpcconnector/grpccommon" "github.com/wundergraph/cosmo/router/pkg/grpcconnector/grpcplugin" "github.com/wundergraph/cosmo/router/pkg/grpcconnector/grpcpluginoci" @@ -57,6 +58,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/statistics" rtrace "github.com/wundergraph/cosmo/router/pkg/trace" + grpcdatasource "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/grpc_datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" ) @@ -104,6 +106,7 @@ type ( connector *grpcconnector.Connector circuitBreakerManager *circuit.Manager headerPropagation *HeaderPropagation + grpcProtocolConfig *config.GRPCProtocolConfig } ) @@ -175,6 +178,10 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC } } + if err := grpcprotocol.Validate(r.grpcProtocol); err != nil { + return nil, fmt.Errorf("invalid grpc_protocol configuration: %w", err) + } + ctx, cancel := context.WithCancel(ctx) s := &graphServer{ context: ctx, @@ -192,8 +199,9 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC HostName: r.hostName, ListenAddress: r.listenAddr, }, - storageProviders: &r.storageProviders, - headerPropagation: r.headerPropagation, + storageProviders: &r.storageProviders, + headerPropagation: r.headerPropagation, + grpcProtocolConfig: r.grpcProtocol, } baseOtelAttributes := []attribute.KeyValue{ @@ -1242,7 +1250,42 @@ func (s *graphServer) buildGraphMux( subgraphTippers[subgraph] = subgraphTransport } - if err := s.setupConnector(ctx, opts.EngineConfig, opts.ConfigSubgraphs, telemetryAttExpressions, tracingAttExpressions); err != nil { + // Build HTTP clients for Connect subgraphs, matching the timeout and transport + // configuration applied to regular GraphQL subgraph clients. + connectDefaultHTTPClient := &http.Client{ + Transport: s.baseTransport, + Timeout: s.subgraphTransportOptions.RequestTimeout, + } + connectSubgraphHTTPClients := map[string]*http.Client{} + for subgraph, subgraphOpts := range s.subgraphTransportOptions.SubgraphMap { + transport, ok := s.subgraphTransports[subgraph] + if !ok { + transport = s.baseTransport + } + connectSubgraphHTTPClients[subgraph] = &http.Client{ + Transport: transport, + Timeout: subgraphOpts.RequestTimeout, + } + } + // Include subgraphs with per-subgraph TLS but no traffic shaping overrides. + for subgraph, transport := range s.subgraphTransports { + if _, exists := connectSubgraphHTTPClients[subgraph]; !exists { + connectSubgraphHTTPClients[subgraph] = &http.Client{ + Transport: transport, + Timeout: s.subgraphTransportOptions.RequestTimeout, + } + } + } + + // Build Connect transports for subgraphs configured to use ConnectRPC protocol. + connectTransports := grpcprotocol.BuildConnectTransports( + s.grpcProtocolConfig, + collectGRPCSubgraphURLs(opts.EngineConfig, opts.ConfigSubgraphs), + connectSubgraphHTTPClients, + connectDefaultHTTPClient, + ) + + if err := s.setupConnector(ctx, opts.EngineConfig, opts.ConfigSubgraphs, telemetryAttExpressions, tracingAttExpressions, connectTransports); err != nil { return nil, fmt.Errorf("failed to setup plugin host: %w", err) } @@ -1288,6 +1331,7 @@ func (s *graphServer) buildGraphMux( CircuitBreaker: s.circuitBreakerManager, }, subscriptionHooks: s.subscriptionHooks, + connectTransports: connectTransports, } executor, providers, err := ecb.Build( @@ -1623,6 +1667,7 @@ func (s *graphServer) setupConnector( configSubgraphs []*nodev1.Subgraph, telemetryAttributeExpressions *attributeExpressions, tracingAttributeExpressions *attributeExpressions, + connectTransports map[string]grpcdatasource.RPCTransport, ) error { s.connector = grpcconnector.NewConnector() @@ -1645,6 +1690,14 @@ func (s *graphServer) setupConnector( return fmt.Errorf("subgraph %s not found", dsConfig.Id) } + // Skip gRPC connector registration for Connect subgraphs — + // they use HTTP via the Connect protocol instead of native gRPC. + if connectTransports != nil { + if _, isConnect := connectTransports[sg.Name]; isConnect { + continue + } + } + pluginConfig := grpcConfig.GetPlugin() if pluginConfig == nil { remoteProvider, err := grpcremote.NewRemoteGRPCProvider(grpcremote.RemoteGRPCProviderConfig{ @@ -2049,3 +2102,24 @@ func configureSubgraphOverwrites( return subgraphs, nil } + +// collectGRPCSubgraphURLs returns a map of subgraphName → routingUrl +// for all subgraphs that have gRPC configuration in the engine config. +func collectGRPCSubgraphURLs( + engineConfig *nodev1.EngineConfiguration, + configSubgraphs []*nodev1.Subgraph, +) map[string]string { + urls := map[string]string{} + for _, dsConfig := range engineConfig.DatasourceConfigurations { + if dsConfig.GetCustomGraphql().GetGrpc() == nil { + continue + } + for _, sg := range configSubgraphs { + if sg.Id == dsConfig.Id { + urls[sg.Name] = sg.RoutingUrl + break + } + } + } + return urls +} diff --git a/router/core/router.go b/router/core/router.go index eef8b8734d..e5e67523b6 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -2264,6 +2264,12 @@ func WithConnectRPC(cfg config.ConnectRPCConfiguration) Option { } } +func WithGRPCProtocol(cfg *config.GRPCProtocolConfig) Option { + return func(r *Router) { + r.grpcProtocol = cfg + } +} + func WithDemoMode(demoMode bool) Option { return func(r *Router) { r.demoMode = demoMode diff --git a/router/core/router_config.go b/router/core/router_config.go index bdb126614f..abca0a15f3 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -144,6 +144,7 @@ type Config struct { mcp config.MCPConfiguration connectRPC config.ConnectRPCConfiguration plugins config.PluginsConfiguration + grpcProtocol *config.GRPCProtocolConfig tracingAttributes []config.CustomAttribute subscriptionHooks subscriptionHooks } diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 2f9f6fcbfb..8f6efc15e9 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -271,6 +271,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config, reloadPersi WithMCP(config.MCP), WithConnectRPC(config.ConnectRPC), WithPlugins(config.Plugins), + WithGRPCProtocol(config.GRPCProtocol), WithDemoMode(config.DemoMode), WithStreamsHandlerConfiguration(config.Events.Handlers), WithReloadPersistentState(reloadPersistentState), diff --git a/router/go.mod b/router/go.mod index c0ea066c4c..611db8aa22 100644 --- a/router/go.mod +++ b/router/go.mod @@ -197,4 +197,5 @@ replace ( // Remember you can use Go workspaces to avoid using replace directives in multiple go.mod files // Use what is best for your personal workflow. See CONTRIBUTING.md for more information -// replace github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 +// TODO: Update to released version once wundergraph/graphql-go-tools#1453 is merged and tagged. +replace github.com/wundergraph/graphql-go-tools/v2 => github.com/fengyuwusong/graphql-go-tools/v2 v2.0.0-20260319034538-12c891d918df diff --git a/router/go.sum b/router/go.sum index 93c4cfd031..f306661f4c 100644 --- a/router/go.sum +++ b/router/go.sum @@ -81,6 +81,8 @@ github.com/fatih/color v1.18.0 h1:S8gINlzdQ840/4pfAwic/ZE0djQEH3wM94VfqLTZcOM= github.com/fatih/color v1.18.0/go.mod h1:4FelSpRwEGDpQ12mAdzqdOukCy4u8WUtOY6lkT/6HfU= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fengyuwusong/graphql-go-tools/v2 v2.0.0-20260319034538-12c891d918df h1:4wW9/8mELQS2qGPaz4br5bjZwxtm+OLK1T4AEWZi9Xo= +github.com/fengyuwusong/graphql-go-tools/v2 v2.0.0-20260319034538-12c891d918df/go.mod h1:HjTAO/cuICpu31IfHY9qmSPygx6Gza7Wt9hTSReTI+A= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= @@ -329,8 +331,6 @@ github.com/wundergraph/astjson v1.1.0 h1:xORDosrZ87zQFJwNGe/HIHXqzpdHOFmqWgykCLV github.com/wundergraph/astjson v1.1.0/go.mod h1:h12D/dxxnedtLzsKyBLK7/Oe4TAoGpRVC9nDpDrZSWw= github.com/wundergraph/go-arena v1.1.0 h1:9+wSRkJAkA2vbYHp6s8tEGhPViRGQNGXqPHT0QzhdIc= github.com/wundergraph/go-arena v1.1.0/go.mod h1:ROOysEHWJjLQ8FSfNxZCziagb7Qw2nXY3/vgKRh7eWw= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.265 h1:KVmojt3oH13VX8Yr8NZ+fuOiruLyznderHITJs1MyWE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.265/go.mod h1:HjTAO/cuICpu31IfHY9qmSPygx6Gza7Wt9hTSReTI+A= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 5b31c1f932..3d9ac7e401 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -1114,6 +1114,24 @@ type MCPServer struct { BaseURL string `yaml:"base_url,omitempty" env:"MCP_SERVER_BASE_URL"` } +// GRPCProtocolConfig configures the transport protocol for gRPC subgraphs. +// By default all gRPC subgraphs use native gRPC (HTTP/2 + Protobuf). +// Setting the protocol to "connectrpc" switches a subgraph to the Connect protocol (HTTP/1.1). +type GRPCProtocolConfig struct { + // Default protocol for all gRPC subgraphs ("grpc" or "connectrpc"). + Default string `yaml:"default,omitempty" json:"default,omitempty"` + // DefaultEncoding for Connect subgraphs ("proto" or "json"). + DefaultEncoding string `yaml:"default_encoding,omitempty" json:"default_encoding,omitempty"` + // Per-subgraph protocol and encoding overrides. + Subgraphs map[string]SubgraphGRPCProtocolConfig `yaml:"subgraphs,omitempty" json:"subgraphs,omitempty"` +} + +// SubgraphGRPCProtocolConfig holds per-subgraph protocol/encoding settings. +type SubgraphGRPCProtocolConfig struct { + Protocol string `yaml:"protocol,omitempty" json:"protocol,omitempty"` + Encoding string `yaml:"encoding,omitempty" json:"encoding,omitempty"` +} + type ConnectRPCConfiguration struct { Enabled bool `yaml:"enabled" envDefault:"false" env:"CONNECT_RPC_ENABLED"` Server ConnectRPCServer `yaml:"server,omitempty" envPrefix:"CONNECT_RPC_SERVER_"` @@ -1222,6 +1240,8 @@ type Config struct { Plugins PluginsConfiguration `yaml:"plugins" envPrefix:"PLUGINS_"` WatchConfig WatchConfig `yaml:"watch_config" envPrefix:"WATCH_CONFIG_"` + + GRPCProtocol *GRPCProtocolConfig `yaml:"grpc_protocol,omitempty"` } type WatchConfig struct { diff --git a/router/pkg/grpcprotocol/config.go b/router/pkg/grpcprotocol/config.go new file mode 100644 index 0000000000..9167707389 --- /dev/null +++ b/router/pkg/grpcprotocol/config.go @@ -0,0 +1,65 @@ +package grpcprotocol + +import ( + "fmt" + + "github.com/wundergraph/cosmo/router/pkg/config" +) + +const ( + ProtocolGRPC = "grpc" + ProtocolConnectRPC = "connectrpc" + + EncodingProto = "proto" + EncodingJSON = "json" +) + +// Validate checks that all config values in a GRPCProtocolConfig are valid. +func Validate(cfg *config.GRPCProtocolConfig) error { + if cfg == nil { + return nil + } + if cfg.Default != "" && cfg.Default != ProtocolGRPC && cfg.Default != ProtocolConnectRPC { + return fmt.Errorf("grpc_protocol.default: invalid value %q, must be %q or %q", cfg.Default, ProtocolGRPC, ProtocolConnectRPC) + } + if cfg.DefaultEncoding != "" && cfg.DefaultEncoding != EncodingProto && cfg.DefaultEncoding != EncodingJSON { + return fmt.Errorf("grpc_protocol.default_encoding: invalid value %q, must be %q or %q", cfg.DefaultEncoding, EncodingProto, EncodingJSON) + } + for name, sg := range cfg.Subgraphs { + if sg.Protocol != "" && sg.Protocol != ProtocolGRPC && sg.Protocol != ProtocolConnectRPC { + return fmt.Errorf("grpc_protocol.subgraphs.%s.protocol: invalid value %q", name, sg.Protocol) + } + if sg.Encoding != "" && sg.Encoding != EncodingProto && sg.Encoding != EncodingJSON { + return fmt.Errorf("grpc_protocol.subgraphs.%s.encoding: invalid value %q", name, sg.Encoding) + } + } + return nil +} + +// ResolveProtocol returns the effective protocol for a subgraph. +func ResolveProtocol(cfg *config.GRPCProtocolConfig, subgraphName string) string { + if cfg == nil { + return ProtocolGRPC + } + if sg, ok := cfg.Subgraphs[subgraphName]; ok && sg.Protocol != "" { + return sg.Protocol + } + if cfg.Default != "" { + return cfg.Default + } + return ProtocolGRPC +} + +// ResolveEncoding returns the effective encoding for a subgraph. +func ResolveEncoding(cfg *config.GRPCProtocolConfig, subgraphName string) string { + if cfg == nil { + return EncodingProto + } + if sg, ok := cfg.Subgraphs[subgraphName]; ok && sg.Encoding != "" { + return sg.Encoding + } + if cfg.DefaultEncoding != "" { + return cfg.DefaultEncoding + } + return EncodingProto +} diff --git a/router/pkg/grpcprotocol/config_test.go b/router/pkg/grpcprotocol/config_test.go new file mode 100644 index 0000000000..e8a53ae199 --- /dev/null +++ b/router/pkg/grpcprotocol/config_test.go @@ -0,0 +1,94 @@ +package grpcprotocol + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestResolveProtocol_Default(t *testing.T) { + assert.Equal(t, ProtocolGRPC, ResolveProtocol(&config.GRPCProtocolConfig{}, "any")) +} + +func TestResolveProtocol_Nil(t *testing.T) { + assert.Equal(t, ProtocolGRPC, ResolveProtocol(nil, "any")) +} + +func TestResolveProtocol_GlobalDefault(t *testing.T) { + cfg := &config.GRPCProtocolConfig{Default: ProtocolConnectRPC} + assert.Equal(t, ProtocolConnectRPC, ResolveProtocol(cfg, "any")) +} + +func TestResolveProtocol_PerSubgraphOverride(t *testing.T) { + cfg := &config.GRPCProtocolConfig{ + Default: ProtocolGRPC, + Subgraphs: map[string]config.SubgraphGRPCProtocolConfig{ + "rpc-a": {Protocol: ProtocolConnectRPC}, + }, + } + assert.Equal(t, ProtocolConnectRPC, ResolveProtocol(cfg, "rpc-a")) + assert.Equal(t, ProtocolGRPC, ResolveProtocol(cfg, "rpc-b")) +} + +func TestResolveEncoding_Default(t *testing.T) { + assert.Equal(t, EncodingProto, ResolveEncoding(&config.GRPCProtocolConfig{}, "any")) +} + +func TestResolveEncoding_GlobalDefault(t *testing.T) { + cfg := &config.GRPCProtocolConfig{DefaultEncoding: EncodingJSON} + assert.Equal(t, EncodingJSON, ResolveEncoding(cfg, "any")) +} + +func TestResolveEncoding_PerSubgraphOverride(t *testing.T) { + cfg := &config.GRPCProtocolConfig{ + DefaultEncoding: EncodingProto, + Subgraphs: map[string]config.SubgraphGRPCProtocolConfig{ + "rpc-a": {Encoding: EncodingJSON}, + }, + } + assert.Equal(t, EncodingJSON, ResolveEncoding(cfg, "rpc-a")) + assert.Equal(t, EncodingProto, ResolveEncoding(cfg, "rpc-b")) +} + +func TestValidate_Valid(t *testing.T) { + cfg := &config.GRPCProtocolConfig{ + Default: ProtocolConnectRPC, + DefaultEncoding: EncodingJSON, + Subgraphs: map[string]config.SubgraphGRPCProtocolConfig{ + "rpc-a": {Protocol: ProtocolGRPC, Encoding: EncodingProto}, + }, + } + require.NoError(t, Validate(cfg)) +} + +func TestValidate_Nil(t *testing.T) { + require.NoError(t, Validate(nil)) +} + +func TestValidate_Empty(t *testing.T) { + require.NoError(t, Validate(&config.GRPCProtocolConfig{})) +} + +func TestValidate_InvalidDefault(t *testing.T) { + err := Validate(&config.GRPCProtocolConfig{Default: "invalid"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "grpc_protocol.default") +} + +func TestValidate_InvalidEncoding(t *testing.T) { + err := Validate(&config.GRPCProtocolConfig{DefaultEncoding: "xml"}) + require.Error(t, err) + assert.Contains(t, err.Error(), "grpc_protocol.default_encoding") +} + +func TestValidate_InvalidSubgraphProtocol(t *testing.T) { + err := Validate(&config.GRPCProtocolConfig{ + Subgraphs: map[string]config.SubgraphGRPCProtocolConfig{ + "rpc-a": {Protocol: "http2"}, + }, + }) + require.Error(t, err) + assert.Contains(t, err.Error(), "grpc_protocol.subgraphs.rpc-a.protocol") +} diff --git a/router/pkg/grpcprotocol/transport_builder.go b/router/pkg/grpcprotocol/transport_builder.go new file mode 100644 index 0000000000..603772d6b8 --- /dev/null +++ b/router/pkg/grpcprotocol/transport_builder.go @@ -0,0 +1,55 @@ +package grpcprotocol + +import ( + "net/http" + + "github.com/wundergraph/cosmo/router/pkg/config" + grpcdatasource "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/grpc_datasource" +) + +// BuildConnectTransports creates a map of subgraphName → RPCTransport +// for all subgraphs configured to use ConnectRPC. +// Returns nil if no subgraphs are configured for Connect. +func BuildConnectTransports( + cfg *config.GRPCProtocolConfig, + grpcSubgraphURLs map[string]string, + subgraphHTTPClients map[string]*http.Client, + defaultHTTPClient *http.Client, +) map[string]grpcdatasource.RPCTransport { + if cfg == nil { + return nil + } + + transports := make(map[string]grpcdatasource.RPCTransport) + + for subgraphName, routingURL := range grpcSubgraphURLs { + if ResolveProtocol(cfg, subgraphName) != ProtocolConnectRPC { + continue + } + + httpClient := defaultHTTPClient + if sgClient, ok := subgraphHTTPClients[subgraphName]; ok { + httpClient = sgClient + } + + var connectEncoding grpcdatasource.ConnectEncoding + if ResolveEncoding(cfg, subgraphName) == EncodingJSON { + connectEncoding = grpcdatasource.ConnectEncodingJSON + } else { + connectEncoding = grpcdatasource.ConnectEncodingProtobuf + } + + transports[subgraphName] = grpcdatasource.NewConnectTransport( + grpcdatasource.ConnectTransportConfig{ + BaseURL: routingURL, + HTTPClient: httpClient, + Encoding: connectEncoding, + }, + ) + } + + if len(transports) == 0 { + return nil + } + return transports +} diff --git a/router/pkg/grpcprotocol/transport_builder_test.go b/router/pkg/grpcprotocol/transport_builder_test.go new file mode 100644 index 0000000000..5c64f753bf --- /dev/null +++ b/router/pkg/grpcprotocol/transport_builder_test.go @@ -0,0 +1,63 @@ +package grpcprotocol + +import ( + "net/http" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestBuildConnectTransports_NilConfig(t *testing.T) { + result := BuildConnectTransports(nil, map[string]string{"rpc": "http://localhost"}, nil, http.DefaultClient) + assert.Nil(t, result) +} + +func TestBuildConnectTransports_AllGRPC(t *testing.T) { + cfg := &config.GRPCProtocolConfig{Default: ProtocolGRPC} + urls := map[string]string{"rpc-a": "http://localhost:3000"} + result := BuildConnectTransports(cfg, urls, nil, http.DefaultClient) + assert.Nil(t, result) +} + +func TestBuildConnectTransports_ConnectSubgraph(t *testing.T) { + cfg := &config.GRPCProtocolConfig{Default: ProtocolConnectRPC} + urls := map[string]string{"rpc-a": "http://localhost:3000"} + result := BuildConnectTransports(cfg, urls, nil, http.DefaultClient) + assert.NotNil(t, result) + assert.Contains(t, result, "rpc-a") +} + +func TestBuildConnectTransports_MixedProtocols(t *testing.T) { + cfg := &config.GRPCProtocolConfig{ + Default: ProtocolGRPC, + Subgraphs: map[string]config.SubgraphGRPCProtocolConfig{ + "rpc-connect": {Protocol: ProtocolConnectRPC}, + }, + } + urls := map[string]string{ + "rpc-grpc": "http://localhost:3001", + "rpc-connect": "http://localhost:3002", + } + result := BuildConnectTransports(cfg, urls, nil, http.DefaultClient) + assert.NotNil(t, result) + assert.Contains(t, result, "rpc-connect") + assert.NotContains(t, result, "rpc-grpc") +} + +func TestBuildConnectTransports_UsesPerSubgraphHTTPClient(t *testing.T) { + customClient := &http.Client{} + cfg := &config.GRPCProtocolConfig{Default: ProtocolConnectRPC} + urls := map[string]string{"rpc-a": "http://localhost:3000"} + sgClients := map[string]*http.Client{"rpc-a": customClient} + + result := BuildConnectTransports(cfg, urls, sgClients, http.DefaultClient) + assert.NotNil(t, result) + assert.Contains(t, result, "rpc-a") +} + +func TestBuildConnectTransports_EmptyURLs(t *testing.T) { + cfg := &config.GRPCProtocolConfig{Default: ProtocolConnectRPC} + result := BuildConnectTransports(cfg, map[string]string{}, nil, http.DefaultClient) + assert.Nil(t, result) +}