diff --git a/go/mysql/client.go b/go/mysql/client.go index 108918be287..119640a8237 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -221,13 +221,14 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { } // The ServerName to verify depends on what the hostname is. + // We use the params's ServerName if specified. Otherwise: // - If using a socket, we use "localhost". // - If it is an IP address, we need to prefix it with 'IP:'. // - If not, we can just use it as is. - // We may need to add a ServerName field to ConnParams to - // make this more explicit. serverName := "localhost" - if params.Host != "" { + if params.ServerName != "" { + serverName = params.ServerName + } else if params.Host != "" { if net.ParseIP(params.Host) != nil { serverName = "IP:" + params.Host } else { diff --git a/go/mysql/conn_params.go b/go/mysql/conn_params.go index 1b6688ce022..053e662b165 100644 --- a/go/mysql/conn_params.go +++ b/go/mysql/conn_params.go @@ -29,10 +29,11 @@ type ConnParams struct { // The following SSL flags are only used when flags |= 2048 // is set (CapabilityClientSSL). - SslCa string `json:"ssl_ca"` - SslCaPath string `json:"ssl_ca_path"` - SslCert string `json:"ssl_cert"` - SslKey string `json:"ssl_key"` + SslCa string `json:"ssl_ca"` + SslCaPath string `json:"ssl_ca_path"` + SslCert string `json:"ssl_cert"` + SslKey string `json:"ssl_key"` + ServerName string `json:"server_name"` } // EnableSSL will set the right flag on the parameters. diff --git a/go/mysql/handshake_test.go b/go/mysql/handshake_test.go index e3626961705..a6f2e4287b6 100644 --- a/go/mysql/handshake_test.go +++ b/go/mysql/handshake_test.go @@ -116,7 +116,7 @@ func TestSSLConnection(t *testing.T) { } defer os.RemoveAll(root) tlstest.CreateCA(root) - tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "IP:"+host) + tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com") tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert") // Create the server with TLS config. @@ -139,10 +139,11 @@ func TestSSLConnection(t *testing.T) { Uname: "user1", Pass: "password1", // SSL flags. - Flags: CapabilityClientSSL, - SslCa: path.Join(root, "ca-cert.pem"), - SslCert: path.Join(root, "client-cert.pem"), - SslKey: path.Join(root, "client-key.pem"), + Flags: CapabilityClientSSL, + SslCa: path.Join(root, "ca-cert.pem"), + SslCert: path.Join(root, "client-cert.pem"), + SslKey: path.Join(root, "client-key.pem"), + ServerName: "server.example.com", } t.Run("Basics", func(t *testing.T) { diff --git a/go/vt/tlstest/tlstest_test.go b/go/vt/tlstest/tlstest_test.go index 1f985fc7c8b..686a2dcbd7d 100644 --- a/go/vt/tlstest/tlstest_test.go +++ b/go/vt/tlstest/tlstest_test.go @@ -47,7 +47,7 @@ func TestClientServer(t *testing.T) { CreateCA(root) CreateSignedCert(root, CA, "01", "servers", "Servers CA") - CreateSignedCert(root, "servers", "01", "server-instance", "Server Instance") + CreateSignedCert(root, "servers", "01", "server-instance", "server.example.com") CreateSignedCert(root, CA, "02", "clients", "Clients CA") CreateSignedCert(root, "clients", "01", "client-instance", "Client Instance") @@ -62,7 +62,7 @@ func TestClientServer(t *testing.T) { path.Join(root, "client-instance-cert.pem"), path.Join(root, "client-instance-key.pem"), path.Join(root, "servers-cert.pem"), - "Server Instance") + "server.example.com") if err != nil { t.Fatalf("TLSClientConfig failed: %v", err) } @@ -118,7 +118,7 @@ func TestClientServer(t *testing.T) { path.Join(root, "server-instance-cert.pem"), path.Join(root, "server-instance-key.pem"), path.Join(root, "servers-cert.pem"), - "Server Instance") + "server.example.com") if err != nil { t.Fatalf("TLSClientConfig failed: %v", err) }