diff --git a/go/mysql/server.go b/go/mysql/server.go index 37d5e2b3e5b..29ebeec6a14 100644 --- a/go/mysql/server.go +++ b/go/mysql/server.go @@ -17,7 +17,7 @@ limitations under the License. package mysql import ( - "crypto/tls" + tls "crypto/tls" "fmt" "io" "net" @@ -37,8 +37,14 @@ const ( DefaultServerVersion = "5.5.10-Vitess" // timing metric keys - connectTimingKey = "Connect" - queryTimingKey = "Query" + connectTimingKey = "Connect" + queryTimingKey = "Query" + versionSSL30 = "SSL30" + versionTLS10 = "TLS10" + versionTLS11 = "TLS11" + versionTLS12 = "TLS12" + versionTLSUnknown = "UnknownTLSVersion" + versionNoTLS = "None" ) var ( @@ -48,8 +54,9 @@ var ( connAccept = stats.NewCounter("MysqlServerConnAccepted", "Connections accepted by MySQL server") connSlow = stats.NewCounter("MysqlServerConnSlow", "Connections that took more than the configured mysql_slow_connect_warn_threshold to establish") - connCountPerUser = stats.NewGaugesWithSingleLabel("MysqlServerConnCountPerUser", "Active MySQL server connections per user", "count") - _ = stats.NewGaugeFunc("MysqlServerConnCountUnauthenticated", "Active MySQL server connections that haven't authenticated yet", func() int64 { + connCountByTLSVer = stats.NewGaugesWithSingleLabel("MysqlServerConnCountByTLSVer", "Active MySQL server connections by TLS version", "tls") + connCountPerUser = stats.NewGaugesWithSingleLabel("MysqlServerConnCountPerUser", "Active MySQL server connections per user", "count") + _ = stats.NewGaugeFunc("MysqlServerConnCountUnauthenticated", "Active MySQL server connections that haven't authenticated yet", func() int64 { totalUsers := int64(0) for _, v := range connCountPerUser.Counts() { totalUsers += v @@ -300,6 +307,18 @@ func (l *Listener) handle(conn net.Conn, connectionID uint32, acceptTime time.Ti return } c.recycleReadPacket() + + if con, ok := c.conn.(*tls.Conn); ok { + connState := con.ConnectionState() + tlsVerStr := tlsVersionToString(connState.Version) + if tlsVerStr != "" { + connCountByTLSVer.Add(tlsVerStr, 1) + defer connCountByTLSVer.Add(tlsVerStr, -1) + } + } + } else { + connCountByTLSVer.Add(versionNoTLS, 1) + defer connCountByTLSVer.Add(versionNoTLS, -1) } // See what auth method the AuthServer wants to use for that user. @@ -649,3 +668,19 @@ func (c *Conn) writeAuthSwitchRequest(pluginName string, pluginData []byte) erro } return c.writeEphemeralPacket() } + +// Whenever we move to a new version of go, we will need add any new supported TLS versions here +func tlsVersionToString(version uint16) string { + switch version { + case tls.VersionSSL30: + return versionSSL30 + case tls.VersionTLS10: + return versionTLS10 + case tls.VersionTLS11: + return versionTLS11 + case tls.VersionTLS12: + return versionTLS12 + default: + return versionTLSUnknown + } +} diff --git a/go/mysql/server_test.go b/go/mysql/server_test.go index 7023df61bb0..9a27f6f1a5f 100644 --- a/go/mysql/server_test.go +++ b/go/mysql/server_test.go @@ -930,6 +930,23 @@ func TestTLSServer(t *testing.T) { if results.Rows[0][0].ToString() != "ON" { t.Errorf("Unexpected output for 'ssl echo': %v", results) } + + checkCountForTLSVer(t, versionTLS12, 1) + checkCountForTLSVer(t, versionNoTLS, 0) + conn.Close() + +} + +func checkCountForTLSVer(t *testing.T, version string, expected int64) { + connCounts := connCountByTLSVer.Counts() + count, ok := connCounts[version] + if ok { + if count != expected { + t.Errorf("Expected connection count for version %s to be %d, got %d", version, expected, count) + } + } else { + t.Errorf("No count found for version %s", version) + } } func TestErrorCodes(t *testing.T) {