diff --git a/go/mysql/client.go b/go/mysql/client.go index 1c30044197d..3494a2a637c 100644 --- a/go/mysql/client.go +++ b/go/mysql/client.go @@ -303,17 +303,23 @@ func (c *Conn) clientHandshake(characterSet uint8, params *ConnParams) error { case AuthSwitchRequestPacket: // Server is asking to use a different auth method. We // only support cleartext plugin. - pluginName, _, err := parseAuthSwitchRequest(response) + pluginName, salt, err := parseAuthSwitchRequest(response) if err != nil { return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "cannot parse auth switch request: %v", err) } - if pluginName != MysqlClearPassword { - return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", pluginName) - } - // Write the password packet. - if err := c.writeClearTextPassword(params); err != nil { - return err + if pluginName == MysqlClearPassword { + // Write the cleartext password packet. + if err := c.writeClearTextPassword(params); err != nil { + return err + } + } else if pluginName == MysqlNativePassword { + // Write the mysql_native_password packet. + if err := c.writeMysqlNativePassword(params, salt); err != nil { + return err + } + } else { + return NewSQLError(CRServerHandshakeErr, SSUnknownSQLState, "server asked for unsupported auth method: %v", pluginName) } // Wait for OK packet. @@ -649,7 +655,12 @@ func parseAuthSwitchRequest(data []byte) (string, []byte, error) { return "", nil, vterrors.Errorf(vtrpc.Code_INTERNAL, "cannot get plugin name from AuthSwitchRequest: %v", data) } - return pluginName, data[pos:], nil + // If this was a request with a salt in it, max 20 bytes + salt := data[pos:] + if len(salt) > 20 { + salt = salt[:20] + } + return pluginName, salt, nil } // writeClearTextPassword writes the clear text password. @@ -665,3 +676,17 @@ func (c *Conn) writeClearTextPassword(params *ConnParams) error { } return c.writeEphemeralPacket() } + +// writeMysqlNativePassword writes the encrypted mysql_native_password format +// Returns a SQLError. +func (c *Conn) writeMysqlNativePassword(params *ConnParams, salt []byte) error { + scrambledPassword := ScramblePassword(salt, []byte(params.Pass)) + data := c.startEphemeralPacket(len(scrambledPassword)) + pos := 0 + pos += copy(data[pos:], scrambledPassword) + // Sanity check. + if pos != len(data) { + return vterrors.Errorf(vtrpc.Code_INTERNAL, "error building MysqlNativePassword packet: got %v bytes expected %v", pos, len(data)) + } + return c.writeEphemeralPacket() +}