diff --git a/router-tests/subgraph_mtls_test.go b/router-tests/subgraph_mtls_test.go new file mode 100644 index 0000000000..f920deee9d --- /dev/null +++ b/router-tests/subgraph_mtls_test.go @@ -0,0 +1,860 @@ +package integration + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestSubgraphMTLS(t *testing.T) { + t.Parallel() + + t.Run("InsecureSkipVerify", func(t *testing.T) { + t.Parallel() + + t.Run("All", func(t *testing.T) { + t.Parallel() + + t.Run("connects when enabled", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, false), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("fails when disabled", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, false), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: false, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.Equal(t, `{"errors":[{"message":"Failed to fetch from Subgraph 'employees'."}],"data":{"employees":null}}`, res.Body) + }) + }) + }) + + t.Run("Per-subgraph", func(t *testing.T) { + t.Parallel() + + t.Run("overrides global CaFile", func(t *testing.T) { + t.Parallel() + + // Global config uses CaFile that does NOT match the httptest server cert, + // so it would fail. Per-subgraph overrides with InsecureSkipVerify to skip + // verification entirely, proving per-subgraph can be less secure than global. + certPath, _ := generateServerCert(t) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, false), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CaFile: certPath, + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + InsecureSkipCaVerification: true, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + }) + }) + + t.Run("Client certificate", func(t *testing.T) { + t.Parallel() + + t.Run("All", func(t *testing.T) { + t.Parallel() + + t.Run("presents correct cert to mTLS subgraph", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert.pem", + KeyFile: "testdata/tls/key.pem", + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("fails without client cert", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, `Failed to fetch from Subgraph`) + }) + }) + + t.Run("fails with wrong client cert", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert-2.pem", + KeyFile: "testdata/tls/key-2.pem", + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, `Failed to fetch from Subgraph`) + }) + }) + }) + + t.Run("Per-subgraph", func(t *testing.T) { + t.Parallel() + + t.Run("correct config without global", func(t *testing.T) { + t.Parallel() + + // No global config — only per-subgraph with correct client cert + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert.pem", + KeyFile: "testdata/tls/key.pem", + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("correct config overrides incorrect global", func(t *testing.T) { + t.Parallel() + + // Global has wrong certs (cert-2), per-subgraph has correct ones (cert). + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert-2.pem", + KeyFile: "testdata/tls/key-2.pem", + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert.pem", + KeyFile: "testdata/tls/key.pem", + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("incorrect config overrides correct global", func(t *testing.T) { + t.Parallel() + + // Global has correct certs (cert), per-subgraph overrides with wrong ones (cert-2). + // Per-subgraph always wins, even when it causes failure. + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert.pem", + KeyFile: "testdata/tls/key.pem", + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert-2.pem", + KeyFile: "testdata/tls/key-2.pem", + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, `Failed to fetch from Subgraph`) + }) + }) + + t.Run("override without cert fails even when global has cert", func(t *testing.T) { + t.Parallel() + + // Global has full working mTLS config (InsecureSkipVerify + client cert). + // Per-subgraph overrides with ONLY InsecureSkipVerify — no client cert. + // Because per-subgraph COMPLETELY REPLACES global (no field merging), + // the router will not present a client cert, causing mTLS failure. + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert.pem", + KeyFile: "testdata/tls/key.pem", + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + InsecureSkipCaVerification: true, + // NO CertFile/KeyFile — proves fields are NOT inherited from All + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res, err := xEnv.MakeGraphQLRequest(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.NoError(t, err) + require.Contains(t, res.Body, `Failed to fetch from Subgraph`) + }) + }) + }) + }) + + t.Run("CaFile", func(t *testing.T) { + t.Parallel() + + t.Run("All", func(t *testing.T) { + t.Parallel() + + t.Run("trusts subgraph server", func(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateServerCert(t) + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CaFile: certPath, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + }) + + t.Run("Per-subgraph", func(t *testing.T) { + t.Parallel() + + t.Run("trusts subgraph server without global config", func(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateServerCert(t) + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + CaFile: certPath, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("overrides global InsecureSkipVerify with proper verification", func(t *testing.T) { + t.Parallel() + + // Global uses InsecureSkipVerify (insecure), per-subgraph uses CaFile + // (proper verification). Proves per-subgraph can be more secure than global. + certPath, keyPath := generateServerCert(t) + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + CaFile: certPath, + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + }) + }) + + t.Run("Full mTLS", func(t *testing.T) { + t.Parallel() + + t.Run("All", func(t *testing.T) { + t.Parallel() + + t.Run("with CaFile and client certificate", func(t *testing.T) { + t.Parallel() + + // Production-like: router verifies server cert via CaFile + // AND presents client cert for mTLS — no InsecureSkipCaVerification. + certPath, keyPath := generateServerCert(t) + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + caPool := loadSubgraphMTLSCACertPool(t, "testdata/tls/cert.pem") + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientCAs: caPool, + ClientAuth: tls.RequireAndVerifyClientCert, + }, + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CaFile: certPath, + CertFile: "testdata/tls/cert.pem", + KeyFile: "testdata/tls/key.pem", + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + }) + + t.Run("Per-subgraph", func(t *testing.T) { + t.Parallel() + + t.Run("with CaFile and client certificate without global config", func(t *testing.T) { + t.Parallel() + + // Production-like per-subgraph: CaFile for server verification + // + client cert for mTLS, no global config at all. + certPath, keyPath := generateServerCert(t) + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + caPool := loadSubgraphMTLSCACertPool(t, "testdata/tls/cert.pem") + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientCAs: caPool, + ClientAuth: tls.RequireAndVerifyClientCert, + }, + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + CaFile: certPath, + CertFile: "testdata/tls/cert.pem", + KeyFile: "testdata/tls/key.pem", + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + }) + }) + + t.Run("Traffic shaping integration", func(t *testing.T) { + t.Parallel() + + t.Run("TLS with per-subgraph transport", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, false), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTransportOptions(core.NewSubgraphTransportOptions(config.TrafficShapingRules{ + All: config.GlobalSubgraphRequestRule{ + RequestTimeout: ToPtr(30 * time.Second), + }, + Subgraphs: map[string]config.GlobalSubgraphRequestRule{ + "employees": { + RequestTimeout: ToPtr(5 * time.Second), + }, + }, + })), + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("mTLS with per-subgraph transport", func(t *testing.T) { + t.Parallel() + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTransportOptions(core.NewSubgraphTransportOptions(config.TrafficShapingRules{ + Subgraphs: map[string]config.GlobalSubgraphRequestRule{ + "employees": { + RequestTimeout: ToPtr(5 * time.Second), + }, + }, + })), + core.WithSubgraphTLSConfiguration(config.ClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "employees": { + InsecureSkipCaVerification: true, + CertFile: "testdata/tls/cert.pem", + KeyFile: "testdata/tls/key.pem", + }, + }, + }), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + }) +} + +func TestSubgraphMTLSEnvVarConfig(t *testing.T) { + t.Run("Verify envars being set", func(t *testing.T) { + t.Run("InsecureSkipCaVerification defaults to false via env var", func(t *testing.T) { + // Do not set TLS_CLIENT_ALL_INSECURE_SKIP_CA_VERIFICATION — should default to false + cfg := loadConfigFromEnv(t) + require.False(t, cfg.TLS.Client.All.InsecureSkipCaVerification) + }) + + t.Run("InsecureSkipCaVerification set via env var", func(t *testing.T) { + t.Setenv("TLS_CLIENT_ALL_INSECURE_SKIP_CA_VERIFICATION", "true") + + cfg := loadConfigFromEnv(t) + require.True(t, cfg.TLS.Client.All.InsecureSkipCaVerification) + }) + + t.Run("CertFile and KeyFile set via env vars", func(t *testing.T) { + t.Setenv("TLS_CLIENT_ALL_CERT_FILE", "testdata/tls/cert.pem") + t.Setenv("TLS_CLIENT_ALL_KEY_FILE", "testdata/tls/key.pem") + + cfg := loadConfigFromEnv(t) + + require.Equal(t, "testdata/tls/cert.pem", cfg.TLS.Client.All.CertFile) + require.Equal(t, "testdata/tls/key.pem", cfg.TLS.Client.All.KeyFile) + }) + + t.Run("CaFile set via env var", func(t *testing.T) { + t.Setenv("TLS_CLIENT_ALL_CA_FILE", "testdata/tls/cert.pem") + + cfg := loadConfigFromEnv(t) + + require.Equal(t, "testdata/tls/cert.pem", cfg.TLS.Client.All.CaFile) + }) + }) + + t.Run("InsecureSkipCaVerification set via env var", func(t *testing.T) { + t.Setenv("TLS_CLIENT_ALL_INSECURE_SKIP_CA_VERIFICATION", "true") + + cfg := loadConfigFromEnv(t) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, false), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(cfg.TLS.Client), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("Router presents client certificate to mTLS subgraph via env vars", func(t *testing.T) { + t.Setenv("TLS_CLIENT_ALL_INSECURE_SKIP_CA_VERIFICATION", "true") + t.Setenv("TLS_CLIENT_ALL_CERT_FILE", "testdata/tls/cert.pem") + t.Setenv("TLS_CLIENT_ALL_KEY_FILE", "testdata/tls/key.pem") + + cfg := loadConfigFromEnv(t) + + require.True(t, cfg.TLS.Client.All.InsecureSkipCaVerification) + require.Equal(t, "testdata/tls/cert.pem", cfg.TLS.Client.All.CertFile) + require.Equal(t, "testdata/tls/key.pem", cfg.TLS.Client.All.KeyFile) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: subgraphMTLSServerConfig(t, true), + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(cfg.TLS.Client), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("Router trusts subgraph server via CaFile env var", func(t *testing.T) { + certPath, keyPath := generateServerCert(t) + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + t.Setenv("TLS_CLIENT_ALL_CA_FILE", certPath) + + cfg := loadConfigFromEnv(t) + require.Equal(t, certPath, cfg.TLS.Client.All.CaFile) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + }, + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(cfg.TLS.Client), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) + + t.Run("Full mTLS via env vars with CaFile and client certificate", func(t *testing.T) { + certPath, keyPath := generateServerCert(t) + serverCert, err := tls.LoadX509KeyPair(certPath, keyPath) + require.NoError(t, err) + + caPool := loadSubgraphMTLSCACertPool(t, "testdata/tls/cert.pem") + + t.Setenv("TLS_CLIENT_ALL_CA_FILE", certPath) + t.Setenv("TLS_CLIENT_ALL_CERT_FILE", "testdata/tls/cert.pem") + t.Setenv("TLS_CLIENT_ALL_KEY_FILE", "testdata/tls/key.pem") + + cfg := loadConfigFromEnv(t) + + require.Equal(t, certPath, cfg.TLS.Client.All.CaFile) + require.Equal(t, "testdata/tls/cert.pem", cfg.TLS.Client.All.CertFile) + require.Equal(t, "testdata/tls/key.pem", cfg.TLS.Client.All.KeyFile) + require.False(t, cfg.TLS.Client.All.InsecureSkipCaVerification) + + testenv.Run(t, &testenv.Config{ + Subgraphs: testenv.SubgraphsConfig{ + Employees: testenv.SubgraphConfig{ + TLSConfig: &tls.Config{ + Certificates: []tls.Certificate{serverCert}, + ClientCAs: caPool, + ClientAuth: tls.RequireAndVerifyClientCert, + }, + }, + }, + RouterOptions: []core.Option{ + core.WithSubgraphTLSConfiguration(cfg.TLS.Client), + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + res := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `query { employees { id } }`, + }) + require.JSONEq(t, employeesIDData, res.Body) + }) + }) +} + +// loadConfigFromEnv creates a minimal config file and loads config, allowing +// environment variables to populate the TLS client configuration fields. +func loadConfigFromEnv(t *testing.T) config.Config { + t.Helper() + + f, err := os.CreateTemp(t.TempDir(), "config_test_*.yaml") + require.NoError(t, err) + _, err = f.WriteString(`version: "1"`) + require.NoError(t, err) + require.NoError(t, f.Close()) + + result, err := config.LoadConfig([]string{f.Name()}) + require.NoError(t, err) + return result.Config +} + +// loadSubgraphMTLSCACertPool loads a CA certificate pool from a PEM file. +func loadSubgraphMTLSCACertPool(t *testing.T, caFile string) *x509.CertPool { + t.Helper() + caCert, err := os.ReadFile(caFile) + require.NoError(t, err) + + pool := x509.NewCertPool() + ok := pool.AppendCertsFromPEM(caCert) + require.True(t, ok, "failed to append CA cert to pool") + return pool +} + +// subgraphMTLSServerConfig creates a tls.Config for a subgraph test server. +// It does NOT set Certificates — httptest.StartTLS() will generate a cert valid for 127.0.0.1. +// If requireClientCert is true, the subgraph will require the router to present a valid client certificate +// signed by the CA in testdata/tls/cert.pem. +func subgraphMTLSServerConfig(t *testing.T, requireClientCert bool) *tls.Config { + t.Helper() + cfg := &tls.Config{} + if requireClientCert { + caPool := loadSubgraphMTLSCACertPool(t, "testdata/tls/cert.pem") + cfg.ClientCAs = caPool + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + return cfg +} + +// generateServerCert creates a self-signed certificate valid for 127.0.0.1. +// Returns paths to the cert and key PEM files in a temp directory. +// The cert can be used as both the server certificate and the router's CaFile +// (since it's self-signed, it is its own CA). +func generateServerCert(t *testing.T) (certPath, keyPath string) { + t.Helper() + + dir := t.TempDir() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test-server"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)}, + IsCA: true, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + certPath = filepath.Join(dir, "server.crt") + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + keyPath = filepath.Join(dir, "server.key") + keyFile, err := os.Create(keyPath) + require.NoError(t, err) + keyDER, err := x509.MarshalECPrivateKey(key) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + require.NoError(t, keyFile.Close()) + + return certPath, keyPath +} diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 6d1d884ed3..407f1f47b4 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -394,6 +394,10 @@ type SubgraphConfig struct { GRPCInterceptor grpc.UnaryServerInterceptor Delay time.Duration CloseOnStart bool + + // TLSConfig enables TLS on this subgraph server. When set, the subgraph uses StartTLS() + // instead of Start(). This is useful for testing mTLS between the router and subgraphs. + TLSConfig *tls.Config } type LogObservationConfig struct { @@ -584,15 +588,15 @@ func CreateTestSupervisorEnv(t testing.TB, cfg *Config) (*Environment, error) { localDelay: cfg.Subgraphs.Countries.Delay, } - employeesServer := makeSafeHttpTestServer(t, employees) - familyServer := makeSafeHttpTestServer(t, family) - hobbiesServer := makeSafeHttpTestServer(t, hobbies) - productsServer := makeSafeHttpTestServer(t, products) - test1Server := makeSafeHttpTestServer(t, test1) - availabilityServer := makeSafeHttpTestServer(t, availability) - moodServer := makeSafeHttpTestServer(t, mood) - countriesServer := makeSafeHttpTestServer(t, countries) - productFgServer := makeSafeHttpTestServer(t, productsFg) + employeesServer := makeSubgraphTestServer(t, employees, cfg.Subgraphs.Employees.TLSConfig) + familyServer := makeSubgraphTestServer(t, family, cfg.Subgraphs.Family.TLSConfig) + hobbiesServer := makeSubgraphTestServer(t, hobbies, cfg.Subgraphs.Hobbies.TLSConfig) + productsServer := makeSubgraphTestServer(t, products, cfg.Subgraphs.Products.TLSConfig) + test1Server := makeSubgraphTestServer(t, test1, cfg.Subgraphs.Test1.TLSConfig) + availabilityServer := makeSubgraphTestServer(t, availability, cfg.Subgraphs.Availability.TLSConfig) + moodServer := makeSubgraphTestServer(t, mood, cfg.Subgraphs.Mood.TLSConfig) + countriesServer := makeSubgraphTestServer(t, countries, cfg.Subgraphs.Countries.TLSConfig) + productFgServer := makeSubgraphTestServer(t, productsFg, cfg.Subgraphs.ProductsFg.TLSConfig) var ( projectServer *grpc.Server @@ -1014,15 +1018,15 @@ func CreateTestEnv(t testing.TB, cfg *Config) (*Environment, error) { localDelay: cfg.Subgraphs.Countries.Delay, } - employeesServer := makeSafeHttpTestServer(t, employees) - familyServer := makeSafeHttpTestServer(t, family) - hobbiesServer := makeSafeHttpTestServer(t, hobbies) - productsServer := makeSafeHttpTestServer(t, products) - test1Server := makeSafeHttpTestServer(t, test1) - availabilityServer := makeSafeHttpTestServer(t, availability) - moodServer := makeSafeHttpTestServer(t, mood) - countriesServer := makeSafeHttpTestServer(t, countries) - productFgServer := makeSafeHttpTestServer(t, productsFg) + employeesServer := makeSubgraphTestServer(t, employees, cfg.Subgraphs.Employees.TLSConfig) + familyServer := makeSubgraphTestServer(t, family, cfg.Subgraphs.Family.TLSConfig) + hobbiesServer := makeSubgraphTestServer(t, hobbies, cfg.Subgraphs.Hobbies.TLSConfig) + productsServer := makeSubgraphTestServer(t, products, cfg.Subgraphs.Products.TLSConfig) + test1Server := makeSubgraphTestServer(t, test1, cfg.Subgraphs.Test1.TLSConfig) + availabilityServer := makeSubgraphTestServer(t, availability, cfg.Subgraphs.Availability.TLSConfig) + moodServer := makeSubgraphTestServer(t, mood, cfg.Subgraphs.Mood.TLSConfig) + countriesServer := makeSubgraphTestServer(t, countries, cfg.Subgraphs.Countries.TLSConfig) + productFgServer := makeSubgraphTestServer(t, productsFg, cfg.Subgraphs.ProductsFg.TLSConfig) var ( projectServer *grpc.Server @@ -1676,12 +1680,18 @@ func testVersionedTokenClaims() jwt.MapClaims { } } -func makeSafeHttpTestServer(t testing.TB, handler http.Handler) *httptest.Server { +func makeSubgraphTestServer(_ testing.TB, handler http.Handler, tlsConfig *tls.Config) *httptest.Server { // NewUnstartedServer binds an ephemeral port. // We want to avoid using freeport because it creates too much strain on the network stack: // freeport checks if port is available by listening on it and then closing the listener. // On Linux trying to listen on the just-closed port could lead to the "unable to bind" error. s := httptest.NewUnstartedServer(handler) + + if tlsConfig != nil { + s.TLS = tlsConfig + s.StartTLS() + return s + } s.Start() return s } diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index 4ab0fd3bb7..33f2dc386f 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -104,6 +104,18 @@ func NewDefaultFactoryResolver( subgraphHTTPClients[subgraph] = subgraphClient } + // Create HTTP clients for subgraphs that have per-subgraph TLS + // but no per-subgraph transport options. These use the default request timeout. + for subgraph, subgraphTransport := range subgraphTransports { + if _, exists := subgraphHTTPClients[subgraph]; !exists { + subgraphClient := &http.Client{ + Transport: transportFactory.RoundTripper(subgraphTransport), + Timeout: transportOptions.SubgraphTransportOptions.RequestTimeout, + } + subgraphHTTPClients[subgraph] = subgraphClient + } + } + var factoryLogger abstractlogger.Logger if log != nil { factoryLogger = abstractlogger.NewZapLogger(log, abstractlogger.DebugLevel) diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 8413bcfca6..325ec74c25 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -145,16 +145,35 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC traceDialer = NewTraceDialer() } + // Build subgraph client TLS configs (mTLS for outbound subgraph connections) + defaultClientTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(r.logger, &r.subgraphTLSConfiguration) + if err != nil { + return nil, fmt.Errorf("could not build subgraph client TLS config: %w", err) + } + // Base transport - baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "") + baseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, "", defaultClientTLS) // Subgraph transports subgraphTransports := map[string]*http.Transport{} for subgraph, subgraphOpts := range r.subgraphTransportOptions.SubgraphMap { - subgraphBaseTransport := newHTTPTransport(subgraphOpts, proxy, traceDialer, subgraph) + clientTLS := defaultClientTLS + if sgTLS, ok := perSubgraphTLS[subgraph]; ok { + clientTLS = sgTLS + } + subgraphBaseTransport := newHTTPTransport(subgraphOpts, proxy, traceDialer, subgraph, clientTLS) subgraphTransports[subgraph] = subgraphBaseTransport } + // Create transports for subgraphs with per-subgraph TLS configs that don't have + // per-subgraph transport options (they inherit the base transport options). + for subgraph, sgTLS := range perSubgraphTLS { + if _, exists := subgraphTransports[subgraph]; !exists { + subgraphBaseTransport := newHTTPTransport(r.subgraphTransportOptions.TransportRequestOptions, proxy, traceDialer, subgraph, sgTLS) + subgraphTransports[subgraph] = subgraphBaseTransport + } + } + ctx, cancel := context.WithCancel(ctx) s := &graphServer{ context: ctx, @@ -1305,9 +1324,9 @@ func (s *graphServer) buildGraphMux( MaxDepth: s.securityConfiguration.ParserLimits.ApproximateDepthLimit, MaxFields: s.securityConfiguration.ParserLimits.TotalFieldsLimit, }, - OperationNameLengthLimit: s.securityConfiguration.OperationNameLengthLimit, - ApolloCompatibilityFlags: s.apolloCompatibilityFlags, - ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, + OperationNameLengthLimit: s.securityConfiguration.OperationNameLengthLimit, + ApolloCompatibilityFlags: s.apolloCompatibilityFlags, + ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, DisableExposingVariablesContentOnValidationError: s.engineExecutionConfiguration.DisableExposingVariablesContentOnValidationError, RelaxSubgraphOperationFieldSelectionMergingNullability: s.engineExecutionConfiguration.RelaxSubgraphOperationFieldSelectionMergingNullability, ComplexityLimits: s.securityConfiguration.ComplexityLimits, diff --git a/router/core/router.go b/router/core/router.go index 63ae344944..aea15aefde 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -2140,6 +2140,12 @@ func WithTLSConfig(cfg *TlsConfig) Option { } } +func WithSubgraphTLSConfiguration(cfg config.ClientTLSConfiguration) Option { + return func(r *Router) { + r.subgraphTLSConfiguration = cfg + } +} + func WithTelemetryAttributes(attributes []config.CustomAttribute) Option { return func(r *Router) { r.telemetryAttributes = attributes @@ -2254,7 +2260,7 @@ func WithStreamsHandlerConfiguration(cfg config.StreamsHandlerConfiguration) Opt type ProxyFunc func(req *http.Request) (*url.URL, error) -func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc, traceDialer *TraceDialer, subgraph string) *http.Transport { +func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc, traceDialer *TraceDialer, subgraph string, clientTLS *tls.Config) *http.Transport { dialer := &net.Dialer{ Timeout: opts.DialTimeout, KeepAlive: opts.KeepAliveProbeInterval, @@ -2283,6 +2289,11 @@ func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc, traceDiale // Will return nil when HTTP(S)_PROXY does not exist or is empty. // This will prevent the transport from handling the proxy when it is not needed. Proxy: proxy, + // TLSClientConfig configures client TLS for outbound subgraph connections (mTLS). + } + + if clientTLS != nil { + transport.TLSClientConfig = clientTLS } if traceDialer != nil { diff --git a/router/core/router_config.go b/router/core/router_config.go index c3af6c201d..bdb126614f 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -118,6 +118,7 @@ type Config struct { localhostFallbackInsideDocker bool tlsServerConfig *tls.Config tlsConfig *TlsConfig + subgraphTLSConfiguration config.ClientTLSConfiguration telemetryAttributes []config.CustomAttribute tracePropagators []propagation.TextMapPropagator compositePropagator propagation.TextMapPropagator diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index dad0771988..2f9f6fcbfb 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -274,6 +274,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config, reloadPersi WithDemoMode(config.DemoMode), WithStreamsHandlerConfiguration(config.Events.Handlers), WithReloadPersistentState(reloadPersistentState), + WithSubgraphTLSConfiguration(config.TLS.Client), } return options diff --git a/router/core/tls.go b/router/core/tls.go new file mode 100644 index 0000000000..752e0b214c --- /dev/null +++ b/router/core/tls.go @@ -0,0 +1,84 @@ +package core + +import ( + "crypto/tls" + "crypto/x509" + "errors" + "fmt" + "go.uber.org/zap" + "os" + + "github.com/wundergraph/cosmo/router/pkg/config" +) + +// buildTLSClientConfig creates a *tls.Config from a TLSClientCertConfiguration. +func buildTLSClientConfig(clientCfg *config.TLSClientCertConfiguration) (*tls.Config, error) { + tlsConfig := &tls.Config{ + InsecureSkipVerify: clientCfg.InsecureSkipCaVerification, + } + + // Load client certificate and key if provided + if clientCfg.CertFile != "" && clientCfg.KeyFile != "" { + cert, err := tls.LoadX509KeyPair(clientCfg.CertFile, clientCfg.KeyFile) + if err != nil { + return nil, fmt.Errorf("failed to load client TLS cert and key: %w", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + + // Load custom CA for verifying subgraph server certificates + if clientCfg.CaFile != "" { + caCert, err := os.ReadFile(clientCfg.CaFile) + if err != nil { + return nil, fmt.Errorf("failed to read client TLS CA file: %w", err) + } + caPool := x509.NewCertPool() + if ok := caPool.AppendCertsFromPEM(caCert); !ok { + return nil, errors.New("failed to append client TLS CA cert to pool") + } + tlsConfig.RootCAs = caPool + } + + return tlsConfig, nil +} + +// buildSubgraphTLSConfigs builds the default and per-subgraph TLS configs from raw configuration. +// Returns (defaultClientTLS, perSubgraphTLS, error). +func buildSubgraphTLSConfigs(logger *zap.Logger, cfg *config.ClientTLSConfiguration) (*tls.Config, map[string]*tls.Config, error) { + hasAll := (cfg.All.CertFile != "" && cfg.All.KeyFile != "") || cfg.All.CaFile != "" || cfg.All.InsecureSkipCaVerification + + // If no global TLS config is provided and there are no subgraph specific TLS configs + if !hasAll && len(cfg.Subgraphs) == 0 { + return nil, nil, nil + } + + var defaultClientTLS *tls.Config + perSubgraphTLS := make(map[string]*tls.Config) + + if hasAll { + if cfg.All.InsecureSkipCaVerification { + logger.Warn("Global TLS config has InsecureSkipCaVerification enabled. This is not recommended for production environments.") + } + + defaultTLS, err := buildTLSClientConfig(&cfg.All) + if err != nil { + return nil, nil, fmt.Errorf("failed to build global subgraph TLS config: %w", err) + } + defaultClientTLS = defaultTLS + } + + for name, sgCfg := range cfg.Subgraphs { + if sgCfg.InsecureSkipCaVerification { + logger.Warn("Subgraph TLS config inherits InsecureSkipCaVerification from global config. This is not recommended for production environments.", + zap.String("subgraph", name)) + } + + subgraphTLS, err := buildTLSClientConfig(&sgCfg) + if err != nil { + return nil, nil, fmt.Errorf("failed to build TLS config for subgraph %q: %w", name, err) + } + perSubgraphTLS[name] = subgraphTLS + } + + return defaultClientTLS, perSubgraphTLS, nil +} diff --git a/router/core/tls_test.go b/router/core/tls_test.go new file mode 100644 index 0000000000..d478e58b3f --- /dev/null +++ b/router/core/tls_test.go @@ -0,0 +1,283 @@ +package core + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "go.uber.org/zap/zaptest/observer" + "math/big" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/config" +) + +func TestBuildTLSClientConfig(t *testing.T) { + t.Parallel() + + t.Run("returns config with insecure_skip_ca_verification only", func(t *testing.T) { + t.Parallel() + + tlsCfg, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }) + + require.NoError(t, err) + require.NotNil(t, tlsCfg) + require.True(t, tlsCfg.InsecureSkipVerify) + require.Empty(t, tlsCfg.Certificates) + require.Nil(t, tlsCfg.RootCAs) + }) + + t.Run("loads client cert and key", func(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateTestCert(t, "client") + + tlsCfg, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + CertFile: certPath, + KeyFile: keyPath, + }) + + require.NoError(t, err) + require.NotNil(t, tlsCfg) + require.Len(t, tlsCfg.Certificates, 1) + }) + + t.Run("loads CA file", func(t *testing.T) { + t.Parallel() + + certPath, _ := generateTestCert(t, "ca") + + tlsCfg, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + CaFile: certPath, + }) + require.NoError(t, err) + require.NotNil(t, tlsCfg) + require.NotNil(t, tlsCfg.RootCAs) + }) + + t.Run("errors on invalid cert path", func(t *testing.T) { + t.Parallel() + + _, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + CertFile: "/nonexistent/cert.pem", + KeyFile: "/nonexistent/key.pem", + }) + require.Error(t, err) + require.EqualError(t, err, "failed to load client TLS cert and key: open /nonexistent/cert.pem: no such file or directory") + }) + + t.Run("errors on invalid CA path", func(t *testing.T) { + t.Parallel() + + _, err := buildTLSClientConfig(&config.TLSClientCertConfiguration{ + CaFile: "/nonexistent/ca.pem", + }) + require.Error(t, err) + require.EqualError(t, err, "failed to read client TLS CA file: open /nonexistent/ca.pem: no such file or directory") + }) + + t.Run("returns nil when no TLS configured", func(t *testing.T) { + t.Parallel() + + cfg := &config.ClientTLSConfiguration{} + defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + require.NoError(t, err) + require.Nil(t, defaultTLS) + require.Nil(t, perSubgraphTLS) + }) + + t.Run("builds global client TLS config", func(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateTestCert(t, "client") + caPath, _ := generateTestCert(t, "ca") + + cfg := &config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CertFile: certPath, + KeyFile: keyPath, + CaFile: caPath, + }, + } + + defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + require.NoError(t, err) + require.NotNil(t, defaultTLS) + require.Len(t, defaultTLS.Certificates, 1) + require.NotNil(t, defaultTLS.RootCAs) + require.Empty(t, perSubgraphTLS) + }) + + t.Run("builds per-subgraph TLS config", func(t *testing.T) { + t.Parallel() + + certPath, keyPath := generateTestCert(t, "products") + + cfg := &config.ClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "products": { + CertFile: certPath, + KeyFile: keyPath, + }, + }, + } + + defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + require.NoError(t, err) + require.Nil(t, defaultTLS) + require.Contains(t, perSubgraphTLS, "products") + require.Len(t, perSubgraphTLS["products"].Certificates, 1) + }) + + t.Run("builds both global and per-subgraph TLS config", func(t *testing.T) { + t.Parallel() + + globalCert, globalKey := generateTestCert(t, "global") + productsCert, productsKey := generateTestCert(t, "products") + + cfg := &config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CertFile: globalCert, + KeyFile: globalKey, + }, + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "products": { + CertFile: productsCert, + KeyFile: productsKey, + }, + }, + } + + defaultTLS, perSubgraphTLS, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + require.NoError(t, err) + require.NotNil(t, defaultTLS) + require.Contains(t, perSubgraphTLS, "products") + }) + + t.Run("errors on invalid global cert", func(t *testing.T) { + t.Parallel() + + cfg := &config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + CertFile: "/nonexistent/cert.pem", + KeyFile: "/nonexistent/key.pem", + }, + } + + _, _, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + require.Error(t, err) + require.EqualError(t, err, "failed to build global subgraph TLS config: failed to load client TLS cert and key: open /nonexistent/cert.pem: no such file or directory") + }) + + t.Run("errors on invalid per-subgraph cert", func(t *testing.T) { + t.Parallel() + + cfg := &config.ClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "products": { + CertFile: "/nonexistent/cert.pem", + KeyFile: "/nonexistent/key.pem", + }, + }, + } + + _, _, err := buildSubgraphTLSConfigs(zap.NewNop(), cfg) + require.Error(t, err) + require.EqualError(t, err, `failed to build TLS config for subgraph "products": failed to load client TLS cert and key: open /nonexistent/cert.pem: no such file or directory`) + }) + + t.Run("logs warning when global InsecureSkipCaVerification is enabled", func(t *testing.T) { + t.Parallel() + + core, logs := observer.New(zapcore.WarnLevel) + logger := zap.New(core) + + cfg := &config.ClientTLSConfiguration{ + All: config.TLSClientCertConfiguration{ + InsecureSkipCaVerification: true, + }, + } + + defaultTLS, _, err := buildSubgraphTLSConfigs(logger, cfg) + require.NoError(t, err) + require.NotNil(t, defaultTLS) + require.True(t, defaultTLS.InsecureSkipVerify) + + require.Equal(t, 1, logs.Len()) + require.Equal(t, "Global TLS config has InsecureSkipCaVerification enabled. This is not recommended for production environments.", logs.All()[0].Message) + }) + + t.Run("logs warning when subgraph InsecureSkipCaVerification is enabled", func(t *testing.T) { + t.Parallel() + + core, logs := observer.New(zapcore.WarnLevel) + logger := zap.New(core) + + cfg := &config.ClientTLSConfiguration{ + Subgraphs: map[string]config.TLSClientCertConfiguration{ + "products": { + InsecureSkipCaVerification: true, + }, + }, + } + + _, perSubgraphTLS, err := buildSubgraphTLSConfigs(logger, cfg) + require.NoError(t, err) + require.Contains(t, perSubgraphTLS, "products") + require.True(t, perSubgraphTLS["products"].InsecureSkipVerify) + + require.Equal(t, 1, logs.Len()) + require.Equal(t, "Subgraph TLS config inherits InsecureSkipCaVerification from global config. This is not recommended for production environments.", logs.All()[0].Message) + require.Equal(t, "products", logs.All()[0].ContextMap()["subgraph"]) + }) +} + +// generateTestCert creates a self-signed certificate and key in the given directory. +// Returns the paths to the cert and key files. +func generateTestCert(t *testing.T, prefix string) (certPath, keyPath string) { + t.Helper() + + dir := t.TempDir() + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: prefix + "-test"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature, + IsCA: true, + BasicConstraintsValid: true, + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + require.NoError(t, err) + + certPath = filepath.Join(dir, prefix+".crt") + certFile, err := os.Create(certPath) + require.NoError(t, err) + require.NoError(t, pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})) + require.NoError(t, certFile.Close()) + + keyPath = filepath.Join(dir, prefix+".key") + keyFile, err := os.Create(keyPath) + require.NoError(t, err) + keyDER, err := x509.MarshalECPrivateKey(key) + require.NoError(t, err) + require.NoError(t, pem.Encode(keyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER})) + require.NoError(t, keyFile.Close()) + + return certPath, keyPath +} diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 5baa086249..79adc0972d 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -802,8 +802,23 @@ type TLSServerConfiguration struct { ClientAuth TLSClientAuthConfiguration `yaml:"client_auth,omitempty"` } +type TLSClientCertConfiguration struct { + CertFile string `yaml:"cert_file,omitempty" env:"CERT_FILE"` + KeyFile string `yaml:"key_file,omitempty" env:"KEY_FILE"` + CaFile string `yaml:"ca_file,omitempty" env:"CA_FILE"` + InsecureSkipCaVerification bool `yaml:"insecure_skip_ca_verification" envDefault:"false" env:"INSECURE_SKIP_CA_VERIFICATION"` +} + +type ClientTLSConfiguration struct { + // All applies to all subgraph connections. + All TLSClientCertConfiguration `yaml:"all" envPrefix:"TLS_CLIENT_ALL_"` + // Subgraphs overrides per-subgraph TLS config. Key is the subgraph name. + Subgraphs map[string]TLSClientCertConfiguration `yaml:"subgraphs,omitempty"` +} + type TLSConfiguration struct { Server TLSServerConfiguration `yaml:"server"` + Client ClientTLSConfiguration `yaml:"client"` } type SubgraphErrorPropagationMode string diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index e9335caf1c..167cb875bd 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -441,6 +441,78 @@ "key_file" ] } + }, + "client": { + "type": "object", + "description": "The TLS configuration for outbound connections from the router to subgraphs. Enables mTLS by presenting a client certificate when connecting to subgraphs.", + "additionalProperties": false, + "properties": { + "all": { + "type": "object", + "description": "TLS configuration applied to all subgraph connections.", + "additionalProperties": false, + "properties": { + "cert_file": { + "type": "string", + "format": "file-path", + "description": "The path to the client certificate chain file. Used to authenticate the router to subgraphs. May include intermediate certificates." + }, + "key_file": { + "type": "string", + "format": "file-path", + "description": "The path to the client private key file." + }, + "ca_file": { + "type": "string", + "format": "file-path", + "description": "The path to the CA certificate file. Used to verify subgraph server certificates. If not set, the system's root CAs are used." + }, + "insecure_skip_ca_verification": { + "type": "boolean", + "default": false, + "description": "Skip verification of the subgraph server certificate. Only use for development or testing." + } + }, + "dependencies": { + "cert_file": ["key_file"], + "key_file": ["cert_file"] + } + }, + "subgraphs": { + "type": "object", + "description": "Per-subgraph TLS configuration overrides. Each key is a subgraph name. Fully overrides the 'all' config for that subgraph.", + "additionalProperties": { + "type": "object", + "additionalProperties": false, + "properties": { + "cert_file": { + "type": "string", + "format": "file-path", + "description": "The path to the client certificate chain file for this subgraph. May include intermediate certificates." + }, + "key_file": { + "type": "string", + "format": "file-path", + "description": "The path to the client private key file for this subgraph." + }, + "ca_file": { + "type": "string", + "format": "file-path", + "description": "The path to the CA certificate file for verifying this subgraph's server certificate." + }, + "insecure_skip_ca_verification": { + "type": "boolean", + "default": false, + "description": "Skip verification of this subgraph's server certificate. Only use for development or testing." + } + }, + "dependencies": { + "cert_file": ["key_file"], + "key_file": ["cert_file"] + } + } + } + } } } }, diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index b089b39eac..3767f468cc 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -33,6 +33,21 @@ liveness_check_path: '/health/live' router_registration: true graphql_path: /graphql dev_mode: false +tls: + server: + enabled: false + client: + all: + cert_file: '/path/to/client.crt' + key_file: '/path/to/client.key' + ca_file: '/path/to/ca.crt' + insecure_skip_ca_verification: false + subgraphs: + products: + cert_file: '/path/to/products-client.crt' + key_file: '/path/to/products-client.key' + ca_file: '/path/to/products-ca.crt' + insecure_skip_ca_verification: false instance_id: '' graphql_metrics: enabled: true diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index f53b401f76..710630ae04 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -118,6 +118,15 @@ "CertFile": "", "Required": false } + }, + "Client": { + "All": { + "CertFile": "", + "KeyFile": "", + "CaFile": "", + "InsecureSkipCaVerification": false + }, + "Subgraphs": null } }, "CacheControl": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 70091fcae3..fa60cb18e8 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -148,6 +148,22 @@ "CertFile": "", "Required": false } + }, + "Client": { + "All": { + "CertFile": "/path/to/client.crt", + "KeyFile": "/path/to/client.key", + "CaFile": "/path/to/ca.crt", + "InsecureSkipCaVerification": false + }, + "Subgraphs": { + "products": { + "CertFile": "/path/to/products-client.crt", + "KeyFile": "/path/to/products-client.key", + "CaFile": "/path/to/products-ca.crt", + "InsecureSkipCaVerification": false + } + } } }, "CacheControl": {