Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.h"

#include "source/common/buffer/buffer_impl.h"
#include "source/common/common/logger.h"
#include "source/extensions/filters/network/mysql_proxy/mysql_codec.h"
#include "source/extensions/filters/network/mysql_proxy/mysql_utils.h"

Expand Down Expand Up @@ -55,6 +57,10 @@ bool ClientLogin::isClientSecureConnection() const {
return client_cap_ & CLIENT_SECURE_CONNECTION;
}

void ClientLogin::addConnectionAttribute(const std::pair<std::string, std::string>& attr) {
conn_attr_.emplace_back(attr);
}

DecodeStatus ClientLogin::parseMessage(Buffer::Instance& buffer, uint32_t len) {
/* 4.0 uses 2 bytes, 4.1+ uses 4 bytes, but the proto-flag is in the lower 2
* bytes */
Expand Down Expand Up @@ -96,6 +102,7 @@ DecodeStatus ClientLogin::parseResponseSsl(Buffer::Instance& buffer) {
}

DecodeStatus ClientLogin::parseResponse41(Buffer::Instance& buffer) {
int total = buffer.length();
uint16_t ext_cap;
if (BufferHelper::readUint16(buffer, ext_cap) != DecodeStatus::Success) {
ENVOY_LOG(debug, "error when parsing client cap flag of client login message");
Expand Down Expand Up @@ -155,6 +162,48 @@ DecodeStatus ClientLogin::parseResponse41(Buffer::Instance& buffer) {
ENVOY_LOG(debug, "error when parsing auth plugin name of client login message");
return DecodeStatus::Failure;
}
if (client_cap_ & CLIENT_CONNECT_ATTRS) {
// length of all key value pairs
uint64_t kvs_len;
if (BufferHelper::readLengthEncodedInteger(buffer, kvs_len) != DecodeStatus::Success) {
ENVOY_LOG(debug, "error when parsing length of all key-values in connection attributes of "
Comment thread
cpakulski marked this conversation as resolved.
"client login message");
return DecodeStatus::Failure;
}
while (kvs_len > 0) {
uint64_t str_len;
uint64_t prev_len = buffer.length();
if (BufferHelper::readLengthEncodedInteger(buffer, str_len) != DecodeStatus::Success) {
ENVOY_LOG(debug, "error when parsing total length of connection attribute key in "
"connection attributes of "
"client login message");
return DecodeStatus::Failure;
}
std::string key;
if (BufferHelper::readStringBySize(buffer, str_len, key) != DecodeStatus::Success) {
ENVOY_LOG(debug, "error when parsing connection attribute key in connection attributes of "
"client login message");
return DecodeStatus::Failure;
}
if (BufferHelper::readLengthEncodedInteger(buffer, str_len) != DecodeStatus::Success) {
ENVOY_LOG(
debug,
"error when parsing length of connection attribute value in connection attributes of "
"client login message");
return DecodeStatus::Failure;
}
std::string val;
if (BufferHelper::readStringBySize(buffer, str_len, val) != DecodeStatus::Success) {
ENVOY_LOG(debug, "error when parsing connection attribute val in connection attributes of "
"client login message");
return DecodeStatus::Failure;
}
conn_attr_.emplace_back(std::make_pair(std::move(key), std::move(val)));
kvs_len -= prev_len - buffer.length();
}
}
ENVOY_LOG(debug, "parsed client login protocol 41, consumed len {}, remain len {}",
total - buffer.length(), buffer.length());
return DecodeStatus::Success;
}

Expand Down Expand Up @@ -238,6 +287,17 @@ void ClientLogin::encodeResponse41(Buffer::Instance& out) const {
BufferHelper::addString(out, auth_plugin_name_);
BufferHelper::addUint8(out, enc_end_string);
}
if (client_cap_ & CLIENT_CONNECT_ATTRS) {
Buffer::OwnedImpl conn_attr;
for (const auto& kv : conn_attr_) {
BufferHelper::addLengthEncodedInteger(conn_attr, kv.first.length());
BufferHelper::addString(conn_attr, kv.first);
BufferHelper::addLengthEncodedInteger(conn_attr, kv.second.length());
BufferHelper::addString(conn_attr, kv.second);
}
BufferHelper::addLengthEncodedInteger(out, conn_attr.length());
out.move(conn_attr);
}
}

void ClientLogin::encodeResponse320(Buffer::Instance& out) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ class ClientLogin : public MySQLCodec {
const std::vector<uint8_t>& getAuthResp() const { return auth_resp_; }
const std::string& getDb() const { return db_; }
const std::string& getAuthPluginName() const { return auth_plugin_name_; }
const std::vector<std::pair<std::string, std::string>>& getConnectionAttribute() const {
return conn_attr_;
}
bool isResponse41() const;
bool isResponse320() const;
bool isSSLRequest() const;
Expand All @@ -40,6 +43,7 @@ class ClientLogin : public MySQLCodec {
void setAuthResp(const std::vector<uint8_t>& auth_resp);
void setDb(const std::string& db);
void setAuthPluginName(const std::string& auth_plugin_name);
void addConnectionAttribute(const std::pair<std::string, std::string>&);

private:
DecodeStatus parseResponseSsl(Buffer::Instance& buffer);
Expand All @@ -56,6 +60,7 @@ class ClientLogin : public MySQLCodec {
std::vector<uint8_t> auth_resp_;
std::string db_;
std::string auth_plugin_name_;
std::vector<std::pair<std::string, std::string>> conn_attr_;
};

} // namespace MySQLProxy
Expand Down
111 changes: 108 additions & 3 deletions test/extensions/filters/network/mysql_proxy/mysql_clogin_test.cc
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#include <functional>

#include "source/common/buffer/buffer_impl.h"
#include "source/extensions/filters/network/mysql_proxy/mysql_codec.h"
#include "source/extensions/filters/network/mysql_proxy/mysql_codec_clogin.h"
#include "source/extensions/filters/network/mysql_proxy/mysql_utils.h"

Expand All @@ -19,6 +22,7 @@ ClientLogin initClientLogin() {
mysql_clogin_encode.setAuthResp(MySQLTestUtils::getAuthResp8());
mysql_clogin_encode.setDb(MySQLTestUtils::getDb());
mysql_clogin_encode.setAuthPluginName(MySQLTestUtils::getAuthPluginName());
mysql_clogin_encode.addConnectionAttribute({"key", "val"});
return mysql_clogin_encode;
}
}; // namespace
Expand Down Expand Up @@ -188,8 +192,8 @@ TEST_F(MySQLCLoginTest, MySQLClientLogin41IncompleteAuthResp) {
* - message is decoded using the ClientLogin class
*/
TEST_F(MySQLCLoginTest, MySQLClientLogin41EncDec) {
ClientLogin& mysql_clogin_encode =
MySQLCLoginTest::getClientLogin(CLIENT_PROTOCOL_41 | CLIENT_CONNECT_WITH_DB);
ClientLogin& mysql_clogin_encode = MySQLCLoginTest::getClientLogin(
CLIENT_PROTOCOL_41 | CLIENT_CONNECT_WITH_DB | CLIENT_CONNECT_ATTRS);
Buffer::OwnedImpl decode_data;
mysql_clogin_encode.encode(decode_data);

Expand All @@ -204,7 +208,8 @@ TEST_F(MySQLCLoginTest, MySQLClientLogin41EncDec) {
EXPECT_EQ(mysql_clogin_decode.getCharset(), mysql_clogin_encode.getCharset());
EXPECT_EQ(mysql_clogin_decode.getUsername(), mysql_clogin_encode.getUsername());
EXPECT_EQ(mysql_clogin_decode.getAuthResp(), mysql_clogin_encode.getAuthResp());

EXPECT_EQ(mysql_clogin_decode.getConnectionAttribute(),
mysql_clogin_encode.getConnectionAttribute());
EXPECT_TRUE(mysql_clogin_decode.getAuthPluginName().empty());
}

Expand Down Expand Up @@ -553,6 +558,106 @@ TEST_F(MySQLCLoginTest, MySQLClientLogin41IncompleteAuthPluginName) {
EXPECT_EQ(mysql_clogin_decode.getAuthPluginName(), "");
}

class MySQL41LoginConnAttrTest : public MySQLCLoginTest {
public:
MySQL41LoginConnAttrTest() {
login_encode = MySQLCLoginTest::getClientLogin(CLIENT_PROTOCOL_41 | CLIENT_CONNECT_WITH_DB |
CLIENT_PLUGIN_AUTH | CLIENT_CONNECT_ATTRS);
incomplete_base_len = sizeof(login_encode.getClientCap()) +
sizeof(login_encode.getMaxPacket()) + sizeof(login_encode.getCharset()) +
UNSET_BYTES + login_encode.getUsername().size() + 1 +
login_encode.getAuthResp().size() + 1 + login_encode.getDb().size() + 1 +
login_encode.getAuthPluginName().length() + 1;
}
void prepareLoginDecode(int delta_len = 0) {
Buffer::OwnedImpl buffer;
login_encode.encode(buffer);
int incomplete_len = incomplete_base_len + delta_len;
Buffer::OwnedImpl decode_data(buffer.toString().data(), incomplete_len);

login_decode.decode(decode_data, CHALLENGE_SEQ_NUM, decode_data.length());
}

void checkLoginDecode(const std::function<void()>& additional_check = nullptr) {
EXPECT_TRUE(login_decode.isConnectWithDb());
EXPECT_EQ(login_decode.getClientCap(), login_encode.getClientCap());
EXPECT_EQ(login_decode.getExtendedClientCap(), login_decode.getExtendedClientCap());
EXPECT_EQ(login_decode.getMaxPacket(), login_encode.getMaxPacket());
EXPECT_EQ(login_decode.getCharset(), login_encode.getCharset());
EXPECT_EQ(login_decode.getUsername(), login_encode.getUsername());
EXPECT_EQ(login_decode.getAuthResp(), login_encode.getAuthResp());
EXPECT_EQ(login_decode.getDb(), login_encode.getDb());
EXPECT_EQ(login_decode.getAuthPluginName(), login_encode.getAuthPluginName());
if (additional_check) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you compare to nullptr?

additional_check();
}
}

ClientLogin login_encode;
ClientLogin login_decode;
int incomplete_base_len;
Comment thread
cpakulski marked this conversation as resolved.
Outdated
};

/*
* Negative Test the MYSQL Client Login 41 message parser:
* Incomplete total length of connection attributions
*/
TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrLength) {
prepareLoginDecode();
checkLoginDecode([&]() { EXPECT_EQ(login_decode.getConnectionAttribute().size(), 0); });
}

/*
* Negative Test the MYSQL Client Login 41 message parser:
* Incomplete length of connection attribution key
*/
TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrKeyLength) {
prepareLoginDecode(
MySQLTestUtils::bytesOfConnAtrributeLength(login_encode.getConnectionAttribute()));

checkLoginDecode([&]() { EXPECT_EQ(login_decode.getConnectionAttribute().size(), 0); });
}

/*
* Negative Test the MYSQL Client Login 41 message parser:
* Incomplete connection attribution key
*/
TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrKey) {
prepareLoginDecode(
MySQLTestUtils::bytesOfConnAtrributeLength(login_encode.getConnectionAttribute()) +
MySQLTestUtils::sizeOfLengthEncodeInteger(
login_encode.getConnectionAttribute()[0].first.length()));
checkLoginDecode([&]() { EXPECT_EQ(login_decode.getConnectionAttribute().size(), 0); });
}

/*
* Negative Test the MYSQL Client Login 41 message parser:
* Incomplete length of connection attribution val
*/
TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrValLength) {
prepareLoginDecode(
MySQLTestUtils::bytesOfConnAtrributeLength(login_encode.getConnectionAttribute()) +
MySQLTestUtils::sizeOfLengthEncodeInteger(
login_encode.getConnectionAttribute()[0].first.length()) +
login_encode.getConnectionAttribute()[0].first.length());
checkLoginDecode([&]() { EXPECT_EQ(login_decode.getConnectionAttribute().size(), 0); });
}

/*
* Negative Test the MYSQL Client Login 41 message parser:
* Incomplete connection attribution val
*/
TEST_F(MySQL41LoginConnAttrTest, MySQLClientLogin41IncompleteConnAttrVal) {
prepareLoginDecode(
MySQLTestUtils::bytesOfConnAtrributeLength(login_encode.getConnectionAttribute()) +
MySQLTestUtils::sizeOfLengthEncodeInteger(
login_encode.getConnectionAttribute()[0].first.length()) +
login_encode.getConnectionAttribute()[0].first.length() +
MySQLTestUtils::sizeOfLengthEncodeInteger(
login_encode.getConnectionAttribute()[0].second.length()));
checkLoginDecode([&]() { EXPECT_EQ(login_decode.getConnectionAttribute().size(), 0); });
}

/*
* Negative Test the MYSQL Client 320 login message parser:
* Incomplete header at cap
Expand Down
12 changes: 12 additions & 0 deletions test/extensions/filters/network/mysql_proxy/mysql_test_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,18 @@ std::string MySQLTestUtils::encodeMessage(uint32_t packet_len, uint8_t it, uint8
return buffer.toString();
}

int MySQLTestUtils::bytesOfConnAtrributeLength(
const std::vector<std::pair<std::string, std::string>> conn_attrs) {
int64_t allLen = 0;
Comment thread
cpakulski marked this conversation as resolved.
Outdated
for (const auto& kv : conn_attrs) {
allLen += sizeOfLengthEncodeInteger(kv.first.length());
allLen += kv.first.length();
allLen += sizeOfLengthEncodeInteger(kv.second.length());
allLen += kv.second.length();
}
return sizeOfLengthEncodeInteger(allLen);
}

int MySQLTestUtils::sizeOfLengthEncodeInteger(uint64_t val) {
if (val < 251) {
return sizeof(uint8_t);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class MySQLTestUtils {
static std::string getDb() { return "mysql.db"; }
static std::string getCommandResponse() { return "command response"; }
static std::string getInfo() { return "info"; }
static int
bytesOfConnAtrributeLength(const std::vector<std::pair<std::string, std::string>> conn);
static int sizeOfLengthEncodeInteger(uint64_t val);

std::string encodeServerGreeting(int protocol);
Expand Down