Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions source/common/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,17 @@ envoy_cc_library(
name = "buffered_async_client_lib",
hdrs = ["buffered_async_client.h"],
deps = [
":buffered_message_ttl_manager_lib",
":typed_async_client_lib",
"//source/common/protobuf:utility_lib",
"@com_google_absl//absl/container:btree",
],
)

envoy_cc_library(
name = "buffered_message_ttl_manager_lib",
hdrs = ["buffered_message_ttl_manager.h"],
deps = [
"//envoy/event:dispatcher_interface",
],
)
16 changes: 13 additions & 3 deletions source/common/grpc/buffered_async_client.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#pragma once

#include <chrono>
#include <cstdint>

#include "source/common/grpc/buffered_message_ttl_manager.h"
#include "source/common/grpc/typed_async_client.h"
#include "source/common/protobuf/utility.h"

Expand All @@ -20,9 +22,12 @@ template <class RequestType, class ResponseType> class BufferedAsyncClient {
public:
BufferedAsyncClient(uint32_t max_buffer_bytes, const Protobuf::MethodDescriptor& service_method,
Grpc::AsyncStreamCallbacks<ResponseType>& callbacks,
const Grpc::AsyncClient<RequestType, ResponseType>& client)
const Grpc::AsyncClient<RequestType, ResponseType>& client,
Event::Dispatcher& dispatcher, std::chrono::milliseconds message_timeout_msec)
: max_buffer_bytes_(max_buffer_bytes), service_method_(service_method), callbacks_(callbacks),
client_(client) {}
client_(client),
ttl_manager_(
dispatcher, [this](uint64_t id) { onError(id); }, message_timeout_msec) {}

~BufferedAsyncClient() {
if (active_stream_ != nullptr) {
Expand Down Expand Up @@ -70,13 +75,17 @@ template <class RequestType, class ResponseType> class BufferedAsyncClient {
active_stream_->sendMessage(message, false);
}

ttl_manager_.addDeadlineEntry(inflight_message_ids);
return inflight_message_ids;
}

void onSuccess(uint64_t message_id) { erasePendingMessage(message_id); }

void onError(uint64_t message_id) {
if (message_buffer_.find(message_id) == message_buffer_.end()) {
const auto& message_it = message_buffer_.find(message_id);

if (message_it == message_buffer_.end() ||
message_it->second.first != Grpc::BufferState::PendingFlush) {
return;
}

Expand Down Expand Up @@ -119,6 +128,7 @@ template <class RequestType, class ResponseType> class BufferedAsyncClient {
absl::btree_map<uint64_t, std::pair<BufferState, RequestType>> message_buffer_;
uint32_t current_buffer_bytes_ = 0;
uint64_t next_message_id_ = 0;
BufferedMessageTtlManager ttl_manager_;
};

template <class RequestType, class ResponseType>
Expand Down
76 changes: 76 additions & 0 deletions source/common/grpc/buffered_message_ttl_manager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#pragma once

#include <queue>

#include "envoy/event/dispatcher.h"

namespace Envoy {
namespace Grpc {

using BufferedMessageExpirationCallback = std::function<void(uint64_t)>;

// This class is used to manage the lifetime of messages stored in BufferedAsyncClient. Messages
// whose survival period has expired will be deleted from the Buffer.
// We can set the ID you want to monitor TTL in the TTL Manager,
// which will set up a callback to check for expiration.
Comment thread
Shikugawa marked this conversation as resolved.
Outdated
// You can set the ID even after the callback has been invoked.
Comment thread
Shikugawa marked this conversation as resolved.
Outdated
// When the callback is invoked, the callback given to the constructor will be
// executed with the TTL-elapsed ID as an argument.
Comment thread
Shikugawa marked this conversation as resolved.
Outdated
// After that, if the ID to be monitored is not empty, the callback for expiration check will be set
Comment thread
Shikugawa marked this conversation as resolved.
Outdated
// again. The TTL Manager can be given a set of IDs that are expected to expire at the same time.
// When checking for ID expiration, an expiration callback will be called for each ID
// belonging to this set of IDs.
Comment thread
Shikugawa marked this conversation as resolved.
Outdated
class BufferedMessageTtlManager {
Comment thread
Shikugawa marked this conversation as resolved.
public:
BufferedMessageTtlManager(Event::Dispatcher& dispatcher,
BufferedMessageExpirationCallback&& expiry_callback,
std::chrono::milliseconds message_ack_timeout)
: dispatcher_(dispatcher), message_ack_timeout_(message_ack_timeout),
expiry_callback_(expiry_callback),
timer_(dispatcher_.createTimer([this] { checkExpiredMessages(); })) {}

~BufferedMessageTtlManager() { timer_->disableTimer(); }

void addDeadlineEntry(const absl::flat_hash_set<uint64_t>& ids) {
const auto expires_at = dispatcher_.timeSource().monotonicTime() + message_ack_timeout_;
deadline_.emplace(expires_at, std::move(ids));

if (!timer_->enabled()) {
timer_->enableTimer(message_ack_timeout_);
}
}

const std::queue<std::pair<MonotonicTime, absl::flat_hash_set<uint64_t>>>& deadlineForTest() {
return deadline_;
}

private:
void checkExpiredMessages() {
const auto now = dispatcher_.timeSource().monotonicTime();

while (!deadline_.empty()) {
auto& it = deadline_.front();
if (it.first > now) {
break;
}
for (auto&& id : it.second) {
expiry_callback_(id);
}
deadline_.pop();
}

if (!deadline_.empty()) {
const auto earliest_timepoint = deadline_.front().first;
timer_->enableTimer(
std::chrono::duration_cast<std::chrono::milliseconds>(earliest_timepoint - now));
}
}

Event::Dispatcher& dispatcher_;
std::chrono::milliseconds message_ack_timeout_;
BufferedMessageExpirationCallback expiry_callback_;
Event::TimerPtr timer_;
std::queue<std::pair<MonotonicTime, absl::flat_hash_set<uint64_t>>> deadline_;
};
} // namespace Grpc
} // namespace Envoy
10 changes: 10 additions & 0 deletions test/common/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,13 @@ envoy_cc_test(
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)

envoy_cc_test(
name = "buffered_message_ttl_manager_test",
srcs = ["buffered_message_ttl_manager_test.cc"],
deps = [
"//source/common/event:dispatcher_lib",
"//source/common/grpc:buffered_message_ttl_manager_lib",
"//test/test_common:utility_lib",
],
)
189 changes: 110 additions & 79 deletions test/common/grpc/buffered_async_client_test.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#include <chrono>
#include <cstdint>
#include <memory>

#include "envoy/config/core/v3/grpc_service.pb.h"

#include "source/common/grpc/async_client_impl.h"
Expand All @@ -23,117 +27,144 @@ namespace Envoy {
namespace Grpc {
namespace {

constexpr uint32_t kPendingBufferSizeLimit = 1 << 16;

class BufferedAsyncClientTest : public testing::Test {
public:
BufferedAsyncClientTest()
: method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
: api_(Api::createApiForTest()), dispatcher_(api_->allocateDispatcher("test_thread")),
method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
config_.mutable_envoy_grpc()->set_cluster_name("test_cluster");

cm_.initializeThreadLocalClusters({"test_cluster"});
ON_CALL(cm_.thread_local_cluster_, httpAsyncClient()).WillByDefault(ReturnRef(http_client_));
}

void SetUp() override {
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream_));
EXPECT_CALL(http_stream_, sendHeaders(_, _));
EXPECT_CALL(http_stream_, reset());

raw_client_ = std::make_shared<AsyncClientImpl>(cm_, config_, dispatcher_->timeSource());
client_ = std::make_unique<AsyncClient<helloworld::HelloRequest, helloworld::HelloReply>>(
raw_client_);
}

void prepareBufferedClient(uint32_t buffer_size, std::chrono::milliseconds ttl) {
buffered_client_ =
std::make_unique<BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply>>(
buffer_size, *method_descriptor_, callback_, *client_, *dispatcher_, ttl);
}

void bufferNewMessage(absl::optional<uint32_t> expected_message_id) {
helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(expected_message_id, buffered_client_->bufferMessage(request));
}

void validateBuffer(uint32_t expected_buffered_count, uint32_t expected_pending_count) {
const auto buffer = buffered_client_->messageBuffer();
uint32_t buffered_count = 0;
uint32_t pending_count = 0;

for (const auto& message : buffer) {
switch (message.second.first) {
case BufferState::Buffered:
++buffered_count;
break;
case BufferState::PendingFlush:
++pending_count;
break;
default:
break;
}
}

EXPECT_EQ(buffered_count, expected_buffered_count);
EXPECT_EQ(pending_count, expected_pending_count);
}

Api::ApiPtr api_;
Event::DispatcherPtr dispatcher_;
const Protobuf::MethodDescriptor* method_descriptor_;
envoy::config::core::v3::GrpcService config_;
NiceMock<Upstream::MockClusterManager> cm_;
NiceMock<Http::MockAsyncClient> http_client_;
Http::MockAsyncClientStream http_stream_;
std::shared_ptr<AsyncClientImpl> raw_client_;
std::unique_ptr<AsyncClient<helloworld::HelloRequest, helloworld::HelloReply>> client_;
NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback_;
std::unique_ptr<BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply>>
buffered_client_;
};

TEST_F(BufferedAsyncClientTest, BasicSendFlow) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
EXPECT_CALL(http_stream, sendData(_, _));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
100000, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(0, buffered_client.bufferMessage(request).value());
const auto inflight_message_ids = buffered_client.sendBufferedMessages();
EXPECT_TRUE(buffered_client.hasActiveStream());
EXPECT_EQ(1, inflight_message_ids.size());
EXPECT_CALL(http_stream_, sendData(_, _));
EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillRepeatedly(Return(false));

prepareBufferedClient(kPendingBufferSizeLimit, std::chrono::milliseconds(1000));
bufferNewMessage(0);
validateBuffer(1, 0);

EXPECT_EQ(1, buffered_client_->sendBufferedMessages().size());
EXPECT_TRUE(buffered_client_->hasActiveStream());
validateBuffer(0, 1);

// Pending messages should not be re-sent.
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
const auto inflight_message_ids2 = buffered_client.sendBufferedMessages();
EXPECT_EQ(0, inflight_message_ids2.size());
EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
validateBuffer(0, 1);

// Re-buffer, and transport.
buffered_client.onError(*inflight_message_ids.begin());
// It will call onError().
dispatcher_->run(Event::Dispatcher::RunType::Block);
validateBuffer(1, 0);

EXPECT_CALL(http_stream, sendData(_, _)).Times(2);
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
// If we call onSuccess(), after onError() called,
// onSuccess() do not affect to the buffer.
buffered_client_->onSuccess(0);
validateBuffer(1, 0);

helloworld::HelloRequest request2;
request2.set_name("Bob");
EXPECT_EQ(1, buffered_client.bufferMessage(request2).value());
auto ids2 = buffered_client.sendBufferedMessages();
EXPECT_EQ(2, ids2.size());
EXPECT_CALL(http_stream_, sendData(_, _)).Times(2);
bufferNewMessage(1);
validateBuffer(2, 0);

const auto inflight_message_ids = buffered_client_->sendBufferedMessages();
EXPECT_EQ(2, inflight_message_ids.size());
validateBuffer(0, 2);

// Clear existing messages.
for (auto&& id : ids2) {
buffered_client.onSuccess(id);
for (auto&& id : inflight_message_ids) {
buffered_client_->onSuccess(id);
}
validateBuffer(0, 0);

// It will call onError().
// But messages have been already cleared.
dispatcher_->run(Event::Dispatcher::RunType::Block);
validateBuffer(0, 0);

// Successfully cleared pending messages.
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
auto ids3 = buffered_client.sendBufferedMessages();
EXPECT_EQ(0, ids3.size());
EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
validateBuffer(0, 0);
}

TEST_F(BufferedAsyncClientTest, BufferLimitExceeded) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
0, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(absl::nullopt, buffered_client.bufferMessage(request));

EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
EXPECT_TRUE(buffered_client.hasActiveStream());
EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));

prepareBufferedClient(0, std::chrono::milliseconds(1000));
bufferNewMessage(absl::nullopt);

EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
EXPECT_TRUE(buffered_client_->hasActiveStream());
}

TEST_F(BufferedAsyncClientTest, BufferHighWatermarkTest) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(true));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
100000, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(0, buffered_client.bufferMessage(request).value());

EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
EXPECT_TRUE(buffered_client.hasActiveStream());
EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillOnce(Return(true));

prepareBufferedClient(kPendingBufferSizeLimit, std::chrono::milliseconds(1000));
bufferNewMessage(0);

EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
EXPECT_TRUE(buffered_client_->hasActiveStream());
}

} // namespace
Expand Down
Loading