Merged pull request: add a buffered bidirectional gRPC client (BufferedAsyncClient) with unit tests.
9 changes: 9 additions & 0 deletions source/common/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -206,3 +206,12 @@ envoy_cc_library(
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)

# Header-only buffered bidirectional gRPC client wrapper; buffers outgoing
# messages and tracks per-message delivery state (see buffered_async_client.h).
envoy_cc_library(
name = "buffered_async_client_lib",
hdrs = ["buffered_async_client.h"],
deps = [
":typed_async_client_lib",
"//source/common/protobuf:utility_lib",
],
)
128 changes: 128 additions & 0 deletions source/common/grpc/buffered_async_client.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
#pragma once

#include <cstdint>

#include "source/common/grpc/typed_async_client.h"
#include "source/common/protobuf/utility.h"

#include "absl/container/btree_map.h"

namespace Envoy {
namespace Grpc {

// Delivery state of a buffered message: Buffered means it is waiting to be
// sent (or was re-buffered after a failure); PendingFlush means it has been
// written to the stream and is awaiting acknowledgement.
enum class BufferState { Buffered, PendingFlush };

// This class wraps a bidirectional gRPC stream and provides a message arrival
// guarantee. It stores messages that are waiting to be sent, or are in the
// process of being sent, in a buffer, and tracks the status of each message
// via the ID assigned to it. If a message fails to be sent, it can be
// re-buffered so that its arrival is still guaranteed.
template <class RequestType, class ResponseType> class BufferedAsyncClient {
public:
// @param max_buffer_bytes limit on the total serialized size of buffered messages.
// @param service_method descriptor of the bidirectional gRPC method to invoke.
// @param callbacks receiver of stream events and response messages.
// @param client async client used to lazily open the stream on first send.
BufferedAsyncClient(uint32_t max_buffer_bytes, const Protobuf::MethodDescriptor& service_method,
Grpc::AsyncStreamCallbacks<ResponseType>& callbacks,
const Grpc::AsyncClient<RequestType, ResponseType>& client)
: max_buffer_bytes_(max_buffer_bytes), service_method_(service_method), callbacks_(callbacks),
client_(client) {}

// Resets the underlying stream (if one was opened) so that no further
// callbacks can fire into this object after destruction. Merely nulling the
// handle, as before, left the stream running against destroyed state.
~BufferedAsyncClient() {
  if (active_stream_ != nullptr) {
    active_stream_->resetStream();
  }
}

// Pushes a message into the internal message buffer and returns the id
// assigned to it; the caller may set this id on the message before it is
// sent. Returns absl::nullopt when storing the message would exceed
// `max_buffer_bytes_`, i.e. the buffer is full.
absl::optional<uint64_t> bufferMessage(RequestType& message) {
  const auto buffer_size = message.ByteSizeLong();
  if (current_buffer_bytes_ + buffer_size > max_buffer_bytes_) {
    // Buffer full: the message is not stored.
    return absl::nullopt;
  }

  const auto id = publishId();
  message_buffer_[id] = std::make_pair(BufferState::Buffered, message);
  current_buffer_bytes_ += buffer_size;
  return id;
}

// Sends every message currently in the Buffered state and returns the set of
// ids now in flight. Messages already in PendingFlush are skipped so they are
// never double-sent. The stream is opened lazily on first use; an empty set
// is returned when the stream could not be established or its write buffer is
// above the high watermark (flow-control backoff).
absl::flat_hash_set<uint64_t> sendBufferedMessages() {
  if (active_stream_ == nullptr) {
    active_stream_ =
        client_.start(service_method_, callbacks_, Http::AsyncClient::StreamOptions());
  }

  // Guard against a failed stream start, and back off under flow control.
  if (active_stream_ == nullptr || active_stream_->isAboveWriteBufferHighWatermark()) {
    return {};
  }

  absl::flat_hash_set<uint64_t> inflight_message_ids;

  for (auto& it : message_buffer_) {
    const auto id = it.first;
    auto& state = it.second.first;
    auto& message = it.second.second;

    // Already in flight: do not send twice.
    if (state == BufferState::PendingFlush) {
      continue;
    }

    state = BufferState::PendingFlush;
    inflight_message_ids.emplace(id);
    active_stream_->sendMessage(message, false);
  }

  return inflight_message_ids;
}

// Called when delivery of `message_id` was acknowledged; drops the message
// from the buffer and releases its bytes.
void onSuccess(uint64_t message_id) { erasePendingMessage(message_id); }

// Called when delivery of `message_id` failed. Moves the message back to the
// Buffered state so the next sendBufferedMessages() retries it. Unknown ids
// are ignored. Uses a single map lookup instead of find() followed by at().
void onError(uint64_t message_id) {
  const auto it = message_buffer_.find(message_id);
  if (it == message_buffer_.end()) {
    return;
  }

  it->second.first = BufferState::Buffered;
}

// True once a stream has been opened (streams are opened lazily by sendBufferedMessages()).
bool hasActiveStream() { return active_stream_ != nullptr; }

// Read-only view of the buffered messages keyed by id (used by tests and introspection).
const absl::btree_map<uint64_t, std::pair<BufferState, RequestType>>& messageBuffer() {
return message_buffer_;
}

private:
// Removes a message acknowledged by the peer and releases its bytes from the
// buffer accounting. Performs one map lookup (the original did find(), at()
// and erase(key) — three lookups) and erases via the iterator.
void erasePendingMessage(uint64_t message_id) {
  // `onSuccess` may be invoked with an id Envoy never issued (or one already
  // erased); such ids are ignored.
  const auto it = message_buffer_.find(message_id);
  if (it == message_buffer_.end()) {
    return;
  }

  // The entry may have left PendingFlush by the time the ack arrives (e.g. it
  // was re-buffered after a timeout), in which case it must stay buffered.
  if (it->second.first == BufferState::PendingFlush) {
    current_buffer_bytes_ -= it->second.second.ByteSizeLong();
    message_buffer_.erase(it);
  }
}

// Returns the next monotonically increasing message id.
uint64_t publishId() { return next_message_id_++; }

// Limit on the total serialized size of buffered messages.
const uint32_t max_buffer_bytes_ = 0;
const Protobuf::MethodDescriptor& service_method_;
Grpc::AsyncStreamCallbacks<ResponseType>& callbacks_;
Grpc::AsyncClient<RequestType, ResponseType> client_;
// Handle to the lazily opened stream; null-comparable when no stream is open.
Grpc::AsyncStream<RequestType> active_stream_;
// Messages keyed by id, each paired with its delivery state. Ordered map so
// messages are flushed in id (i.e. arrival) order.
absl::btree_map<uint64_t, std::pair<BufferState, RequestType>> message_buffer_;
// Running total of serialized bytes currently held in message_buffer_.
uint32_t current_buffer_bytes_ = 0;
uint64_t next_message_id_ = 0;
};

// Convenience alias for a uniquely owned BufferedAsyncClient.
template <class RequestType, class ResponseType>
using BufferedAsyncClientPtr = std::unique_ptr<BufferedAsyncClient<RequestType, ResponseType>>;

} // namespace Grpc
} // namespace Envoy
15 changes: 15 additions & 0 deletions test/common/grpc/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,18 @@ envoy_cc_test_library(
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)

# Unit tests for BufferedAsyncClient (buffering, flow control, re-buffering).
envoy_cc_test(
name = "buffered_async_client_test",
srcs = ["buffered_async_client_test.cc"],
deps = [
"//source/common/grpc:async_client_lib",
"//source/common/grpc:buffered_async_client_lib",
"//test/mocks/http:http_mocks",
"//test/mocks/tracing:tracing_mocks",
"//test/mocks/upstream:cluster_manager_mocks",
"//test/proto:helloworld_proto_cc_proto",
"//test/test_common:test_time_lib",
"@envoy_api//envoy/config/core/v3:pkg_cc_proto",
],
)
141 changes: 141 additions & 0 deletions test/common/grpc/buffered_async_client_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
#include "envoy/config/core/v3/grpc_service.pb.h"

#include "source/common/grpc/async_client_impl.h"
#include "source/common/grpc/buffered_async_client.h"
#include "source/common/network/address_impl.h"
#include "source/common/network/socket_impl.h"

#include "test/mocks/http/mocks.h"
#include "test/mocks/tracing/mocks.h"
#include "test/mocks/upstream/cluster_manager.h"
#include "test/proto/helloworld.pb.h"
#include "test/test_common/test_time.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"

using testing::_;
using testing::NiceMock;
using testing::Return;
using testing::ReturnRef;

namespace Envoy {
namespace Grpc {
namespace {

// Fixture wiring a mock cluster manager and mock HTTP async client so the
// gRPC AsyncClientImpl under test can "open" streams without a network.
class BufferedAsyncClientTest : public testing::Test {
public:
BufferedAsyncClientTest()
: method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
config_.mutable_envoy_grpc()->set_cluster_name("test_cluster");

// Route async HTTP calls for "test_cluster" to the mock client below.
cm_.initializeThreadLocalClusters({"test_cluster"});
ON_CALL(cm_.thread_local_cluster_, httpAsyncClient()).WillByDefault(ReturnRef(http_client_));
}

const Protobuf::MethodDescriptor* method_descriptor_;
envoy::config::core::v3::GrpcService config_;
NiceMock<Upstream::MockClusterManager> cm_;
NiceMock<Http::MockAsyncClient> http_client_;
};

// Exercises the full lifecycle: buffer, send, skip already-pending messages,
// re-buffer on error, re-send, then ack to drain the buffer.
TEST_F(BufferedAsyncClientTest, BasicSendFlow) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
EXPECT_CALL(http_stream, sendData(_, _));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
100000, *method_descriptor_, callback, client);

// First message gets id 0 and is sent, entering the PendingFlush state.
helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(0, buffered_client.bufferMessage(request).value());
const auto inflight_message_ids = buffered_client.sendBufferedMessages();
EXPECT_TRUE(buffered_client.hasActiveStream());
EXPECT_EQ(1, inflight_message_ids.size());

// Pending messages should not be re-sent.
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
const auto inflight_message_ids2 = buffered_client.sendBufferedMessages();
EXPECT_EQ(0, inflight_message_ids2.size());

// Re-buffer, and transport.
buffered_client.onError(*inflight_message_ids.begin());

EXPECT_CALL(http_stream, sendData(_, _)).Times(2);
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));

// Both the re-buffered message and the new one should be sent.
helloworld::HelloRequest request2;
request2.set_name("Bob");
EXPECT_EQ(1, buffered_client.bufferMessage(request2).value());
auto ids2 = buffered_client.sendBufferedMessages();
EXPECT_EQ(2, ids2.size());

// Clear existing messages.
for (auto&& id : ids2) {
buffered_client.onSuccess(id);
}

// Successfully cleared pending messages.
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
auto ids3 = buffered_client.sendBufferedMessages();
EXPECT_EQ(0, ids3.size());
}

// With a zero-byte limit, bufferMessage must return absl::nullopt (buffer
// full) and nothing is sent, though the stream itself is still opened.
TEST_F(BufferedAsyncClientTest, BufferLimitExceeded) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
// Note: max_buffer_bytes is 0, so any non-empty message exceeds the limit.
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
0, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(absl::nullopt, buffered_client.bufferMessage(request));

EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
EXPECT_TRUE(buffered_client.hasActiveStream());
}

// When the stream's write buffer is above the high watermark, buffered
// messages are held back and sendBufferedMessages returns an empty set.
TEST_F(BufferedAsyncClientTest, BufferHighWatermarkTest) {
Http::MockAsyncClientStream http_stream;
EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream));
EXPECT_CALL(http_stream, sendHeaders(_, _));
EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(true));
EXPECT_CALL(http_stream, reset());

DangerousDeprecatedTestTime test_time_;
auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem());
AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client);

NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback;
BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client(
100000, *method_descriptor_, callback, client);

helloworld::HelloRequest request;
request.set_name("Alice");
EXPECT_EQ(0, buffered_client.bufferMessage(request).value());

// The message stays buffered; nothing is put in flight.
EXPECT_EQ(0, buffered_client.sendBufferedMessages().size());
EXPECT_TRUE(buffered_client.hasActiveStream());
}

} // namespace
} // namespace Grpc
} // namespace Envoy