diff --git a/include/envoy/grpc/async_client.h b/include/envoy/grpc/async_client.h index 56ad3f1d5fbc5..c156c4a66dd6e 100644 --- a/include/envoy/grpc/async_client.h +++ b/include/envoy/grpc/async_client.h @@ -2,6 +2,7 @@ #include +#include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" #include "envoy/grpc/status.h" #include "envoy/http/header_map.h" @@ -41,7 +42,16 @@ class AsyncStream { * object, but callbacks may still be received until the stream is closed * remotely. */ - virtual void sendMessage(const Protobuf::Message& request, bool end_stream) PURE; + virtual void sendMessage(const Protobuf::Message& request, bool end_stream); + + /** + * Send request message to the stream. + * @param request serializalized message. + * @param end_stream close the stream locally. No further methods may be invoked on the stream + * object, but callbacks may still be received until the stream is closed + * remotely. + */ + virtual void sendRawMessage(Buffer::InstancePtr request, bool end_stream) PURE; /** * Close the stream locally and send an empty DATA frame to the remote. No further methods may be @@ -57,9 +67,9 @@ class AsyncStream { virtual void resetStream() PURE; }; -class AsyncRequestCallbacks { +class RawAsyncRequestCallbacks { public: - virtual ~AsyncRequestCallbacks() {} + virtual ~RawAsyncRequestCallbacks() {} /** * Called when populating the headers to send with initial metadata. @@ -71,15 +81,18 @@ class AsyncRequestCallbacks { * Factory for empty response messages. * @return ProtobufTypes::MessagePtr a Protobuf::Message with the response * type for the request. + * NB: createEmptyResponse will not be called if onSuccessRaw() is overriden. */ - virtual ProtobufTypes::MessagePtr createEmptyResponse() PURE; + virtual ProtobufTypes::MessagePtr createEmptyResponse() { + throw EnvoyException("AsyncRequestCallbacks::createEmptyResponse must be overriden"); + } /** * Called when the async gRPC request succeeds. No further callbacks will be invoked. - * @param response the gRPC response. + * @param response the gRPC response bytes. * @param span a tracing span to fill with extra tags. */ - virtual void onSuccessUntyped(ProtobufTypes::MessagePtr&& response, Tracing::Span& span) PURE; + virtual void onSuccessRaw(Buffer::InstancePtr response, Tracing::Span& span) PURE; /** * Called when the async gRPC request fails. No further callbacks will be invoked. @@ -91,6 +104,20 @@ class AsyncRequestCallbacks { Tracing::Span& span) PURE; }; +class AsyncRequestCallbacks : public RawAsyncRequestCallbacks { +public: + virtual ~AsyncRequestCallbacks() {} + + void onSuccessRaw(Buffer::InstancePtr response, Tracing::Span& span) override; + /** + * Called when the async gRPC request succeeds. No further callbacks will be invoked. + * @param response the gRPC response. + * @param span a tracing span to fill with extra tags. + * NB: requires overriding createEmptyResponse(). + */ + virtual void onSuccessUntyped(ProtobufTypes::MessagePtr&& response, Tracing::Span& span) PURE; +}; + // Templatized variant of AsyncRequestCallbacks. template class TypedAsyncRequestCallbacks : public AsyncRequestCallbacks { public: @@ -108,20 +135,23 @@ template class TypedAsyncRequestCallbacks : public AsyncReq /** * Notifies caller of async gRPC stream status. * Note the gRPC stream is full-duplex, even if the local to remote stream has been ended by - * AsyncStream.close(), AsyncStreamCallbacks can continue to receive events until the remote + * AsyncStream.close(), RawAsyncStreamCallbacks can continue to receive events until the remote * to local stream is closed (onRemoteClose), and vice versa. Once the stream is closed remotely, no * further callbacks will be invoked. */ -class AsyncStreamCallbacks { +class RawAsyncStreamCallbacks { public: - virtual ~AsyncStreamCallbacks() {} + virtual ~RawAsyncStreamCallbacks() {} /** * Factory for empty response messages. * @return ProtobufTypes::MessagePtr a Protobuf::Message with the response * type for the stream. + * NB: createEmptyResponse will not be called if onRecieveRawMessage() is overriden. */ - virtual ProtobufTypes::MessagePtr createEmptyResponse() PURE; + virtual ProtobufTypes::MessagePtr createEmptyResponse() { + throw EnvoyException("AsyncStreamCallbacks::createEmptyResponse must be overriden"); + } /** * Called when populating the headers to send with initial metadata. @@ -139,8 +169,10 @@ class AsyncStreamCallbacks { /** * Called when an async gRPC message is received. * @param response the gRPC message. + * @return bool which is true if the message well formed and false otherwise which will cause + the stream to shutdown with an INTERNAL error. */ - virtual void onReceiveMessageUntyped(ProtobufTypes::MessagePtr&& message) PURE; + virtual bool onReceiveRawMessage(Buffer::InstancePtr response) PURE; /** * Called when trailing metadata is received. This will also be called on non-Ok grpc-status @@ -152,13 +184,33 @@ class AsyncStreamCallbacks { /** * Called when the remote closes or an error occurs on the gRPC stream. The stream is * considered remotely closed after this invocation and no further callbacks will be - * invoked. In addition, no further stream operations are permitted. + { * invoked. In addition, no further stream operations are permitted. * @param status the gRPC status. * @param message the gRPC status message or empty string if not present. */ virtual void onRemoteClose(Status::GrpcStatus status, const std::string& message) PURE; }; +/** + * Notifies caller of async gRPC stream status. + * Note the gRPC stream is full-duplex, even if the local to remote stream has been ended by + * AsyncStream.close(), AsyncStreamCallbacks can continue to receive events until the remote + * to local stream is closed (onRemoteClose), and vice versa. Once the stream is closed remotely, no + * further callbacks will be invoked. + */ +class AsyncStreamCallbacks : public RawAsyncStreamCallbacks { +public: + virtual ~AsyncStreamCallbacks() {} + + bool onReceiveRawMessage(Buffer::InstancePtr response) override; + /** + * Called when an async gRPC message is received. + * @param response the gRPC message. + * NB: requires overriding createEmptyResponse(). + */ + virtual void onReceiveMessageUntyped(ProtobufTypes::MessagePtr&& message) PURE; +}; + // Templatized variant of AsyncStreamCallbacks. template class TypedAsyncStreamCallbacks : public AsyncStreamCallbacks { public: @@ -195,7 +247,24 @@ class AsyncClient { virtual AsyncRequest* send(const Protobuf::MethodDescriptor& service_method, const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, - const absl::optional& timeout) PURE; + const absl::optional& timeout); + + /** + * Start a gRPC unary RPC asynchronously. + * @param service_full_name full name of the service (i.e. service_method.service()->full_name()). + * @param method_name name of the method (i.e. service_method.name()). + * @param request serialized message. + * @param callbacks the callbacks to be notified of RPC status. + * @param parent_span the current parent tracing context. + * @param timeout supplies the request timeout. + * @return a request handle or nullptr if no request could be started. NOTE: In this case + * onFailure() has already been called inline. The client owns the request and the + * handle should just be used to cancel. + */ + virtual AsyncRequest* sendRaw(absl::string_view service_full_name, absl::string_view method_name, + Buffer::InstancePtr request, RawAsyncRequestCallbacks& callbacks, + Tracing::Span& parent_span, + const absl::optional& timeout) PURE; /** * Start a gRPC stream asynchronously. @@ -209,7 +278,22 @@ class AsyncClient { * may be reclaimed. */ virtual AsyncStream* start(const Protobuf::MethodDescriptor& service_method, - AsyncStreamCallbacks& callbacks) PURE; + AsyncStreamCallbacks& callbacks); + + /** + * Start a gRPC stream asynchronously. + * TODO(mattklein123): Determine if tracing should be added to streaming requests. + * @param service_full_name full name of the service (i.e. service_method.service()->full_name()). + * @param method_name name of the method (i.e. service_method.name()). + * @param callbacks the callbacks to be notified of stream status. + * @return a stream handle or nullptr if no stream could be started. NOTE: In this case + * onRemoteClose() has already been called inline. The client owns the stream and + * the handle can be used to send more messages or finish the stream. It is expected that + * closeStream() is invoked by the caller to notify the client that the stream resources + * may be reclaimed. + */ + virtual AsyncStream* startRaw(absl::string_view service_full_name, absl::string_view method_name, + RawAsyncStreamCallbacks& callbacks) PURE; }; typedef std::unique_ptr AsyncClientPtr; diff --git a/source/common/config/BUILD b/source/common/config/BUILD index 4737db10e65eb..428c014b194db 100644 --- a/source/common/config/BUILD +++ b/source/common/config/BUILD @@ -109,6 +109,7 @@ envoy_cc_library( "//source/common/common:backoff_lib", "//source/common/common:minimal_logger_lib", "//source/common/common:token_bucket_impl_lib", + "//source/common/grpc:async_client_lib", "//source/common/protobuf", ], ) diff --git a/source/common/grpc/BUILD b/source/common/grpc/BUILD index f0471c3568ee3..bfe2e369f6234 100644 --- a/source/common/grpc/BUILD +++ b/source/common/grpc/BUILD @@ -11,7 +11,10 @@ envoy_package() envoy_cc_library( name = "async_client_lib", - srcs = ["async_client_impl.cc"], + srcs = [ + "async_client.cc", + "async_client_impl.cc", + ], hdrs = ["async_client_impl.h"], deps = [ ":codec_lib", @@ -28,6 +31,7 @@ envoy_cc_library( hdrs = ["async_client_manager_impl.h"], deps = [ ":async_client_lib", + ":common_lib", "//include/envoy/grpc:async_client_manager_interface", "//include/envoy/singleton:manager_interface", "//include/envoy/thread_local:thread_local_interface", @@ -61,7 +65,10 @@ envoy_cc_library( name = "common_lib", srcs = ["common.cc"], hdrs = ["common.h"], - external_deps = ["abseil_optional"], + external_deps = [ + "abseil_optional", + "grpc", + ], deps = [ "//include/envoy/http:header_map_interface", "//include/envoy/http:message_interface", @@ -91,6 +98,8 @@ envoy_cc_library( "grpc", ], deps = [ + ":async_client_lib", + ":common_lib", ":google_grpc_creds_lib", "//include/envoy/api:api_interface", "//include/envoy/grpc:google_grpc_creds_interface", diff --git a/source/common/grpc/async_client.cc b/source/common/grpc/async_client.cc new file mode 100644 index 0000000000000..739ab5e6747f4 --- /dev/null +++ b/source/common/grpc/async_client.cc @@ -0,0 +1,55 @@ +#include "envoy/grpc/async_client.h" + +#include "common/buffer/zero_copy_input_stream_impl.h" +#include "common/common/utility.h" +#include "common/grpc/common.h" +#include "common/http/utility.h" + +namespace Envoy { +namespace Grpc { + +void AsyncStream::sendMessage(const Protobuf::Message& request, bool end_stream) { + sendRawMessage(Common::serializeBody(request), end_stream); +} + +AsyncRequest* AsyncClient::send(const Protobuf::MethodDescriptor& service_method, + const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, + Tracing::Span& parent_span, + const absl::optional& timeout) { + return sendRaw(service_method.service()->full_name(), service_method.name(), + Common::serializeBody(request), callbacks, parent_span, timeout); +} + +AsyncStream* AsyncClient::start(const Protobuf::MethodDescriptor& service_method, + AsyncStreamCallbacks& callbacks) { + return startRaw(service_method.service()->full_name(), service_method.name(), callbacks); +} + +void AsyncRequestCallbacks::onSuccessRaw(Buffer::InstancePtr response, Tracing::Span& span) { + ProtobufTypes::MessagePtr response_message = createEmptyResponse(); + // TODO(htuch): Need to add support for compressed responses as well here. + if (response->length() > 0) { + Buffer::ZeroCopyInputStreamImpl stream(std::move(response)); + if (!response_message->ParseFromZeroCopyStream(&stream)) { + onFailure(Status::GrpcStatus::Internal, "", span); + return; + } + } + onSuccessUntyped(std::move(response_message), span); +} + +bool AsyncStreamCallbacks::onReceiveRawMessage(Buffer::InstancePtr response) { + ProtobufTypes::MessagePtr response_message = createEmptyResponse(); + // TODO(htuch): Need to add support for compressed responses as well here. + if (response->length() > 0) { + Buffer::ZeroCopyInputStreamImpl stream(std::move(response)); + if (!response_message->ParseFromZeroCopyStream(&stream)) { + return false; + } + } + onReceiveMessageUntyped(std::move(response_message)); + return true; +} + +} // namespace Grpc +} // namespace Envoy diff --git a/source/common/grpc/async_client_impl.cc b/source/common/grpc/async_client_impl.cc index fc0f62e544d29..80d4230868b1a 100644 --- a/source/common/grpc/async_client_impl.cc +++ b/source/common/grpc/async_client_impl.cc @@ -26,8 +26,17 @@ AsyncRequest* AsyncClientImpl::send(const Protobuf::MethodDescriptor& service_me const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, const absl::optional& timeout) { - auto* const async_request = - new AsyncRequestImpl(*this, service_method, request, callbacks, parent_span, timeout); + return sendRaw(service_method.service()->full_name(), service_method.name(), + Common::serializeBody(request), callbacks, parent_span, timeout); +} + +AsyncRequest* AsyncClientImpl::sendRaw(absl::string_view service_full_name, + absl::string_view method_name, Buffer::InstancePtr request, + RawAsyncRequestCallbacks& callbacks, + Tracing::Span& parent_span, + const absl::optional& timeout) { + auto* const async_request = new AsyncRequestImpl( + *this, service_full_name, method_name, std::move(request), callbacks, parent_span, timeout); std::unique_ptr grpc_stream{async_request}; grpc_stream->initialize(true); @@ -39,11 +48,12 @@ AsyncRequest* AsyncClientImpl::send(const Protobuf::MethodDescriptor& service_me return async_request; } -AsyncStream* AsyncClientImpl::start(const Protobuf::MethodDescriptor& service_method, - AsyncStreamCallbacks& callbacks) { +AsyncStream* AsyncClientImpl::startRaw(absl::string_view service_full_name, + absl::string_view method_name, + RawAsyncStreamCallbacks& callbacks) { const absl::optional no_timeout; - auto grpc_stream = - std::make_unique(*this, service_method, callbacks, no_timeout); + auto grpc_stream = std::make_unique(*this, service_full_name, method_name, + callbacks, no_timeout); grpc_stream->initialize(false); if (grpc_stream->hasResetStream()) { @@ -58,7 +68,13 @@ AsyncStreamImpl::AsyncStreamImpl(AsyncClientImpl& parent, const Protobuf::MethodDescriptor& service_method, AsyncStreamCallbacks& callbacks, const absl::optional& timeout) - : parent_(parent), service_method_(service_method), callbacks_(callbacks), timeout_(timeout) {} + : AsyncStreamImpl(parent, service_method.service()->full_name(), service_method.name(), + callbacks, timeout) {} +AsyncStreamImpl::AsyncStreamImpl(AsyncClientImpl& parent, absl::string_view service_full_name, + absl::string_view method_name, RawAsyncStreamCallbacks& callbacks, + const absl::optional& timeout) + : parent_(parent), service_full_name_(service_full_name), method_name_(method_name), + callbacks_(callbacks), timeout_(timeout) {} void AsyncStreamImpl::initialize(bool buffer_body_for_retry) { if (parent_.cm_.get(parent_.remote_cluster_name_) == nullptr) { @@ -81,9 +97,9 @@ void AsyncStreamImpl::initialize(bool buffer_body_for_retry) { // TODO(htuch): match Google gRPC base64 encoding behavior for *-bin headers, see // https://github.com/envoyproxy/envoy/pull/2444#discussion_r163914459. - headers_message_ = Common::prepareHeaders( - parent_.remote_cluster_name_, service_method_.service()->full_name(), service_method_.name(), - absl::optional(timeout_)); + headers_message_ = + Common::prepareHeaders(parent_.remote_cluster_name_, service_full_name_, method_name_, + absl::optional(timeout_)); // Fill service-wide initial metadata. for (const auto& header_value : parent_.initial_metadata_) { headers_message_->headers().addCopy(Http::LowerCaseString(header_value.key()), @@ -129,21 +145,18 @@ void AsyncStreamImpl::onData(Buffer::Instance& data, bool end_stream) { } for (auto& frame : decoded_frames_) { - ProtobufTypes::MessagePtr response = callbacks_.createEmptyResponse(); - // TODO(htuch): Need to add support for compressed responses as well here. - if (frame.length_ > 0) { - Buffer::ZeroCopyInputStreamImpl stream(std::move(frame.data_)); - - if (frame.flags_ != GRPC_FH_DEFAULT || !response->ParseFromZeroCopyStream(&stream)) { - streamError(Status::GrpcStatus::Internal); - return; - } + if (frame.length_ > 0 && frame.flags_ != GRPC_FH_DEFAULT) { + streamError(Status::GrpcStatus::Internal); + return; + } + if (!callbacks_.onReceiveRawMessage(frame.data_ ? std::move(frame.data_) + : std::make_unique())) { + streamError(Status::GrpcStatus::Internal); + return; } - callbacks_.onReceiveMessageUntyped(std::move(response)); } - if (end_stream) { - Http::HeaderMapPtr empty_trailers = std::make_unique(); + if (!http_reset_ && end_stream) { streamError(Status::GrpcStatus::Unknown); } } @@ -177,7 +190,11 @@ void AsyncStreamImpl::onReset() { } void AsyncStreamImpl::sendMessage(const Protobuf::Message& request, bool end_stream) { - stream_->sendData(*Common::serializeBody(request), end_stream); + sendRawMessage(Common::serializeBody(request), end_stream); +} + +void AsyncStreamImpl::sendRawMessage(Buffer::InstancePtr request, bool end_stream) { + stream_->sendData(*request, end_stream); } void AsyncStreamImpl::closeStream() { @@ -201,13 +218,12 @@ void AsyncStreamImpl::cleanup() { } } -AsyncRequestImpl::AsyncRequestImpl(AsyncClientImpl& parent, - const Protobuf::MethodDescriptor& service_method, - const Protobuf::Message& request, - AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, +AsyncRequestImpl::AsyncRequestImpl(AsyncClientImpl& parent, absl::string_view service_full_name, + absl::string_view method_name, Buffer::InstancePtr request, + RawAsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, const absl::optional& timeout) - : AsyncStreamImpl(parent, service_method, *this, timeout), request_(request), - callbacks_(callbacks) { + : AsyncStreamImpl(parent, service_full_name, method_name, *this, timeout), + request_(std::move(request)), callbacks_(callbacks) { current_span_ = parent_span.spawnChild(Tracing::EgressConfig::get(), "async " + parent.remote_cluster_name_ + " egress", @@ -216,12 +232,20 @@ AsyncRequestImpl::AsyncRequestImpl(AsyncClientImpl& parent, current_span_->setTag(Tracing::Tags::get().COMPONENT, Tracing::Tags::get().PROXY); } +AsyncRequestImpl::AsyncRequestImpl(AsyncClientImpl& parent, + const Protobuf::MethodDescriptor& service_method, + const Protobuf::Message& request, + AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, + const absl::optional& timeout) + : AsyncRequestImpl(parent, service_method.service()->full_name(), service_method.name(), + Common::serializeBody(request), callbacks, parent_span, timeout) {} + void AsyncRequestImpl::initialize(bool buffer_body_for_retry) { AsyncStreamImpl::initialize(buffer_body_for_retry); if (this->hasResetStream()) { return; } - this->sendMessage(request_, true); + this->sendRawMessage(std::move(request_), true); } void AsyncRequestImpl::cancel() { @@ -241,8 +265,9 @@ void AsyncRequestImpl::onCreateInitialMetadata(Http::HeaderMap& metadata) { void AsyncRequestImpl::onReceiveInitialMetadata(Http::HeaderMapPtr&&) {} -void AsyncRequestImpl::onReceiveMessageUntyped(ProtobufTypes::MessagePtr&& message) { - response_ = std::move(message); +bool AsyncRequestImpl::onReceiveRawMessage(Buffer::InstancePtr response) { + response_ = std::move(response); + return true; } void AsyncRequestImpl::onReceiveTrailingMetadata(Http::HeaderMapPtr&&) {} @@ -257,7 +282,7 @@ void AsyncRequestImpl::onRemoteClose(Grpc::Status::GrpcStatus status, const std: current_span_->setTag(Tracing::Tags::get().ERROR, Tracing::Tags::get().TRUE); callbacks_.onFailure(Status::Internal, EMPTY_STRING, *current_span_); } else { - callbacks_.onSuccessUntyped(std::move(response_), *current_span_); + callbacks_.onSuccessRaw(std::move(response_), *current_span_); } current_span_->finishSpan(); diff --git a/source/common/grpc/async_client_impl.h b/source/common/grpc/async_client_impl.h index 6fa54b8dfdf72..a403c7357159a 100644 --- a/source/common/grpc/async_client_impl.h +++ b/source/common/grpc/async_client_impl.h @@ -23,8 +23,12 @@ class AsyncClientImpl final : public AsyncClient { const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, const absl::optional& timeout) override; - AsyncStream* start(const Protobuf::MethodDescriptor& service_method, - AsyncStreamCallbacks& callbacks) override; + AsyncRequest* sendRaw(absl::string_view service_full_name, absl::string_view method_name, + Buffer::InstancePtr request, RawAsyncRequestCallbacks& callbacks, + Tracing::Span& parent_span, + const absl::optional& timeout) override; + AsyncStream* startRaw(absl::string_view service_full_name, absl::string_view method_name, + RawAsyncStreamCallbacks& callbacks) override; private: Upstream::ClusterManager& cm_; @@ -45,6 +49,9 @@ class AsyncStreamImpl : public AsyncStream, AsyncStreamImpl(AsyncClientImpl& parent, const Protobuf::MethodDescriptor& service_method, AsyncStreamCallbacks& callbacks, const absl::optional& timeout); + AsyncStreamImpl(AsyncClientImpl& parent, absl::string_view service_full_name, + absl::string_view method_name, RawAsyncStreamCallbacks& callbacks, + const absl::optional& timeout); virtual void initialize(bool buffer_body_for_retry); @@ -56,6 +63,7 @@ class AsyncStreamImpl : public AsyncStream, // Grpc::AsyncStream void sendMessage(const Protobuf::Message& request, bool end_stream) override; + void sendRawMessage(Buffer::InstancePtr request, bool end_stream) override; void closeStream() override; void resetStream() override; @@ -72,8 +80,9 @@ class AsyncStreamImpl : public AsyncStream, Event::Dispatcher* dispatcher_{}; Http::MessagePtr headers_message_; AsyncClientImpl& parent_; - const Protobuf::MethodDescriptor& service_method_; - AsyncStreamCallbacks& callbacks_; + std::string service_full_name_; + std::string method_name_; + RawAsyncStreamCallbacks& callbacks_; const absl::optional& timeout_; bool http_reset_{}; Http::AsyncClient::Stream* stream_{}; @@ -84,12 +93,16 @@ class AsyncStreamImpl : public AsyncStream, friend class AsyncClientImpl; }; -class AsyncRequestImpl : public AsyncRequest, public AsyncStreamImpl, AsyncStreamCallbacks { +class AsyncRequestImpl : public AsyncRequest, public AsyncStreamImpl, RawAsyncStreamCallbacks { public: AsyncRequestImpl(AsyncClientImpl& parent, const Protobuf::MethodDescriptor& service_method, const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, const absl::optional& timeout); + AsyncRequestImpl(AsyncClientImpl& parent, absl::string_view service_full_name, + absl::string_view method_name, Buffer::InstancePtr request, + RawAsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, + const absl::optional& timeout); void initialize(bool buffer_body_for_retry) override; @@ -101,14 +114,14 @@ class AsyncRequestImpl : public AsyncRequest, public AsyncStreamImpl, AsyncStrea ProtobufTypes::MessagePtr createEmptyResponse() override; void onCreateInitialMetadata(Http::HeaderMap& metadata) override; void onReceiveInitialMetadata(Http::HeaderMapPtr&&) override; - void onReceiveMessageUntyped(ProtobufTypes::MessagePtr&& message) override; + bool onReceiveRawMessage(Buffer::InstancePtr response) override; void onReceiveTrailingMetadata(Http::HeaderMapPtr&&) override; void onRemoteClose(Grpc::Status::GrpcStatus status, const std::string& message) override; - const Protobuf::Message& request_; - AsyncRequestCallbacks& callbacks_; + Buffer::InstancePtr request_; + RawAsyncRequestCallbacks& callbacks_; Tracing::SpanPtr current_span_; - ProtobufTypes::MessagePtr response_; + Buffer::InstancePtr response_; }; } // namespace Grpc diff --git a/source/common/grpc/common.cc b/source/common/grpc/common.cc index d2ea436b4227f..9d1f20ee3362c 100644 --- a/source/common/grpc/common.cc +++ b/source/common/grpc/common.cc @@ -2,6 +2,7 @@ #include +#include #include #include #include @@ -12,6 +13,7 @@ #include "common/common/enum_to_int.h" #include "common/common/fmt.h" #include "common/common/macros.h" +#include "common/common/stack_array.h" #include "common/common/utility.h" #include "common/http/headers.h" #include "common/http/message_impl.h" @@ -133,6 +135,21 @@ Buffer::InstancePtr Common::serializeBody(const Protobuf::Message& message) { return body; } +Buffer::InstancePtr Common::serializeMessage(const Protobuf::Message& message) { + Buffer::InstancePtr body(new Buffer::OwnedImpl()); + const uint32_t size = message.ByteSize(); + Buffer::RawSlice iovec; + body->reserve(size, &iovec, 1); + ASSERT(iovec.len_ >= size); + iovec.len_ = size; + uint8_t* current = reinterpret_cast(iovec.mem_); + Protobuf::io::ArrayOutputStream stream(current, size, -1); + Protobuf::io::CodedOutputStream codec_stream(&stream); + message.SerializeWithCachedSizes(&codec_stream); + body->commit(&iovec, 1); + return body; +} + std::chrono::milliseconds Common::getGrpcTimeout(Http::HeaderMap& request_headers) { std::chrono::milliseconds timeout(0); Http::HeaderEntry* header_grpc_timeout_entry = request_headers.GrpcTimeout(); @@ -269,5 +286,90 @@ std::string Common::typeUrl(const std::string& qualified_name) { return typeUrlPrefix() + "/" + qualified_name; } +struct BufferInstanceContainer { + BufferInstanceContainer(int ref_count, Buffer::InstancePtr buffer) + : ref_count_(ref_count), buffer_(std::move(buffer)) {} + std::atomic ref_count_; + Buffer::InstancePtr buffer_; +}; + +static void derefBufferInstanceContainer(void* container_ptr) { + auto container = reinterpret_cast(container_ptr); + container->ref_count_--; + if (container->ref_count_ <= 0) { + delete container; + } +} + +grpc::ByteBuffer Common::makeByteBuffer(Buffer::InstancePtr bufferInstance) { + if (!bufferInstance) { + return {}; + } + Buffer::RawSlice oneRawSlice; + // NB: we need to pass in >= 1 in order to get the real "n" (see Buffer::Instance for details). + int nSlices = bufferInstance->getRawSlices(&oneRawSlice, 1); + if (nSlices <= 0) { + return {}; + } + auto container = new BufferInstanceContainer{nSlices, std::move(bufferInstance)}; + if (nSlices == 1) { + grpc::Slice oneSlice(oneRawSlice.mem_, oneRawSlice.len_, &derefBufferInstanceContainer, + container); + return {&oneSlice, 1}; + } + STACK_ARRAY(manyRawSlices, Buffer::RawSlice, nSlices); + bufferInstance->getRawSlices(manyRawSlices.begin(), nSlices); + std::vector slices; + slices.reserve(nSlices); + for (int i = 0; i < nSlices; i++) { + slices.emplace_back(manyRawSlices[i].mem_, manyRawSlices[i].len_, &derefBufferInstanceContainer, + container); + } + return {&slices[0], slices.size()}; +} + +struct ByteBufferContainer { + ByteBufferContainer(int ref_count) : ref_count_(ref_count) {} + ~ByteBufferContainer() { ::free(fragments); } + std::atomic ref_count_; + Buffer::BufferFragmentImpl* fragments = nullptr; + std::vector slices_; +}; + +Buffer::InstancePtr Common::makeBufferInstance(const grpc::ByteBuffer& byteBuffer) { + auto buffer = std::make_unique(); + if (byteBuffer.Length() == 0) { + return buffer; + } + // NB: ByteBuffer::Dump moves the data out of the ByteBuffer so we need to ensure that the + // lifetime of the Slice(s) exceeds our Buffer::Instance. + std::vector slices; + byteBuffer.Dump(&slices); + if (slices.size() == 0) { + return buffer; + } + auto container = new ByteBufferContainer(static_cast(slices.size())); + std::function releaser = + [container](const void*, size_t, const Buffer::BufferFragmentImpl*) { + container->ref_count_--; + if (container->ref_count_ <= 0) { + delete container; + } + }; + // NB: addBufferFragment takes a pointer alias to the BufferFragmentImpl which is passed in so we + // need to ensure that the lifetime of those objects exceeds that of the Buffer::Instance. + container->fragments = static_cast( + ::malloc(sizeof(Buffer::BufferFragmentImpl) * slices.size())); + for (size_t i = 0; i < slices.size(); i++) { + new (&container->fragments[i]) + Buffer::BufferFragmentImpl(slices[i].begin(), slices[i].size(), releaser); + } + for (size_t i = 0; i < slices.size(); i++) { + buffer->addBufferFragment(container->fragments[i]); + } + container->slices_ = std::move(slices); + return buffer; +} + } // namespace Grpc } // namespace Envoy diff --git a/source/common/grpc/common.h b/source/common/grpc/common.h index 7a7399c1679c0..7e05d63399fa3 100644 --- a/source/common/grpc/common.h +++ b/source/common/grpc/common.h @@ -13,6 +13,7 @@ #include "common/protobuf/protobuf.h" #include "absl/types/optional.h" +#include "grpcpp/grpcpp.h" namespace Envoy { namespace Grpc { @@ -119,10 +120,15 @@ class Common { std::string* method); /** - * Serialize protobuf message. + * Serialize protobuf message. With grpc header. */ static Buffer::InstancePtr serializeBody(const Protobuf::Message& message); + /** + * Serialize protobuf message. Without grpc header. + */ + static Buffer::InstancePtr serializeMessage(const Protobuf::Message& message); + /** * Prepare headers for protobuf service. */ @@ -148,6 +154,21 @@ class Common { */ static std::string typeUrl(const std::string& qualified_name); + /** + * BUild grpc::ByteBuffer which aliases the data in a Buffer::InstancePtr. + * @param bufferInstance source data container. + * @return byteBuffer target container aliased to the data in Buffer::Instance and owning the + * Buffer::Instance. + */ + static grpc::ByteBuffer makeByteBuffer(Buffer::InstancePtr bufferInstance); + + /** + * BUild Buffer::Instance which aliases the data in a grpc::ByteBuffer. + * @param byteBuffer source data container. + * @param Buffer::InstancePtr target container aliased to the data in grpc::ByteBuffer. + */ + static Buffer::InstancePtr makeBufferInstance(const grpc::ByteBuffer& byteBuffer); + private: static void checkForHeaderOnlyError(Http::Message& http_response); }; diff --git a/source/common/grpc/google_async_client_impl.cc b/source/common/grpc/google_async_client_impl.cc index a2adeed68936d..8c6526fcddba9 100644 --- a/source/common/grpc/google_async_client_impl.cc +++ b/source/common/grpc/google_async_client_impl.cc @@ -5,6 +5,7 @@ #include "common/common/empty_string.h" #include "common/common/lock_guard.h" #include "common/config/datasource.h" +#include "common/grpc/common.h" #include "common/grpc/google_grpc_creds_impl.h" #include "common/tracing/http_tracer_impl.h" @@ -98,8 +99,17 @@ GoogleAsyncClientImpl::send(const Protobuf::MethodDescriptor& service_method, const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, const absl::optional& timeout) { - auto* const async_request = - new GoogleAsyncRequestImpl(*this, service_method, request, callbacks, parent_span, timeout); + return sendRaw(service_method.service()->full_name(), service_method.name(), + Common::serializeMessage(request), callbacks, parent_span, timeout); +} + +AsyncRequest* +GoogleAsyncClientImpl::sendRaw(absl::string_view service_full_name, absl::string_view method_name, + Buffer::InstancePtr request, RawAsyncRequestCallbacks& callbacks, + Tracing::Span& parent_span, + const absl::optional& timeout) { + auto* const async_request = new GoogleAsyncRequestImpl( + *this, service_full_name, method_name, std::move(request), callbacks, parent_span, timeout); std::unique_ptr grpc_stream{async_request}; grpc_stream->initialize(true); @@ -111,11 +121,12 @@ GoogleAsyncClientImpl::send(const Protobuf::MethodDescriptor& service_method, return async_request; } -AsyncStream* GoogleAsyncClientImpl::start(const Protobuf::MethodDescriptor& service_method, - AsyncStreamCallbacks& callbacks) { +AsyncStream* GoogleAsyncClientImpl::startRaw(absl::string_view service_full_name, + absl::string_view method_name, + RawAsyncStreamCallbacks& callbacks) { const absl::optional no_timeout; - auto grpc_stream = - std::make_unique(*this, service_method, callbacks, no_timeout); + auto grpc_stream = std::make_unique(*this, service_full_name, method_name, + callbacks, no_timeout); grpc_stream->initialize(false); if (grpc_stream->call_failed()) { @@ -126,16 +137,27 @@ AsyncStream* GoogleAsyncClientImpl::start(const Protobuf::MethodDescriptor& serv return active_streams_.front().get(); } +GoogleAsyncStreamImpl::GoogleAsyncStreamImpl( + GoogleAsyncClientImpl& parent, absl::string_view service_full_name, + absl::string_view method_name, RawAsyncStreamCallbacks& callbacks, + const absl::optional& timeout) + : parent_(parent), tls_(parent_.tls_), dispatcher_(parent_.dispatcher_), stub_(parent_.stub_), + service_full_name_(service_full_name), method_name_(method_name), callbacks_(callbacks), + timeout_(timeout) {} + GoogleAsyncStreamImpl::GoogleAsyncStreamImpl( GoogleAsyncClientImpl& parent, const Protobuf::MethodDescriptor& service_method, AsyncStreamCallbacks& callbacks, const absl::optional& timeout) - : parent_(parent), tls_(parent_.tls_), dispatcher_(parent_.dispatcher_), stub_(parent_.stub_), - service_method_(service_method), callbacks_(callbacks), timeout_(timeout) {} + : GoogleAsyncStreamImpl(parent, service_method.service()->full_name(), service_method.name(), + callbacks, timeout) {} GoogleAsyncStreamImpl::~GoogleAsyncStreamImpl() { ENVOY_LOG(debug, "GoogleAsyncStreamImpl destruct"); } +GoogleAsyncStreamImpl::PendingMessage::PendingMessage(Buffer::InstancePtr request, bool end_stream) + : buf_(Common::makeByteBuffer(std::move(request))), end_stream_(end_stream) {} + // TODO(htuch): figure out how to propagate "this request should be buffered for // retry" bit to Google gRPC library. void GoogleAsyncStreamImpl::initialize(bool /*buffer_body_for_retry*/) { @@ -161,9 +183,8 @@ void GoogleAsyncStreamImpl::initialize(bool /*buffer_body_for_retry*/) { }, &ctxt_); // Invoke stub call. - rw_ = parent_.stub_->PrepareCall( - &ctxt_, "/" + service_method_.service()->full_name() + "/" + service_method_.name(), - &parent_.tls_.completionQueue()); + rw_ = parent_.stub_->PrepareCall(&ctxt_, "/" + service_full_name_ + "/" + method_name_, + &parent_.tls_.completionQueue()); if (rw_ == nullptr) { notifyRemoteClose(Status::GrpcStatus::Unavailable, nullptr, EMPTY_STRING); call_failed_ = true; @@ -193,7 +214,11 @@ void GoogleAsyncStreamImpl::notifyRemoteClose(Status::GrpcStatus grpc_status, } void GoogleAsyncStreamImpl::sendMessage(const Protobuf::Message& request, bool end_stream) { - write_pending_queue_.emplace(request, end_stream); + sendRawMessage(Common::serializeMessage(request), end_stream); +} + +void GoogleAsyncStreamImpl::sendRawMessage(Buffer::InstancePtr request, bool end_stream) { + write_pending_queue_.emplace(std::move(request), end_stream); ENVOY_LOG(trace, "Queued message to write ({} bytes)", write_pending_queue_.back().buf_.value().Length()); writeQueued(); @@ -309,20 +334,12 @@ void GoogleAsyncStreamImpl::handleOpCompletion(GoogleAsyncTag::Operation op, boo } case GoogleAsyncTag::Operation::Read: { ASSERT(ok); - ProtobufTypes::MessagePtr response = callbacks_.createEmptyResponse(); - { - // reader must be destructed before we queue up read_buf_ for the next - // Read op, otherwise we can race between this thread in the reader - // destructor and the gRPC op thread. - grpc::ProtoBufferReader reader(&read_buf_); - if (!response->ParseFromZeroCopyStream(&reader)) { - // This is basically streamError in Grpc::AsyncClientImpl. - notifyRemoteClose(Status::GrpcStatus::Internal, nullptr, EMPTY_STRING); - resetStream(); - break; - }; + if (!callbacks_.onReceiveRawMessage(Common::makeBufferInstance(read_buf_))) { + // This is basically streamError in Grpc::AsyncClientImpl. + notifyRemoteClose(Status::GrpcStatus::Internal, nullptr, EMPTY_STRING); + resetStream(); + break; } - callbacks_.onReceiveMessageUntyped(std::move(response)); rw_->Read(&read_buf_, &read_tag_); ++inflight_tags_; break; @@ -384,11 +401,11 @@ void GoogleAsyncStreamImpl::cleanup() { } GoogleAsyncRequestImpl::GoogleAsyncRequestImpl( - GoogleAsyncClientImpl& parent, const Protobuf::MethodDescriptor& service_method, - const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, - const absl::optional& timeout) - : GoogleAsyncStreamImpl(parent, service_method, *this, timeout), request_(request), - callbacks_(callbacks) { + GoogleAsyncClientImpl& parent, absl::string_view service_full_name, + absl::string_view method_name, Buffer::InstancePtr request, RawAsyncRequestCallbacks& callbacks, + Tracing::Span& parent_span, const absl::optional& timeout) + : GoogleAsyncStreamImpl(parent, service_full_name, method_name, *this, timeout), + request_(std::move(request)), callbacks_(callbacks) { current_span_ = parent_span.spawnChild(Tracing::EgressConfig::get(), "async " + parent.stat_prefix_ + " egress", parent.timeSource().systemTime()); @@ -396,12 +413,20 @@ GoogleAsyncRequestImpl::GoogleAsyncRequestImpl( current_span_->setTag(Tracing::Tags::get().COMPONENT, Tracing::Tags::get().PROXY); } +GoogleAsyncRequestImpl::GoogleAsyncRequestImpl( + GoogleAsyncClientImpl& parent, const Protobuf::MethodDescriptor& service_method, + const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, + const absl::optional& timeout) + : GoogleAsyncRequestImpl(parent, service_method.service()->full_name(), service_method.name(), + Common::serializeMessage(std::move(request)), callbacks, parent_span, + timeout) {} + void GoogleAsyncRequestImpl::initialize(bool buffer_body_for_retry) { GoogleAsyncStreamImpl::initialize(buffer_body_for_retry); if (this->call_failed()) { return; } - this->sendMessage(request_, true); + this->sendRawMessage(std::move(request_), true); } void GoogleAsyncRequestImpl::cancel() { @@ -417,8 +442,9 @@ void GoogleAsyncRequestImpl::onCreateInitialMetadata(Http::HeaderMap& metadata) void GoogleAsyncRequestImpl::onReceiveInitialMetadata(Http::HeaderMapPtr&&) {} -void GoogleAsyncRequestImpl::onReceiveMessageUntyped(ProtobufTypes::MessagePtr&& message) { - response_ = std::move(message); +bool GoogleAsyncRequestImpl::onReceiveRawMessage(Buffer::InstancePtr response) { + response_ = std::move(response); + return true; } void GoogleAsyncRequestImpl::onReceiveTrailingMetadata(Http::HeaderMapPtr&&) {} @@ -438,7 +464,7 @@ void GoogleAsyncRequestImpl::onRemoteClose(Grpc::Status::GrpcStatus status, current_span_->setTag(Tracing::Tags::get().ERROR, Tracing::Tags::get().TRUE); callbacks_.onFailure(Status::Internal, EMPTY_STRING, *current_span_); } else { - callbacks_.onSuccessUntyped(std::move(response_), *current_span_); + callbacks_.onSuccessRaw(std::move(response_), *current_span_); } current_span_->finishSpan(); diff --git a/source/common/grpc/google_async_client_impl.h b/source/common/grpc/google_async_client_impl.h index 315b0cfa37bc4..e6621f2cfad29 100644 --- a/source/common/grpc/google_async_client_impl.h +++ b/source/common/grpc/google_async_client_impl.h @@ -164,8 +164,12 @@ class GoogleAsyncClientImpl final : public AsyncClient, Logger::Loggable& timeout) override; - AsyncStream* start(const Protobuf::MethodDescriptor& service_method, - AsyncStreamCallbacks& callbacks) override; + AsyncRequest* sendRaw(absl::string_view service_full_name, absl::string_view method_name, + Buffer::InstancePtr request, RawAsyncRequestCallbacks& callbacks, + Tracing::Span& parent_span, + const absl::optional& timeout) override; + AsyncStream* startRaw(absl::string_view service_full_name, absl::string_view method_name, + RawAsyncStreamCallbacks& callbacks) override; TimeSource& timeSource() { return dispatcher_.timeSystem(); } @@ -199,12 +203,16 @@ class GoogleAsyncStreamImpl : public AsyncStream, const Protobuf::MethodDescriptor& service_method, AsyncStreamCallbacks& callbacks, const absl::optional& timeout); + GoogleAsyncStreamImpl(GoogleAsyncClientImpl& parent, absl::string_view service_full_name, + absl::string_view method_name, RawAsyncStreamCallbacks& callbacks, + const absl::optional& timeout); ~GoogleAsyncStreamImpl(); virtual void initialize(bool buffer_body_for_retry); // Grpc::AsyncStream void sendMessage(const Protobuf::Message& request, bool end_stream) override; + void sendRawMessage(Buffer::InstancePtr request, bool end_stream) override; void closeStream() override; void resetStream() override; @@ -235,17 +243,7 @@ class GoogleAsyncStreamImpl : public AsyncStream, // Pending serialized message on write queue. Only one Operation::Write is in-flight at any // point-in-time, so we queue pending writes here. struct PendingMessage { - // We serialize the message to a grpc::ByteBuffer prior to queueing. - PendingMessage(const Protobuf::Message& request, bool end_stream) - : buf_([](const Protobuf::Message& request) -> absl::optional { - grpc::ByteBuffer buffer; - grpc::ProtoBufferWriter writer(&buffer, grpc::kProtoBufferWriterMaxBufferLength, - request.ByteSize()); - return request.SerializeToZeroCopyStream(&writer) - ? absl::make_optional(buffer) - : absl::nullopt; - }(request)), - end_stream_(end_stream) {} + PendingMessage(Buffer::InstancePtr request, bool end_stream); // End-of-stream with no additional message. PendingMessage() : end_stream_(true) {} @@ -270,8 +268,9 @@ class GoogleAsyncStreamImpl : public AsyncStream, // We hold a ref count on the stub_ to allow the stream to wait for its tags // to drain from the CQ on cleanup. std::shared_ptr stub_; - const Protobuf::MethodDescriptor& service_method_; - AsyncStreamCallbacks& callbacks_; + std::string service_full_name_; + std::string method_name_; + RawAsyncStreamCallbacks& callbacks_; const absl::optional& timeout_; grpc::ClientContext ctxt_; std::unique_ptr rw_; @@ -304,13 +303,17 @@ class GoogleAsyncStreamImpl : public AsyncStream, class GoogleAsyncRequestImpl : public AsyncRequest, public GoogleAsyncStreamImpl, - AsyncStreamCallbacks { + RawAsyncStreamCallbacks { public: GoogleAsyncRequestImpl(GoogleAsyncClientImpl& parent, const Protobuf::MethodDescriptor& service_method, const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, const absl::optional& timeout); + GoogleAsyncRequestImpl(GoogleAsyncClientImpl& parent, absl::string_view service_full_name, + absl::string_view method_name, Buffer::InstancePtr request, + RawAsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, + const absl::optional& timeout); void initialize(bool buffer_body_for_retry) override; @@ -321,15 +324,15 @@ class GoogleAsyncRequestImpl : public AsyncRequest, // Grpc::AsyncStreamCallbacks void onCreateInitialMetadata(Http::HeaderMap& metadata) override; void onReceiveInitialMetadata(Http::HeaderMapPtr&&) override; - void onReceiveMessageUntyped(ProtobufTypes::MessagePtr&& message) override; + bool onReceiveRawMessage(Buffer::InstancePtr response) override; void onReceiveTrailingMetadata(Http::HeaderMapPtr&&) override; ProtobufTypes::MessagePtr createEmptyResponse() override; void onRemoteClose(Grpc::Status::GrpcStatus status, const std::string& message) override; - const Protobuf::Message& request_; - AsyncRequestCallbacks& callbacks_; + Buffer::InstancePtr request_; + RawAsyncRequestCallbacks& callbacks_; Tracing::SpanPtr current_span_; - ProtobufTypes::MessagePtr response_; + Buffer::InstancePtr response_; }; } // namespace Grpc diff --git a/test/common/grpc/common_test.cc b/test/common/grpc/common_test.cc index 34a582d35c937..722e30519acd4 100644 --- a/test/common/grpc/common_test.cc +++ b/test/common/grpc/common_test.cc @@ -348,5 +348,15 @@ TEST(GrpcCommonTest, ValidateResponse) { } } +TEST(GrpcCommonTest, MakeBufferInstance) { + grpc::ByteBuffer byteBuffer; + Common::makeBufferInstance(byteBuffer); +} + +TEST(GrpcCommonTest, MakeByteBuffer) { + auto buffer = std::make_unique(); + Common::makeByteBuffer(std::move(buffer)); +} + } // namespace Grpc } // namespace Envoy diff --git a/test/mocks/grpc/mocks.h b/test/mocks/grpc/mocks.h index e0db3dfe2772d..694dda1a882d9 100644 --- a/test/mocks/grpc/mocks.h +++ b/test/mocks/grpc/mocks.h @@ -26,8 +26,8 @@ class MockAsyncStream : public AsyncStream { ~MockAsyncStream(); MOCK_METHOD2_T(sendMessage, void(const Protobuf::Message& request, bool end_stream)); - MOCK_METHOD0_T(closeStream, void()); - MOCK_METHOD0_T(resetStream, void()); + MOCK_METHOD2_T(sendRawMessage, void(Buffer::InstancePtr request, bool end_stream)); + MOCK_METHOD0_T(closeStream, void()); MOCK_METHOD0_T(resetStream, void()); }; template @@ -69,8 +69,14 @@ class MockAsyncClient : public AsyncClient { const Protobuf::Message& request, AsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, const absl::optional& timeout)); + MOCK_METHOD6_T(sendRaw, AsyncRequest*(absl::string_view service_full_name, absl::string_view method_name, + Buffer::InstancePtr request, + RawAsyncRequestCallbacks& callbacks, Tracing::Span& parent_span, + const absl::optional& timeout)); MOCK_METHOD2_T(start, AsyncStream*(const Protobuf::MethodDescriptor& service_method, AsyncStreamCallbacks& callbacks)); + MOCK_METHOD3_T(startRaw, AsyncStream*(absl::string_view service_full_name, absl::string_view method_name, + RawAsyncStreamCallbacks& callbacks)); }; class MockAsyncClientFactory : public AsyncClientFactory {