Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions go/mysql/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,9 @@ type Listener struct {

// shutdown indicates that Shutdown method was called.
shutdown sync2.AtomicBool

// RequireSecureTransport configures the server to reject connections from insecure clients
RequireSecureTransport bool
}

// NewFromListener creares a new mysql listener from an existing net.Listener
Expand Down Expand Up @@ -317,6 +320,9 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti
}
}
} else {
if l.RequireSecureTransport {
c.writeErrorPacketFromError(fmt.Errorf("Server does not allow insecure connections, client must use SSL/TLS"))
}
connCountByTLSVer.Add(versionNoTLS, 1)
defer connCountByTLSVer.Add(versionNoTLS, -1)
}
Expand Down
80 changes: 80 additions & 0 deletions go/mysql/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,86 @@ func TestTLSServer(t *testing.T) {

}

// TestTLSRequired creates a Server with TLS required, then tests that an insecure mysql
// client is rejected
func TestTLSRequired(t *testing.T) {
th := &testHandler{}

authServer := NewAuthServerStatic()
authServer.Entries["user1"] = []*AuthServerStaticEntry{{
Password: "password1",
}}

// Create the listener, so we can get its host.
// Below, we are enabling --ssl-verify-server-cert, which adds
// a check that the common name of the certificate matches the
// server host name we connect to.
l, err := NewListener("tcp", ":0", authServer, th, 0, 0)
if err != nil {
t.Fatalf("NewListener failed: %v", err)
}
defer l.Close()

// Make sure hostname is added as an entry to /etc/hosts, otherwise ssl handshake will fail
host, err := os.Hostname()
if err != nil {
t.Fatalf("Failed to get os Hostname: %v", err)
}

port := l.Addr().(*net.TCPAddr).Port

// Create the certs.
root, err := ioutil.TempDir("", "TestTLSRequired")
if err != nil {
t.Fatalf("TempDir failed: %v", err)
}
defer os.RemoveAll(root)
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", host)

// Create the server with TLS config.
serverConfig, err := vttls.ServerConfig(
path.Join(root, "server-cert.pem"),
path.Join(root, "server-key.pem"),
path.Join(root, "ca-cert.pem"))
if err != nil {
t.Fatalf("TLSServerConfig failed: %v", err)
}
l.TLSConfig = serverConfig
l.RequireSecureTransport = true
go l.Accept()

// Setup conn params without SSL.
params := &ConnParams{
Host: host,
Port: port,
Uname: "user1",
Pass: "password1",
}
conn, err := Connect(context.Background(), params)
if err == nil {
t.Fatal("mysql should have failed")
}
if conn != nil {
conn.Close()
}

// setup conn params with TLS
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")
params.Flags = CapabilityClientSSL
params.SslCa = path.Join(root, "ca-cert.pem")
params.SslCert = path.Join(root, "client-cert.pem")
params.SslKey = path.Join(root, "client-key.pem")

conn, err = Connect(context.Background(), params)
if err != nil {
t.Fatalf("mysql failed: %v", err)
}
if conn != nil {
conn.Close()
}
}

func checkCountForTLSVer(t *testing.T, version string, expected int64) {
connCounts := connCountByTLSVer.Counts()
count, ok := connCounts[version]
Expand Down
4 changes: 3 additions & 1 deletion go/vt/vtgate/plugin_mysql_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ var (
mysqlAllowClearTextWithoutTLS = flag.Bool("mysql_allow_clear_text_without_tls", false, "If set, the server will allow the use of a clear text password over non-SSL connections.")
mysqlServerVersion = flag.String("mysql_server_version", mysql.DefaultServerVersion, "MySQL server version to advertise.")

mysqlServerRequireSecureTransport = flag.Bool("mysql_server_require_secure_transport", false, "Reject insecure connections but only if mysql_server_ssl_cert and mysql_server_ssl_key are provided")

mysqlSslCert = flag.String("mysql_server_ssl_cert", "", "Path to the ssl cert for mysql server plugin SSL")
mysqlSslKey = flag.String("mysql_server_ssl_key", "", "Path to ssl key for mysql server plugin SSL")
mysqlSslCa = flag.String("mysql_server_ssl_ca", "", "Path to ssl CA for mysql server plugin SSL. If specified, server will require and validate client certs.")
Expand Down Expand Up @@ -212,9 +214,9 @@ func initMySQLProtocol() {
log.Exitf("grpcutils.TLSServerConfig failed: %v", err)
return
}
mysqlListener.RequireSecureTransport = *mysqlServerRequireSecureTransport
}
mysqlListener.AllowClearTextWithoutTLS = *mysqlAllowClearTextWithoutTLS

// Check for the connection threshold
if *mysqlSlowConnectWarnThreshold != 0 {
log.Infof("setting mysql slow connection threshold to %v", mysqlSlowConnectWarnThreshold)
Expand Down