diff --git a/source/common/grpc/BUILD b/source/common/grpc/BUILD
index 6729a434b2466..670068dd7f221 100644
--- a/source/common/grpc/BUILD
+++ b/source/common/grpc/BUILD
@@ -211,7 +211,17 @@ envoy_cc_library(
     name = "buffered_async_client_lib",
    hdrs = ["buffered_async_client.h"],
     deps = [
+        ":buffered_message_ttl_manager_lib",
         ":typed_async_client_lib",
         "//source/common/protobuf:utility_lib",
+        "@com_google_absl//absl/container:btree",
+    ],
+)
+
+envoy_cc_library(
+    name = "buffered_message_ttl_manager_lib",
+    hdrs = ["buffered_message_ttl_manager.h"],
+    deps = [
+        "//envoy/event:dispatcher_interface",
     ],
 )
diff --git a/source/common/grpc/buffered_async_client.h b/source/common/grpc/buffered_async_client.h
index 7c04673429429..bdd981d6bb6a7 100644
--- a/source/common/grpc/buffered_async_client.h
+++ b/source/common/grpc/buffered_async_client.h
@@ -1,7 +1,9 @@
 #pragma once
 
+#include <chrono>
 #include <memory>
 
+#include "source/common/grpc/buffered_message_ttl_manager.h"
 #include "source/common/grpc/typed_async_client.h"
 #include "source/common/protobuf/utility.h"
 
@@ -20,9 +22,12 @@ template <typename RequestType, typename ResponseType> class BufferedAsyncClient {
 public:
   BufferedAsyncClient(uint32_t max_buffer_bytes, const Protobuf::MethodDescriptor& service_method,
                       Grpc::AsyncStreamCallbacks<ResponseType>& callbacks,
-                      const Grpc::AsyncClient<RequestType>& client)
+                      const Grpc::AsyncClient<RequestType>& client,
+                      Event::Dispatcher& dispatcher, std::chrono::milliseconds message_timeout_msec)
       : max_buffer_bytes_(max_buffer_bytes), service_method_(service_method), callbacks_(callbacks),
-        client_(client) {}
+        client_(client),
+        ttl_manager_(
+            dispatcher, [this](uint64_t id) { onError(id); }, message_timeout_msec) {}
 
   ~BufferedAsyncClient() {
     if (active_stream_ != nullptr) {
@@ -70,13 +75,17 @@ template <typename RequestType, typename ResponseType> class BufferedAsyncClient {
       active_stream_->sendMessage(message, false);
     }
 
+    ttl_manager_.addDeadlineEntry(inflight_message_ids);
     return inflight_message_ids;
   }
 
   void onSuccess(uint64_t message_id) { erasePendingMessage(message_id); }
 
   void onError(uint64_t message_id) {
-    if (message_buffer_.find(message_id) == message_buffer_.end()) {
+    const auto& message_it = message_buffer_.find(message_id);
+
+    if (message_it == message_buffer_.end() ||
+        message_it->second.first != Grpc::BufferState::PendingFlush) {
       return;
     }
 
@@ -119,6 +128,7 @@ template <typename RequestType, typename ResponseType> class BufferedAsyncClient {
   absl::btree_map<uint64_t, std::pair<BufferState, RequestType>> message_buffer_;
   uint32_t current_buffer_bytes_ = 0;
   uint64_t next_message_id_ = 0;
+  BufferedMessageTtlManager ttl_manager_;
 };
 
 template <typename RequestType, typename ResponseType>
diff --git a/source/common/grpc/buffered_message_ttl_manager.h b/source/common/grpc/buffered_message_ttl_manager.h
new file mode 100644
index 0000000000000..8419c398b7527
--- /dev/null
+++ b/source/common/grpc/buffered_message_ttl_manager.h
@@ -0,0 +1,74 @@
+#pragma once
+
+#include <chrono>
+
+#include "envoy/event/dispatcher.h"
+
+namespace Envoy {
+namespace Grpc {
+
+using BufferedMessageExpirationCallback = std::function<void(uint64_t)>;
+
+// This class is used to manage the TTL for pending uploads within a BufferedAsyncClient. Multiple
+// IDs can be inserted into the TTL manager at once, all with the same TTL (specified in the
+// constructor). Upon expiry, the TTL manager will invoke the provided expiry callback for each ID.
+// Note that there is no way to disable the expiration, and so it's up to the recipient of the
+// callback to handle this. BufferedAsyncClient will do the right thing here: if the expired ID is
+// still in flight it will be returned to the buffer, otherwise it does nothing. The TTL manager is
+// designed to handle multiple sets of IDs inserted at various times, backing this with a single
+// Timer. This allows us to track a large amount of IDs inserted at different times without using a
+// lot of different timers, which could put undue pressure on the event loop.
+class BufferedMessageTtlManager {
+public:
+  BufferedMessageTtlManager(Event::Dispatcher& dispatcher,
+                            BufferedMessageExpirationCallback&& expiry_callback,
+                            std::chrono::milliseconds message_ack_timeout)
+      : dispatcher_(dispatcher), message_ack_timeout_(message_ack_timeout),
+        expiry_callback_(expiry_callback),
+        timer_(dispatcher_.createTimer([this] { checkExpiredMessages(); })) {}
+
+  ~BufferedMessageTtlManager() { timer_->disableTimer(); }
+
+  void addDeadlineEntry(const absl::flat_hash_set<uint64_t>& ids) {
+    const auto expires_at = dispatcher_.timeSource().monotonicTime() + message_ack_timeout_;
+    deadline_.emplace(expires_at, std::move(ids));
+
+    if (!timer_->enabled()) {
+      timer_->enableTimer(message_ack_timeout_);
+    }
+  }
+
+  const std::queue<std::pair<MonotonicTime, absl::flat_hash_set<uint64_t>>>& deadlineForTest() {
+    return deadline_;
+  }
+
+private:
+  void checkExpiredMessages() {
+    const auto now = dispatcher_.timeSource().monotonicTime();
+
+    while (!deadline_.empty()) {
+      auto& it = deadline_.front();
+      if (it.first > now) {
+        break;
+      }
+      for (auto&& id : it.second) {
+        expiry_callback_(id);
+      }
+      deadline_.pop();
+    }
+
+    if (!deadline_.empty()) {
+      const auto earliest_timepoint = deadline_.front().first;
+      timer_->enableTimer(
+          std::chrono::duration_cast<std::chrono::milliseconds>(earliest_timepoint - now));
+    }
+  }
+
+  Event::Dispatcher& dispatcher_;
+  std::chrono::milliseconds message_ack_timeout_;
+  BufferedMessageExpirationCallback expiry_callback_;
+  Event::TimerPtr timer_;
+  std::queue<std::pair<MonotonicTime, absl::flat_hash_set<uint64_t>>> deadline_;
+};
+} // namespace Grpc
+} // namespace Envoy
diff --git a/test/common/grpc/BUILD b/test/common/grpc/BUILD
index 2a7c3681ae72d..d146d3f6759b1 100644
--- a/test/common/grpc/BUILD
+++ b/test/common/grpc/BUILD
@@ -202,3 +202,13 @@ envoy_cc_test(
         "@envoy_api//envoy/config/core/v3:pkg_cc_proto",
     ],
 )
+
+envoy_cc_test(
+    name = "buffered_message_ttl_manager_test",
+    srcs = ["buffered_message_ttl_manager_test.cc"],
+    deps = [
+        "//source/common/event:dispatcher_lib",
+        "//source/common/grpc:buffered_message_ttl_manager_lib",
+        "//test/test_common:utility_lib",
+    ],
+)
diff --git a/test/common/grpc/buffered_async_client_test.cc b/test/common/grpc/buffered_async_client_test.cc
index fc025f6532a00..15ac3b8530fc6 100644
--- a/test/common/grpc/buffered_async_client_test.cc
+++ b/test/common/grpc/buffered_async_client_test.cc
@@ -1,3 +1,7 @@
+#include <chrono>
+#include <memory>
+#include <string>
+
 #include "envoy/config/core/v3/grpc_service.pb.h"
 
 #include "source/common/grpc/async_client_impl.h"
@@ -23,117 +27,144 @@ namespace Envoy {
 namespace Grpc {
 namespace {
 
+constexpr uint32_t kPendingBufferSizeLimit = 1 << 16;
+
 class BufferedAsyncClientTest : public testing::Test {
 public:
   BufferedAsyncClientTest()
-      : method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
+      : api_(Api::createApiForTest()), dispatcher_(api_->allocateDispatcher("test_thread")),
+        method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
     config_.mutable_envoy_grpc()->set_cluster_name("test_cluster");
     cm_.initializeThreadLocalClusters({"test_cluster"});
     ON_CALL(cm_.thread_local_cluster_, httpAsyncClient()).WillByDefault(ReturnRef(http_client_));
   }
 
+  void SetUp() override {
+    EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream_));
+    EXPECT_CALL(http_stream_, sendHeaders(_, _));
+    EXPECT_CALL(http_stream_, reset());
+
+    raw_client_ = std::make_shared<AsyncClientImpl>(cm_, config_, dispatcher_->timeSource());
+    client_ = std::make_unique<AsyncClient<helloworld::HelloRequest>>(
+        raw_client_);
+  }
+
+  void prepareBufferedClient(uint32_t buffer_size, std::chrono::milliseconds ttl) {
+    buffered_client_ =
+        std::make_unique<BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply>>(
+            buffer_size, *method_descriptor_, callback_, *client_, *dispatcher_, ttl);
+  }
+
+  void bufferNewMessage(absl::optional<uint64_t> expected_message_id) {
+    helloworld::HelloRequest request;
+    request.set_name("Alice");
+    EXPECT_EQ(expected_message_id, buffered_client_->bufferMessage(request));
+  }
+
+  void validateBuffer(uint32_t expected_buffered_count, uint32_t expected_pending_count) {
+    const auto buffer = buffered_client_->messageBuffer();
+    uint32_t buffered_count = 0;
+    uint32_t pending_count = 0;
+
+    for (const auto& message : buffer) {
+      switch (message.second.first) {
+      case BufferState::Buffered:
+        ++buffered_count;
+        break;
+      case BufferState::PendingFlush:
+        ++pending_count;
+        break;
+      default:
+        break;
+      }
+    }
+
+    EXPECT_EQ(buffered_count, expected_buffered_count);
+    EXPECT_EQ(pending_count, expected_pending_count);
+  }
+
+  Api::ApiPtr api_;
+  Event::DispatcherPtr dispatcher_;
   const Protobuf::MethodDescriptor* method_descriptor_;
   envoy::config::core::v3::GrpcService config_;
   NiceMock<Upstream::MockClusterManager> cm_;
   NiceMock<Http::MockAsyncClient> http_client_;
+  Http::MockAsyncClientStream http_stream_;
+  std::shared_ptr<AsyncClientImpl> raw_client_;
+  std::unique_ptr<AsyncClient<helloworld::HelloRequest>> client_;
+  NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback_;
+  std::unique_ptr<BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply>>
+      buffered_client_;
 };
 
 TEST_F(BufferedAsyncClientTest, BasicSendFlow) {
-  Http::MockAsyncClientStream http_stream;
-  EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
-  EXPECT_CALL(http_stream, sendHeaders(_, _));
-  EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
-  EXPECT_CALL(http_stream, sendData(_, _));
-  EXPECT_CALL(http_stream, reset());
-
-  DangerousDeprecatedTestTime test_time_;
-  auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
-  AsyncClient<helloworld::HelloRequest> client(raw_client);
-
-  NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
-  BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
-      100000, *method_descriptor_, callback, client);
-
-  helloworld::HelloRequest request;
-  request.set_name("Alice");
-  EXPECT_EQ(0, buffered_client.bufferMessage(request).value());
-  const auto inflight_message_ids = buffered_client.sendBufferedMessages();
-  EXPECT_TRUE(buffered_client.hasActiveStream());
-  EXPECT_EQ(1, inflight_message_ids.size());
+  EXPECT_CALL(http_stream_, sendData(_, _));
+  EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillRepeatedly(Return(false));
+
+  prepareBufferedClient(kPendingBufferSizeLimit, std::chrono::milliseconds(1000));
+  bufferNewMessage(0);
+  validateBuffer(1, 0);
+
+  EXPECT_EQ(1, buffered_client_->sendBufferedMessages().size());
+  EXPECT_TRUE(buffered_client_->hasActiveStream());
+  validateBuffer(0, 1);
 
   // Pending messages should not be re-sent.
-  EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
-  const auto inflight_message_ids2 = buffered_client.sendBufferedMessages();
-  EXPECT_EQ(0, inflight_message_ids2.size());
+  EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
+  validateBuffer(0, 1);
 
-  // Re-buffer, and transport.
-  buffered_client.onError(*inflight_message_ids.begin());
+  // It will call onError().
+  dispatcher_->run(Event::Dispatcher::RunType::Block);
+  validateBuffer(1, 0);
 
-  EXPECT_CALL(http_stream, sendData(_, _)).Times(2);
-  EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
+  // If we call onSuccess(), after onError() called,
+  // onSuccess() do not affect to the buffer.
+  buffered_client_->onSuccess(0);
+  validateBuffer(1, 0);
 
-  helloworld::HelloRequest request2;
-  request2.set_name("Bob");
-  EXPECT_EQ(1, buffered_client.bufferMessage(request2).value());
-  auto ids2 = buffered_client.sendBufferedMessages();
-  EXPECT_EQ(2, ids2.size());
+  EXPECT_CALL(http_stream_, sendData(_, _)).Times(2);
+  bufferNewMessage(1);
+  validateBuffer(2, 0);
+
+  const auto inflight_message_ids = buffered_client_->sendBufferedMessages();
+  EXPECT_EQ(2, inflight_message_ids.size());
+  validateBuffer(0, 2);
 
   // Clear existing messages.
-  for (auto&& id : ids2) {
-    buffered_client.onSuccess(id);
+  for (auto&& id : inflight_message_ids) {
+    buffered_client_->onSuccess(id);
   }
+  validateBuffer(0, 0);
+
+  // It will call onError().
+  // But messages have been already cleared.
+  dispatcher_->run(Event::Dispatcher::RunType::Block);
+  validateBuffer(0, 0);
 
   // Successfully cleared pending messages.
-  EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
-  auto ids3 = buffered_client.sendBufferedMessages();
-  EXPECT_EQ(0, ids3.size());
+  EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
+  validateBuffer(0, 0);
 }
 
 TEST_F(BufferedAsyncClientTest, BufferLimitExceeded) {
-  Http::MockAsyncClientStream http_stream;
-  EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
-  EXPECT_CALL(http_stream, sendHeaders(_, _));
-  EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
-  EXPECT_CALL(http_stream, reset());
-
-  DangerousDeprecatedTestTime test_time_;
-  auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
-  AsyncClient<helloworld::HelloRequest> client(raw_client);
-
-  NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
-  BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
-      0, *method_descriptor_, callback, client);
-
-  helloworld::HelloRequest request;
-  request.set_name("Alice");
-  EXPECT_EQ(absl::nullopt, buffered_client.bufferMessage(request));
-
-  EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
-  EXPECT_TRUE(buffered_client.hasActiveStream());
+  EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
+
+  prepareBufferedClient(0, std::chrono::milliseconds(1000));
+  bufferNewMessage(absl::nullopt);
+
+  EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
+  EXPECT_TRUE(buffered_client_->hasActiveStream());
 }
 
 TEST_F(BufferedAsyncClientTest, BufferHighWatermarkTest) {
-  Http::MockAsyncClientStream http_stream;
-  EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
-  EXPECT_CALL(http_stream, sendHeaders(_, _));
-  EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(true));
-  EXPECT_CALL(http_stream, reset());
-
-  DangerousDeprecatedTestTime test_time_;
-  auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
-  AsyncClient<helloworld::HelloRequest> client(raw_client);
-
-  NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
-  BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
-      100000, *method_descriptor_, callback, client);
-
-  helloworld::HelloRequest request;
-  request.set_name("Alice");
-  EXPECT_EQ(0, buffered_client.bufferMessage(request).value());
-
-  EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
-  EXPECT_TRUE(buffered_client.hasActiveStream());
+  EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillOnce(Return(true));
+
+  prepareBufferedClient(kPendingBufferSizeLimit, std::chrono::milliseconds(1000));
+  bufferNewMessage(0);
+
+  EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
+  EXPECT_TRUE(buffered_client_->hasActiveStream());
 }
 
 } // namespace
diff --git a/test/common/grpc/buffered_message_ttl_manager_test.cc b/test/common/grpc/buffered_message_ttl_manager_test.cc
new file mode 100644
index 0000000000000..ca1e2804145d2
--- /dev/null
+++ b/test/common/grpc/buffered_message_ttl_manager_test.cc
@@ -0,0 +1,69 @@
+#include <chrono>
+#include <cstdint>
+#include <memory>
+
+#include "source/common/event/dispatcher_impl.h"
+#include "source/common/grpc/buffered_message_ttl_manager.h"
+
+#include "test/test_common/utility.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace Envoy {
+namespace Grpc {
+namespace {
+
+class BufferedMessageTtlManagerTest : public testing::Test {
+public:
+  BufferedMessageTtlManagerTest()
+      : api_(Api::createApiForTest()), dispatcher_(api_->allocateDispatcher("test_thread")) {}
+
+  Api::ApiPtr api_;
+  Event::DispatcherPtr dispatcher_;
+  std::shared_ptr<BufferedMessageTtlManager> ttl_manager_;
+  std::chrono::milliseconds msec_{1000};
+  uint32_t callback_called_counter_ = 0;
+};
+
+// In this test, we will test the basic TTL Manager behavior,
+// making sure that the buffers for identity management are in the proper state after deadline.
+TEST_F(BufferedMessageTtlManagerTest, BasicFlow) {
+  absl::flat_hash_set<uint64_t> ids{0};
+  ttl_manager_ = std::make_shared<BufferedMessageTtlManager>(
+      *dispatcher_,
+      [this](uint64_t) {
+        switch (callback_called_counter_) {
+        case 0: {
+          EXPECT_EQ(ttl_manager_->deadlineForTest().size(), 1);
+          absl::flat_hash_set<uint64_t> ids{1, 2};
+          ttl_manager_->addDeadlineEntry(std::move(ids));
+        } break;
+        case 1:
+        case 2:
+        case 3:
+          EXPECT_EQ(ttl_manager_->deadlineForTest().size(), 1);
+          break;
+        default:
+          break;
+        }
+        ++callback_called_counter_;
+      },
+      msec_);
+  ttl_manager_->addDeadlineEntry(std::move(ids));
+
+  dispatcher_->run(Event::Dispatcher::RunType::Block);
+
+  // Test if deadline queue is empty after queue cleared once.
+  EXPECT_EQ(ttl_manager_->deadlineForTest().size(), 0);
+
+  absl::flat_hash_set<uint64_t> ids2{3};
+  ttl_manager_->addDeadlineEntry(std::move(ids2));
+  dispatcher_->run(Event::Dispatcher::RunType::Block);
+  EXPECT_EQ(callback_called_counter_, 4);
+  EXPECT_EQ(ttl_manager_->deadlineForTest().size(), 0);
+}
+
+} // namespace
+} // namespace Grpc
+} // namespace Envoy