diff --git a/go/mysql/server.go b/go/mysql/server.go index 4a9726fed23..c21535be12f 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -149,6 +149,9 @@ type Listener struct { // shutdown indicates that Shutdown method was called. shutdown sync2.AtomicBool + + // RequireSecureTransport configures the server to reject connections from insecure clients + RequireSecureTransport bool } // NewFromListener creares a new mysql listener from an existing net.Listener @@ -317,6 +320,9 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti } } } else { + if l.RequireSecureTransport { + c.writeErrorPacketFromError(fmt.Errorf("Server does not allow insecure connections, client must use SSL/TLS")) + } connCountByTLSVer.Add(versionNoTLS, 1) defer connCountByTLSVer.Add(versionNoTLS, -1) } diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 9a27f6f1a5f..5fa1efe4f67 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -937,6 +937,86 @@ func TestTLSServer(t *testing.T) { } +// TestTLSRequired creates a Server with TLS required, then tests that an insecure mysql +// client is rejected +func TestTLSRequired(t *testing.T) { + th := &testHandler{} + + authServer := NewAuthServerStatic() + authServer.Entries["user1"] = []*AuthServerStaticEntry{{ + Password: "password1", + }} + + // Create the listener, so we can get its host. + // Below, we are enabling --ssl-verify-server-cert, which adds + // a check that the common name of the certificate matches the + // server host name we connect to. + l, err := NewListener("tcp", ":0", authServer, th, 0, 0) + if err != nil { + t.Fatalf("NewListener failed: %v", err) + } + defer l.Close() + + // Make sure hostname is added as an entry to /etc/hosts, otherwise ssl handshake will fail + host, err := os.Hostname() + if err != nil { + t.Fatalf("Failed to get os Hostname: %v", err) + } + + port := l.Addr().(*net.TCPAddr).Port + + // Create the certs. + root, err := ioutil.TempDir("", "TestTLSRequired") + if err != nil { + t.Fatalf("TempDir failed: %v", err) + } + defer os.RemoveAll(root) + tlstest.CreateCA(root) + tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", host) + + // Create the server with TLS config. + serverConfig, err := vttls.ServerConfig( + path.Join(root, "server-cert.pem"), + path.Join(root, "server-key.pem"), + path.Join(root, "ca-cert.pem")) + if err != nil { + t.Fatalf("TLSServerConfig failed: %v", err) + } + l.TLSConfig = serverConfig + l.RequireSecureTransport = true + go l.Accept() + + // Setup conn params without SSL. + params := &ConnParams{ + Host: host, + Port: port, + Uname: "user1", + Pass: "password1", + } + conn, err := Connect(context.Background(), params) + if err == nil { + t.Fatal("mysql should have failed") + } + if conn != nil { + conn.Close() + } + + // setup conn params with TLS + tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") + params.Flags = CapabilityClientSSL + params.SslCa = path.Join(root, "ca-cert.pem") + params.SslCert = path.Join(root, "client-cert.pem") + params.SslKey = path.Join(root, "client-key.pem") + + conn, err = Connect(context.Background(), params) + if err != nil { + t.Fatalf("mysql failed: %v", err) + } + if conn != nil { + conn.Close() + } +} + func checkCountForTLSVer(t *testing.T, version string, expected int64) { connCounts := connCountByTLSVer.Counts() count, ok := connCounts[version] diff --git a/go/vt/vtgate/plugin_mysql_server.go b/go/vt/vtgate/plugin_mysql_server.go index e67e1026e3b..e1c0e3408f0 100644 --- a/go/vt/vtgate/plugin_mysql_server.go +++ b/go/vt/vtgate/plugin_mysql_server.go @@ -48,6 +48,8 @@ var ( mysqlAllowClearTextWithoutTLS = flag.Bool("mysql_allow_clear_text_without_tls", false, "If set, the server will allow the use of a clear text password over non-SSL connections.") mysqlServerVersion = flag.String("mysql_server_version", mysql.DefaultServerVersion, "MySQL server version to advertise.") + mysqlServerRequireSecureTransport = flag.Bool("mysql_server_require_secure_transport", false, "Reject insecure connections but only if mysql_server_ssl_cert and mysql_server_ssl_key are provided") + mysqlSslCert = flag.String("mysql_server_ssl_cert", "", "Path to the ssl cert for mysql server plugin SSL") mysqlSslKey = flag.String("mysql_server_ssl_key", "", "Path to ssl key for mysql server plugin SSL") mysqlSslCa = flag.String("mysql_server_ssl_ca", "", "Path to ssl CA for mysql server plugin SSL. If specified, server will require and validate client certs.") @@ -212,9 +214,9 @@ func initMySQLProtocol() { log.Exitf("grpcutils.TLSServerConfig failed: %v", err) return } + mysqlListener.RequireSecureTransport = *mysqlServerRequireSecureTransport } mysqlListener.AllowClearTextWithoutTLS = *mysqlAllowClearTextWithoutTLS - // Check for the connection threshold if *mysqlSlowConnectWarnThreshold != 0 { log.Infof("setting mysql slow connection threshold to %v", mysqlSlowConnectWarnThreshold)