-
Notifications
You must be signed in to change notification settings - Fork 5.3k
grpc: implement BufferedAsyncClient for bidirectional gRPC stream #18129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
a400410
4885e81
82bf7c3
84073b0
469a5ca
8bff3c4
c43a128
680eff8
cad140b
952122e
f00e891
10ec8c5
4acfe2d
a473666
a7df1b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,128 @@ | ||
| #pragma once | ||
|
|
||
| #include <cstdint> | ||
|
|
||
| #include "source/common/grpc/typed_async_client.h" | ||
| #include "source/common/protobuf/utility.h" | ||
|
|
||
| #include "absl/container/btree_map.h" | ||
|
|
||
| namespace Envoy { | ||
| namespace Grpc { | ||
|
|
||
// Lifecycle state of a buffered message: `Buffered` means the message is
// waiting to be written to the stream; `PendingFlush` means it has been
// written and is awaiting acknowledgement (onSuccess/onError).
enum class BufferState { Buffered, PendingFlush };
|
|
||
| // This class wraps bidirectional gRPC and provides message arrival guarantee. | ||
| // It stores messages to be sent or in the process of being sent in a buffer, | ||
| // and can track the status of the message based on the ID assigned to each message. | ||
| // If a message fails to be sent, it can be re-buffered to guarantee its arrival. | ||
| template <class RequestType, class ResponseType> class BufferedAsyncClient { | ||
| public: | ||
| BufferedAsyncClient(uint32_t max_buffer_bytes, const Protobuf::MethodDescriptor& service_method, | ||
| Grpc::AsyncStreamCallbacks<ResponseType>& callbacks, | ||
| const Grpc::AsyncClient<RequestType, ResponseType>& client) | ||
| : max_buffer_bytes_(max_buffer_bytes), service_method_(service_method), callbacks_(callbacks), | ||
| client_(client) {} | ||
|
|
||
| ~BufferedAsyncClient() { | ||
| if (active_stream_ != nullptr) { | ||
| active_stream_ = nullptr; | ||
| } | ||
| } | ||
|
|
||
| // It push message into internal message buffer. | ||
| // If the buffer is full, it will return absl::nullopt. | ||
| absl::optional<uint64_t> bufferMessage(RequestType& message) { | ||
| const auto buffer_size = message.ByteSizeLong(); | ||
| if (current_buffer_bytes_ + buffer_size > max_buffer_bytes_) { | ||
| return absl::nullopt; | ||
| } | ||
|
|
||
| auto id = publishId(); | ||
| message_buffer_[id] = std::make_pair(BufferState::Buffered, message); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In #17486 this id is set on the RequestType (CriticalAccessLogsMessage). In conjunction with the previous comment, maybe we should return
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I missed it. I agree with to return |
||
| current_buffer_bytes_ += buffer_size; | ||
| return id; | ||
| } | ||
|
|
||
| absl::flat_hash_set<uint64_t> sendBufferedMessages() { | ||
| if (active_stream_ == nullptr) { | ||
| active_stream_ = | ||
| client_.start(service_method_, callbacks_, Http::AsyncClient::StreamOptions()); | ||
| } | ||
|
|
||
| if (active_stream_->isAboveWriteBufferHighWatermark()) { | ||
| return {}; | ||
| } | ||
|
|
||
| absl::flat_hash_set<uint64_t> inflight_message_ids; | ||
|
|
||
| for (auto&& it : message_buffer_) { | ||
| const auto id = it.first; | ||
| auto& state = it.second.first; | ||
| auto& message = it.second.second; | ||
|
|
||
| if (state == BufferState::PendingFlush) { | ||
| continue; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should add a test that covers this case, since it prevents double sending. I noticed it's not hit in the coverage report
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
| } | ||
|
|
||
| state = BufferState::PendingFlush; | ||
| inflight_message_ids.emplace(id); | ||
| active_stream_->sendMessage(message, false); | ||
| } | ||
Shikugawa marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| return inflight_message_ids; | ||
| } | ||
|
|
||
| void onSuccess(uint64_t message_id) { erasePendingMessage(message_id); } | ||
|
|
||
| void onError(uint64_t message_id) { | ||
| if (message_buffer_.find(message_id) == message_buffer_.end()) { | ||
| return; | ||
| } | ||
|
|
||
| message_buffer_.at(message_id).first = BufferState::Buffered; | ||
| } | ||
|
|
||
| bool hasActiveStream() { return active_stream_ != nullptr; } | ||
|
|
||
| const absl::btree_map<uint64_t, std::pair<BufferState, RequestType>>& messageBuffer() { | ||
| return message_buffer_; | ||
| } | ||
|
|
||
| private: | ||
| void erasePendingMessage(uint64_t message_id) { | ||
| // This case will be considered if `onSuccess` had called with unknown message id that is not | ||
| // received by envoy as response. | ||
| if (message_buffer_.find(message_id) == message_buffer_.end()) { | ||
| return; | ||
| } | ||
| auto& buffer = message_buffer_.at(message_id); | ||
|
|
||
| // There may be cases where the buffer status is not PendingFlush when | ||
| // this function is called. For example, a message_buffer that was | ||
| // PendingFlush may become Buffered due to an external state change | ||
| // (e.g. re-buffering due to timeout). | ||
| if (buffer.first == BufferState::PendingFlush) { | ||
| const auto buffer_size = buffer.second.ByteSizeLong(); | ||
| current_buffer_bytes_ -= buffer_size; | ||
| message_buffer_.erase(message_id); | ||
| } | ||
| } | ||
|
|
||
| uint64_t publishId() { return next_message_id_++; } | ||
|
|
||
| const uint32_t max_buffer_bytes_ = 0; | ||
| const Protobuf::MethodDescriptor& service_method_; | ||
| Grpc::AsyncStreamCallbacks<ResponseType>& callbacks_; | ||
| Grpc::AsyncClient<RequestType, ResponseType> client_; | ||
| Grpc::AsyncStream<RequestType> active_stream_; | ||
| absl::btree_map<uint64_t, std::pair<BufferState, RequestType>> message_buffer_; | ||
| uint32_t current_buffer_bytes_ = 0; | ||
| uint64_t next_message_id_ = 0; | ||
| }; | ||
|
|
||
// Convenience owning-pointer alias for a BufferedAsyncClient instance.
template <class RequestType, class ResponseType>
using BufferedAsyncClientPtr = std::unique_ptr<BufferedAsyncClient<RequestType, ResponseType>>;
|
|
||
| } // namespace Grpc | ||
| } // namespace Envoy | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,141 @@ | ||
| #include "envoy/config/core/v3/grpc_service.pb.h" | ||
|
|
||
| #include "source/common/grpc/async_client_impl.h" | ||
| #include "source/common/grpc/buffered_async_client.h" | ||
| #include "source/common/network/address_impl.h" | ||
| #include "source/common/network/socket_impl.h" | ||
|
|
||
| #include "test/mocks/http/mocks.h" | ||
| #include "test/mocks/tracing/mocks.h" | ||
| #include "test/mocks/upstream/cluster_manager.h" | ||
| #include "test/proto/helloworld.pb.h" | ||
| #include "test/test_common/test_time.h" | ||
|
|
||
| #include "gmock/gmock.h" | ||
| #include "gtest/gtest.h" | ||
|
|
||
| using testing::_; | ||
| using testing::NiceMock; | ||
| using testing::Return; | ||
| using testing::ReturnRef; | ||
|
|
||
| namespace Envoy { | ||
| namespace Grpc { | ||
| namespace { | ||
|
|
||
// Test fixture that wires a mocked cluster manager and HTTP async client so a
// real Grpc::AsyncClientImpl can be constructed against the "SayHello" method
// of the helloworld test service.
class BufferedAsyncClientTest : public testing::Test {
public:
  BufferedAsyncClientTest()
      : method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) {
    // Point the gRPC service config at a cluster the mock cluster manager knows.
    config_.mutable_envoy_grpc()->set_cluster_name("test_cluster");

    cm_.initializeThreadLocalClusters({"test_cluster"});
    // Any HTTP async client lookup on the thread-local cluster returns our mock.
    ON_CALL(cm_.thread_local_cluster_, httpAsyncClient()).WillByDefault(ReturnRef(http_client_));
  }

  const Protobuf::MethodDescriptor* method_descriptor_;
  envoy::config::core::v3::GrpcService config_;
  NiceMock<Upstream::MockClusterManager> cm_;
  NiceMock<Http::MockAsyncClient> http_client_;
};
|
|
||
| TEST_F(BufferedAsyncClientTest, BasicSendFlow) { | ||
| Http::MockAsyncClientStream http_stream; | ||
| EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream)); | ||
| EXPECT_CALL(http_stream, sendHeaders(_, _)); | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
| EXPECT_CALL(http_stream, sendData(_, _)); | ||
| EXPECT_CALL(http_stream, reset()); | ||
|
|
||
| DangerousDeprecatedTestTime test_time_; | ||
| auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem()); | ||
| AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client); | ||
|
|
||
| NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback; | ||
| BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client( | ||
| 100000, *method_descriptor_, callback, client); | ||
|
|
||
| helloworld::HelloRequest request; | ||
| request.set_name("Alice"); | ||
| EXPECT_EQ(0, buffered_client.bufferMessage(request).value()); | ||
| const auto inflight_message_ids = buffered_client.sendBufferedMessages(); | ||
| EXPECT_TRUE(buffered_client.hasActiveStream()); | ||
| EXPECT_EQ(1, inflight_message_ids.size()); | ||
|
|
||
| // Pending messages should not be re-sent. | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
| const auto inflight_message_ids2 = buffered_client.sendBufferedMessages(); | ||
| EXPECT_EQ(0, inflight_message_ids2.size()); | ||
|
|
||
| // Re-buffer, and transport. | ||
| buffered_client.onError(*inflight_message_ids.begin()); | ||
|
|
||
| EXPECT_CALL(http_stream, sendData(_, _)).Times(2); | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
|
|
||
| helloworld::HelloRequest request2; | ||
| request2.set_name("Bob"); | ||
| EXPECT_EQ(1, buffered_client.bufferMessage(request2).value()); | ||
| auto ids2 = buffered_client.sendBufferedMessages(); | ||
| EXPECT_EQ(2, ids2.size()); | ||
|
|
||
| // Clear existing messages. | ||
| for (auto&& id : ids2) { | ||
| buffered_client.onSuccess(id); | ||
| } | ||
|
|
||
| // Successfully cleared pending messages. | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
| auto ids3 = buffered_client.sendBufferedMessages(); | ||
| EXPECT_EQ(0, ids3.size()); | ||
| } | ||
|
|
||
| TEST_F(BufferedAsyncClientTest, BufferLimitExceeded) { | ||
| Http::MockAsyncClientStream http_stream; | ||
| EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream)); | ||
| EXPECT_CALL(http_stream, sendHeaders(_, _)); | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
| EXPECT_CALL(http_stream, reset()); | ||
|
|
||
| DangerousDeprecatedTestTime test_time_; | ||
| auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem()); | ||
| AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client); | ||
|
|
||
| NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback; | ||
| BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client( | ||
| 0, *method_descriptor_, callback, client); | ||
|
|
||
| helloworld::HelloRequest request; | ||
| request.set_name("Alice"); | ||
| EXPECT_EQ(absl::nullopt, buffered_client.bufferMessage(request)); | ||
|
|
||
| EXPECT_EQ(0, buffered_client.sendBufferedMessages().size()); | ||
| EXPECT_TRUE(buffered_client.hasActiveStream()); | ||
| } | ||
|
|
||
| TEST_F(BufferedAsyncClientTest, BufferHighWatermarkTest) { | ||
| Http::MockAsyncClientStream http_stream; | ||
| EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream)); | ||
| EXPECT_CALL(http_stream, sendHeaders(_, _)); | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(true)); | ||
| EXPECT_CALL(http_stream, reset()); | ||
|
|
||
| DangerousDeprecatedTestTime test_time_; | ||
| auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem()); | ||
| AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client); | ||
|
|
||
| NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback; | ||
| BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client( | ||
| 100000, *method_descriptor_, callback, client); | ||
|
|
||
| helloworld::HelloRequest request; | ||
| request.set_name("Alice"); | ||
| EXPECT_EQ(0, buffered_client.bufferMessage(request).value()); | ||
|
|
||
| EXPECT_EQ(0, buffered_client.sendBufferedMessages().size()); | ||
| EXPECT_TRUE(buffered_client.hasActiveStream()); | ||
| } | ||
|
|
||
| } // namespace | ||
| } // namespace Grpc | ||
| } // namespace Envoy |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Consider adding a comment that indicates that null_opt means buffer full.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added