Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
220 changes: 220 additions & 0 deletions router-tests/subgraph_mtls_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
package integration

import (
"crypto/tls"
"crypto/x509"
"os"
"testing"

"github.com/stretchr/testify/require"
"github.com/wundergraph/cosmo/router-tests/testenv"
"github.com/wundergraph/cosmo/router/core"
"github.com/wundergraph/cosmo/router/pkg/config"
)

// loadSubgraphMTLSCert loads a TLS certificate from the testdata/tls directory.
func loadSubgraphMTLSCert(t *testing.T, certFile, keyFile string) tls.Certificate {
t.Helper()
cert, err := tls.LoadX509KeyPair(certFile, keyFile)
require.NoError(t, err)
return cert
}

// loadSubgraphMTLSCACertPool loads a CA certificate pool from a PEM file.
func loadSubgraphMTLSCACertPool(t *testing.T, caFile string) *x509.CertPool {
t.Helper()
caCert, err := os.ReadFile(caFile)
require.NoError(t, err)

pool := x509.NewCertPool()
ok := pool.AppendCertsFromPEM(caCert)
require.True(t, ok, "failed to append CA cert to pool")
return pool
}

// subgraphMTLSServerConfig creates a tls.Config for a subgraph test server.
// It does NOT set Certificates — httptest.StartTLS() will generate a cert valid for 127.0.0.1.
// If requireClientCert is true, the subgraph will require the router to present a valid client certificate
// signed by the CA in testdata/tls/cert.pem.
func subgraphMTLSServerConfig(t *testing.T, requireClientCert bool) *tls.Config {
t.Helper()
cfg := &tls.Config{}
if requireClientCert {
caPool := loadSubgraphMTLSCACertPool(t, "testdata/tls/cert.pem")
cfg.ClientCAs = caPool
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}
return cfg
}

func TestSubgraphMTLS(t *testing.T) {
t.Parallel()

t.Run("Router connects to TLS subgraph with InsecureSkipVerify", func(t *testing.T) {
t.Parallel()

// Subgraph is a TLS server (httptest generates a self-signed cert for 127.0.0.1).
// Router uses InsecureSkipVerify to trust it.
testenv.Run(t, &testenv.Config{
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
TLSConfig: subgraphMTLSServerConfig(t, false),
},
},
RouterOptions: []core.Option{
core.WithSubgraphTLSConfig(&core.SubgraphTLSConfig{
DefaultClientTLS: &tls.Config{
InsecureSkipVerify: true,
},
}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.JSONEq(t, employeesIDData, res.Body)
})
})

t.Run("Router presents client certificate to mTLS subgraph", func(t *testing.T) {
t.Parallel()

// Subgraph requires client cert signed by testdata/tls/cert.pem CA
clientCert := loadSubgraphMTLSCert(t, "testdata/tls/cert.pem", "testdata/tls/key.pem")

testenv.Run(t, &testenv.Config{
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
TLSConfig: subgraphMTLSServerConfig(t, true),
},
},
RouterOptions: []core.Option{
core.WithSubgraphTLSConfig(&core.SubgraphTLSConfig{
DefaultClientTLS: &tls.Config{
// InsecureSkipVerify for httptest's self-signed server cert
InsecureSkipVerify: true,
// Present client cert for mTLS
Certificates: []tls.Certificate{clientCert},
},
}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.JSONEq(t, employeesIDData, res.Body)
})
})

t.Run("Router fails to connect to mTLS subgraph without client certificate", func(t *testing.T) {
t.Parallel()

// Subgraph requires client cert, but router does not provide one
testenv.Run(t, &testenv.Config{
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
TLSConfig: subgraphMTLSServerConfig(t, true),
},
},
RouterOptions: []core.Option{
core.WithSubgraphTLSConfig(&core.SubgraphTLSConfig{
DefaultClientTLS: &tls.Config{
InsecureSkipVerify: true,
// NO client certificate — should cause mTLS failure
},
}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
// The query should fail because the router has no client cert to present
res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.NoError(t, err)
// The router returns 200 with a GraphQL error about the subgraph fetch failure
require.Contains(t, res.Body, `Failed to fetch from Subgraph`)
})
})

t.Run("Router fails to connect to mTLS subgraph with wrong client certificate", func(t *testing.T) {
t.Parallel()

// Subgraph requires client cert signed by cert.pem CA, but router presents cert-2
wrongCert := loadSubgraphMTLSCert(t, "testdata/tls/cert-2.pem", "testdata/tls/key-2.pem")

testenv.Run(t, &testenv.Config{
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
TLSConfig: subgraphMTLSServerConfig(t, true),
},
},
RouterOptions: []core.Option{
core.WithSubgraphTLSConfig(&core.SubgraphTLSConfig{
DefaultClientTLS: &tls.Config{
InsecureSkipVerify: true,
Certificates: []tls.Certificate{wrongCert},
},
}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.NoError(t, err)
require.Contains(t, res.Body, `Failed to fetch from Subgraph`)
})
})

t.Run("Per-subgraph TLS config overrides global", func(t *testing.T) {
t.Parallel()

// Subgraph requires client cert
clientCert := loadSubgraphMTLSCert(t, "testdata/tls/cert.pem", "testdata/tls/key.pem")

testenv.Run(t, &testenv.Config{
Subgraphs: testenv.SubgraphsConfig{
Employees: testenv.SubgraphConfig{
TLSConfig: subgraphMTLSServerConfig(t, true),
},
},
RouterOptions: []core.Option{
core.WithSubgraphTLSConfig(&core.SubgraphTLSConfig{
// No global client TLS — would fail without per-subgraph override
PerSubgraphTLS: map[string]*tls.Config{
"employees": {
InsecureSkipVerify: true,
Certificates: []tls.Certificate{clientCert},
},
},
}),
},
}, func(t *testing.T, xEnv *testenv.Environment) {
res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{
Query: `query { employees { id } }`,
})
require.JSONEq(t, employeesIDData, res.Body)
})
})

t.Run("Router config builds SubgraphTLSConfig from config struct", func(t *testing.T) {
t.Parallel()

cfg := &config.Config{
TLS: config.TLSConfiguration{
Subgraph: config.SubgraphTLSConfiguration{
All: config.TLSClientCertConfiguration{
CertificateChain: "testdata/tls/cert.pem",
Key: "testdata/tls/key.pem",
CaFile: "testdata/tls/cert.pem",
},
},
},
}

subgraphTLS, err := core.NewSubgraphTLSConfig(cfg)
require.NoError(t, err)
require.NotNil(t, subgraphTLS)
require.NotNil(t, subgraphTLS.DefaultClientTLS)
require.Len(t, subgraphTLS.DefaultClientTLS.Certificates, 1)
require.NotNil(t, subgraphTLS.DefaultClientTLS.RootCAs)
})
}
53 changes: 35 additions & 18 deletions router-tests/testenv/testenv.go
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,9 @@ type SubgraphConfig struct {
Middleware func(http.Handler) http.Handler
Delay time.Duration
CloseOnStart bool
// TLSConfig enables TLS on this subgraph server. When set, the subgraph uses StartTLS()
// instead of Start(). This is useful for testing mTLS between the router and subgraphs.
TLSConfig *tls.Config
}

type LogObservationConfig struct {
Expand Down Expand Up @@ -581,15 +584,15 @@ func CreateTestSupervisorEnv(t testing.TB, cfg *Config) (*Environment, error) {
localDelay: cfg.Subgraphs.Countries.Delay,
}

employeesServer := makeSafeHttpTestServer(t, employees)
familyServer := makeSafeHttpTestServer(t, family)
hobbiesServer := makeSafeHttpTestServer(t, hobbies)
productsServer := makeSafeHttpTestServer(t, products)
test1Server := makeSafeHttpTestServer(t, test1)
availabilityServer := makeSafeHttpTestServer(t, availability)
moodServer := makeSafeHttpTestServer(t, mood)
countriesServer := makeSafeHttpTestServer(t, countries)
productFgServer := makeSafeHttpTestServer(t, productsFg)
employeesServer := makeSubgraphTestServer(t, employees, cfg.Subgraphs.Employees.TLSConfig)
familyServer := makeSubgraphTestServer(t, family, cfg.Subgraphs.Family.TLSConfig)
hobbiesServer := makeSubgraphTestServer(t, hobbies, cfg.Subgraphs.Hobbies.TLSConfig)
productsServer := makeSubgraphTestServer(t, products, cfg.Subgraphs.Products.TLSConfig)
test1Server := makeSubgraphTestServer(t, test1, cfg.Subgraphs.Test1.TLSConfig)
availabilityServer := makeSubgraphTestServer(t, availability, cfg.Subgraphs.Availability.TLSConfig)
moodServer := makeSubgraphTestServer(t, mood, cfg.Subgraphs.Mood.TLSConfig)
countriesServer := makeSubgraphTestServer(t, countries, cfg.Subgraphs.Countries.TLSConfig)
productFgServer := makeSubgraphTestServer(t, productsFg, cfg.Subgraphs.ProductsFg.TLSConfig)

var (
projectServer *grpc.Server
Expand Down Expand Up @@ -1011,15 +1014,15 @@ func CreateTestEnv(t testing.TB, cfg *Config) (*Environment, error) {
localDelay: cfg.Subgraphs.Countries.Delay,
}

employeesServer := makeSafeHttpTestServer(t, employees)
familyServer := makeSafeHttpTestServer(t, family)
hobbiesServer := makeSafeHttpTestServer(t, hobbies)
productsServer := makeSafeHttpTestServer(t, products)
test1Server := makeSafeHttpTestServer(t, test1)
availabilityServer := makeSafeHttpTestServer(t, availability)
moodServer := makeSafeHttpTestServer(t, mood)
countriesServer := makeSafeHttpTestServer(t, countries)
productFgServer := makeSafeHttpTestServer(t, productsFg)
employeesServer := makeSubgraphTestServer(t, employees, cfg.Subgraphs.Employees.TLSConfig)
familyServer := makeSubgraphTestServer(t, family, cfg.Subgraphs.Family.TLSConfig)
hobbiesServer := makeSubgraphTestServer(t, hobbies, cfg.Subgraphs.Hobbies.TLSConfig)
productsServer := makeSubgraphTestServer(t, products, cfg.Subgraphs.Products.TLSConfig)
test1Server := makeSubgraphTestServer(t, test1, cfg.Subgraphs.Test1.TLSConfig)
availabilityServer := makeSubgraphTestServer(t, availability, cfg.Subgraphs.Availability.TLSConfig)
moodServer := makeSubgraphTestServer(t, mood, cfg.Subgraphs.Mood.TLSConfig)
countriesServer := makeSubgraphTestServer(t, countries, cfg.Subgraphs.Countries.TLSConfig)
productFgServer := makeSubgraphTestServer(t, productsFg, cfg.Subgraphs.ProductsFg.TLSConfig)

var (
projectServer *grpc.Server
Expand Down Expand Up @@ -1673,6 +1676,20 @@ func makeSafeHttpTestServer(t testing.TB, handler http.Handler) *httptest.Server
return s
}

func makeSafeTLSHttpTestServer(t testing.TB, handler http.Handler, tlsConfig *tls.Config) *httptest.Server {
s := httptest.NewUnstartedServer(handler)
s.TLS = tlsConfig
s.StartTLS()
return s
}

func makeSubgraphTestServer(t testing.TB, handler http.Handler, tlsConfig *tls.Config) *httptest.Server {
if tlsConfig != nil {
return makeSafeTLSHttpTestServer(t, handler, tlsConfig)
}
return makeSafeHttpTestServer(t, handler)
}

func makeSafeGRPCServer(t testing.TB, sd *grpc.ServiceDesc, service any) (*grpc.Server, string) {
t.Helper()

Expand Down
12 changes: 12 additions & 0 deletions router/core/factoryresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,18 @@ func NewDefaultFactoryResolver(
subgraphHTTPClients[subgraph] = subgraphClient
}

// Create HTTP clients for subgraphs that have per-subgraph transports (e.g. per-subgraph TLS)
// but no per-subgraph transport options. These use the default request timeout.
for subgraph, subgraphTransport := range subgraphTransports {
if _, exists := subgraphHTTPClients[subgraph]; !exists {
subgraphClient := &http.Client{
Transport: transportFactory.RoundTripper(subgraphTransport),
Timeout: transportOptions.SubgraphTransportOptions.RequestTimeout,
}
subgraphHTTPClients[subgraph] = subgraphClient
}
}

var factoryLogger abstractlogger.Logger
if log != nil {
factoryLogger = abstractlogger.NewZapLogger(log, abstractlogger.DebugLevel)
Expand Down
25 changes: 23 additions & 2 deletions router/core/graph_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/tls"
"errors"
"fmt"
"net/http"
Expand Down Expand Up @@ -145,15 +146,35 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC
traceDialer = NewTraceDialer()
}

// Resolve default and per-subgraph client TLS configs
var defaultClientTLS *tls.Config
perSubgraphTLS := map[string]*tls.Config{}
if r.subgraphTLSConfig != nil {
defaultClientTLS = r.subgraphTLSConfig.DefaultClientTLS
perSubgraphTLS = r.subgraphTLSConfig.PerSubgraphTLS
}

// Base transport
baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "")
baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "", defaultClientTLS)

// Subgraph transports
subgraphTransports := map[string]*http.Transport{}
for subgraph, subgraphOpts := range r.subgraphTransportOptions.SubgraphMap {
subgraphBaseTransport := newHTTPTransport(subgraphOpts, proxy, traceDialer, subgraph)
clientTLS := defaultClientTLS
if sgTLS, ok := perSubgraphTLS[subgraph]; ok {
clientTLS = sgTLS
}
subgraphBaseTransport := newHTTPTransport(subgraphOpts, proxy, traceDialer, subgraph, clientTLS)
subgraphTransports[subgraph] = subgraphBaseTransport
}
// Create transports for subgraphs with per-subgraph TLS configs that don't have
// per-subgraph transport options (they inherit the base transport options).
for subgraph, sgTLS := range perSubgraphTLS {
if _, exists := subgraphTransports[subgraph]; !exists {
subgraphBaseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, subgraph, sgTLS)
subgraphTransports[subgraph] = subgraphBaseTransport
}
}

ctx, cancel := context.WithCancel(ctx)
s := &graphServer{
Expand Down
Loading
Loading