diff --git a/source/extensions/filters/network/postgres_proxy/postgres_decoder.cc b/source/extensions/filters/network/postgres_proxy/postgres_decoder.cc index 0993f9f4c272e..18a55b088e40d 100644 --- a/source/extensions/filters/network/postgres_proxy/postgres_decoder.cc +++ b/source/extensions/filters/network/postgres_proxy/postgres_decoder.cc @@ -176,7 +176,7 @@ void DecoderImpl::initialize() { }; } -bool DecoderImpl::parseMessage(Buffer::Instance& data) { +bool DecoderImpl::parseHeader(Buffer::Instance& data) { ENVOY_LOG(trace, "postgres_proxy: parsing message, len {}", data.length()); // The minimum size of the message sufficient for parsing is 5 bytes. @@ -220,10 +220,6 @@ bool DecoderImpl::parseMessage(Buffer::Instance& data) { data.drain(startup_ ? 4 : 5); // Length plus optional 1st byte. - uint32_t bytes_to_read = message_len_ - 4; - message.assign(std::string(static_cast(data.linearize(bytes_to_read)), bytes_to_read)); - setMessage(message); - ENVOY_LOG(trace, "postgres_proxy: msg parsed"); return true; } @@ -238,7 +234,7 @@ bool DecoderImpl::onData(Buffer::Instance& data, bool frontend) { ENVOY_LOG(trace, "postgres_proxy: decoding {} bytes", data.length()); - if (!parseMessage(data)) { + if (!parseHeader(data)) { return false; } @@ -259,16 +255,25 @@ bool DecoderImpl::onData(Buffer::Instance& data, bool frontend) { } } - std::vector& actions = std::get<2>(msg.get()); - for (const auto& action : actions) { - action(this); - } - // message_len_ specifies total message length including 4 bytes long // "length" field. The length of message body is total length minus size // of "length" field (4 bytes). uint32_t bytes_to_read = message_len_ - 4; + std::vector& actions = std::get<2>(msg.get()); + if (!actions.empty()) { + // Linearize the message for processing. + message_.assign(std::string(static_cast(data.linearize(bytes_to_read)), bytes_to_read)); + + // Invoke actions associated with the type of received message. + for (const auto& action : actions) { + action(this); + } + + // Drop the linearized message. + message_.erase(); + } + ENVOY_LOG(debug, "({}) command = {} ({})", msg_processor.direction_, command_, std::get<0>(msg.get())); ENVOY_LOG(debug, "({}) length = {}", msg_processor.direction_, message_len_); diff --git a/source/extensions/filters/network/postgres_proxy/postgres_decoder.h b/source/extensions/filters/network/postgres_proxy/postgres_decoder.h index dc4638b1c436a..409cdbba659c9 100644 --- a/source/extensions/filters/network/postgres_proxy/postgres_decoder.h +++ b/source/extensions/filters/network/postgres_proxy/postgres_decoder.h @@ -72,7 +72,6 @@ class DecoderImpl : public Decoder, Logger::Loggable { bool onData(Buffer::Instance& data, bool frontend) override; PostgresSession& getSession() override { return session_; } - void setMessage(std::string message) { message_ = message; } std::string getMessage() { return message_; } void setStartup(bool startup) { startup_ = startup; } @@ -122,7 +121,7 @@ class DecoderImpl : public Decoder, Logger::Loggable { MsgAction unknown_; }; - bool parseMessage(Buffer::Instance& data); + bool parseHeader(Buffer::Instance& data); void decode(Buffer::Instance& data); void decodeAuthentication(); void decodeBackendStatements(); diff --git a/test/extensions/filters/network/postgres_proxy/postgres_decoder_test.cc b/test/extensions/filters/network/postgres_proxy/postgres_decoder_test.cc index aa2d9ff2c7b7b..e787a18f2d5bf 100644 --- a/test/extensions/filters/network/postgres_proxy/postgres_decoder_test.cc +++ b/test/extensions/filters/network/postgres_proxy/postgres_decoder_test.cc @@ -242,6 +242,9 @@ TEST_P(PostgresProxyFrontendDecoderTest, FrontendInc) { EXPECT_CALL(callbacks_, incMessagesFrontend()).Times(1); createPostgresMsg(data_, GetParam(), "SELECT 1;"); decoder_->onData(data_, true); + + // Make sure that decoder releases memory used during message processing. + ASSERT_TRUE(decoder_->getMessage().empty()); } // Run the above test for each frontend message. @@ -507,6 +510,89 @@ TEST_P(PostgresProxyFrontendEncrDecoderTest, EncyptedTraffic) { INSTANTIATE_TEST_SUITE_P(FrontendEncryptedMessagesTests, PostgresProxyFrontendEncrDecoderTest, ::testing::Values(80877103, 80877104)); +class FakeBuffer : public Buffer::Instance { +public: + MOCK_METHOD(void, addDrainTracker, (std::function), (override)); + MOCK_METHOD(void, add, (const void*, uint64_t), (override)); + MOCK_METHOD(void, addBufferFragment, (Buffer::BufferFragment&), (override)); + MOCK_METHOD(void, add, (absl::string_view), (override)); + MOCK_METHOD(void, add, (const Instance&), (override)); + MOCK_METHOD(void, prepend, (absl::string_view), (override)); + MOCK_METHOD(void, prepend, (Instance&), (override)); + MOCK_METHOD(void, commit, (Buffer::RawSlice*, uint64_t), (override)); + MOCK_METHOD(void, copyOut, (size_t, uint64_t, void*), (const, override)); + MOCK_METHOD(void, drain, (uint64_t), (override)); + MOCK_METHOD(Buffer::RawSliceVector, getRawSlices, (absl::optional), (const, override)); + MOCK_METHOD(Buffer::SliceDataPtr, extractMutableFrontSlice, (), (override)); + MOCK_METHOD(uint64_t, length, (), (const, override)); + MOCK_METHOD(void*, linearize, (uint32_t), (override)); + MOCK_METHOD(void, move, (Instance&), (override)); + MOCK_METHOD(void, move, (Instance&, uint64_t), (override)); + MOCK_METHOD(uint64_t, reserve, (uint64_t, Buffer::RawSlice*, uint64_t), (override)); + MOCK_METHOD(ssize_t, search, (const void*, uint64_t, size_t, size_t), (const, override)); + MOCK_METHOD(bool, startsWith, (absl::string_view), (const, override)); + MOCK_METHOD(std::string, toString, (), (const, override)); +}; + +// Test verifies that decoder calls Buffer::linearize method +// for messages which have associated 'action'. +TEST_F(PostgresProxyDecoderTest, Linearize) { + testing::NiceMock fake_buf; + uint8_t body[] = "test\0"; + + decoder_->setStartup(false); + + // Simulate that decoder reads message which needs processing. + // Query 'Q' message's body is just string. + // Message header is 5 bytes and body will contain string "test\0". + EXPECT_CALL(fake_buf, length).WillRepeatedly(testing::Return(10)); + // The decoder will first ask for 1-byte message type + // Then for length and finally for message body. + EXPECT_CALL(fake_buf, copyOut) + .WillOnce([](size_t start, uint64_t size, void* data) { + ASSERT_THAT(start, 0); + ASSERT_THAT(size, 1); + *(static_cast(data)) = 'Q'; + }) + .WillOnce([](size_t start, uint64_t size, void* data) { + ASSERT_THAT(start, 1); + ASSERT_THAT(size, 4); + *(static_cast(data)) = htonl(9); + }) + .WillRepeatedly([=](size_t start, uint64_t size, void* data) { + ASSERT_THAT(start, 0); + ASSERT_THAT(size, 5); + memcpy(data, body, 5); + }); + + // It should call "Buffer::linearize". + EXPECT_CALL(fake_buf, linearize).WillOnce([&](uint32_t) -> void* { return body; }); + + decoder_->onData(fake_buf, false); + + // Simulate that decoder reads message which does not need processing. + // BindComplete message has type '2' and empty body. + // Total message length is equal to length of header (5 bytes). + EXPECT_CALL(fake_buf, length).WillRepeatedly(testing::Return(5)); + // The decoder will first ask for 1-byte message type and next for length. + EXPECT_CALL(fake_buf, copyOut) + .WillOnce([](size_t start, uint64_t size, void* data) { + ASSERT_THAT(start, 0); + ASSERT_THAT(size, 1); + *(static_cast(data)) = '2'; + }) + .WillOnce([](size_t start, uint64_t size, void* data) { + ASSERT_THAT(start, 1); + ASSERT_THAT(size, 4); + *(static_cast(data)) = htonl(4); + }); + + // Make sure that decoder does not call linearize. + EXPECT_CALL(fake_buf, linearize).Times(0); + + decoder_->onData(fake_buf, false); +} + } // namespace PostgresProxy } // namespace NetworkFilters } // namespace Extensions