Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions go/mysql/auth_server_clientcert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,8 @@ func TestValidCert(t *testing.T) {
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"))
path.Join(root, "ca-cert.pem"),
"")
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
Expand Down Expand Up @@ -143,7 +144,8 @@ func TestNoCert(t *testing.T) {
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"))
path.Join(root, "ca-cert.pem"),
"")
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
Expand Down
3 changes: 2 additions & 1 deletion go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,8 @@ func TestSSLConnection(t *testing.T) {
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"))
path.Join(root, "ca-cert.pem"),
"")
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
Expand Down
6 changes: 4 additions & 2 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -839,7 +839,8 @@ func TestTLSServer(t *testing.T) {
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"))
path.Join(root, "ca-cert.pem"),
"")
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
go l.Accept()
Expand Down Expand Up @@ -927,7 +928,8 @@ func TestTLSRequired(t *testing.T) {
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"))
path.Join(root, "ca-cert.pem"),
"")
require.NoError(t, err)
l.TLSConfig.Store(serverConfig)
l.RequireSecureTransport = true
Expand Down
5 changes: 4 additions & 1 deletion go/vt/servenv/grpc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ var (
// GRPCCA is the CA to use if TLS is enabled
GRPCCA = flag.String("grpc_ca", "", "server CA to use for gRPC connections, requires TLS, and enforces client certificate check")

// GRPCServerCA if specified will combine server cert and server CA
GRPCServerCA = flag.String("grpc_server_ca", "", "path to server CA in PEM format, which will be combine with server cert, return full certificate chain to clients")

// GRPCAuth which auth plugin to use (at the moment now only static is supported)
GRPCAuth = flag.String("grpc_auth_mode", "", "Which auth plugin implementation to use (eg: static)")

Expand Down Expand Up @@ -125,7 +128,7 @@ func createGRPCServer() {

var opts []grpc.ServerOption
if GRPCPort != nil && *GRPCCert != "" && *GRPCKey != "" {
config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA)
config, err := vttls.ServerConfig(*GRPCCert, *GRPCKey, *GRPCCA, *GRPCServerCA)
if err != nil {
log.Exitf("Failed to log gRPC cert/key/ca: %v", err)
}
Expand Down
79 changes: 72 additions & 7 deletions go/vt/tlstest/tlstest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,20 @@ import (
"vitess.io/vitess/go/vt/vttls"
)

// TestClientServer generates:
func TestClientServerWithoutCombineCerts(t *testing.T) {
testClientServer(t, false)
}

func TestClientServerWithCombineCerts(t *testing.T) {
testClientServer(t, true)
}

// testClientServer generates:
// - a root CA
// - a server intermediate CA, with a server.
// - a client intermediate CA, with a client.
// And then performs a few tests on them.
func TestClientServer(t *testing.T) {
func testClientServer(t *testing.T, combineCerts bool) {
// Our test root.
root, err := ioutil.TempDir("", "tlstest")
if err != nil {
Expand All @@ -48,11 +56,17 @@ func TestClientServer(t *testing.T) {
defer os.RemoveAll(root)

clientServerKeyPairs := CreateClientServerCertPairs(root)
serverCA := ""

if combineCerts {
serverCA = clientServerKeyPairs.ServerCA
}

serverConfig, err := vttls.ServerConfig(
clientServerKeyPairs.ServerCert,
clientServerKeyPairs.ServerKey,
clientServerKeyPairs.ClientCA)
clientServerKeyPairs.ClientCA,
serverCA)
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
Expand Down Expand Up @@ -165,10 +179,19 @@ func TestClientServer(t *testing.T) {
}
}

func getServerConfig(keypairs ClientServerKeyPairs) (*tls.Config, error) {
func getServerConfigWithoutCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Config, error) {
return vttls.ServerConfig(
keypairs.ClientCert,
keypairs.ClientKey,
keypairs.ServerCert,
keypairs.ServerKey,
keypairs.ClientCA,
"")
}

func getServerConfigWithCombinedCerts(keypairs ClientServerKeyPairs) (*tls.Config, error) {
return vttls.ServerConfig(
keypairs.ServerCert,
keypairs.ServerKey,
keypairs.ClientCA,
keypairs.ServerCA)
}

Expand All @@ -180,12 +203,20 @@ func getClientConfig(keypairs ClientServerKeyPairs) (*tls.Config, error) {
keypairs.ServerName)
}

func TestServerTLSConfigCaching(t *testing.T) {
func testServerTLSConfigCaching(t *testing.T, getServerConfig func(ClientServerKeyPairs) (*tls.Config, error)) {
testConfigGeneration(t, "servertlstest", getServerConfig, func(config *tls.Config) *x509.CertPool {
return config.ClientCAs
})
}

func TestServerTLSConfigCachingWithoutCombinedCerts(t *testing.T) {
testServerTLSConfigCaching(t, getServerConfigWithoutCombinedCerts)
}

func TestServerTLSConfigCachingWithCombinedCerts(t *testing.T) {
testServerTLSConfigCaching(t, getServerConfigWithCombinedCerts)
}

func TestClientTLSConfigCaching(t *testing.T) {
testConfigGeneration(t, "clienttlstest", getClientConfig, func(config *tls.Config) *x509.CertPool {
return config.RootCAs
Expand Down Expand Up @@ -238,3 +269,37 @@ func testConfigGeneration(t *testing.T, rootPrefix string, generateConfig func(C
}

}

func testNumberOfCertsWithOrWithoutCombining(t *testing.T, numCertsExpected int, combine bool) {
// Our test root.
root, err := ioutil.TempDir("", "tlstest")
if err != nil {
t.Fatalf("TempDir failed: %v", err)
}
defer os.RemoveAll(root)

clientServerKeyPairs := CreateClientServerCertPairs(root)
serverCA := ""
if combine {
serverCA = clientServerKeyPairs.ServerCA
}

serverConfig, err := vttls.ServerConfig(
clientServerKeyPairs.ServerCert,
clientServerKeyPairs.ServerKey,
clientServerKeyPairs.ClientCA,
serverCA)

if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
assert.Equal(t, numCertsExpected, len(serverConfig.Certificates[0].Certificate))
}

func TestNumberOfCertsWithoutCombining(t *testing.T) {
testNumberOfCertsWithOrWithoutCombining(t, 1, false)
}

func TestNumberOfCertsWithCombining(t *testing.T) {
testNumberOfCertsWithOrWithoutCombining(t, 2, true)
}
10 changes: 6 additions & 4 deletions go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ var (
mysqlSslKey = flag.String("mysql_server_ssl_key", "", "Path to ssl key for mysql server plugin SSL")
mysqlSslCa = flag.String("mysql_server_ssl_ca", "", "Path to ssl CA for mysql server plugin SSL. If specified, server will require and validate client certs.")

mysqlSslServerCA = flag.String("mysql_server_ssl_server_ca", "", "path to server CA in PEM format, which will be combine with server cert, return full certificate chain to clients")

mysqlSlowConnectWarnThreshold = flag.Duration("mysql_slow_connect_warn_threshold", 0, "Warn if it takes more than the given threshold for a mysql connection to establish")

mysqlConnReadTimeout = flag.Duration("mysql_server_read_timeout", 0, "connection read timeout")
Expand Down Expand Up @@ -360,8 +362,8 @@ var sigChan chan os.Signal
var vtgateHandle *vtgateHandler

// initTLSConfig inits tls config for the given mysql listener
func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa string, mysqlServerRequireSecureTransport bool) error {
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa)
func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA string, mysqlServerRequireSecureTransport bool) error {
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA)
if err != nil {
log.Exitf("grpcutils.TLSServerConfig failed: %v", err)
return err
Expand All @@ -372,7 +374,7 @@ func initTLSConfig(mysqlListener *mysql.Listener, mysqlSslCert, mysqlSslKey, mys
signal.Notify(sigChan, syscall.SIGHUP)
go func() {
for range sigChan {
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa)
serverConfig, err := vttls.ServerConfig(mysqlSslCert, mysqlSslKey, mysqlSslCa, mysqlSslServerCA)
if err != nil {
log.Errorf("grpcutils.TLSServerConfig failed: %v", err)
} else {
Expand Down Expand Up @@ -428,7 +430,7 @@ func initMySQLProtocol() {
mysqlListener.ServerVersion = *servenv.MySQLServerVersion
}
if *mysqlSslCert != "" && *mysqlSslKey != "" {
initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlServerRequireSecureTransport)
initTLSConfig(mysqlListener, *mysqlSslCert, *mysqlSslKey, *mysqlSslCa, *mysqlSslServerCA, *mysqlServerRequireSecureTransport)
}
mysqlListener.AllowClearTextWithoutTLS.Set(*mysqlAllowClearTextWithoutTLS)
// Check for the connection threshold
Expand Down
17 changes: 15 additions & 2 deletions go/vt/vtgate/plugin_mysql_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,15 @@ func TestDefaultWorkloadOLAP(t *testing.T) {
}
}

func TestInitTLSConfig(t *testing.T) {
func TestInitTLSConfigWithoutServerCA(t *testing.T) {
testInitTLSConfig(t, false)
}

func TestInitTLSConfigWithServerCA(t *testing.T) {
testInitTLSConfig(t, true)
}

func testInitTLSConfig(t *testing.T, serverCA bool) {
// Create the certs.
root, err := ioutil.TempDir("", "TestInitTLSConfig")
if err != nil {
Expand All @@ -258,8 +266,13 @@ func TestInitTLSConfig(t *testing.T) {
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")

serverCACert := ""
if serverCA {
serverCACert = path.Join(root, "ca-cert.pem")
}

listener := &mysql.Listener{}
if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), true); err != nil {
if err := initTLSConfig(listener, path.Join(root, "server-cert.pem"), path.Join(root, "server-key.pem"), path.Join(root, "ca-cert.pem"), serverCACert, true); err != nil {
t.Fatalf("init tls config failure due to: +%v", err)
}

Expand Down
75 changes: 70 additions & 5 deletions go/vt/vttls/vttls.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,15 +94,21 @@ func ClientConfig(cert, key, ca, name string) (*tls.Config, error) {

// ServerConfig returns the TLS config to use for a server to
// accept client connections.
func ServerConfig(cert, key, ca string) (*tls.Config, error) {
func ServerConfig(cert, key, ca, serverCA string) (*tls.Config, error) {
config := newTLSConfig()

certificates, err := loadTLSCertificate(cert, key)
var certificates *[]tls.Certificate
var err error

if serverCA != "" {
certificates, err = combineAndLoadTLSCertificates(serverCA, cert, key)
} else {
certificates, err = loadTLSCertificate(cert, key)
}

if err != nil {
return nil, err
}

config.Certificates = *certificates

// if specified, load ca to validate client,
Expand Down Expand Up @@ -161,8 +167,8 @@ func doLoadx509CertPool(ca string) error {

var tlsCertificates = sync.Map{}

func tlsCertificatesIdentifier(cert, key string) string {
return strings.Join([]string{cert, key}, ";")
func tlsCertificatesIdentifier(tokens ...string) string {
return strings.Join(tokens, ";")
}

func loadTLSCertificate(cert, key string) (*[]tls.Certificate, error) {
Expand Down Expand Up @@ -203,3 +209,62 @@ func doLoadTLSCertificate(cert, key string) error {

return nil
}

var combinedTlsCertificates = sync.Map{}

func combineAndLoadTLSCertificates(ca, cert, key string) (*[]tls.Certificate, error) {
combinedTlsIdentifier := tlsCertificatesIdentifier(ca, cert, key)
once, _ := onceByKeys.LoadOrStore(combinedTlsIdentifier, &sync.Once{})

var err error
once.(*sync.Once).Do(func() {
err = doLoadAndCombineTLSCertificates(ca, cert, key)
})

if err != nil {
return nil, err
}

result, ok := combinedTlsCertificates.Load(combinedTlsIdentifier)

if !ok {
return nil, vterrors.Errorf(vtrpc.Code_NOT_FOUND, "Cannot find loaded tls certificate chain with ca: %s, cert: %s, key: %s", ca, cert, key)
}

return result.(*[]tls.Certificate), nil
}

func doLoadAndCombineTLSCertificates(ca, cert, key string) error {
combinedTlsIdentifier := tlsCertificatesIdentifier(ca, cert, key)

// Read CA certificates chain
ca_b, err := ioutil.ReadFile(ca)
if err != nil {
return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read ca file: %s", ca)
}

// Read server certificate
cert_b, err := ioutil.ReadFile(cert)
if err != nil {
return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read server cert file: %s", cert)
}

// Read server key file
key_b, err := ioutil.ReadFile(key)
if err != nil {
return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to read key file: %s", key)
}

// Load CA, server cert and key.
var certificate []tls.Certificate
crt, err := tls.X509KeyPair(append(cert_b, ca_b...), key_b)
if err != nil {
return vterrors.Errorf(vtrpc.Code_NOT_FOUND, "failed to load and merge tls certificate with CA, ca %s, cert %s, key: %s", ca, cert, key)
}

certificate = []tls.Certificate{crt}

combinedTlsCertificates.Store(combinedTlsIdentifier, &certificate)

return nil
}