Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
860 changes: 860 additions & 0 deletions router-tests/subgraph_mtls_test.go

Large diffs are not rendered by default.

48 changes: 29 additions & 19 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,10 @@ type SubgraphConfig struct {
GRPCInterceptor grpc.UnaryServerInterceptor
Delay time.Duration
CloseOnStart bool

// TLSConfig enables TLS on this subgraph server. When set, the subgraph uses StartTLS()
// instead of Start(). This is useful for testing mTLS between the router and subgraphs.
TLSConfig *tls.Config
}

type LogObservationConfig struct {
Expand Down Expand Up @@ -584,15 +588,15 @@ func CreateTestSupervisorEnv(t testing.TB, cfg *Config) (*Environment, error) {
localDelay: cfg.Subgraphs.Countries.Delay,
}

employeesServer := makeSafeHttpTestServer(t, employees)
familyServer := makeSafeHttpTestServer(t, family)
hobbiesServer := makeSafeHttpTestServer(t, hobbies)
productsServer := makeSafeHttpTestServer(t, products)
test1Server := makeSafeHttpTestServer(t, test1)
availabilityServer := makeSafeHttpTestServer(t, availability)
moodServer := makeSafeHttpTestServer(t, mood)
countriesServer := makeSafeHttpTestServer(t, countries)
productFgServer := makeSafeHttpTestServer(t, productsFg)
employeesServer := makeSubgraphTestServer(t, employees, cfg.Subgraphs.Employees.TLSConfig)
familyServer := makeSubgraphTestServer(t, family, cfg.Subgraphs.Family.TLSConfig)
hobbiesServer := makeSubgraphTestServer(t, hobbies, cfg.Subgraphs.Hobbies.TLSConfig)
productsServer := makeSubgraphTestServer(t, products, cfg.Subgraphs.Products.TLSConfig)
test1Server := makeSubgraphTestServer(t, test1, cfg.Subgraphs.Test1.TLSConfig)
availabilityServer := makeSubgraphTestServer(t, availability, cfg.Subgraphs.Availability.TLSConfig)
moodServer := makeSubgraphTestServer(t, mood, cfg.Subgraphs.Mood.TLSConfig)
countriesServer := makeSubgraphTestServer(t, countries, cfg.Subgraphs.Countries.TLSConfig)
productFgServer := makeSubgraphTestServer(t, productsFg, cfg.Subgraphs.ProductsFg.TLSConfig)

var (
projectServer *grpc.Server
Expand Down Expand Up @@ -1014,15 +1018,15 @@ func CreateTestEnv(t testing.TB, cfg *Config) (*Environment, error) {
localDelay: cfg.Subgraphs.Countries.Delay,
}

employeesServer := makeSafeHttpTestServer(t, employees)
familyServer := makeSafeHttpTestServer(t, family)
hobbiesServer := makeSafeHttpTestServer(t, hobbies)
productsServer := makeSafeHttpTestServer(t, products)
test1Server := makeSafeHttpTestServer(t, test1)
availabilityServer := makeSafeHttpTestServer(t, availability)
moodServer := makeSafeHttpTestServer(t, mood)
countriesServer := makeSafeHttpTestServer(t, countries)
productFgServer := makeSafeHttpTestServer(t, productsFg)
employeesServer := makeSubgraphTestServer(t, employees, cfg.Subgraphs.Employees.TLSConfig)
familyServer := makeSubgraphTestServer(t, family, cfg.Subgraphs.Family.TLSConfig)
hobbiesServer := makeSubgraphTestServer(t, hobbies, cfg.Subgraphs.Hobbies.TLSConfig)
productsServer := makeSubgraphTestServer(t, products, cfg.Subgraphs.Products.TLSConfig)
test1Server := makeSubgraphTestServer(t, test1, cfg.Subgraphs.Test1.TLSConfig)
availabilityServer := makeSubgraphTestServer(t, availability, cfg.Subgraphs.Availability.TLSConfig)
moodServer := makeSubgraphTestServer(t, mood, cfg.Subgraphs.Mood.TLSConfig)
countriesServer := makeSubgraphTestServer(t, countries, cfg.Subgraphs.Countries.TLSConfig)
productFgServer := makeSubgraphTestServer(t, productsFg, cfg.Subgraphs.ProductsFg.TLSConfig)

var (
projectServer *grpc.Server
Expand Down Expand Up @@ -1676,12 +1680,18 @@ func testVersionedTokenClaims() jwt.MapClaims {
}
}

func makeSafeHttpTestServer(t testing.TB, handler http.Handler) *httptest.Server {
func makeSubgraphTestServer(_ testing.TB, handler http.Handler, tlsConfig *tls.Config) *httptest.Server {
// NewUnstartedServer binds an ephemeral port.
// We want to avoid using freeport because it creates too much strain on the network stack:
// freeport checks if port is available by listening on it and then closing the listener.
// On Linux trying to listen on the just-closed port could lead to the "unable to bind" error.
s := httptest.NewUnstartedServer(handler)

if tlsConfig != nil {
s.TLS = tlsConfig
s.StartTLS()
return s
}
s.Start()
return s
}
Expand Down
12 changes: 12 additions & 0 deletions router/core/factoryresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ func NewDefaultFactoryResolver(
subgraphHTTPClients[subgraph] = subgraphClient
}

// Create HTTP clients for subgraphs that have per-subgraph TLS
// but no per-subgraph transport options. These use the default request timeout.
for subgraph, subgraphTransport := range subgraphTransports {
if _, exists := subgraphHTTPClients[subgraph]; !exists {
subgraphClient := &http.Client{
Transport: transportFactory.RoundTripper(subgraphTransport),
Timeout: transportOptions.SubgraphTransportOptions.RequestTimeout,
}
subgraphHTTPClients[subgraph] = subgraphClient
}
}

var factoryLogger abstractlogger.Logger
if log != nil {
factoryLogger = abstractlogger.NewZapLogger(log, abstractlogger.DebugLevel)
Expand Down
29 changes: 24 additions & 5 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,35 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC
traceDialer = NewTraceDialer()
}

// Build subgraph client TLS configs (mTLS for outbound subgraph connections)
defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(r.logger, &r.subgraphTLSConfiguration)
if err != nil {
return nil, fmt.Errorf("could not build subgraph client TLS config: %w", err)
}

// Base transport
baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "")
baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "", defaultClientTLS)

// Subgraph transports
subgraphTransports := map[string]*http.Transport{}
for subgraph, subgraphOpts := range r.subgraphTransportOptions.SubgraphMap {
subgraphBaseTransport := newHTTPTransport(subgraphOpts, proxy, traceDialer, subgraph)
clientTLS := defaultClientTLS
if sgTLS, ok := perSubgraphTLS[subgraph]; ok {
clientTLS = sgTLS
}
subgraphBaseTransport := newHTTPTransport(subgraphOpts, proxy, traceDialer, subgraph, clientTLS)
subgraphTransports[subgraph] = subgraphBaseTransport
}

// Create transports for subgraphs with per-subgraph TLS configs that don't have
// per-subgraph transport options (they inherit the base transport options).
for subgraph, sgTLS := range perSubgraphTLS {
if _, exists := subgraphTransports[subgraph]; !exists {
subgraphBaseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, subgraph, sgTLS)
subgraphTransports[subgraph] = subgraphBaseTransport
}
}

ctx, cancel := context.WithCancel(ctx)
s := &graphServer{
context: ctx,
Expand Down Expand Up @@ -1305,9 +1324,9 @@ func (s *graphServer) buildGraphMux(
MaxDepth: s.securityConfiguration.ParserLimits.ApproximateDepthLimit,
MaxFields: s.securityConfiguration.ParserLimits.TotalFieldsLimit,
},
OperationNameLengthLimit: s.securityConfiguration.OperationNameLengthLimit,
ApolloCompatibilityFlags: s.apolloCompatibilityFlags,
ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags,
OperationNameLengthLimit: s.securityConfiguration.OperationNameLengthLimit,
ApolloCompatibilityFlags: s.apolloCompatibilityFlags,
ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags,
DisableExposingVariablesContentOnValidationError: s.engineExecutionConfiguration.DisableExposingVariablesContentOnValidationError,
RelaxSubgraphOperationFieldSelectionMergingNullability: s.engineExecutionConfiguration.RelaxSubgraphOperationFieldSelectionMergingNullability,
ComplexityLimits: s.securityConfiguration.ComplexityLimits,
Expand Down
13 changes: 12 additions & 1 deletion router/core/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -2140,6 +2140,12 @@ func WithTLSConfig(cfg *TlsConfig) Option {
}
}

func WithSubgraphTLSConfiguration(cfg config.ClientTLSConfiguration) Option {
return func(r *Router) {
r.subgraphTLSConfiguration = cfg
}
}

func WithTelemetryAttributes(attributes []config.CustomAttribute) Option {
return func(r *Router) {
r.telemetryAttributes = attributes
Expand Down Expand Up @@ -2254,7 +2260,7 @@ func WithStreamsHandlerConfiguration(cfg config.StreamsHandlerConfiguration) Opt

type ProxyFunc func(req *http.Request) (*url.URL, error)

func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc, traceDialer *TraceDialer, subgraph string) *http.Transport {
func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc, traceDialer *TraceDialer, subgraph string, clientTLS *tls.Config) *http.Transport {
dialer := &net.Dialer{
Timeout: opts.DialTimeout,
KeepAlive: opts.KeepAliveProbeInterval,
Expand Down Expand Up @@ -2283,6 +2289,11 @@ func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc, traceDiale
// Will return nil when HTTP(S)_PROXY does not exist or is empty.
// This will prevent the transport from handling the proxy when it is not needed.
Proxy: proxy,
// TLSClientConfig configures client TLS for outbound subgraph connections (mTLS).
}

if clientTLS != nil {
transport.TLSClientConfig = clientTLS
}

if traceDialer != nil {
Expand Down
1 change: 1 addition & 0 deletions router/core/router_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ type Config struct {
localhostFallbackInsideDocker bool
tlsServerConfig *tls.Config
tlsConfig *TlsConfig
subgraphTLSConfiguration config.ClientTLSConfiguration
telemetryAttributes []config.CustomAttribute
tracePropagators []propagation.TextMapPropagator
compositePropagator propagation.TextMapPropagator
Expand Down
1 change: 1 addition & 0 deletions router/core/supervisor_instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config, reloadPersi
WithDemoMode(config.DemoMode),
WithStreamsHandlerConfiguration(config.Events.Handlers),
WithReloadPersistentState(reloadPersistentState),
WithSubgraphTLSConfiguration(config.TLS.Client),
}

return options
Expand Down
84 changes: 84 additions & 0 deletions router/core/tls.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package core

import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"go.uber.org/zap"
"os"

"github.com/wundergraph/cosmo/router/pkg/config"
)

// buildTLSClientConfig creates a *tls.Config from a TLSClientCertConfiguration.
func buildTLSClientConfig(clientCfg *config.TLSClientCertConfiguration) (*tls.Config, error) {
tlsConfig := &tls.Config{
InsecureSkipVerify: clientCfg.InsecureSkipCaVerification,
}

// Load client certificate and key if provided
if clientCfg.CertFile != "" && clientCfg.KeyFile != "" {
cert, err := tls.LoadX509KeyPair(clientCfg.CertFile, clientCfg.KeyFile)
if err != nil {
return nil, fmt.Errorf("failed to load client TLS cert and key: %w", err)
}
tlsConfig.Certificates = []tls.Certificate{cert}
}

// Load custom CA for verifying subgraph server certificates
if clientCfg.CaFile != "" {
caCert, err := os.ReadFile(clientCfg.CaFile)
if err != nil {
return nil, fmt.Errorf("failed to read client TLS CA file: %w", err)
}
caPool := x509.NewCertPool()
if ok := caPool.AppendCertsFromPEM(caCert); !ok {
return nil, errors.New("failed to append client TLS CA cert to pool")
}
tlsConfig.RootCAs = caPool
}

return tlsConfig, nil
}

// buildSubgraphTLSConfigs builds the default and per-subgraph TLS configs from raw configuration.
// Returns (defaultClientTLS, perSubgraphTLS, error).
func buildSubgraphTLSConfigs(logger *zap.Logger, cfg *config.ClientTLSConfiguration) (*tls.Config, map[string]*tls.Config, error) {
hasAll := (cfg.All.CertFile != "" && cfg.All.KeyFile != "") || cfg.All.CaFile != "" || cfg.All.InsecureSkipCaVerification

// If no global TLS config is provided and there are no subgraph specific TLS configs
if !hasAll && len(cfg.Subgraphs) == 0 {
return nil, nil, nil
}

var defaultClientTLS *tls.Config
perSubgraphTLS := make(map[string]*tls.Config)

if hasAll {
if cfg.All.InsecureSkipCaVerification {
logger.Warn("Global TLS config has InsecureSkipCaVerification enabled. This is not recommended for production environments.")
}

defaultTLS, err := buildTLSClientConfig(&cfg.All)
if err != nil {
return nil, nil, fmt.Errorf("failed to build global subgraph TLS config: %w", err)
}
defaultClientTLS = defaultTLS
}

for name, sgCfg := range cfg.Subgraphs {
if sgCfg.InsecureSkipCaVerification {
logger.Warn("Subgraph TLS config inherits InsecureSkipCaVerification from global config. This is not recommended for production environments.",
zap.String("subgraph", name))
}

subgraphTLS, err := buildTLSClientConfig(&sgCfg)
if err != nil {
return nil, nil, fmt.Errorf("failed to build TLS config for subgraph %q: %w", name, err)
}
perSubgraphTLS[name] = subgraphTLS
}

return defaultClientTLS, perSubgraphTLS, nil
}
Loading
Loading