-
Notifications
You must be signed in to change notification settings - Fork 5.5k
grpc: implement BufferedAsyncClient for bidirectional gRPC stream #18129
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
a400410
4885e81
82bf7c3
84073b0
469a5ca
8bff3c4
c43a128
680eff8
cad140b
952122e
f00e891
10ec8c5
4acfe2d
a473666
a7df1b1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,114 @@ | ||
| #pragma once | ||
|
|
||
| #include "source/common/grpc/typed_async_client.h" | ||
| #include "source/common/protobuf/utility.h" | ||
|
|
||
| namespace Envoy { | ||
| namespace Grpc { | ||
|
|
||
| enum class BufferState { Buffered, PendingFlush }; | ||
|
|
||
| template <class RequestType, class ResponseType> class BufferedAsyncClient { | ||
| public: | ||
| BufferedAsyncClient(uint32_t max_buffer_bytes, const Protobuf::MethodDescriptor& service_method, | ||
| Grpc::AsyncStreamCallbacks<ResponseType>& callbacks, | ||
| const Grpc::AsyncClient<RequestType, ResponseType>& client) | ||
| : max_buffer_bytes_(max_buffer_bytes), service_method_(service_method), callbacks_(callbacks), | ||
| client_(client) {} | ||
|
|
||
| virtual ~BufferedAsyncClient() { cleanup(); } | ||
|
|
||
| uint32_t publishId(RequestType& message) { return MessageUtil::hash(message); } | ||
|
Shikugawa marked this conversation as resolved.
Outdated
|
||
|
|
||
| void bufferMessage(uint32_t id, RequestType& message) { | ||
| const auto buffer_size = message.ByteSizeLong(); | ||
| if (current_buffer_bytes_ + buffer_size > max_buffer_bytes_) { | ||
| return; | ||
|
Shikugawa marked this conversation as resolved.
Outdated
|
||
| } | ||
|
|
||
| message_buffer_[id] = std::make_pair(BufferState::Buffered, message); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In #17486 this id is set on the RequestType (CriticalAccessLogsMessage). In conjunction with the previous comment, maybe we should return
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I missed it. I agree with to return |
||
| current_buffer_bytes_ += buffer_size; | ||
| } | ||
|
|
||
| absl::flat_hash_set<uint32_t> sendBufferedMessages() { | ||
| if (active_stream_ == nullptr) { | ||
| active_stream_ = | ||
| client_.start(service_method_, callbacks_, Http::AsyncClient::StreamOptions()); | ||
| } | ||
|
|
||
| if (active_stream_->isAboveWriteBufferHighWatermark()) { | ||
| return {}; | ||
| } | ||
|
|
||
| absl::flat_hash_set<uint32_t> inflight_message_ids; | ||
|
|
||
| for (auto&& it : message_buffer_) { | ||
| const auto id = it.first; | ||
| auto& state = it.second.first; | ||
| auto& message = it.second.second; | ||
|
|
||
| if (state == BufferState::PendingFlush) { | ||
| continue; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should add a test that covers this case, since it prevents double sending. I noticed it's not hit in the coverage report
Member
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added |
||
| } | ||
|
|
||
| state = BufferState::PendingFlush; | ||
| inflight_message_ids.emplace(id); | ||
| active_stream_->sendMessage(message, false); | ||
| } | ||
|
Shikugawa marked this conversation as resolved.
|
||
|
|
||
| return inflight_message_ids; | ||
| } | ||
|
|
||
| void onSuccess(uint32_t message_id) { erasePendingMessage(message_id); } | ||
|
|
||
| void onError(uint32_t message_id) { | ||
| if (message_buffer_.find(message_id) == message_buffer_.end()) { | ||
| return; | ||
| } | ||
| message_buffer_.at(message_id).first = BufferState::Buffered; | ||
| } | ||
|
|
||
| void cleanup() { | ||
| if (active_stream_ != nullptr) { | ||
| active_stream_ = nullptr; | ||
| } | ||
| } | ||
|
|
||
| bool hasActiveStream() { return active_stream_ != nullptr; } | ||
|
|
||
| const absl::flat_hash_map<uint32_t, std::pair<BufferState, RequestType>>& messageBuffer() { | ||
| return message_buffer_; | ||
| } | ||
|
|
||
| private: | ||
| void erasePendingMessage(uint32_t message_id) { | ||
| if (message_buffer_.find(message_id) == message_buffer_.end()) { | ||
| return; | ||
| } | ||
| auto& buffer = message_buffer_.at(message_id); | ||
|
|
||
| // There may be cases where the buffer status is not PendingFlush when | ||
| // this function is called. For example, a message_buffer that was | ||
| // PendingFlush may become Buffered due to an external state change | ||
| // (e.g. re-buffering due to timeout). | ||
| if (buffer.first == BufferState::PendingFlush) { | ||
| const auto buffer_size = buffer.second.ByteSizeLong(); | ||
| current_buffer_bytes_ -= buffer_size; | ||
| message_buffer_.erase(message_id); | ||
| } | ||
| } | ||
|
|
||
| uint32_t max_buffer_bytes_ = 0; | ||
|
Shikugawa marked this conversation as resolved.
Outdated
|
||
| const Protobuf::MethodDescriptor& service_method_; | ||
| Grpc::AsyncStreamCallbacks<ResponseType>& callbacks_; | ||
| Grpc::AsyncClient<RequestType, ResponseType> client_; | ||
| Grpc::AsyncStream<RequestType> active_stream_; | ||
| absl::flat_hash_map<uint32_t, std::pair<BufferState, RequestType>> message_buffer_; | ||
| uint32_t current_buffer_bytes_ = 0; | ||
| }; | ||
|
|
||
| template <class RequestType, class ResponseType> | ||
| using BufferedAsyncClientPtr = std::unique_ptr<BufferedAsyncClient<RequestType, ResponseType>>; | ||
|
|
||
| } // namespace Grpc | ||
| } // namespace Envoy | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| #include "envoy/config/core/v3/grpc_service.pb.h" | ||
|
|
||
| #include "source/common/grpc/async_client_impl.h" | ||
| #include "source/common/grpc/buffered_async_client.h" | ||
| #include "source/common/network/address_impl.h" | ||
| #include "source/common/network/socket_impl.h" | ||
|
|
||
| #include "test/mocks/http/mocks.h" | ||
| #include "test/mocks/tracing/mocks.h" | ||
| #include "test/mocks/upstream/cluster_manager.h" | ||
| #include "test/proto/helloworld.pb.h" | ||
| #include "test/test_common/test_time.h" | ||
|
|
||
| #include "gmock/gmock.h" | ||
| #include "gtest/gtest.h" | ||
|
|
||
| using testing::_; | ||
| using testing::NiceMock; | ||
| using testing::Return; | ||
| using testing::ReturnRef; | ||
|
|
||
| namespace Envoy { | ||
| namespace Grpc { | ||
| namespace { | ||
|
|
||
| class BufferedAsyncClientTest : public testing::Test { | ||
| public: | ||
| BufferedAsyncClientTest() | ||
| : method_descriptor_(helloworld::Greeter::descriptor()->FindMethodByName("SayHello")) { | ||
| config_.mutable_envoy_grpc()->set_cluster_name("test_cluster"); | ||
|
|
||
| cm_.initializeThreadLocalClusters({"test_cluster"}); | ||
| ON_CALL(cm_.thread_local_cluster_, httpAsyncClient()).WillByDefault(ReturnRef(http_client_)); | ||
| } | ||
|
|
||
| const Protobuf::MethodDescriptor* method_descriptor_; | ||
| envoy::config::core::v3::GrpcService config_; | ||
| NiceMock<Upstream::MockClusterManager> cm_; | ||
| NiceMock<Http::MockAsyncClient> http_client_; | ||
| }; | ||
|
|
||
| TEST_F(BufferedAsyncClientTest, BasicSendFlow) { | ||
| Http::MockAsyncClientStream http_stream; | ||
| EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream)); | ||
| EXPECT_CALL(http_stream, sendHeaders(_, _)); | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
| EXPECT_CALL(http_stream, sendData(_, _)); | ||
| EXPECT_CALL(http_stream, reset()); | ||
|
|
||
| DangerousDeprecatedTestTime test_time_; | ||
| auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem()); | ||
| AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client); | ||
|
|
||
| NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback; | ||
| BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client( | ||
| 100000, *method_descriptor_, callback, client); | ||
|
|
||
| helloworld::HelloRequest request; | ||
| request.set_name("Alice"); | ||
| auto id = buffered_client.publishId(request); | ||
| buffered_client.bufferMessage(id, request); | ||
| EXPECT_EQ(1, buffered_client.sendBufferedMessages().size()); | ||
|
|
||
| // Re-buffer, and transport. | ||
| buffered_client.onError(id); | ||
|
|
||
| EXPECT_CALL(http_stream, sendData(_, _)).Times(2); | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
|
|
||
| helloworld::HelloRequest request2; | ||
| request2.set_name("Bob"); | ||
| auto id2 = buffered_client.publishId(request2); | ||
| buffered_client.bufferMessage(id2, request2); | ||
| auto ids2 = buffered_client.sendBufferedMessages(); | ||
| EXPECT_EQ(2, ids2.size()); | ||
|
|
||
| // Clear existing messages. | ||
| for (auto&& id : ids2) { | ||
| buffered_client.onSuccess(id); | ||
| } | ||
|
|
||
| // Successfully cleared pending messages. | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
| auto ids3 = buffered_client.sendBufferedMessages(); | ||
| EXPECT_EQ(0, ids3.size()); | ||
| } | ||
|
|
||
| TEST_F(BufferedAsyncClientTest, BufferLimitExceeded) { | ||
| Http::MockAsyncClientStream http_stream; | ||
| EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream)); | ||
| EXPECT_CALL(http_stream, sendHeaders(_, _)); | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(false)); | ||
| EXPECT_CALL(http_stream, reset()); | ||
|
|
||
| DangerousDeprecatedTestTime test_time_; | ||
| auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem()); | ||
| AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client); | ||
|
|
||
| NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback; | ||
| BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client( | ||
| 0, *method_descriptor_, callback, client); | ||
|
|
||
| helloworld::HelloRequest request; | ||
| request.set_name("Alice"); | ||
| auto id = buffered_client.publishId(request); | ||
| buffered_client.bufferMessage(id, request); | ||
|
|
||
| EXPECT_EQ(0, buffered_client.sendBufferedMessages().size()); | ||
| } | ||
|
|
||
| TEST_F(BufferedAsyncClientTest, BufferHighWatermarkTest) { | ||
| Http::MockAsyncClientStream http_stream; | ||
| EXPECT_CALL(http_client_, start(_, _)).WillOnce(Return(&http_stream)); | ||
| EXPECT_CALL(http_stream, sendHeaders(_, _)); | ||
| EXPECT_CALL(http_stream, isAboveWriteBufferHighWatermark()).WillOnce(Return(true)); | ||
| EXPECT_CALL(http_stream, reset()); | ||
|
|
||
| DangerousDeprecatedTestTime test_time_; | ||
| auto raw_client = std::make_shared<AsyncClientImpl>(cm_, config_, test_time_.timeSystem()); | ||
| AsyncClient<helloworld::HelloRequest, helloworld::HelloReply> client(raw_client); | ||
|
|
||
| NiceMock<MockAsyncStreamCallbacks<helloworld::HelloReply>> callback; | ||
| BufferedAsyncClient<helloworld::HelloRequest, helloworld::HelloReply> buffered_client( | ||
| 100000, *method_descriptor_, callback, client); | ||
|
|
||
| helloworld::HelloRequest request; | ||
| request.set_name("Alice"); | ||
| auto id = buffered_client.publishId(request); | ||
| buffered_client.bufferMessage(id, request); | ||
|
|
||
| EXPECT_EQ(0, buffered_client.sendBufferedMessages().size()); | ||
| } | ||
|
|
||
| } // namespace | ||
| } // namespace Grpc | ||
| } // namespace Envoy |
Uh oh!
There was an error while loading. Please reload this page.