diff --git a/source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.cc b/source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.cc index be4739d909f93..a851b4ad581dc 100644 --- a/source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.cc +++ b/source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.cc @@ -1,5 +1,7 @@ #include "source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.h" +#include "source/common/buffer/buffer_impl.h" +#include "source/common/common/logger.h" #include "source/extensions/filters/network/mysql_proxy/mysql_codec.h" #include "source/extensions/filters/network/mysql_proxy/mysql_utils.h" @@ -55,6 +57,10 @@ bool ClientLogin::isClientSecureConnection() const { return client_cap_ & CLIENT_SECURE_CONNECTION; } +void ClientLogin::addConnectionAttribute(const std::pair& attr) { + conn_attr_.emplace_back(attr); +} + DecodeStatus ClientLogin::parseMessage(Buffer::Instance& buffer, uint32_t len) { /* 4.0 uses 2 bytes, 4.1+ uses 4 bytes, but the proto-flag is in the lower 2 * bytes */ @@ -96,6 +102,7 @@ DecodeStatus ClientLogin::parseResponseSsl(Buffer::Instance& buffer) { } DecodeStatus ClientLogin::parseResponse41(Buffer::Instance& buffer) { + int total = buffer.length(); uint16_t ext_cap; if (BufferHelper::readUint16(buffer, ext_cap) != DecodeStatus::Success) { ENVOY_LOG(debug, "error when parsing client cap flag of client login message"); @@ -155,6 +162,48 @@ DecodeStatus ClientLogin::parseResponse41(Buffer::Instance& buffer) { ENVOY_LOG(debug, "error when parsing auth plugin name of client login message"); return DecodeStatus::Failure; } + if (client_cap_ & CLIENT_CONNECT_ATTRS) { + // length of all key value pairs + uint64_t kvs_len; + if (BufferHelper::readLengthEncodedInteger(buffer, kvs_len) != DecodeStatus::Success) { + ENVOY_LOG(debug, "error when parsing length of all key-values in connection attributes of " + "client login message"); + return DecodeStatus::Failure; + } + while (kvs_len > 0) { + uint64_t str_len; + uint64_t prev_len = buffer.length(); + if (BufferHelper::readLengthEncodedInteger(buffer, str_len) != DecodeStatus::Success) { + ENVOY_LOG(debug, "error when parsing total length of connection attribute key in " + "connection attributes of " + "client login message"); + return DecodeStatus::Failure; + } + std::string key; + if (BufferHelper::readStringBySize(buffer, str_len, key) != DecodeStatus::Success) { + ENVOY_LOG(debug, "error when parsing connection attribute key in connection attributes of " + "client login message"); + return DecodeStatus::Failure; + } + if (BufferHelper::readLengthEncodedInteger(buffer, str_len) != DecodeStatus::Success) { + ENVOY_LOG( + debug, + "error when parsing length of connection attribute value in connection attributes of " + "client login message"); + return DecodeStatus::Failure; + } + std::string val; + if (BufferHelper::readStringBySize(buffer, str_len, val) != DecodeStatus::Success) { + ENVOY_LOG(debug, "error when parsing connection attribute val in connection attributes of " + "client login message"); + return DecodeStatus::Failure; + } + conn_attr_.emplace_back(std::make_pair(std::move(key), std::move(val))); + kvs_len -= prev_len - buffer.length(); + } + } + ENVOY_LOG(debug, "parsed client login protocol 41, consumed len {}, remain len {}", + total - buffer.length(), buffer.length()); return DecodeStatus::Success; } @@ -238,6 +287,17 @@ void ClientLogin::encodeResponse41(Buffer::Instance& out) const { BufferHelper::addString(out, auth_plugin_name_); BufferHelper::addUint8(out, enc_end_string); } + if (client_cap_ & CLIENT_CONNECT_ATTRS) { + Buffer::OwnedImpl conn_attr; + for (const auto& kv : conn_attr_) { + BufferHelper::addLengthEncodedInteger(conn_attr, kv.first.length()); + BufferHelper::addString(conn_attr, kv.first); + BufferHelper::addLengthEncodedInteger(conn_attr, kv.second.length()); + BufferHelper::addString(conn_attr, kv.second); + } + BufferHelper::addLengthEncodedInteger(out, conn_attr.length()); + out.move(conn_attr); + } } void ClientLogin::encodeResponse320(Buffer::Instance& out) const { diff --git a/source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.h b/source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.h index fbf3f7b22a723..44e8ba16ee74e 100644 --- a/source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.h +++ b/source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.h @@ -25,6 +25,9 @@ class ClientLogin : public MySQLCodec { const std::vector& getAuthResp() const { return auth_resp_; } const std::string& getDb() const { return db_; } const std::string& getAuthPluginName() const { return auth_plugin_name_; } + const std::vector>& getConnectionAttribute() const { + return conn_attr_; + } bool isResponse41() const; bool isResponse320() const; bool isSSLRequest() const; @@ -40,6 +43,7 @@ class ClientLogin : public MySQLCodec { void setAuthResp(const std::vector& auth_resp); void setDb(const std::string& db); void setAuthPluginName(const std::string& auth_plugin_name); + void addConnectionAttribute(const std::pair&); private: DecodeStatus parseResponseSsl(Buffer::Instance& buffer); @@ -56,6 +60,7 @@ class ClientLogin : public MySQLCodec { std::vector auth_resp_; std::string db_; std::string auth_plugin_name_; + std::vector> conn_attr_; }; } // namespace MySQLProxy diff --git a/test/extensions/filters/network/mysql_proxy/mysql_clogin_test.cc b/test/extensions/filters/network/mysql_proxy/mysql_clogin_test.cc index d06aad773f1d2..ee6611f4fac54 100644 --- a/test/extensions/filters/network/mysql_proxy/mysql_clogin_test.cc +++ b/test/extensions/filters/network/mysql_proxy/mysql_clogin_test.cc @@ -1,4 +1,7 @@ +#include + #include "source/common/buffer/buffer_impl.h" +#include "source/extensions/filters/network/mysql_proxy/mysql_codec.h" #include "source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.h" #include "source/extensions/filters/network/mysql_proxy/mysql_utils.h" @@ -19,6 +22,7 @@ ClientLogin initClientLogin() { mysql_clogin_encode.setAuthResp(MySQLTestUtils::getAuthResp8()); mysql_clogin_encode.setDb(MySQLTestUtils::getDb()); mysql_clogin_encode.setAuthPluginName(MySQLTestUtils::getAuthPluginName()); + mysql_clogin_encode.addConnectionAttribute({"key", "val"}); return mysql_clogin_encode; } }; // namespace @@ -188,8 +192,8 @@ TEST_F(MySQLCLoginTest, MySQLClientLogin41IncompleteAuthResp) { * - message is decoded using the ClientLogin class */ TEST_F(MySQLCLoginTest, MySQLClientLogin41EncDec) { - ClientLogin& mysql_clogin_encode = - MySQLCLoginTest::getClientLogin(CLIENT_PROTOCOL_41 | CLIENT_CONNECT_WITH_DB); + ClientLogin& mysql_clogin_encode = MySQLCLoginTest::getClientLogin( + CLIENT_PROTOCOL_41 | CLIENT_CONNECT_WITH_DB | CLIENT_CONNECT_ATTRS); Buffer::OwnedImpl decode_data; mysql_clogin_encode.encode(decode_data); @@ -204,7 +208,8 @@ TEST_F(MySQLCLoginTest, MySQLClientLogin41EncDec) { EXPECT_EQ(mysql_clogin_decode.getCharset(), mysql_clogin_encode.getCharset()); EXPECT_EQ(mysql_clogin_decode.getUsername(), mysql_clogin_encode.getUsername()); EXPECT_EQ(mysql_clogin_decode.getAuthResp(), mysql_clogin_encode.getAuthResp()); - + EXPECT_EQ(mysql_clogin_decode.getConnectionAttribute(), + mysql_clogin_encode.getConnectionAttribute()); EXPECT_TRUE(mysql_clogin_decode.getAuthPluginName().empty()); } @@ -553,6 +558,110 @@ TEST_F(MySQLCLoginTest, MySQLClientLogin41IncompleteAuthPluginName) { EXPECT_EQ(mysql_clogin_decode.getAuthPluginName(), ""); } +class MySQL41LoginConnAttrTest : public MySQLCLoginTest { +public: + MySQL41LoginConnAttrTest() { + login_encode_ = MySQLCLoginTest::getClientLogin(CLIENT_PROTOCOL_41 | CLIENT_CONNECT_WITH_DB | + CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS); + incomplete_base_len_ = + sizeof(login_encode_.getClientCap()) + sizeof(login_encode_.getMaxPacket()) + + sizeof(login_encode_.getCharset()) + UNSET_BYTES + login_encode_.getUsername().size() + 1 + + login_encode_.getAuthResp().size() + 1 + login_encode_.getDb().size() + 1 + + login_encode_.getAuthPluginName().length() + 1; + } + + void prepareLoginDecode(int delta_len = 0) { + Buffer::OwnedImpl buffer; + login_encode_.encode(buffer); + int incomplete_len = incomplete_base_len_ + delta_len; + Buffer::OwnedImpl decode_data(buffer.toString().data(), incomplete_len); + + login_decode_.decode(decode_data, CHALLENGE_SEQ_NUM, decode_data.length()); + } + + void checkLoginDecode(const std::function& additional_check = nullptr) { + EXPECT_TRUE(login_decode_.isConnectWithDb()); + EXPECT_EQ(login_decode_.getClientCap(), login_encode_.getClientCap()); + EXPECT_EQ(login_decode_.getExtendedClientCap(), login_decode_.getExtendedClientCap()); + EXPECT_EQ(login_decode_.getMaxPacket(), login_encode_.getMaxPacket()); + EXPECT_EQ(login_decode_.getCharset(), login_encode_.getCharset()); + EXPECT_EQ(login_decode_.getUsername(), login_encode_.getUsername()); + EXPECT_EQ(login_decode_.getAuthResp(), login_encode_.getAuthResp()); + EXPECT_EQ(login_decode_.getDb(), login_encode_.getDb()); + EXPECT_EQ(login_decode_.getAuthPluginName(), login_encode_.getAuthPluginName()); + if (additional_check != nullptr) { + additional_check(); + } + } + const ClientLogin& loginEncode() const { return login_encode_; } + const ClientLogin& loginDecode() const { return login_decode_; } + +private: + ClientLogin login_encode_; + ClientLogin login_decode_; + int incomplete_base_len_; +}; + +/* + * Negative Test the MYSQL Client Login 41 message parser: + * Incomplete total length of connection attributions + */ +TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrLength) { + prepareLoginDecode(); + checkLoginDecode([&]() { EXPECT_EQ(loginDecode().getConnectionAttribute().size(), 0); }); +} + +/* + * Negative Test the MYSQL Client Login 41 message parser: + * Incomplete length of connection attribution key + */ +TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrKeyLength) { + prepareLoginDecode( + MySQLTestUtils::bytesOfConnAtrributeLength(loginEncode().getConnectionAttribute())); + + checkLoginDecode([&]() { EXPECT_EQ(loginDecode().getConnectionAttribute().size(), 0); }); +} + +/* + * Negative Test the MYSQL Client Login 41 message parser: + * Incomplete connection attribution key + */ +TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrKey) { + prepareLoginDecode( + MySQLTestUtils::bytesOfConnAtrributeLength(loginEncode().getConnectionAttribute()) + + MySQLTestUtils::sizeOfLengthEncodeInteger( + loginEncode().getConnectionAttribute()[0].first.length())); + checkLoginDecode([&]() { EXPECT_EQ(loginDecode().getConnectionAttribute().size(), 0); }); +} + +/* + * Negative Test the MYSQL Client Login 41 message parser: + * Incomplete length of connection attribution val + */ +TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrValLength) { + prepareLoginDecode( + MySQLTestUtils::bytesOfConnAtrributeLength(loginEncode().getConnectionAttribute()) + + MySQLTestUtils::sizeOfLengthEncodeInteger( + loginEncode().getConnectionAttribute()[0].first.length()) + + loginEncode().getConnectionAttribute()[0].first.length()); + checkLoginDecode([&]() { EXPECT_EQ(loginDecode().getConnectionAttribute().size(), 0); }); +} + +/* + * Negative Test the MYSQL Client Login 41 message parser: + * Incomplete connection attribution val + */ +TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrVal) { + prepareLoginDecode( + MySQLTestUtils::bytesOfConnAtrributeLength(loginEncode().getConnectionAttribute()) + + MySQLTestUtils::sizeOfLengthEncodeInteger( + loginEncode().getConnectionAttribute()[0].first.length()) + + loginEncode().getConnectionAttribute()[0].first.length() + + MySQLTestUtils::sizeOfLengthEncodeInteger( + loginEncode().getConnectionAttribute()[0].second.length())); + checkLoginDecode([&]() { EXPECT_EQ(loginDecode().getConnectionAttribute().size(), 0); }); +} + /* * Negative Test the MYSQL Client 320 login message parser: * Incomplete header at cap diff --git a/test/extensions/filters/network/mysql_proxy/mysql_test_utils.cc b/test/extensions/filters/network/mysql_proxy/mysql_test_utils.cc index 989e0ebef708a..6bc91d26ca661 100644 --- a/test/extensions/filters/network/mysql_proxy/mysql_test_utils.cc +++ b/test/extensions/filters/network/mysql_proxy/mysql_test_utils.cc @@ -122,6 +122,18 @@ std::string MySQLTestUtils::encodeMessage(uint32_t packet_len, uint8_t it, uint8 return buffer.toString(); } +int MySQLTestUtils::bytesOfConnAtrributeLength( + const std::vector>& conn_attrs) { + int64_t total_len = 0; + for (const auto& kv : conn_attrs) { + total_len += sizeOfLengthEncodeInteger(kv.first.length()); + total_len += kv.first.length(); + total_len += sizeOfLengthEncodeInteger(kv.second.length()); + total_len += kv.second.length(); + } + return sizeOfLengthEncodeInteger(total_len); +} + int MySQLTestUtils::sizeOfLengthEncodeInteger(uint64_t val) { if (val < 251) { return sizeof(uint8_t); diff --git a/test/extensions/filters/network/mysql_proxy/mysql_test_utils.h b/test/extensions/filters/network/mysql_proxy/mysql_test_utils.h index 996cc92ff918c..d2724c9a369ff 100644 --- a/test/extensions/filters/network/mysql_proxy/mysql_test_utils.h +++ b/test/extensions/filters/network/mysql_proxy/mysql_test_utils.h @@ -37,6 +37,8 @@ class MySQLTestUtils { static std::string getDb() { return "mysql.db"; } static std::string getCommandResponse() { return "command response"; } static std::string getInfo() { return "info"; } + static int + bytesOfConnAtrributeLength(const std::vector>& conn); static int sizeOfLengthEncodeInteger(uint64_t val); std::string encodeServerGreeting(int protocol);