Skip to content
Merged
9 changes: 9 additions & 0 deletions source/common/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,12 @@ envoy_cc_library(
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)

envoy_cc_library(
name = "buffered_async_client_lib",
hdrs = ["buffered_async_client.h"],
deps = [
":typed_async_client_lib",
"//source/common/protobuf:utility_lib",
],
)
114 changes: 114 additions & 0 deletions source/common/grpc/buffered_async_client.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
#pragma once

#include "source/common/grpc/typed_async_client.h"
#include "source/common/protobuf/utility.h"

namespace Envoy {
namespace Grpc {

enum class BufferState { Buffered, PendingFlush };

template <class RequestType, class ResponseType> class BufferedAsyncClient {
public:
BufferedAsyncClient(uint32_t max_buffer_bytes, const Protobuf::MethodDescriptor& service_method,
Grpc::AsyncStreamCallbacks<ResponseType>& callbacks,
const Grpc::AsyncClient<RequestType, ResponseType>& client)
: max_buffer_bytes_(max_buffer_bytes), service_method_(service_method), callbacks_(callbacks),
client_(client) {}

virtual ~BufferedAsyncClient() { cleanup(); }
Comment thread
Shikugawa marked this conversation as resolved.
Outdated

uint32_t publishId(RequestType& message) { return MessageUtil::hash(message); }
Comment thread
Shikugawa marked this conversation as resolved.
Outdated

void bufferMessage(uint32_t id, RequestType& message) {
const auto buffer_size = message.ByteSizeLong();
if (current_buffer_bytes_ + buffer_size > max_buffer_bytes_) {
return;
Comment thread
Shikugawa marked this conversation as resolved.
Outdated
}

message_buffer_[id] = std::make_pair(BufferState::Buffered, message);

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In #17486 this id is set on the RequestType (CriticalAccessLogsMessage). In conjunction with the previous comment, maybe we should return absl::optional<uint64_t> with the id if buffered or null_opt if the buffer's full? The caller can then add the id to the message (although that kind of change via the message reference feels a little weird). Another possibility is to require that the id is passed obtained via publishId, used to call set_id on the message, and then passed to bufferMessage as a separate parameter.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed it. I agree with to return absl::optional<uint64_t> them the caller set it to the message.

current_buffer_bytes_ += buffer_size;
}

absl::flat_hash_set<uint32_t> sendBufferedMessages() {
if (active_stream_ == nullptr) {
active_stream_ =
client_.start(service_method_, callbacks_, Http::AsyncClient::StreamOptions());
}

if (active_stream_->isAboveWriteBufferHighWatermark()) {
return {};
}

absl::flat_hash_set<uint32_t> inflight_message_ids;

for (auto&& it : message_buffer_) {
const auto id = it.first;
auto& state = it.second.first;
auto& message = it.second.second;

if (state == BufferState::PendingFlush) {
continue;

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should add a test that covers this case, since it prevents double sending. I noticed it's not hit in the coverage report

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

}

state = BufferState::PendingFlush;
inflight_message_ids.emplace(id);
active_stream_->sendMessage(message, false);
}
Comment thread
Shikugawa marked this conversation as resolved.

return inflight_message_ids;
}

void onSuccess(uint32_t message_id) { erasePendingMessage(message_id); }

void onError(uint32_t message_id) {
if (message_buffer_.find(message_id) == message_buffer_.end()) {
return;
}
message_buffer_.at(message_id).first = BufferState::Buffered;
}

void cleanup() {
if (active_stream_ != nullptr) {
active_stream_ = nullptr;
}
}

bool hasActiveStream() { return active_stream_ != nullptr; }

const absl::flat_hash_map<uint32_t, std::pair<BufferState, RequestType>>& messageBuffer() {
return message_buffer_;
}

private:
void erasePendingMessage(uint32_t message_id) {
if (message_buffer_.find(message_id) == message_buffer_.end()) {
return;
}
auto& buffer = message_buffer_.at(message_id);

// There may be cases where the buffer status is not PendingFlush when
// this function is called. For example, a message_buffer that was
// PendingFlush may become Buffered due to an external state change
// (e.g. re-buffering due to timeout).
if (buffer.first == BufferState::PendingFlush) {
const auto buffer_size = buffer.second.ByteSizeLong();
current_buffer_bytes_ -= buffer_size;
message_buffer_.erase(message_id);
}
}

uint32_t max_buffer_bytes_ = 0;
Comment thread
Shikugawa marked this conversation as resolved.
Outdated
const Protobuf::MethodDescriptor& service_method_;
Grpc::AsyncStreamCallbacks<ResponseType>& callbacks_;
Grpc::AsyncClient<RequestType, ResponseType> client_;
Grpc::AsyncStream<RequestType> active_stream_;
absl::flat_hash_map<uint32_t, std::pair<BufferState, RequestType>> message_buffer_;
uint32_t current_buffer_bytes_ = 0;
};

template <class RequestType, class ResponseType>
using BufferedAsyncClientPtr = std::unique_ptr<BufferedAsyncClient<RequestType, ResponseType>>;

} // namespace Grpc
} // namespace Envoy
15 changes: 15 additions & 0 deletions test/common/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,18 @@ envoy_cc_test_library(
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)

envoy_cc_test(
name = "buffered_async_client_test",
srcs = ["buffered_async_client_test.cc"],
deps = [
"//source/common/grpc:async_client_lib",
"//source/common/grpc:buffered_async_client_lib",
"//test/mocks/http:http_mocks",
"//test/mocks/tracing:tracing_mocks",
"//test/mocks/upstream:cluster_manager_mocks",
"//test/proto:helloworld_proto_cc_proto",
"//test/test_common:test_time_lib",
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)
136 changes: 136 additions & 0 deletions test/common/grpc/buffered_async_client_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#include "envoy/config/core/v3/grpc_service.pb.h"

#include "source/common/grpc/async_client_impl.h"
#include "source/common/grpc/buffered_async_client.h"
#include "source/common/network/address_impl.h"
#include "source/common/network/socket_impl.h"

#include "test/mocks/http/mocks.h"
#include "test/mocks/tracing/mocks.h"
#include "test/mocks/upstream/cluster_manager.h"
#include "test/proto/helloworld.pb.h"
#include "test/test_common/test_time.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

using testing::_;
using testing::NiceMock;
using testing::Return;
using testing::ReturnRef;

namespace Envoy {
namespace Grpc {
namespace {

class BufferedAsyncClientTest : public testing::Test {
public:
BufferedAsyncClientTest()
: method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
config_.mutable_envoy_grpc()->set_cluster_name("test_cluster");

cm_.initializeThreadLocalClusters({"test_cluster"});
ON_CALL(cm_.thread_local_cluster_, httpAsyncClient()).WillByDefault(ReturnRef(http_client_));
}

const Protobuf::MethodDescriptor* method_descriptor_;
envoy::config::core::v3::GrpcService config_;
NiceMock<Upstream::MockClusterManager> cm_;
NiceMock<Http::MockAsyncClient> http_client_;
};

TEST_F(BufferedAsyncClientTest, BasicSendFlow) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
EXPECT_CALL(http_stream, sendData(_, _));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
100000, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
auto id = buffered_client.publishId(request);
buffered_client.bufferMessage(id, request);
EXPECT_EQ(1, buffered_client.sendBufferedMessages().size());

// Re-buffer, and transport.
buffered_client.onError(id);

EXPECT_CALL(http_stream, sendData(_, _)).Times(2);
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));

helloworld::HelloRequest request2;
request2.set_name("Bob");
auto id2 = buffered_client.publishId(request2);
buffered_client.bufferMessage(id2, request2);
auto ids2 = buffered_client.sendBufferedMessages();
EXPECT_EQ(2, ids2.size());

// Clear existing messages.
for (auto&& id : ids2) {
buffered_client.onSuccess(id);
}

// Successfully cleared pending messages.
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
auto ids3 = buffered_client.sendBufferedMessages();
EXPECT_EQ(0, ids3.size());
}

TEST_F(BufferedAsyncClientTest, BufferLimitExceeded) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
0, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
auto id = buffered_client.publishId(request);
buffered_client.bufferMessage(id, request);

EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
}

TEST_F(BufferedAsyncClientTest, BufferHighWatermarkTest) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(true));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
100000, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
auto id = buffered_client.publishId(request);
buffered_client.bufferMessage(id, request);

EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
}

} // namespace
} // namespace Grpc
} // namespace Envoy