Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions source/common/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,17 @@ envoy_cc_library(
name = "buffered_async_client_lib",
hdrs = ["buffered_async_client.h"],
deps = [
":buffered_message_ttl_manager_lib",
":typed_async_client_lib",
"//source/common/protobuf:utility_lib",
"@com_google_absl//absl/container:btree",
],
)

# TTL bookkeeping for pending uploads in BufferedAsyncClient; header-only.
envoy_cc_library(
    name = "buffered_message_ttl_manager_lib",
    hdrs = ["buffered_message_ttl_manager.h"],
    deps = [
        "//envoy/event:dispatcher_interface",
        # The header uses absl::flat_hash_set directly; depend on it
        # explicitly instead of relying on a transitive dependency.
        "@com_google_absl//absl/container:flat_hash_set",
    ],
)
16 changes: 13 additions & 3 deletions source/common/grpc/buffered_async_client.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#pragma once

#include <chrono>
#include <cstdint>

#include "source/common/grpc/buffered_message_ttl_manager.h"
#include "source/common/grpc/typed_async_client.h"
#include "source/common/protobuf/utility.h"

Expand All @@ -20,9 +22,12 @@ template <class RequestType, class ResponseType> class BufferedAsyncClient {
public:
BufferedAsyncClient(uint32_t max_buffer_bytes, const Protobuf::MethodDescriptor& service_method,
Grpc::AsyncStreamCallbacks<ResponseType>& callbacks,
const Grpc::AsyncClient<RequestType, ResponseType>& client)
const Grpc::AsyncClient<RequestType, ResponseType>& client,
Event::Dispatcher& dispatcher, std::chrono::milliseconds message_timeout_msec)
: max_buffer_bytes_(max_buffer_bytes), service_method_(service_method), callbacks_(callbacks),
client_(client) {}
client_(client),
ttl_manager_(
dispatcher, [this](uint64_t id) { onError(id); }, message_timeout_msec) {}

~BufferedAsyncClient() {
if (active_stream_ != nullptr) {
Expand Down Expand Up @@ -70,13 +75,17 @@ template <class RequestType, class ResponseType> class BufferedAsyncClient {
active_stream_->sendMessage(message, false);
}

ttl_manager_.addDeadlineEntry(inflight_message_ids);
return inflight_message_ids;
}

void onSuccess(uint64_t message_id) { erasePendingMessage(message_id); }

void onError(uint64_t message_id) {
if (message_buffer_.find(message_id) == message_buffer_.end()) {
const auto& message_it = message_buffer_.find(message_id);

if (message_it == message_buffer_.end() ||
message_it->second.first != Grpc::BufferState::PendingFlush) {
return;
}

Expand Down Expand Up @@ -119,6 +128,7 @@ template <class RequestType, class ResponseType> class BufferedAsyncClient {
absl::btree_map<uint64_t, std::pair<BufferState, RequestType>> message_buffer_;
uint32_t current_buffer_bytes_ = 0;
uint64_t next_message_id_ = 0;
BufferedMessageTtlManager ttl_manager_;
};

template <class RequestType, class ResponseType>
Expand Down
74 changes: 74 additions & 0 deletions source/common/grpc/buffered_message_ttl_manager.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#pragma once

#include <chrono>
#include <functional>
#include <queue>

#include "envoy/event/dispatcher.h"

#include "absl/container/flat_hash_set.h"

namespace Envoy {
namespace Grpc {

using BufferedMessageExpirationCallback = std::function<void(uint64_t)>;

// This class is used to manage the TTL for pending uploads within a BufferedAsyncClient. Multiple
// IDs can be inserted into the TTL manager at once, all with the same TTL (specified in the
// constructor). Upon expiry, the TTL manager will invoke the provided expiry callback for each ID.
// Note that there is no way to disable the expiration, and so it's up to the recipient of the
// callback to handle this. BufferedAsyncClient will do the right thing here: if the expired ID is
// still in flight it will be returned to the buffer, otherwise it does nothing. The TTL manager is
// designed to handle multiple sets of IDs inserted at various times, backing this with a single
// Timer. This allows us to track a large amount of IDs inserted at different times without using a
// lot of different timers, which could put undue pressure on the event loop.
class BufferedMessageTtlManager {
public:
BufferedMessageTtlManager(Event::Dispatcher& dispatcher,
BufferedMessageExpirationCallback&& expiry_callback,
std::chrono::milliseconds message_ack_timeout)
: dispatcher_(dispatcher), message_ack_timeout_(message_ack_timeout),
expiry_callback_(expiry_callback),
timer_(dispatcher_.createTimer([this] { checkExpiredMessages(); })) {}

~BufferedMessageTtlManager() { timer_->disableTimer(); }

void addDeadlineEntry(const absl::flat_hash_set<uint64_t>& ids) {
const auto expires_at = dispatcher_.timeSource().monotonicTime() + message_ack_timeout_;
deadline_.emplace(expires_at, std::move(ids));

if (!timer_->enabled()) {
timer_->enableTimer(message_ack_timeout_);
}
}

const std::queue<std::pair<MonotonicTime, absl::flat_hash_set<uint64_t>>>& deadlineForTest() {
return deadline_;
}

private:
void checkExpiredMessages() {
const auto now = dispatcher_.timeSource().monotonicTime();

while (!deadline_.empty()) {
auto& it = deadline_.front();
if (it.first > now) {
break;
}
for (auto&& id : it.second) {
expiry_callback_(id);
}
deadline_.pop();
}

if (!deadline_.empty()) {
const auto earliest_timepoint = deadline_.front().first;
timer_->enableTimer(
std::chrono::duration_cast<std::chrono::milliseconds>(earliest_timepoint - now));
}
}

Event::Dispatcher& dispatcher_;
std::chrono::milliseconds message_ack_timeout_;
BufferedMessageExpirationCallback expiry_callback_;
Event::TimerPtr timer_;
std::queue<std::pair<MonotonicTime, absl::flat_hash_set<uint64_t>>> deadline_;
};
} // namespace Grpc
} // namespace Envoy
10 changes: 10 additions & 0 deletions test/common/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -202,3 +202,13 @@ envoy_cc_test(
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)

# Unit test for BufferedMessageTtlManager; needs a real dispatcher to drive
# the backing timer.
envoy_cc_test(
    name = "buffered_message_ttl_manager_test",
    srcs = ["buffered_message_ttl_manager_test.cc"],
    deps = [
        "//source/common/event:dispatcher_lib",
        "//source/common/grpc:buffered_message_ttl_manager_lib",
        "//test/test_common:utility_lib",
    ],
)
189 changes: 110 additions & 79 deletions test/common/grpc/buffered_async_client_test.cc
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
#include <chrono>
#include <cstdint>
#include <memory>

#include "envoy/config/core/v3/grpc_service.pb.h"

#include "source/common/grpc/async_client_impl.h"
Expand All @@ -23,117 +27,144 @@ namespace Envoy {
namespace Grpc {
namespace {

constexpr uint32_t kPendingBufferSizeLimit = 1 << 16;

class BufferedAsyncClientTest : public testing::Test {
public:
BufferedAsyncClientTest()
: method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
: api_(Api::createApiForTest()), dispatcher_(api_->allocateDispatcher("test_thread")),
method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
config_.mutable_envoy_grpc()->set_cluster_name("test_cluster");

cm_.initializeThreadLocalClusters({"test_cluster"});
ON_CALL(cm_.thread_local_cluster_, httpAsyncClient()).WillByDefault(ReturnRef(http_client_));
}

void SetUp() override {
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream_));
EXPECT_CALL(http_stream_, sendHeaders(_, _));
EXPECT_CALL(http_stream_, reset());

raw_client_ = std::make_shared<AsyncClientImpl>(cm_, config_, dispatcher_->timeSource());
client_ = std::make_unique<AsyncClient<helloworld::HelloRequest, helloworld::HelloReply>>(
raw_client_);
}

void prepareBufferedClient(uint32_t buffer_size, std::chrono::milliseconds ttl) {
buffered_client_ =
std::make_unique<BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply>>(
buffer_size, *method_descriptor_, callback_, *client_, *dispatcher_, ttl);
}

void bufferNewMessage(absl::optional<uint32_t> expected_message_id) {
helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(expected_message_id, buffered_client_->bufferMessage(request));
}

void validateBuffer(uint32_t expected_buffered_count, uint32_t expected_pending_count) {
const auto buffer = buffered_client_->messageBuffer();
uint32_t buffered_count = 0;
uint32_t pending_count = 0;

for (const auto& message : buffer) {
switch (message.second.first) {
case BufferState::Buffered:
++buffered_count;
break;
case BufferState::PendingFlush:
++pending_count;
break;
default:
break;
}
}

EXPECT_EQ(buffered_count, expected_buffered_count);
EXPECT_EQ(pending_count, expected_pending_count);
}

Api::ApiPtr api_;
Event::DispatcherPtr dispatcher_;
const Protobuf::MethodDescriptor* method_descriptor_;
envoy::config::core::v3::GrpcService config_;
NiceMock<Upstream::MockClusterManager> cm_;
NiceMock<Http::MockAsyncClient> http_client_;
Http::MockAsyncClientStream http_stream_;
std::shared_ptr<AsyncClientImpl> raw_client_;
std::unique_ptr<AsyncClient<helloworld::HelloRequest, helloworld::HelloReply>> client_;
NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback_;
std::unique_ptr<BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply>>
buffered_client_;
};

TEST_F(BufferedAsyncClientTest, BasicSendFlow) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
EXPECT_CALL(http_stream, sendData(_, _));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
100000, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(0, buffered_client.bufferMessage(request).value());
const auto inflight_message_ids = buffered_client.sendBufferedMessages();
EXPECT_TRUE(buffered_client.hasActiveStream());
EXPECT_EQ(1, inflight_message_ids.size());
EXPECT_CALL(http_stream_, sendData(_, _));
EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillRepeatedly(Return(false));

prepareBufferedClient(kPendingBufferSizeLimit, std::chrono::milliseconds(1000));
bufferNewMessage(0);
validateBuffer(1, 0);

EXPECT_EQ(1, buffered_client_->sendBufferedMessages().size());
EXPECT_TRUE(buffered_client_->hasActiveStream());
validateBuffer(0, 1);

// Pending messages should not be re-sent.
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
const auto inflight_message_ids2 = buffered_client.sendBufferedMessages();
EXPECT_EQ(0, inflight_message_ids2.size());
EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
validateBuffer(0, 1);

// Re-buffer, and transport.
buffered_client.onError(*inflight_message_ids.begin());
// It will call onError().
dispatcher_->run(Event::Dispatcher::RunType::Block);
validateBuffer(1, 0);

EXPECT_CALL(http_stream, sendData(_, _)).Times(2);
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
// If onSuccess() is called after onError() has fired,
// it does not affect the buffer.
buffered_client_->onSuccess(0);
validateBuffer(1, 0);

helloworld::HelloRequest request2;
request2.set_name("Bob");
EXPECT_EQ(1, buffered_client.bufferMessage(request2).value());
auto ids2 = buffered_client.sendBufferedMessages();
EXPECT_EQ(2, ids2.size());
EXPECT_CALL(http_stream_, sendData(_, _)).Times(2);
bufferNewMessage(1);
validateBuffer(2, 0);

const auto inflight_message_ids = buffered_client_->sendBufferedMessages();
EXPECT_EQ(2, inflight_message_ids.size());
validateBuffer(0, 2);

// Clear existing messages.
for (auto&& id : ids2) {
buffered_client.onSuccess(id);
for (auto&& id : inflight_message_ids) {
buffered_client_->onSuccess(id);
}
validateBuffer(0, 0);

// The timer fires and onError() is invoked,
// but the messages have already been cleared.
dispatcher_->run(Event::Dispatcher::RunType::Block);
validateBuffer(0, 0);

// Successfully cleared pending messages.
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
auto ids3 = buffered_client.sendBufferedMessages();
EXPECT_EQ(0, ids3.size());
EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
validateBuffer(0, 0);
}

TEST_F(BufferedAsyncClientTest, BufferLimitExceeded) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
0, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(absl::nullopt, buffered_client.bufferMessage(request));

EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
EXPECT_TRUE(buffered_client.hasActiveStream());
EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));

prepareBufferedClient(0, std::chrono::milliseconds(1000));
bufferNewMessage(absl::nullopt);

EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
EXPECT_TRUE(buffered_client_->hasActiveStream());
}

TEST_F(BufferedAsyncClientTest, BufferHighWatermarkTest) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(true));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
100000, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(0, buffered_client.bufferMessage(request).value());

EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
EXPECT_TRUE(buffered_client.hasActiveStream());
EXPECT_CALL(http_stream_, isAboveWriteBufferHighWatermark()).WillOnce(Return(true));

prepareBufferedClient(kPendingBufferSizeLimit, std::chrono::milliseconds(1000));
bufferNewMessage(0);

EXPECT_EQ(0, buffered_client_->sendBufferedMessages().size());
EXPECT_TRUE(buffered_client_->hasActiveStream());
}

} // namespace
Expand Down
Loading