diff --git a/go/mysql/auth_server_clientcert_test.go b/go/mysql/auth_server_clientcert_test.go index 9dbfdfe0d72..ed9062d29bf 100644 --- a/go/mysql/auth_server_clientcert_test.go +++ b/go/mysql/auth_server_clientcert_test.go @@ -61,7 +61,8 @@ func TestValidCert(t *testing.T) { serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem")) + path.Join(root, "ca-cert.pem"), + "") if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } @@ -143,7 +144,8 @@ func TestNoCert(t *testing.T) { serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem")) + path.Join(root, "ca-cert.pem"), + "") if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index 696836c44c6..9015bcd0951 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -125,7 +125,8 @@ func TestSSLConnection(t *testing.T) { serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem")) + path.Join(root, "ca-cert.pem"), + "") if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 7eab14765af..a1ca1271255 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -839,7 +839,8 @@ func TestTLSServer(t *testing.T) { serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem")) + path.Join(root, "ca-cert.pem"), + "") require.NoError(t, err) l.TLSConfig.Store(serverConfig) go l.Accept() @@ -927,7 +928,8 @@ func TestTLSRequired(t *testing.T) { serverConfig, err := vttls.ServerConfig( path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), - path.Join(root, "ca-cert.pem")) + path.Join(root, "ca-cert.pem"), + "") require.NoError(t, err) l.TLSConfig.Store(serverConfig) l.RequireSecureTransport = true diff --git a/go/vt/servenv/grpc_server.go b/go/vt/servenv/grpc_server.go index 448fbb1d73b..98268a50942 100644 --- a/go/vt/servenv/grpc_server.go +++ b/go/vt/servenv/grpc_server.go @@ -64,6 +64,9 @@ var ( // GRPCCA is the CA to use if TLS is enabled GRPCCA = flag.String("grpc_ca", "", "server CA to use for gRPC connections, requires TLS, and enforces client certificate check") + // GRPCServerCA if specified will combine server cert and server CA + GRPCServerCA = flag.String("grpc_server_ca", "", "path to server CA in PEM format, which will be combine with server cert, return full certificate chain to clients") + // GRPCAuth which auth plugin to use (at the moment now only static is supported) GRPCAuth = flag.String("grpc_auth_mode", "", "Which auth plugin implementation to use (eg: static)") @@ -125,7 +128,7 @@ func createGRPCServer() { var opts []grpc.ServerOption if GRPCPort != nil && *GRPCCert != "" && *GRPCKey != "" { - config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA) + config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA, *GRPCServerCA) if err != nil { log.Exitf("Failed to log gRPC cert/key/ca: %v", err) } diff --git a/go/vt/tlstest/tlstest_test.go b/go/vt/tlstest/tlstest_test.go index 7ad29cbf8f7..b7994fe731d 100644 --- a/go/vt/tlstest/tlstest_test.go +++ b/go/vt/tlstest/tlstest_test.go @@ -34,12 +34,20 @@ import ( "vitess.io/vitess/go/vt/vttls" ) -// TestClientServer generates: +func TestClientServerWithoutCombineCerts(t *testing.T) { + testClientServer(t, false) +} + +func TestClientServerWithCombineCerts(t *testing.T) { + testClientServer(t, true) +} + +// testClientServer generates: // - a root CA // - a server intermediate CA, with a server. // - a client intermediate CA, with a client. // And then performs a few tests on them. -func TestClientServer(t *testing.T) { +func testClientServer(t *testing.T, combineCerts bool) { // Our test root. root, err := ioutil.TempDir("", "tlstest") if err != nil { @@ -48,11 +56,17 @@ func TestClientServer(t *testing.T) { defer os.RemoveAll(root) clientServerKeyPairs := CreateClientServerCertPairs(root) + serverCA := "" + + if combineCerts { + serverCA = clientServerKeyPairs.ServerCA + } serverConfig, err := vttls.ServerConfig( clientServerKeyPairs.ServerCert, clientServerKeyPairs.ServerKey, - clientServerKeyPairs.ClientCA) + clientServerKeyPairs.ClientCA, + serverCA) if err != nil { t.Fatalf("TLSServerConfig failed: %v", err) } @@ -165,10 +179,19 @@ func TestClientServer(t *testing.T) { } } -func getServerConfig(keypairs ClientServerKeyPairs) (*tls.Config, error) { +func getServerConfigWithoutCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Config, error) { return vttls.ServerConfig( - keypairs.ClientCert, - keypairs.ClientKey, + keypairs.ServerCert, + keypairs.ServerKey, + keypairs.ClientCA, + "") +} + +func getServerConfigWithCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Config, error) { + return vttls.ServerConfig( + keypairs.ServerCert, + keypairs.ServerKey, + keypairs.ClientCA, keypairs.ServerCA) } @@ -180,12 +203,20 @@ func getClientConfig(keypairs ClientServerKeyPairs) (*tls.Config, error) { keypairs.ServerName) } -func TestServerTLSConfigCaching(t *testing.T) { +func testServerTLSConfigCaching(t *testing.T, getServerConfig func(ClientServerKeyPairs) (*tls.Config, error)) { testConfigGeneration(t, "servertlstest", getServerConfig, func(config *tls.Config) *x509.CertPool { return config.ClientCAs }) } +func TestServerTLSConfigCachingWithoutCombinedCerts(t *testing.T) { + testServerTLSConfigCaching(t, getServerConfigWithoutCombinedCerts) +} + +func TestServerTLSConfigCachingWithCombinedCerts(t *testing.T) { + testServerTLSConfigCaching(t, getServerConfigWithCombinedCerts) +} + func TestClientTLSConfigCaching(t *testing.T) { testConfigGeneration(t, "clienttlstest", getClientConfig, func(config *tls.Config) *x509.CertPool { return config.RootCAs @@ -238,3 +269,37 @@ func testConfigGeneration(t *testing.T, rootPrefix string, generateConfig func(C } } + +func testNumberOfCertsWithOrWithoutCombining(t *testing.T, numCertsExpected int, combine bool) { + // Our test root. + root, err := ioutil.TempDir("", "tlstest") + if err != nil { + t.Fatalf("TempDir failed: %v", err) + } + defer os.RemoveAll(root) + + clientServerKeyPairs := CreateClientServerCertPairs(root) + serverCA := "" + if combine { + serverCA = clientServerKeyPairs.ServerCA + } + + serverConfig, err := vttls.ServerConfig( + clientServerKeyPairs.ServerCert, + clientServerKeyPairs.ServerKey, + clientServerKeyPairs.ClientCA, + serverCA) + + if err != nil { + t.Fatalf("TLSServerConfig failed: %v", err) + } + assert.Equal(t, numCertsExpected, len(serverConfig.Certificates[0].Certificate)) +} + +func TestNumberOfCertsWithoutCombining(t *testing.T) { + testNumberOfCertsWithOrWithoutCombining(t, 1, false) +} + +func TestNumberOfCertsWithCombining(t *testing.T) { + testNumberOfCertsWithOrWithoutCombining(t, 2, true) +} diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index 212f3ba4261..12803f1251f 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -64,6 +64,8 @@ var ( mysqlSslKey = flag.String("mysql_server_ssl_key", "", "Path to ssl key for mysql server plugin SSL") mysqlSslCa = flag.String("mysql_server_ssl_ca", "", "Path to ssl CA for mysql server plugin SSL. If specified, server will require and validate client certs.") + mysqlSslServerCA = flag.String("mysql_server_ssl_server_ca", "", "path to server CA in PEM format, which will be combine with server cert, return full certificate chain to clients") + mysqlSlowConnectWarnThreshold = flag.Duration("mysql_slow_connect_warn_threshold", 0, "Warn if it takes more than the given threshold for a mysql connection to establish") mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", 0, "connection read timeout") @@ -360,8 +362,8 @@ var sigChan chan os.Signal var vtgateHandle *vtgateHandler // initTLSConfig inits tls config for the given mysql listener -func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa string, mysqlServerRequireSecureTransport bool) error { - serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa) +func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool) error { + serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA) if err != nil { log.Exitf("grpcutils.TLSServerConfig failed: %v", err) return err @@ -372,7 +374,7 @@ func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mys signal.Notify(sigChan, syscall.SIGHUP) go func() { for range sigChan { - serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa) + serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA) if err != nil { log.Errorf("grpcutils.TLSServerConfig failed: %v", err) } else { @@ -428,7 +430,7 @@ func initMySQLProtocol() { mysqlListener.ServerVersion = *servenv.MySQLServerVersion } if *mysqlSslCert != "" && *mysqlSslKey != "" { - initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlServerRequireSecureTransport) + initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlSslServerCA, *mysqlServerRequireSecureTransport) } mysqlListener.AllowClearTextWithoutTLS.Set(*mysqlAllowClearTextWithoutTLS) // Check for the connection threshold diff --git a/go/vt/vtgate/plugin_mysql_server_test.go b/go/vt/vtgate/plugin_mysql_server_test.go index 618f5aa269f..0fa7ff36810 100644 --- a/go/vt/vtgate/plugin_mysql_server_test.go +++ b/go/vt/vtgate/plugin_mysql_server_test.go @@ -248,7 +248,15 @@ func TestDefaultWorkloadOLAP(t *testing.T) { } } -func TestInitTLSConfig(t *testing.T) { +func TestInitTLSConfigWithoutServerCA(t *testing.T) { + testInitTLSConfig(t, false) +} + +func TestInitTLSConfigWithServerCA(t *testing.T) { + testInitTLSConfig(t, true) +} + +func testInitTLSConfig(t *testing.T, serverCA bool) { // Create the certs. root, err := ioutil.TempDir("", "TestInitTLSConfig") if err != nil { @@ -258,8 +266,13 @@ func TestInitTLSConfig(t *testing.T) { tlstest.CreateCA(root) tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") + serverCACert := "" + if serverCA { + serverCACert = path.Join(root, "ca-cert.pem") + } + listener := &mysql.Listener{} - if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), true); err != nil { + if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), serverCACert, true); err != nil { t.Fatalf("init tls config failure due to: +%v", err) } diff --git a/go/vt/vttls/vttls.go b/go/vt/vttls/vttls.go index 65c8724d95b..d7f212252e0 100644 --- a/go/vt/vttls/vttls.go +++ b/go/vt/vttls/vttls.go @@ -94,15 +94,21 @@ func ClientConfig(cert, key, ca, name string) (*tls.Config, error) { // ServerConfig returns the TLS config to use for a server to // accept client connections. -func ServerConfig(cert, key, ca string) (*tls.Config, error) { +func ServerConfig(cert, key, ca, serverCA string) (*tls.Config, error) { config := newTLSConfig() - certificates, err := loadTLSCertificate(cert, key) + var certificates *[]tls.Certificate + var err error + + if serverCA != "" { + certificates, err = combineAndLoadTLSCertificates(serverCA, cert, key) + } else { + certificates, err = loadTLSCertificate(cert, key) + } if err != nil { return nil, err } - config.Certificates = *certificates // if specified, load ca to validate client, @@ -161,8 +167,8 @@ func doLoadx509CertPool(ca string) error { var tlsCertificates = sync.Map{} -func tlsCertificatesIdentifier(cert, key string) string { - return strings.Join([]string{cert, key}, ";") +func tlsCertificatesIdentifier(tokens ...string) string { + return strings.Join(tokens, ";") } func loadTLSCertificate(cert, key string) (*[]tls.Certificate, error) { @@ -203,3 +209,62 @@ func doLoadTLSCertificate(cert, key string) error { return nil } + +var combinedTlsCertificates = sync.Map{} + +func combineAndLoadTLSCertificates(ca, cert, key string) (*[]tls.Certificate, error) { + combinedTlsIdentifier := tlsCertificatesIdentifier(ca, cert, key) + once, _ := onceByKeys.LoadOrStore(combinedTlsIdentifier, &sync.Once{}) + + var err error + once.(*sync.Once).Do(func() { + err = doLoadAndCombineTLSCertificates(ca, cert, key) + }) + + if err != nil { + return nil, err + } + + result, ok := combinedTlsCertificates.Load(combinedTlsIdentifier) + + if !ok { + return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded tls certificate chain with ca: %s, cert: %s, key: %s", ca, cert, key) + } + + return result.(*[]tls.Certificate), nil +} + +func doLoadAndCombineTLSCertificates(ca, cert, key string) error { + combinedTlsIdentifier := tlsCertificatesIdentifier(ca, cert, key) + + // Read CA certificates chain + ca_b, err := ioutil.ReadFile(ca) + if err != nil { + return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read ca file: %s", ca) + } + + // Read server certificate + cert_b, err := ioutil.ReadFile(cert) + if err != nil { + return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read server cert file: %s", cert) + } + + // Read server key file + key_b, err := ioutil.ReadFile(key) + if err != nil { + return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read key file: %s", key) + } + + // Load CA, server cert and key. + var certificate []tls.Certificate + crt, err := tls.X509KeyPair(append(cert_b, ca_b...), key_b) + if err != nil { + return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to load and merge tls certificate with CA, ca %s, cert %s, key: %s", ca, cert, key) + } + + certificate = []tls.Certificate{crt} + + combinedTlsCertificates.Store(combinedTlsIdentifier, &certificate) + + return nil +}