Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions go/mysql/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,14 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error {
}

// The ServerName to verify depends on what the hostname is.
// We use the params's ServerName if specified. Otherwise:
// - If using a socket, we use "localhost".
// - If it is an IP address, we need to prefix it with 'IP:'.
// - If not, we can just use it as is.
// We may need to add a ServerName field to ConnParams to
// make this more explicit.
serverName := "localhost"
if params.Host != "" {
if params.ServerName != "" {
serverName = params.ServerName
} else if params.Host != "" {
if net.ParseIP(params.Host) != nil {
serverName = "IP:" + params.Host
} else {
Expand Down
9 changes: 5 additions & 4 deletions go/mysql/conn_params.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ type ConnParams struct {

// The following SSL flags are only used when flags |= 2048
// is set (CapabilityClientSSL).
SslCa string `json:"ssl_ca"`
SslCaPath string `json:"ssl_ca_path"`
SslCert string `json:"ssl_cert"`
SslKey string `json:"ssl_key"`
SslCa string `json:"ssl_ca"`
SslCaPath string `json:"ssl_ca_path"`
SslCert string `json:"ssl_cert"`
SslKey string `json:"ssl_key"`
ServerName string `json:"server_name"`
}

// EnableSSL will set the right flag on the parameters.
Expand Down
11 changes: 6 additions & 5 deletions go/mysql/handshake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func TestSSLConnection(t *testing.T) {
}
defer os.RemoveAll(root)
tlstest.CreateCA(root)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "IP:"+host)
tlstest.CreateSignedCert(root, tlstest.CA, "01", "server", "server.example.com")
tlstest.CreateSignedCert(root, tlstest.CA, "02", "client", "Client Cert")

// Create the server with TLS config.
Expand All @@ -139,10 +139,11 @@ func TestSSLConnection(t *testing.T) {
Uname: "user1",
Pass: "password1",
// SSL flags.
Flags: CapabilityClientSSL,
SslCa: path.Join(root, "ca-cert.pem"),
SslCert: path.Join(root, "client-cert.pem"),
SslKey: path.Join(root, "client-key.pem"),
Flags: CapabilityClientSSL,
SslCa: path.Join(root, "ca-cert.pem"),
SslCert: path.Join(root, "client-cert.pem"),
SslKey: path.Join(root, "client-key.pem"),
ServerName: "server.example.com",
}

t.Run("Basics", func(t *testing.T) {
Expand Down
6 changes: 3 additions & 3 deletions go/vt/tlstest/tlstest_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestClientServer(t *testing.T) {
CreateCA(root)

CreateSignedCert(root, CA, "01", "servers", "Servers CA")
CreateSignedCert(root, "servers", "01", "server-instance", "Server Instance")
CreateSignedCert(root, "servers", "01", "server-instance", "server.example.com")

CreateSignedCert(root, CA, "02", "clients", "Clients CA")
CreateSignedCert(root, "clients", "01", "client-instance", "Client Instance")
Expand All @@ -62,7 +62,7 @@ func TestClientServer(t *testing.T) {
path.Join(root, "client-instance-cert.pem"),
path.Join(root, "client-instance-key.pem"),
path.Join(root, "servers-cert.pem"),
"Server Instance")
"server.example.com")
if err != nil {
t.Fatalf("TLSClientConfig failed: %v", err)
}
Expand Down Expand Up @@ -118,7 +118,7 @@ func TestClientServer(t *testing.T) {
path.Join(root, "server-instance-cert.pem"),
path.Join(root, "server-instance-key.pem"),
path.Join(root, "servers-cert.pem"),
"Server Instance")
"server.example.com")
if err != nil {
t.Fatalf("TLSClientConfig failed: %v", err)
}
Expand Down