diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index 05264d60abac1..e75efa1c07609 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -14,8 +14,8 @@ envoy_cc_library( hdrs = ["app_exception_impl.h"], deps = [ ":protocol_interface", + ":thrift_lib", "//include/envoy/buffer:buffer_interface", - "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) @@ -35,6 +35,7 @@ envoy_cc_library( srcs = ["config.cc"], hdrs = ["config.h"], deps = [ + ":app_exception_lib", ":conn_manager_lib", ":decoder_lib", ":protocol_lib", @@ -88,6 +89,17 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "metadata_lib", + srcs = ["metadata.cc"], + hdrs = ["metadata.h"], + external_deps = ["abseil_optional"], + deps = [ + ":thrift_lib", + "//source/common/common:macros", + ], +) + envoy_cc_library( name = "protocol_converter_lib", hdrs = [ @@ -107,6 +119,8 @@ envoy_cc_library( ], external_deps = ["abseil_optional"], deps = [ + ":metadata_lib", + ":thrift_lib", "//include/envoy/buffer:buffer_interface", "//include/envoy/registry", "//source/common/common:assert_lib", @@ -149,6 +163,9 @@ envoy_cc_library( hdrs = ["transport.h"], external_deps = ["abseil_optional"], deps = [ + ":buffer_helper_lib", + ":metadata_lib", + ":thrift_lib", "//include/envoy/buffer:buffer_interface", "//include/envoy/registry", "//source/common/common:assert_lib", @@ -157,6 +174,15 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "thrift_lib", + hdrs = ["thrift.h"], + deps = [ + "//source/common/common:assert_lib", + "//source/common/singleton:const_singleton", + ], +) + envoy_cc_library( name = "transport_lib", srcs = [ @@ -170,7 +196,9 @@ envoy_cc_library( "unframed_transport_impl.h", ], deps = [ + ":app_exception_lib", ":buffer_helper_lib", + ":metadata_lib", ":protocol_lib", ":transport_interface", "//source/common/common:assert_lib", diff --git a/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc b/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc index 65455c12b3609..a7d73742eb81e 100644 --- a/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc @@ -10,19 +10,31 @@ static const std::string MessageField = "message"; static const std::string TypeField = "type"; static const std::string StopField = ""; -void AppException::encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) { - proto.writeMessageBegin(buffer, method_name_, ThriftProxy::MessageType::Exception, seq_id_); +void AppException::encode(MessageMetadata& metadata, ThriftProxy::Protocol& proto, + Buffer::Instance& buffer) const { + // Handle cases where the exception occurs before the message name (e.g. some header transport + // errors). + if (!metadata.hasMethodName()) { + metadata.setMethodName(""); + } + if (!metadata.hasSequenceId()) { + metadata.setSequenceId(0); + } + + metadata.setMessageType(MessageType::Exception); + + proto.writeMessageBegin(buffer, metadata); proto.writeStructBegin(buffer, TApplicationException); - proto.writeFieldBegin(buffer, MessageField, ThriftProxy::FieldType::String, 1); - proto.writeString(buffer, error_message_); + proto.writeFieldBegin(buffer, MessageField, FieldType::String, 1); + proto.writeString(buffer, std::string(what())); proto.writeFieldEnd(buffer); - proto.writeFieldBegin(buffer, TypeField, ThriftProxy::FieldType::I32, 2); + proto.writeFieldBegin(buffer, TypeField, FieldType::I32, 2); proto.writeInt32(buffer, static_cast(type_)); proto.writeFieldEnd(buffer); - proto.writeFieldBegin(buffer, StopField, ThriftProxy::FieldType::Stop, 0); + proto.writeFieldBegin(buffer, StopField, FieldType::Stop, 0); proto.writeStructEnd(buffer); proto.writeMessageEnd(buffer); diff --git a/source/extensions/filters/network/thrift_proxy/app_exception_impl.h b/source/extensions/filters/network/thrift_proxy/app_exception_impl.h index 4a0335704100a..b31a91e8e6891 100644 --- a/source/extensions/filters/network/thrift_proxy/app_exception_impl.h +++ b/source/extensions/filters/network/thrift_proxy/app_exception_impl.h @@ -1,41 +1,24 @@ #pragma once -#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "envoy/common/exception.h" + +#include "extensions/filters/network/thrift_proxy/metadata.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/thrift.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/** - * Thrift Application Exception types. - * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md - */ -enum class AppExceptionType { - Unknown = 0, - UnknownMethod = 1, - InvalidMessageType = 2, - WrongMethodName = 3, - BadSequenceId = 4, - MissingResult = 5, - InternalError = 6, - ProtocolError = 7, - InvalidTransform = 8, - InvalidProtocol = 9, - UnsupportedClientType = 10, -}; - -struct AppException : public ThriftFilters::DirectResponse { - AppException(const absl::string_view method_name, int32_t seq_id, AppExceptionType type, - const std::string& error_message) - : method_name_(method_name), seq_id_(seq_id), type_(type), error_message_(error_message) {} +struct AppException : public EnvoyException, public DirectResponse { + AppException(AppExceptionType type, const std::string& what) + : EnvoyException(what), type_(type) {} + AppException(const AppException& ex) : EnvoyException(ex.what()), type_(ex.type_) {} - void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) override; + void encode(MessageMetadata& metadata, Protocol& proto, Buffer::Instance& buffer) const override; - const std::string method_name_; - const int32_t seq_id_; const AppExceptionType type_; - const std::string error_message_; }; } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc index ee5a734eda209..735200636d2d3 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc @@ -17,8 +17,7 @@ namespace ThriftProxy { const uint16_t BinaryProtocolImpl::Magic = 0x8001; -bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& name, - MessageType& msg_type, int32_t& seq_id) { +bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) { // Minimum message length: // version: 2 bytes + // unused: 1 byte + @@ -52,13 +51,14 @@ bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& buffer.drain(8); if (name_len > 0) { - name.assign(std::string(static_cast(buffer.linearize(name_len)), name_len)); + metadata.setMethodName( + std::string(static_cast(buffer.linearize(name_len)), name_len)); buffer.drain(name_len); } else { - name.clear(); + metadata.setMethodName(""); } - msg_type = type; - seq_id = BufferHelper::drainI32(buffer); + metadata.setMessageType(type); + metadata.setSequenceId(BufferHelper::drainI32(buffer)); return true; } @@ -253,7 +253,7 @@ bool BinaryProtocolImpl::readString(Buffer::Instance& buffer, std::string& value } buffer.drain(4); - value.assign(static_cast(buffer.linearize(str_len)), str_len); + value.assign(static_cast(buffer.linearize(str_len)), str_len); buffer.drain(str_len); return true; } @@ -262,12 +262,12 @@ bool BinaryProtocolImpl::readBinary(Buffer::Instance& buffer, std::string& value return readString(buffer, value); } -void BinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name, - MessageType msg_type, int32_t seq_id) { +void BinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, + const MessageMetadata& metadata) { BufferHelper::writeU16(buffer, Magic); - BufferHelper::writeU16(buffer, static_cast(msg_type)); - writeString(buffer, name); - BufferHelper::writeI32(buffer, seq_id); + BufferHelper::writeU16(buffer, static_cast(metadata.messageType())); + writeString(buffer, metadata.methodName()); + BufferHelper::writeI32(buffer, metadata.sequenceId()); } void BinaryProtocolImpl::writeMessageEnd(Buffer::Instance& buffer) { @@ -362,8 +362,7 @@ void BinaryProtocolImpl::writeBinary(Buffer::Instance& buffer, const std::string writeString(buffer, value); } -bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& name, - MessageType& msg_type, int32_t& seq_id) { +bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) { // Minimum message length: // name len: 4 bytes + // name: 0 bytes + @@ -387,24 +386,25 @@ bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::stri buffer.drain(4); if (name_len > 0) { - name.assign(std::string(static_cast(buffer.linearize(name_len)), name_len)); + metadata.setMethodName( + std::string(static_cast(buffer.linearize(name_len)), name_len)); buffer.drain(name_len); } else { - name.clear(); + metadata.setMethodName(""); } - msg_type = type; - seq_id = BufferHelper::peekI32(buffer, 1); + metadata.setMessageType(type); + metadata.setSequenceId(BufferHelper::peekI32(buffer, 1)); buffer.drain(5); return true; } -void LaxBinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name, - MessageType msg_type, int32_t seq_id) { - writeString(buffer, name); - BufferHelper::writeI8(buffer, static_cast(msg_type)); - BufferHelper::writeI32(buffer, seq_id); +void LaxBinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, + const MessageMetadata& metadata) { + writeString(buffer, metadata.methodName()); + BufferHelper::writeI8(buffer, static_cast(metadata.messageType())); + BufferHelper::writeI32(buffer, metadata.sequenceId()); } class BinaryProtocolConfigFactory : public ProtocolFactoryBase { diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h index e292d6cd036b9..75f5f2574c4a8 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h @@ -23,8 +23,7 @@ class BinaryProtocolImpl : public Protocol { // Protocol const std::string& name() const override { return ProtocolNames::get().BINARY; } ProtocolType type() const override { return ProtocolType::Binary; } - bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, - int32_t& seq_id) override; + bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) override; bool readMessageEnd(Buffer::Instance& buffer) override; bool readStructBegin(Buffer::Instance& buffer, std::string& name) override; bool readStructEnd(Buffer::Instance& buffer) override; @@ -46,8 +45,7 @@ class BinaryProtocolImpl : public Protocol { bool readDouble(Buffer::Instance& buffer, double& value) override; bool readString(Buffer::Instance& buffer, std::string& value) override; bool readBinary(Buffer::Instance& buffer, std::string& value) override; - void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type, - int32_t seq_id) override; + void writeMessageBegin(Buffer::Instance& buffer, const MessageMetadata& metadata) override; void writeMessageEnd(Buffer::Instance& buffer) override; void writeStructBegin(Buffer::Instance& buffer, const std::string& name) override; void writeStructEnd(Buffer::Instance& buffer) override; @@ -86,10 +84,8 @@ class LaxBinaryProtocolImpl : public BinaryProtocolImpl { const std::string& name() const override { return ProtocolNames::get().LAX_BINARY; } - bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, - int32_t& seq_id) override; - void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type, - int32_t seq_id) override; + bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) override; + void writeMessageBegin(Buffer::Instance& buffer, const MessageMetadata& metadata) override; }; } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc index 417a80d8b6197..33421e83a7e51 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc @@ -18,8 +18,7 @@ namespace ThriftProxy { const uint16_t CompactProtocolImpl::Magic = 0x8201; const uint16_t CompactProtocolImpl::MagicMask = 0xFF1F; -bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& name, - MessageType& msg_type, int32_t& seq_id) { +bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) { // Minimum message length: // protocol, message type, and version: 2 bytes + // seq id (var int): 1 byte + @@ -64,13 +63,14 @@ bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string buffer.drain(id_size + name_len_size + 2); if (name_len > 0) { - name.assign(std::string(static_cast(buffer.linearize(name_len)), name_len)); + metadata.setMethodName( + std::string(static_cast(buffer.linearize(name_len)), name_len)); buffer.drain(name_len); } else { - name.clear(); + metadata.setMethodName(""); } - msg_type = type; - seq_id = id; + metadata.setMessageType(type); + metadata.setSequenceId(id); return true; } @@ -373,7 +373,7 @@ bool CompactProtocolImpl::readString(Buffer::Instance& buffer, std::string& valu } buffer.drain(len_size); - value.assign(static_cast(buffer.linearize(str_len)), str_len); + value.assign(static_cast(buffer.linearize(str_len)), str_len); buffer.drain(str_len); return true; } @@ -382,17 +382,17 @@ bool CompactProtocolImpl::readBinary(Buffer::Instance& buffer, std::string& valu return readString(buffer, value); } -void CompactProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name, - MessageType msg_type, int32_t seq_id) { - UNREFERENCED_PARAMETER(name); +void CompactProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, + const MessageMetadata& metadata) { + MessageType msg_type = metadata.messageType(); uint16_t ptv = (Magic & MagicMask) | (static_cast(msg_type) << 5); ASSERT((ptv & MagicMask) == Magic); ASSERT((ptv & ~MagicMask) >> 5 == static_cast(msg_type)); BufferHelper::writeU16(buffer, ptv); - BufferHelper::writeVarIntI32(buffer, seq_id); - writeString(buffer, name); + BufferHelper::writeVarIntI32(buffer, metadata.sequenceId()); + writeString(buffer, metadata.methodName()); } void CompactProtocolImpl::writeMessageEnd(Buffer::Instance& buffer) { diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h index 322d03a3a83da..b099290e3c229 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h @@ -26,8 +26,7 @@ class CompactProtocolImpl : public Protocol { // Protocol const std::string& name() const override { return ProtocolNames::get().COMPACT; } ProtocolType type() const override { return ProtocolType::Compact; } - bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, - int32_t& seq_id) override; + bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) override; bool readMessageEnd(Buffer::Instance& buffer) override; bool readStructBegin(Buffer::Instance& buffer, std::string& name) override; bool readStructEnd(Buffer::Instance& buffer) override; @@ -49,8 +48,7 @@ class CompactProtocolImpl : public Protocol { bool readDouble(Buffer::Instance& buffer, double& value) override; bool readString(Buffer::Instance& buffer, std::string& value) override; bool readBinary(Buffer::Instance& buffer, std::string& value) override; - void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type, - int32_t seq_id) override; + void writeMessageBegin(Buffer::Instance& buffer, const MessageMetadata& metadata) override; void writeMessageEnd(Buffer::Instance& buffer) override; void writeStructBegin(Buffer::Instance& buffer, const std::string& name) override; void writeStructEnd(Buffer::Instance& buffer) override; diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index 630edfdf00ce6..01ec9a212afef 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -45,8 +45,8 @@ class ConfigImpl : public Config, void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override; // Router::Config - Router::RouteConstSharedPtr route(const std::string& method_name) const override { - return route_matcher_->route(method_name); + Router::RouteConstSharedPtr route(const MessageMetadata& metadata) const override { + return route_matcher_->route(metadata); } // Config diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index c94bbeefcec36..b5138c22cc98f 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -7,6 +7,7 @@ #include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" #include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" #include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" #include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" namespace Envoy { @@ -30,7 +31,7 @@ Network::FilterStatus ConnectionManager::onData(Buffer::Instance& data, bool end void ConnectionManager::dispatch() { if (stopped_) { - ENVOY_LOG(error, "thrift filter stopped"); + ENVOY_LOG(debug, "thrift filter stopped"); return; } @@ -43,16 +44,47 @@ void ConnectionManager::dispatch() { break; } } + + return; + } catch (const AppException& ex) { + ENVOY_LOG(error, "thrift application exception: {}", ex.what()); + if (rpcs_.empty()) { + MessageMetadata metadata; + sendLocalReply(metadata, ex); + } else { + sendLocalReply(*(*rpcs_.begin())->metadata_, ex); + } } catch (const EnvoyException& ex) { ENVOY_LOG(error, "thrift error: {}", ex.what()); - stats_.request_decoding_error_.inc(); // Use the current rpc to send an error downstream, if possible. rpcs_.front()->onError(ex.what()); - - resetAllRpcs(); - read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite); } + + stats_.request_decoding_error_.inc(); + resetAllRpcs(); + read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite); +} + +void ConnectionManager::sendLocalReply(MessageMetadata& metadata, const DirectResponse& response) { + // Use the factory to get the concrete protocol from the decoder protocol (as opposed to + // potentially pre-detection auto protocol). + ProtocolType proto_type = decoder_->protocolType(); + ProtocolPtr proto = NamedProtocolConfigFactory::getFactory(proto_type).createProtocol(); + Buffer::OwnedImpl buffer; + + response.encode(metadata, *proto, buffer); + + // Same logic as protocol above. + TransportPtr transport = + NamedTransportConfigFactory::getFactory(decoder_->transportType()).createTransport(); + + Buffer::OwnedImpl response_buffer; + + metadata.setProtocol(proto_type); + transport->encodeFrame(response_buffer, metadata, buffer); + + read_callbacks_->connection().write(response_buffer, false); } void ConnectionManager::continueDecoding() { @@ -105,12 +137,12 @@ bool ConnectionManager::ResponseDecoder::onData(Buffer::Instance& data) { return complete_; } -ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::messageBegin(absl::string_view name, - MessageType msg_type, - int32_t seq_id) { - reply_.emplace(std::string(name), msg_type, seq_id); - first_reply_field_ = (msg_type == MessageType::Reply); - return ProtocolConverter::messageBegin(name, msg_type, seq_id); +ThriftFilters::FilterStatus +ConnectionManager::ResponseDecoder::messageBegin(MessageMetadataSharedPtr metadata) { + metadata_ = metadata; + first_reply_field_ = + (metadata->hasMessageType() && metadata->messageType() == MessageType::Reply); + return ProtocolConverter::messageBegin(metadata); } ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::fieldBegin(absl::string_view name, @@ -120,8 +152,7 @@ ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::fieldBegin(absl: // Reply messages contain a struct where field 0 is the call result and fields 1+ are // exceptions, if defined. At most one field may be set. Therefore, the very first field we // encounter in a reply is either field 0 (success) or not (IDL exception returned). - ASSERT(reply_.has_value()); - reply_.value().success_ = field_id == 0 && field_type != FieldType::Stop; + success_ = field_id == 0 && field_type != FieldType::Stop; first_reply_field_ = false; } @@ -129,6 +160,8 @@ ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::fieldBegin(absl: } ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { + ASSERT(metadata_ != nullptr); + ConnectionManager& cm = parent_.parent_; Buffer::OwnedImpl buffer; @@ -138,18 +171,20 @@ ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { TransportPtr transport = NamedTransportConfigFactory::getFactory(parent_.parent_.decoder_->transportType()) .createTransport(); - transport->encodeFrame(buffer, parent_.response_buffer_); + + metadata_->setProtocol(parent_.parent_.decoder_->protocolType()); + metadata_->setSequenceId(parent_.metadata_->sequenceId()); + transport->encodeFrame(buffer, *metadata_, parent_.response_buffer_); complete_ = true; cm.read_callbacks_->connection().write(buffer, false); cm.stats_.response_.inc(); - ASSERT(reply_.has_value()); - switch (reply_.value().msg_type_) { + switch (metadata_->messageType()) { case MessageType::Reply: cm.stats_.response_reply_.inc(); - if (reply_.value().success_.value_or(false)) { + if (success_.value_or(false)) { cm.stats_.response_success_.inc(); } else { cm.stats_.response_error_.inc(); @@ -170,11 +205,11 @@ ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { } ThriftFilters::FilterStatus ConnectionManager::ActiveRpc::transportEnd() { - ASSERT(call_.has_value()); + ASSERT(metadata_ != nullptr && metadata_->hasMessageType()); parent_.stats_.request_.inc(); - switch (call_.value().msg_type_) { + switch (metadata_->messageType()) { case MessageType::Call: parent_.stats_.request_call_.inc(); break; @@ -204,10 +239,8 @@ void ConnectionManager::ActiveRpc::onReset() { } void ConnectionManager::ActiveRpc::onError(const std::string& what) { - if (call_.has_value()) { - const Message& msg = call_.value(); - sendLocalReply(std::make_unique(msg.method_name_, msg.seq_id_, - AppExceptionType::ProtocolError, what)); + if (metadata_) { + sendLocalReply(AppException(AppExceptionType::ProtocolError, what)); return; } @@ -223,9 +256,8 @@ void ConnectionManager::ActiveRpc::continueDecoding() { parent_.continueDecoding Router::RouteConstSharedPtr ConnectionManager::ActiveRpc::route() { if (!cached_route_) { - if (call_.has_value()) { - Router::RouteConstSharedPtr route = - parent_.config_.routerConfig().route(call_.value().method_name_); + if (metadata_ != nullptr) { + Router::RouteConstSharedPtr route = parent_.config_.routerConfig().route(*metadata_); cached_route_ = std::move(route); } else { cached_route_ = nullptr; @@ -235,21 +267,8 @@ Router::RouteConstSharedPtr ConnectionManager::ActiveRpc::route() { return cached_route_.value(); } -void ConnectionManager::ActiveRpc::sendLocalReply(ThriftFilters::DirectResponsePtr&& response) { - // Use the factory to get the concrete protocol from the decoder protocol (as opposed to - // potentially pre-detection auto protocol). - ProtocolPtr proto = - NamedProtocolConfigFactory::getFactory(parent_.decoder_->protocolType()).createProtocol(); - Buffer::OwnedImpl buffer; - - response->encode(*proto, buffer); - - // Same logic as protocol above. - TransportPtr transport = - NamedTransportConfigFactory::getFactory(parent_.decoder_->transportType()).createTransport(); - transport->encodeFrame(response_buffer_, buffer); - - parent_.read_callbacks_->connection().write(response_buffer_, false); +void ConnectionManager::ActiveRpc::sendLocalReply(const DirectResponse& response) { + parent_.sendLocalReply(*metadata_, response); parent_.doDeferredRpcDestroy(*this); } @@ -269,6 +288,13 @@ bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) { parent_.doDeferredRpcDestroy(*this); } return complete; + } catch (const AppException& ex) { + ENVOY_LOG(error, "thrift response application error: {}", ex.what()); + parent_.stats_.response_decoding_error_.inc(); + + sendLocalReply(ex); + decoder_filter_->resetUpstreamConnection(); + return true; } catch (const EnvoyException& ex) { ENVOY_LOG(error, "thrift response error: {}", ex.what()); parent_.stats_.response_decoding_error_.inc(); diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index c366a40c0f2a5..5c4648b7bb5a0 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -61,17 +61,6 @@ class ConnectionManager : public Network::ReadFilter, ThriftFilters::DecoderFilter& newDecoderFilter() override; private: - class Message { - public: - Message(const std::string& method_name, MessageType msg_type, int32_t seq_id) - : method_name_(method_name), msg_type_(msg_type), seq_id_(seq_id) {} - - const std::string method_name_; - const MessageType msg_type_; - const int32_t seq_id_; - absl::optional success_; - }; - struct ActiveRpc; struct ResponseDecoder : public DecoderCallbacks, public ProtocolConverter { @@ -92,12 +81,11 @@ class ConnectionManager : public Network::ReadFilter, bool onData(Buffer::Instance& data); // ProtocolConverter - ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, - int32_t seq_id) override; + ThriftFilters::FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override; ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, int16_t field_id) override; - ThriftFilters::FilterStatus transportBegin(absl::optional size) override { - UNREFERENCED_PARAMETER(size); + ThriftFilters::FilterStatus transportBegin(MessageMetadataSharedPtr metadata) override { + UNREFERENCED_PARAMETER(metadata); return ThriftFilters::FilterStatus::Continue; } ThriftFilters::FilterStatus transportEnd() override; @@ -108,7 +96,8 @@ class ConnectionManager : public Network::ReadFilter, ActiveRpc& parent_; DecoderPtr decoder_; Buffer::OwnedImpl upstream_buffer_; - absl::optional reply_; + MessageMetadataSharedPtr metadata_; + absl::optional success_; bool complete_ : 1; bool first_reply_field_ : 1; }; @@ -140,14 +129,13 @@ class ConnectionManager : public Network::ReadFilter, NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void resetUpstreamConnection() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - ThriftFilters::FilterStatus transportBegin(absl::optional size) override { - return decoder_filter_->transportBegin(size); + ThriftFilters::FilterStatus transportBegin(MessageMetadataSharedPtr metadata) override { + return decoder_filter_->transportBegin(metadata); } ThriftFilters::FilterStatus transportEnd() override; - ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, - int32_t seq_id) override { - call_.emplace(std::string(name), msg_type, seq_id); - return decoder_filter_->messageBegin(name, msg_type, seq_id); + ThriftFilters::FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override { + metadata_ = metadata; + return decoder_filter_->messageBegin(metadata); } ThriftFilters::FilterStatus messageEnd() override { return decoder_filter_->messageEnd(); } ThriftFilters::FilterStatus structBegin(absl::string_view name) override { @@ -205,7 +193,7 @@ class ConnectionManager : public Network::ReadFilter, ProtocolType downstreamProtocolType() const override { return parent_.decoder_->protocolType(); } - void sendLocalReply(ThriftFilters::DirectResponsePtr&& response) override; + void sendLocalReply(const DirectResponse& response) override; void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) override; bool upstreamData(Buffer::Instance& buffer) override; void resetDownstreamConnection() override; @@ -224,10 +212,10 @@ class ConnectionManager : public Network::ReadFilter, ConnectionManager& parent_; Stats::TimespanPtr request_timer_; uint64_t stream_id_; + MessageMetadataSharedPtr metadata_; ThriftFilters::DecoderFilterSharedPtr decoder_filter_; ResponseDecoderPtr response_decoder_; absl::optional cached_route_; - absl::optional call_; Buffer::OwnedImpl response_buffer_; }; @@ -235,6 +223,7 @@ class ConnectionManager : public Network::ReadFilter, void continueDecoding(); void dispatch(); + void sendLocalReply(MessageMetadata& metadata, const DirectResponse& reponse); void doDeferredRpcDestroy(ActiveRpc& rpc); void resetAllRpcs(); diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index 39aa8af6a2836..d3d69b8e09792 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -7,6 +7,9 @@ #include "common/common/assert.h" #include "common/common/macros.h" +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/protocol_impl.h" + namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -14,18 +17,14 @@ namespace ThriftProxy { // MessageBegin -> StructBegin DecoderStateMachine::DecoderStatus DecoderStateMachine::messageBegin(Buffer::Instance& buffer) { - std::string message_name; - MessageType msg_type; - int32_t seq_id; - if (!proto_.readMessageBegin(buffer, message_name, msg_type, seq_id)) { + if (!proto_.readMessageBegin(buffer, *metadata_)) { return DecoderStatus(ProtocolState::WaitForData); } stack_.clear(); stack_.emplace_back(Frame(ProtocolState::MessageEnd)); - return DecoderStatus(ProtocolState::StructBegin, - filter_.messageBegin(absl::string_view(message_name), msg_type, seq_id)); + return DecoderStatus(ProtocolState::StructBegin, filter_.messageBegin(metadata_)); } // MessageEnd -> Done @@ -375,8 +374,11 @@ ThriftFilters::FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer if (!frame_started_) { // Look for start of next frame. - absl::optional size{}; - if (!transport_->decodeFrameStart(data, size)) { + if (!metadata_) { + metadata_ = std::make_shared(); + } + + if (!transport_->decodeFrameStart(data, *metadata_)) { ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); buffer_underflow = true; return ThriftFilters::FilterStatus::Continue; @@ -385,9 +387,10 @@ ThriftFilters::FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer request_ = std::make_unique(callbacks_.newDecoderFilter()); frame_started_ = true; - state_machine_ = std::make_unique(*protocol_, request_->filter_); + state_machine_ = + std::make_unique(*protocol_, metadata_, request_->filter_); - if (request_->filter_.transportBegin(size) == ThriftFilters::FilterStatus::StopIteration) { + if (request_->filter_.transportBegin(metadata_) == ThriftFilters::FilterStatus::StopIteration) { return ThriftFilters::FilterStatus::StopIteration; } } @@ -417,6 +420,7 @@ ThriftFilters::FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer } frame_ended_ = true; + metadata_.reset(); ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); if (request_->filter_.transportEnd() == ThriftFilters::FilterStatus::StopIteration) { diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index 26068858beb6d..660e375e98b79 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -63,8 +63,9 @@ class ProtocolStateNameValues { */ class DecoderStateMachine { public: - DecoderStateMachine(Protocol& proto, ThriftFilters::DecoderFilter& filter) - : proto_(proto), filter_(filter), state_(ProtocolState::MessageBegin) {} + DecoderStateMachine(Protocol& proto, MessageMetadataSharedPtr& metadata, + ThriftFilters::DecoderFilter& filter) + : proto_(proto), metadata_(metadata), filter_(filter), state_(ProtocolState::MessageBegin) {} /** * Consumes as much data from the configured Buffer as possible and executes the decoding state @@ -162,6 +163,7 @@ class DecoderStateMachine { ProtocolState popReturnState(); Protocol& proto_; + MessageMetadataSharedPtr metadata_; ThriftFilters::DecoderFilter& filter_; ProtocolState state_; std::vector stack_; @@ -215,6 +217,7 @@ class Decoder : public Logger::Loggable { ProtocolPtr protocol_; DecoderCallbacks& callbacks_; ActiveRequestPtr request_; + MessageMetadataSharedPtr metadata_; DecoderStateMachinePtr state_machine_; bool frame_started_{false}; bool frame_ended_{false}; diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h index 969ffcadfc46c..476e02cadf0bd 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -20,20 +20,6 @@ namespace NetworkFilters { namespace ThriftProxy { namespace ThriftFilters { -class DirectResponse { -public: - virtual ~DirectResponse() {} - - /** - * Encodes the response via the given Protocol. - * @param proto the Protocol to be used for message encoding - * @param buffer the Buffer into which the message should be encoded - */ - virtual void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) PURE; -}; - -typedef std::unique_ptr DirectResponsePtr; - /** * Decoder filter callbacks add additional callbacks. */ @@ -77,9 +63,9 @@ class DecoderFilterCallbacks { /** * Create a locally generated response using the provided response object. - * @param response DirectResponsePtr the response to send to the downstream client + * @param response DirectResponse the response to send to the downstream client */ - virtual void sendLocalReply(DirectResponsePtr&& response) PURE; + virtual void sendLocalReply(const ThriftProxy::DirectResponse& response) PURE; /** * Indicates the start of an upstream response. May only be called once. @@ -142,9 +128,10 @@ class DecoderFilter { /** * Indicates the start of a Thrift transport frame was detected. Unframed transports generate * simulated start messages. - * @param size the size of the message, if available to the transport + * @param metadata MessageMetadataSharedPtr describing as much as is currently known about the + * message */ - virtual FilterStatus transportBegin(absl::optional size) PURE; + virtual FilterStatus transportBegin(MessageMetadataSharedPtr metadata) PURE; /** * Indicates the end of a Thrift transport frame was detected. Unframed transport generate @@ -154,13 +141,10 @@ class DecoderFilter { /** * Indicates that the start of a Thrift protocol message was detected. - * @param name the name of the message, if available - * @param msg_type the type of the message - * @param seq_id the message sequence id + * @param metadata MessageMetadataSharedPtr describing the message * @return FilterStatus to indicate if filter chain iteration should continue */ - virtual FilterStatus messageBegin(absl::string_view name, MessageType msg_type, - int32_t seq_id) PURE; + virtual FilterStatus messageBegin(MessageMetadataSharedPtr metadata) PURE; /** * Indicates that the end of a Thrift protocol message was detected. diff --git a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc index a45861a349e4b..93b528a5dc56d 100644 --- a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc @@ -10,8 +10,9 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer, - absl::optional& size) { +bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) { + UNREFERENCED_PARAMETER(metadata); + if (buffer.length() < 4) { return false; } @@ -24,13 +25,16 @@ bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer, buffer.drain(4); - size = static_cast(thrift_size); + metadata.setFrameSize(static_cast(thrift_size)); return true; } bool FramedTransportImpl::decodeFrameEnd(Buffer::Instance&) { return true; } -void FramedTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) { +void FramedTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, + Buffer::Instance& message) { + UNREFERENCED_PARAMETER(metadata); + uint64_t size = message.length(); if (size == 0 || size > MaxFrameSize) { throw EnvoyException(fmt::format("invalid thrift framed transport frame size {}", size)); diff --git a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h index 4c7569487ea38..16d51c57093cb 100644 --- a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h @@ -24,9 +24,10 @@ class FramedTransportImpl : public Transport { // Transport const std::string& name() const override { return TransportNames::get().FRAMED; } TransportType type() const override { return TransportType::Framed; } - bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) override; + bool decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) override; bool decodeFrameEnd(Buffer::Instance& buffer) override; - void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override; + void encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, + Buffer::Instance& message) override; static const int32_t MaxFrameSize = 0xFA0000; }; diff --git a/source/extensions/filters/network/thrift_proxy/metadata.cc b/source/extensions/filters/network/thrift_proxy/metadata.cc new file mode 100644 index 0000000000000..1176773478840 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/metadata.cc @@ -0,0 +1,47 @@ +#include "extensions/filters/network/thrift_proxy/metadata.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +HeaderMap::HeaderMap(const std::initializer_list>& values) { + for (auto& value : values) { + headers_.emplace_back(Header(value.first, value.second)); + } +} + +HeaderMap::HeaderMap(const HeaderMap& rhs) { + for (const auto& header : rhs.headers_) { + headers_.emplace_back(header); + } +} + +bool HeaderMap::operator==(const HeaderMap& rhs) const { + if (headers_.size() != rhs.headers_.size()) { + return false; + } + + for (auto i = headers_.begin(), j = rhs.headers_.begin(); i != headers_.end(); ++i, ++j) { + if (i->key() != j->key() || i->value() != j->value()) { + return false; + } + } + + return true; +} + +Header* HeaderMap::get(const std::string& key) { + for (Header& header : headers_) { + if (header.key() == key) { + return &header; + } + } + + return nullptr; +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/metadata.h b/source/extensions/filters/network/thrift_proxy/metadata.h new file mode 100644 index 0000000000000..a7025739df36d --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/metadata.h @@ -0,0 +1,156 @@ +#pragma once + +#include + +#include +#include +#include + +#include "common/common/macros.h" + +#include "extensions/filters/network/thrift_proxy/thrift.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * Header is a name-value pair in Thrift transport or protocol headers. + */ +class Header { +public: + Header(const std::string key, const std::string value) : key_(key), value_(value) {} + Header(const Header& rhs) : key_(rhs.key_), value_(rhs.value_) {} + + const std::string& key() const { return key_; } + const std::string& value() const { return value_; } + +private: + std::string key_; + std::string value_; +}; + +// TODO(zuercher): replace this with Http::HeaderMap[Impl] +/* + * HeaderMap contains Thrift transport and/or protocol-level headers. + */ +class HeaderMap { +public: + HeaderMap() {} + HeaderMap(const std::initializer_list>& values); + HeaderMap(const HeaderMap& rhs); + + /** + * @return true if the HeaderMap is empty + */ + bool empty() const { return headers_.empty(); } + + /** + * @return uint32_t the number of headers in the map + */ + uint32_t size() const { return headers_.size(); } + + /** + * @param header Header to move into the HeaderMap + */ + void add(Header&& header) { headers_.emplace_back(std::move(header)); } + + /** + * Clears all Headers from the HeaderMap. + */ + void clear() { headers_.clear(); } + + /** + * Retrieves a Header from the HeaderMap. + * @param key std::string containing the key to lookup + * @return Header* corresponding to key or nullptr if not found. + */ + Header* get(const std::string& key); + + /** + * Const iterators for the HeaderMap. + */ + std::list
::const_iterator begin() const noexcept { return headers_.begin(); } + std::list
::const_iterator end() const noexcept { return headers_.end(); } + std::list
::const_iterator cbegin() const noexcept { return headers_.cbegin(); } + std::list
::const_iterator cend() const noexcept { return headers_.cend(); } + + /** + * For testing. Equality is based on equality of the backing list. This is an exact match + * comparison (order matters). + */ + bool operator==(const HeaderMap& rhs) const; + + /** + * @return an empty HeaderMap + */ + static const HeaderMap& emptyHeaderMap() { CONSTRUCT_ON_FIRST_USE(HeaderMap, HeaderMap({})); } + +private: + std::list
headers_; +}; + +/** + * MessageMetadata encapsulates metadata about Thrift messages. The various fields are considered + * optional since they may come from either the transport or protocol in some cases. Unless + * otherwise noted, accessor methods throw absl::bad_optional_access if the corresponding value has + * not been set. + */ +class MessageMetadata { +public: + MessageMetadata() {} + + bool hasFrameSize() const { return frame_size_.has_value(); } + uint32_t frameSize() const { return frame_size_.value(); } + void setFrameSize(uint32_t size) { frame_size_ = size; } + + bool hasProtocol() const { return proto_.has_value(); } + ProtocolType protocol() const { return proto_.value(); } + void setProtocol(ProtocolType proto) { proto_ = proto; } + + bool hasMethodName() const { return method_name_.has_value(); } + const std::string& methodName() const { return method_name_.value(); } + void setMethodName(const std::string& method_name) { method_name_ = method_name; } + + bool hasSequenceId() const { return seq_id_.has_value(); } + int32_t sequenceId() const { return seq_id_.value(); } + void setSequenceId(int32_t seq_id) { seq_id_ = seq_id; } + + bool hasMessageType() const { return msg_type_.has_value(); } + MessageType messageType() const { return msg_type_.value(); } + void setMessageType(MessageType msg_type) { msg_type_ = msg_type; } + + void addHeader(Header&& header) { headers_.add(std::move(header)); } + /** + * @return HeaderMap of current headers (never throws) + */ + const HeaderMap& headers() const { return headers_; } + + bool hasAppException() const { return app_ex_type_.has_value(); } + void setAppException(AppExceptionType app_ex_type, const std::string& message) { + app_ex_type_ = app_ex_type; + app_ex_msg_ = message; + } + AppExceptionType appExceptionType() const { return app_ex_type_.value(); } + const std::string& appExceptionMessage() const { return app_ex_msg_.value(); } + +private: + absl::optional frame_size_{}; + absl::optional proto_{}; + absl::optional method_name_{}; + absl::optional seq_id_{}; + absl::optional msg_type_{}; + HeaderMap headers_; + absl::optional app_ex_type_; + absl::optional app_ex_msg_; +}; + +typedef std::shared_ptr MessageMetadataSharedPtr; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 02f2808427e08..7d5baa3b9d9f6 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -11,6 +11,9 @@ #include "common/config/utility.h" #include "common/singleton/const_singleton.h" +#include "extensions/filters/network/thrift_proxy/metadata.h" +#include "extensions/filters/network/thrift_proxy/thrift.h" + #include "absl/strings/string_view.h" namespace Envoy { @@ -18,88 +21,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -enum class ProtocolType { - Binary, - LaxBinary, - Compact, - Auto, - - // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST PROTOCOL TYPE - LastProtocolType = Auto, -}; - -/** - * Names of available Protocol implementations. - */ -class ProtocolNameValues { -public: - // Binary protocol - const std::string BINARY = "binary"; - - // Lax Binary protocol - const std::string LAX_BINARY = "binary/non-strict"; - - // Compact protocol - const std::string COMPACT = "compact"; - - // Auto-detection protocol - const std::string AUTO = "auto"; - - const std::string& fromType(ProtocolType type) const { - switch (type) { - case ProtocolType::Binary: - return BINARY; - case ProtocolType::LaxBinary: - return LAX_BINARY; - case ProtocolType::Compact: - return COMPACT; - case ProtocolType::Auto: - return AUTO; - default: - NOT_REACHED_GCOVR_EXCL_LINE; - } - } -}; - -typedef ConstSingleton ProtocolNames; - -/** - * Thrift protocol message types. - * See https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/protocol/TProtocol.h - */ -enum class MessageType { - Call = 1, - Reply = 2, - Exception = 3, - Oneway = 4, - - // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST MESSAGE TYPE - LastMessageType = Oneway, -}; - -/** - * Thrift protocol struct field types. - * See https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/protocol/TProtocol.h - */ -enum class FieldType { - Stop = 0, - Void = 1, - Bool = 2, - Byte = 3, - Double = 4, - I16 = 6, - I32 = 8, - I64 = 10, - String = 11, - Struct = 12, - Map = 13, - Set = 14, - List = 15, - - // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST FIELD TYPE - LastFieldType = List, -}; - /** * Protocol represents the operations necessary to implement the a generic Thrift protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-protocol-spec.md @@ -119,18 +40,15 @@ class Protocol { virtual ProtocolType type() const PURE; /** - * Reads the start of a Thrift protocol message from the buffer and updates the name, msg_type, - * and seq_id parameters with values from the message header. If successful, the message header - * is removed from the buffer. + * Reads the start of a Thrift protocol message from the buffer and updates the metadata + * parameter with values from the message header. If successful, the message header is removed + * from the buffer. * @param buffer the buffer to read from - * @param name updated with the message name on success only - * @param msg_type updated with the MessageType on success only - * @param seq_id updated with the message sequence ID on success only + * @param metadata MessageMetadata to be updated with name, message type, and sequence id. * @return true if a message header was sucessfully read, false if more data is required * @throw EnvoyException if the data is not a valid message header */ - virtual bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, - int32_t& seq_id) PURE; + virtual bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) PURE; /** * Reads the end of a Thrift protocol message from the buffer. If successful, the message footer @@ -331,12 +249,9 @@ class Protocol { /** * Writes the start of a Thrift protocol message to the buffer. * @param buffer Buffer::Instance to modify - * @param name the message name - * @param msg_type the message's MessageType - * @param seq_id the message sequende ID + * @param metadata MessageMetadata for the message to write. */ - virtual void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, - MessageType msg_type, int32_t seq_id) PURE; + virtual void writeMessageBegin(Buffer::Instance& buffer, const MessageMetadata& metadata) PURE; /** * Writes the end of a Thrift protocol message to the buffer. @@ -522,6 +437,25 @@ template class ProtocolFactoryBase : public NamedProtocolCo const std::string name_; }; +/** + * A DirectResponse manipulates a Protocol to directly create a Thrift response message. + */ +class DirectResponse { +public: + virtual ~DirectResponse() {} + + /** + * Encodes the response via the given Protocol. + * @param metadata the MessageMetadata for the request that generated this response + * @param proto the Protocol to be used for message encoding + * @param buffer the Buffer into which the message should be encoded + */ + virtual void encode(MessageMetadata& metadata, Protocol& proto, + Buffer::Instance& buffer) const PURE; +}; + +typedef std::unique_ptr DirectResponsePtr; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h index af7b6dfa2af3d..0695bb9475a4f 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_converter.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -31,9 +31,8 @@ class ProtocolConverter : public ThriftFilters::DecoderFilter { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void resetUpstreamConnection() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, - int32_t seq_id) override { - proto_->writeMessageBegin(*buffer_, std::string(name), msg_type, seq_id); + ThriftFilters::FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override { + proto_->writeMessageBegin(*buffer_, *metadata); return ThriftFilters::FilterStatus::Continue; } diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc index 46636099f5430..d0780b20d9591 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc @@ -17,8 +17,7 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& name, - MessageType& msg_type, int32_t& seq_id) { +bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) { if (protocol_ == nullptr) { if (buffer.length() < 2) { return false; @@ -29,15 +28,15 @@ bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& n setProtocol(std::make_unique()); } else if (CompactProtocolImpl::isMagic(version)) { setProtocol(std::make_unique()); - } else { + } + + if (!protocol_) { throw EnvoyException( fmt::format("unknown thrift auto protocol message start {:04x}", version)); } - - ASSERT(protocol_ != nullptr); } - return protocol_->readMessageBegin(buffer, name, msg_type, seq_id); + return protocol_->readMessageBegin(buffer, metadata); } bool AutoProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.h b/source/extensions/filters/network/thrift_proxy/protocol_impl.h index 0eb374e9f9869..227941ca5811a 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.h @@ -31,8 +31,7 @@ class AutoProtocolImpl : public Protocol { return ProtocolType::Auto; } - bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, - int32_t& seq_id) override; + bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) override; bool readMessageEnd(Buffer::Instance& buffer) override; bool readStructBegin(Buffer::Instance& buffer, std::string& name) override { return protocol_->readStructBegin(buffer, name); @@ -80,9 +79,8 @@ class AutoProtocolImpl : public Protocol { bool readBinary(Buffer::Instance& buffer, std::string& value) override { return protocol_->readBinary(buffer, value); } - void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type, - int32_t seq_id) override { - protocol_->writeMessageBegin(buffer, name, msg_type, seq_id); + void writeMessageBegin(Buffer::Instance& buffer, const MessageMetadata& metadata) override { + protocol_->writeMessageBegin(buffer, metadata); } void writeMessageEnd(Buffer::Instance& buffer) override { protocol_->writeMessageEnd(buffer); } void writeStructBegin(Buffer::Instance& buffer, const std::string& name) override { diff --git a/source/extensions/filters/network/thrift_proxy/router/router.h b/source/extensions/filters/network/thrift_proxy/router/router.h index 32d717de52351..7f995e5e25a1d 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router.h +++ b/source/extensions/filters/network/thrift_proxy/router/router.h @@ -47,10 +47,10 @@ class Config { /** * Based on the incoming Thrift request transport and/or protocol data, determine the target * route for the request. - * @param method supplies the thrift method name + * @param metadata MessageMetadata for the message to route * @return the route or nullptr if there is no matching route for the request. */ - virtual RouteConstSharedPtr route(const std::string& method) const PURE; + virtual RouteConstSharedPtr route(const MessageMetadata& metadata) const PURE; }; typedef std::shared_ptr ConfigConstSharedPtr; diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index 2b4fb77946c35..4808cdb175af2 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -26,8 +26,12 @@ MethodNameRouteEntryImpl::MethodNameRouteEntryImpl( const envoy::config::filter::network::thrift_proxy::v2alpha1::Route& route) : RouteEntryImplBase(route), method_name_(route.match().method()) {} -RouteConstSharedPtr MethodNameRouteEntryImpl::matches(const std::string& method_name) const { - if (method_name_.empty() || method_name_ == method_name) { +RouteConstSharedPtr MethodNameRouteEntryImpl::matches(const MessageMetadata& metadata) const { + if (method_name_.empty()) { + return clusterEntry(); + } + + if (metadata.hasMethodName() && metadata.methodName() == method_name_) { return clusterEntry(); } @@ -41,9 +45,9 @@ RouteMatcher::RouteMatcher( } } -RouteConstSharedPtr RouteMatcher::route(const std::string& method_name) const { +RouteConstSharedPtr RouteMatcher::route(const MessageMetadata& metadata) const { for (const auto& route : routes_) { - RouteConstSharedPtr route_entry = route->matches(method_name); + RouteConstSharedPtr route_entry = route->matches(metadata); if (nullptr != route_entry) { return route_entry; } @@ -71,13 +75,13 @@ void Router::resetUpstreamConnection() { } } -ThriftFilters::FilterStatus Router::transportBegin(absl::optional size) { - UNREFERENCED_PARAMETER(size); +ThriftFilters::FilterStatus Router::transportBegin(MessageMetadataSharedPtr metadata) { + UNREFERENCED_PARAMETER(metadata); return ThriftFilters::FilterStatus::Continue; } ThriftFilters::FilterStatus Router::transportEnd() { - if (upstream_request_->msg_type_ == MessageType::Oneway) { + if (upstream_request_->metadata_->messageType() == MessageType::Oneway) { // No response expected upstream_request_->onResponseComplete(); cleanup(); @@ -85,17 +89,17 @@ ThriftFilters::FilterStatus Router::transportEnd() { return ThriftFilters::FilterStatus::Continue; } -ThriftFilters::FilterStatus Router::messageBegin(absl::string_view name, MessageType msg_type, - int32_t seq_id) { +ThriftFilters::FilterStatus Router::messageBegin(MessageMetadataSharedPtr metadata) { // TODO(zuercher): route stats (e.g., no_route, no_cluster, upstream_rq_maintenance_mode, no // healtthy upstream) route_ = callbacks_->route(); if (!route_) { - ENVOY_STREAM_LOG(debug, "no cluster match for method '{}'", *callbacks_, name); - callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{ - new AppException(name, seq_id, AppExceptionType::UnknownMethod, - fmt::format("no route for method '{}'", name))}); + ENVOY_STREAM_LOG(debug, "no cluster match for method '{}'", *callbacks_, + metadata->methodName()); + callbacks_->sendLocalReply( + AppException(AppExceptionType::UnknownMethod, + fmt::format("no route for method '{}'", metadata->methodName()))); return ThriftFilters::FilterStatus::StopIteration; } @@ -104,35 +108,35 @@ ThriftFilters::FilterStatus Router::messageBegin(absl::string_view name, Message Upstream::ThreadLocalCluster* cluster = cluster_manager_.get(route_entry_->clusterName()); if (!cluster) { ENVOY_STREAM_LOG(debug, "unknown cluster '{}'", *callbacks_, route_entry_->clusterName()); - callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{ - new AppException(name, seq_id, AppExceptionType::InternalError, - fmt::format("unknown cluster '{}'", route_entry_->clusterName()))}); + callbacks_->sendLocalReply( + AppException(AppExceptionType::InternalError, + fmt::format("unknown cluster '{}'", route_entry_->clusterName()))); return ThriftFilters::FilterStatus::StopIteration; } cluster_ = cluster->info(); ENVOY_STREAM_LOG(debug, "cluster '{}' match for method '{}'", *callbacks_, - route_entry_->clusterName(), name); + route_entry_->clusterName(), metadata->methodName()); if (cluster_->maintenanceMode()) { - callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( - name, seq_id, AppExceptionType::InternalError, - fmt::format("maintenance mode for cluster '{}'", route_entry_->clusterName()))}); + callbacks_->sendLocalReply(AppException( + AppExceptionType::InternalError, + fmt::format("maintenance mode for cluster '{}'", route_entry_->clusterName()))); return ThriftFilters::FilterStatus::StopIteration; } Tcp::ConnectionPool::Instance* conn_pool = cluster_manager_.tcpConnPoolForCluster( route_entry_->clusterName(), Upstream::ResourcePriority::Default, this); if (!conn_pool) { - callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( - name, seq_id, AppExceptionType::InternalError, - fmt::format("no healthy upstream for '{}'", route_entry_->clusterName()))}); + callbacks_->sendLocalReply( + AppException(AppExceptionType::InternalError, + fmt::format("no healthy upstream for '{}'", route_entry_->clusterName()))); return ThriftFilters::FilterStatus::StopIteration; } ENVOY_STREAM_LOG(debug, "router decoding request", *callbacks_); - upstream_request_.reset(new UpstreamRequest(*this, *conn_pool, name, msg_type, seq_id)); + upstream_request_.reset(new UpstreamRequest(*this, *conn_pool, metadata)); upstream_request_->start(); return ThriftFilters::FilterStatus::StopIteration; } @@ -141,7 +145,9 @@ ThriftFilters::FilterStatus Router::messageEnd() { ProtocolConverter::messageEnd(); Buffer::OwnedImpl transport_buffer; - upstream_request_->transport_->encodeFrame(transport_buffer, upstream_request_buffer_); + + upstream_request_->transport_->encodeFrame(transport_buffer, *upstream_request_->metadata_, + upstream_request_buffer_); upstream_request_->conn_data_->connection().write(transport_buffer, false); upstream_request_->onRequestComplete(); return ThriftFilters::FilterStatus::Continue; @@ -171,10 +177,7 @@ void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { } void Router::onEvent(Network::ConnectionEvent event) { - if (!upstream_request_ || upstream_request_->response_complete_) { - // Client closed connection after completing response. - return; - } + ASSERT(upstream_request_ && !upstream_request_->response_complete_); switch (event) { case Network::ConnectionEvent::RemoteClose: @@ -199,18 +202,16 @@ const Network::Connection* Router::downstreamConnection() const { return nullptr; } -void Router::convertMessageBegin(const std::string& name, MessageType msg_type, int32_t seq_id) { - ProtocolConverter::messageBegin(absl::string_view(name), msg_type, seq_id); +void Router::convertMessageBegin(MessageMetadataSharedPtr metadata) { + ProtocolConverter::messageBegin(metadata); } void Router::cleanup() { upstream_request_.reset(); } Router::UpstreamRequest::UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, - absl::string_view method_name, MessageType msg_type, - int32_t seq_id) - : parent_(parent), conn_pool_(pool), method_name_(std::string(method_name)), - msg_type_(msg_type), seq_id_(seq_id), request_complete_(false), response_started_(false), - response_complete_(false) {} + MessageMetadataSharedPtr& metadata) + : parent_(parent), conn_pool_(pool), metadata_(metadata), request_complete_(false), + response_started_(false), response_complete_(false) {} Router::UpstreamRequest::~UpstreamRequest() {} @@ -254,7 +255,7 @@ void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionData& c parent_.upstream_request_buffer_); // TODO(zuercher): need to use an upstream-connection-specific sequence id - parent_.convertMessageBegin(method_name_, msg_type_, seq_id_); + parent_.convertMessageBegin(metadata_); parent_.callbacks_->continueDecoding(); } @@ -276,18 +277,18 @@ void Router::UpstreamRequest::onUpstreamHostSelected(Upstream::HostDescriptionCo void Router::UpstreamRequest::onResetStream(Tcp::ConnectionPool::PoolFailureReason reason) { switch (reason) { case Tcp::ConnectionPool::PoolFailureReason::Overflow: - parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( - method_name_, seq_id_, AppExceptionType::InternalError, - fmt::format("too many connections to '{}'", upstream_host_->address()->asString()))}); + parent_.callbacks_->sendLocalReply(AppException( + AppExceptionType::InternalError, + fmt::format("too many connections to '{}'", upstream_host_->address()->asString()))); break; case Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure: case Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure: case Tcp::ConnectionPool::PoolFailureReason::Timeout: // TODO(zuercher): distinguish between these cases where appropriate (particularly timeout) if (!response_started_) { - parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( - method_name_, seq_id_, AppExceptionType::InternalError, - fmt::format("connection failure '{}'", upstream_host_->address()->asString()))}); + parent_.callbacks_->sendLocalReply(AppException( + AppExceptionType::InternalError, + fmt::format("connection failure '{}'", upstream_host_->address()->asString()))); return; } diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h index 117c99ca2635b..d5c82ecb1aeae 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -34,7 +34,7 @@ class RouteEntryImplBase : public RouteEntry, // Router::Route const RouteEntry* routeEntry() const override; - virtual RouteConstSharedPtr matches(const std::string& method_name) const PURE; + virtual RouteConstSharedPtr matches(const MessageMetadata& metadata) const PURE; protected: RouteConstSharedPtr clusterEntry() const; @@ -53,7 +53,7 @@ class MethodNameRouteEntryImpl : public RouteEntryImplBase { const std::string& methodName() const { return method_name_; } // RoutEntryImplBase - RouteConstSharedPtr matches(const std::string& method_name) const override; + RouteConstSharedPtr matches(const MessageMetadata& metadata) const override; private: const std::string method_name_; @@ -63,7 +63,7 @@ class RouteMatcher { public: RouteMatcher(const envoy::config::filter::network::thrift_proxy::v2alpha1::RouteConfiguration&); - RouteConstSharedPtr route(const std::string& method_name) const; + RouteConstSharedPtr route(const MessageMetadata& metadata) const; private: std::vector routes_; @@ -82,10 +82,9 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, void onDestroy() override; void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) override; void resetUpstreamConnection() override; - ThriftFilters::FilterStatus transportBegin(absl::optional size) override; + ThriftFilters::FilterStatus transportBegin(MessageMetadataSharedPtr metadata) override; ThriftFilters::FilterStatus transportEnd() override; - ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, - int32_t seq_id) override; + ThriftFilters::FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override; ThriftFilters::FilterStatus messageEnd() override; // Upstream::LoadBalancerContext @@ -103,7 +102,7 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, private: struct UpstreamRequest : public Tcp::ConnectionPool::Callbacks { UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, - absl::string_view method_name, MessageType msg_type, int32_t seq_id); + MessageMetadataSharedPtr& metadata); ~UpstreamRequest(); void start(); @@ -122,9 +121,7 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, Router& parent_; Tcp::ConnectionPool::Instance& conn_pool_; - const std::string method_name_; - const MessageType msg_type_; - const int32_t seq_id_; + MessageMetadataSharedPtr metadata_; Tcp::ConnectionPool::Cancellable* conn_pool_handle_{}; Tcp::ConnectionPool::ConnectionData* conn_data_{}; @@ -137,7 +134,7 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, bool response_complete_ : 1; }; - void convertMessageBegin(const std::string& name, MessageType msg_type, int32_t seq_id); + void convertMessageBegin(MessageMetadataSharedPtr metadata); void cleanup(); Upstream::ClusterManager& cluster_manager_; diff --git a/source/extensions/filters/network/thrift_proxy/thrift.h b/source/extensions/filters/network/thrift_proxy/thrift.h new file mode 100644 index 0000000000000..fe3d7022a59e4 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift.h @@ -0,0 +1,154 @@ +#pragma once + +#include "common/common/assert.h" +#include "common/singleton/const_singleton.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +enum class TransportType { + Framed, + Unframed, + Auto, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST TRANSPORT TYPE + LastTransportType = Auto, + +}; + +/** + * Names of available Transport implementations. + */ +class TransportNameValues { +public: + // Framed transport + const std::string FRAMED = "framed"; + + // Unframed transport + const std::string UNFRAMED = "unframed"; + + // Auto-detection transport + const std::string AUTO = "auto"; + + const std::string& fromType(TransportType type) const { + switch (type) { + case TransportType::Framed: + return FRAMED; + case TransportType::Unframed: + return UNFRAMED; + case TransportType::Auto: + return AUTO; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } +}; + +typedef ConstSingleton TransportNames; + +enum class ProtocolType { + Binary, + LaxBinary, + Compact, + Auto, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST PROTOCOL TYPE + LastProtocolType = Auto, +}; + +/** + * Names of available Protocol implementations. + */ +class ProtocolNameValues { +public: + // Binary protocol + const std::string BINARY = "binary"; + + // Lax Binary protocol + const std::string LAX_BINARY = "binary/non-strict"; + + // Compact protocol + const std::string COMPACT = "compact"; + + // Auto-detection protocol + const std::string AUTO = "auto"; + + const std::string& fromType(ProtocolType type) const { + switch (type) { + case ProtocolType::Binary: + return BINARY; + case ProtocolType::LaxBinary: + return LAX_BINARY; + case ProtocolType::Compact: + return COMPACT; + case ProtocolType::Auto: + return AUTO; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } +}; + +typedef ConstSingleton ProtocolNames; + +/** + * Thrift protocol message types. + * See https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/protocol/TProtocol.h + */ +enum class MessageType { + Call = 1, + Reply = 2, + Exception = 3, + Oneway = 4, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST MESSAGE TYPE + LastMessageType = Oneway, +}; + +/** + * Thrift protocol struct field types. + * See https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/protocol/TProtocol.h + */ +enum class FieldType { + Stop = 0, + Void = 1, + Bool = 2, + Byte = 3, + Double = 4, + I16 = 6, + I32 = 8, + I64 = 10, + String = 11, + Struct = 12, + Map = 13, + Set = 14, + List = 15, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST FIELD TYPE + LastFieldType = List, +}; + +/** + * Thrift Application Exception types. + * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md + */ +enum class AppExceptionType { + Unknown = 0, + UnknownMethod = 1, + InvalidMessageType = 2, + WrongMethodName = 3, + BadSequenceId = 4, + MissingResult = 5, + InternalError = 6, + ProtocolError = 7, + InvalidTransform = 8, + InvalidProtocol = 9, + UnsupportedClientType = 10, +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/transport.h b/source/extensions/filters/network/thrift_proxy/transport.h index 1bda083e6bc3e..915e2cbaa9440 100644 --- a/source/extensions/filters/network/thrift_proxy/transport.h +++ b/source/extensions/filters/network/thrift_proxy/transport.h @@ -10,6 +10,9 @@ #include "common/config/utility.h" #include "common/singleton/const_singleton.h" +#include "extensions/filters/network/thrift_proxy/metadata.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" + #include "absl/types/optional.h" namespace Envoy { @@ -17,46 +20,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -enum class TransportType { - Framed, - Unframed, - Auto, - - // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST TRANSPORT TYPE - LastTransportType = Auto, - -}; - -/** - * Names of available Transport implementations. - */ -class TransportNameValues { -public: - // Framed transport - const std::string FRAMED = "framed"; - - // Unframed transport - const std::string UNFRAMED = "unframed"; - - // Auto-detection transport - const std::string AUTO = "auto"; - - const std::string& fromType(TransportType type) const { - switch (type) { - case TransportType::Framed: - return FRAMED; - case TransportType::Unframed: - return UNFRAMED; - case TransportType::Auto: - return AUTO; - default: - NOT_REACHED_GCOVR_EXCL_LINE; - } - } -}; - -typedef ConstSingleton TransportNames; - /** * Transport represents a Thrift transport. The Thrift transport is nominally a generic, * bi-directional byte stream. In Envoy we assume it always represents a network byte stream and @@ -80,16 +43,18 @@ class Transport { /* * Decodes the start of a transport message. If successful, the start of the frame is removed - * from the buffer. + * from the buffer. Transports should not modify the buffer, headers, protocol type, or size if + * more data is required to decode the frame's start. If the full frame start can be decoded, the + * Transport must drain the frame start data from the buffer. The request metadata should be + * modified with any data available to the transport. * * @param buffer the currently buffered thrift data. - * @param size updated with the frame size on success. If frame size is not encoded, the size - * is cleared on success. + * @param metadata MessageMetadata to be modified if transport supports additional information * @return bool true if a complete frame header was successfully consumed, false if more data * is required. * @throws EnvoyException if the data is not valid for this transport. */ - virtual bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) PURE; + virtual bool decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) PURE; /* * Decodes the end of a transport message. If successful, the end of the frame is removed from @@ -106,10 +71,12 @@ class Transport { * Wraps the given message buffer with the transport's header and trailer (if any). After * encoding, message will be empty. * @param buffer is the output buffer + * @param metadata MessageMetadata for the message * @param message a protocol-encoded message * @throws EnvoyException if the message is too large for the transport */ - virtual void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) PURE; + virtual void encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, + Buffer::Instance& message) PURE; }; typedef std::unique_ptr TransportPtr; diff --git a/source/extensions/filters/network/thrift_proxy/transport_impl.cc b/source/extensions/filters/network/thrift_proxy/transport_impl.cc index ff177884d94c5..0d5971f03e7e3 100644 --- a/source/extensions/filters/network/thrift_proxy/transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.cc @@ -15,7 +15,7 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) { +bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) { if (transport_ == nullptr) { // Not enough data to select a transport. if (buffer.length() < 8) { @@ -27,8 +27,7 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer, absl::optiona if (size > 0 && size <= FramedTransportImpl::MaxFrameSize) { // TODO(zuercher): Spec says max size is 16,384,000 (0xFA0000). Apache C++ TFramedTransport - // is configurable, but defaults to 256 MB (0x1000000). THeaderTransport will take up to ~1GB - // (0x3FFFFFFF) when it falls back to framed mode. + // is configurable, but defaults to 256 MB (0x1000000). if (BinaryProtocolImpl::isMagic(proto_start) || CompactProtocolImpl::isMagic(proto_start)) { setTransport(std::make_unique()); } @@ -51,7 +50,7 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer, absl::optiona } } - return transport_->decodeFrameStart(buffer, size); + return transport_->decodeFrameStart(buffer, metadata); } bool AutoTransportImpl::decodeFrameEnd(Buffer::Instance& buffer) { @@ -59,9 +58,10 @@ bool AutoTransportImpl::decodeFrameEnd(Buffer::Instance& buffer) { return transport_->decodeFrameEnd(buffer); } -void AutoTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) { - RELEASE_ASSERT(transport_ != nullptr, ""); - transport_->encodeFrame(buffer, message); +void AutoTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, + Buffer::Instance& message) { + RELEASE_ASSERT(transport_ != nullptr, "auto transport cannot encode before transport detection"); + transport_->encodeFrame(buffer, metadata, message); } class AutoTransportConfigFactory : public TransportFactoryBase { diff --git a/source/extensions/filters/network/thrift_proxy/transport_impl.h b/source/extensions/filters/network/thrift_proxy/transport_impl.h index 08281c53c6d16..56eb9a522dc0a 100644 --- a/source/extensions/filters/network/thrift_proxy/transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.h @@ -33,9 +33,10 @@ class AutoTransportImpl : public Transport { return TransportType::Auto; } - bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) override; + bool decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) override; bool decodeFrameEnd(Buffer::Instance& buffer) override; - void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override; + void encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, + Buffer::Instance& message) override; /* * Explicitly set the transport. Public to simplify testing. diff --git a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h index 29992dfc0812f..92dd78af0ad16 100644 --- a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h @@ -24,12 +24,14 @@ class UnframedTransportImpl : public Transport { // Transport const std::string& name() const override { return TransportNames::get().UNFRAMED; } TransportType type() const override { return TransportType::Unframed; } - bool decodeFrameStart(Buffer::Instance&, absl::optional& size) override { - size.reset(); + bool decodeFrameStart(Buffer::Instance&, MessageMetadata& metadata) override { + UNREFERENCED_PARAMETER(metadata); return true; } bool decodeFrameEnd(Buffer::Instance&) override { return true; } - void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override { + void encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, + Buffer::Instance& message) override { + UNREFERENCED_PARAMETER(metadata); buffer.move(message); } }; diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index 1eea2452232bb..e07767ba0963f 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -40,6 +40,18 @@ envoy_extension_cc_test_library( ], ) +envoy_extension_cc_test( + name = "app_exception_impl_test", + srcs = ["app_exception_impl_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//test/test_common:printers_lib", + "//test/test_common:utility_lib", + ], +) + envoy_extension_cc_test( name = "binary_protocol_impl_test", srcs = ["binary_protocol_impl_test.cc"], @@ -115,6 +127,7 @@ envoy_extension_cc_test( deps = [ ":mocks", ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", "//source/extensions/filters/network/thrift_proxy:decoder_lib", "//test/test_common:printers_lib", "//test/test_common:utility_lib", @@ -134,6 +147,17 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "metadata_test", + srcs = ["metadata_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + "//source/extensions/filters/network/thrift_proxy:metadata_lib", + "//test/test_common:printers_lib", + "//test/test_common:utility_lib", + ], +) + envoy_extension_cc_test( name = "protocol_impl_test", srcs = ["protocol_impl_test.cc"], diff --git a/test/extensions/filters/network/thrift_proxy/app_exception_impl_test.cc b/test/extensions/filters/network/thrift_proxy/app_exception_impl_test.cc new file mode 100644 index 0000000000000..df8592e7776a0 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/app_exception_impl_test.cc @@ -0,0 +1,90 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::InSequence; +using testing::Ref; +using testing::StrictMock; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +TEST(AppExceptionImplTest, CopyConstructor) { + AppException app_ex(AppExceptionType::InternalError, "msg"); + AppException copy(app_ex); + + EXPECT_EQ(app_ex.type_, copy.type_); + EXPECT_STREQ("msg", copy.what()); +} + +TEST(AppExceptionImplTest, TestEncode) { + AppException app_ex(AppExceptionType::InternalError, "msg"); + + MessageMetadata metadata; + metadata.setMethodName("method"); + metadata.setSequenceId(99); + metadata.setMessageType(MessageType::Call); + + StrictMock proto; + Buffer::OwnedImpl buffer; + + InSequence dummy; + EXPECT_CALL(proto, writeMessageBegin(Ref(buffer), Ref(metadata))) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ("method", metadata.methodName()); + EXPECT_EQ(99, metadata.sequenceId()); + EXPECT_EQ(MessageType::Exception, metadata.messageType()); + })); + EXPECT_CALL(proto, writeStructBegin(Ref(buffer), "TApplicationException")); + EXPECT_CALL(proto, writeFieldBegin(Ref(buffer), "message", FieldType::String, 1)); + EXPECT_CALL(proto, writeString(Ref(buffer), "msg")); + EXPECT_CALL(proto, writeFieldEnd(Ref(buffer))); + EXPECT_CALL(proto, writeFieldBegin(Ref(buffer), "type", FieldType::I32, 2)); + EXPECT_CALL(proto, writeInt32(Ref(buffer), static_cast(AppExceptionType::InternalError))); + EXPECT_CALL(proto, writeFieldEnd(Ref(buffer))); + EXPECT_CALL(proto, writeFieldBegin(Ref(buffer), "", FieldType::Stop, 0)); + EXPECT_CALL(proto, writeStructEnd(Ref(buffer))); + EXPECT_CALL(proto, writeMessageEnd(Ref(buffer))); + + app_ex.encode(metadata, proto, buffer); +} + +TEST(AppExceptionImplTest, TestEncodeEmptyMetadata) { + AppException app_ex(AppExceptionType::InternalError, "msg"); + + MessageMetadata metadata; + StrictMock proto; + Buffer::OwnedImpl buffer; + + InSequence dummy; + EXPECT_CALL(proto, writeMessageBegin(Ref(buffer), Ref(metadata))) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ("", metadata.methodName()); + EXPECT_EQ(0, metadata.sequenceId()); + EXPECT_EQ(MessageType::Exception, metadata.messageType()); + })); + EXPECT_CALL(proto, writeStructBegin(Ref(buffer), "TApplicationException")); + EXPECT_CALL(proto, writeFieldBegin(Ref(buffer), "message", FieldType::String, 1)); + EXPECT_CALL(proto, writeString(Ref(buffer), "msg")); + EXPECT_CALL(proto, writeFieldEnd(Ref(buffer))); + EXPECT_CALL(proto, writeFieldBegin(Ref(buffer), "type", FieldType::I32, 2)); + EXPECT_CALL(proto, writeInt32(Ref(buffer), static_cast(AppExceptionType::InternalError))); + EXPECT_CALL(proto, writeFieldEnd(Ref(buffer))); + EXPECT_CALL(proto, writeFieldBegin(Ref(buffer), "", FieldType::Stop, 0)); + EXPECT_CALL(proto, writeStructEnd(Ref(buffer))); + EXPECT_CALL(proto, writeMessageEnd(Ref(buffer))); + + app_ex.encode(metadata, proto, buffer); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc index 3db9d588d5338..dea7c694c5f1f 100644 --- a/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc @@ -15,76 +15,92 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -TEST(BinaryProtocolTest, Name) { +class BinaryProtocolTest : public testing::Test { +public: + void resetMetadata() { + metadata_.setMethodName("-"); + metadata_.setMessageType(MessageType::Oneway); + metadata_.setSequenceId(1); + } + + void expectMetadata(const std::string& name, MessageType msg_type, int32_t seq_id) { + EXPECT_TRUE(metadata_.hasMethodName()); + EXPECT_EQ(name, metadata_.methodName()); + + EXPECT_TRUE(metadata_.hasMessageType()); + EXPECT_EQ(msg_type, metadata_.messageType()); + + EXPECT_TRUE(metadata_.hasSequenceId()); + EXPECT_EQ(seq_id, metadata_.sequenceId()); + + EXPECT_FALSE(metadata_.hasFrameSize()); + EXPECT_FALSE(metadata_.hasProtocol()); + EXPECT_FALSE(metadata_.hasAppException()); + EXPECT_TRUE(metadata_.headers().empty()); + } + + void expectDefaultMetadata() { expectMetadata("-", MessageType::Oneway, 1); } + + MessageMetadata metadata_; +}; + +class LaxBinaryProtocolTest : public BinaryProtocolTest {}; + +TEST_F(BinaryProtocolTest, Name) { BinaryProtocolImpl proto; EXPECT_EQ(proto.name(), "binary"); } -TEST(BinaryProtocolTest, ReadMessageBegin) { +TEST_F(BinaryProtocolTest, ReadMessageBegin) { BinaryProtocolImpl proto; // Insufficient data { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addRepeated(buffer, 11, 'x'); - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 11); } // Wrong protocol version { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x0102); addRepeated(buffer, 10, 'x'); - EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, name, msg_type, seq_id), - EnvoyException, "invalid binary protocol version 0x0102 != 0x8001"); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, + "invalid binary protocol version 0x0102 != 0x8001"); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 12); } // Invalid message type { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8001); addInt8(buffer, 'x'); addInt8(buffer, static_cast(MessageType::LastMessageType) + 1); addRepeated(buffer, 8, 'x'); - EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, name, msg_type, seq_id), - EnvoyException, + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, fmt::format("invalid binary protocol message type {}", static_cast(MessageType::LastMessageType) + 1)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 12); } // Empty name { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8001); addInt8(buffer, 'x'); @@ -92,19 +108,15 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addInt32(buffer, 0); addInt32(buffer, 1234); - EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, ""); - EXPECT_EQ(msg_type, MessageType::Call); - EXPECT_EQ(seq_id, 1234); + EXPECT_TRUE(proto.readMessageBegin(buffer, metadata_)); + expectMetadata("", MessageType::Call, 1234); EXPECT_EQ(buffer.length(), 0); } // Insufficient data after checking name length { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8001); addInt8(buffer, 'x'); @@ -112,19 +124,15 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addInt32(buffer, 4); // name length addString(buffer, "abcd"); - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 12); } // Named message { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8001); addInt8(buffer, 0); @@ -133,22 +141,20 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addString(buffer, "the_name"); addInt32(buffer, 5678); - EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "the_name"); - EXPECT_EQ(msg_type, MessageType::Call); - EXPECT_EQ(seq_id, 5678); + EXPECT_TRUE(proto.readMessageBegin(buffer, metadata_)); + expectMetadata("the_name", MessageType::Call, 5678); EXPECT_EQ(buffer.length(), 0); } } -TEST(BinaryProtocolTest, ReadMessageEnd) { +TEST_F(BinaryProtocolTest, ReadMessageEnd) { Buffer::OwnedImpl buffer; BinaryProtocolImpl proto; EXPECT_TRUE(proto.readMessageEnd(buffer)); } -TEST(BinaryProtocolTest, ReadStructBegin) { +TEST_F(BinaryProtocolTest, ReadStructBegin) { Buffer::OwnedImpl buffer; BinaryProtocolImpl proto; std::string name = "-"; @@ -157,14 +163,14 @@ TEST(BinaryProtocolTest, ReadStructBegin) { EXPECT_EQ(name, ""); } -TEST(BinaryProtocolTest, ReadStructEnd) { +TEST_F(BinaryProtocolTest, ReadStructEnd) { Buffer::OwnedImpl buffer; BinaryProtocolImpl proto; EXPECT_TRUE(proto.readStructEnd(buffer)); } -TEST(BinaryProtocolTest, ReadFieldBegin) { +TEST_F(BinaryProtocolTest, ReadFieldBegin) { BinaryProtocolImpl proto; // Insufficient data @@ -247,13 +253,13 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { } } -TEST(BinaryProtocolTest, ReadFieldEnd) { +TEST_F(BinaryProtocolTest, ReadFieldEnd) { Buffer::OwnedImpl buffer; BinaryProtocolImpl proto; EXPECT_TRUE(proto.readFieldEnd(buffer)); } -TEST(BinaryProtocolTest, ReadMapBegin) { +TEST_F(BinaryProtocolTest, ReadMapBegin) { BinaryProtocolImpl proto; // Insufficient data @@ -310,13 +316,13 @@ TEST(BinaryProtocolTest, ReadMapBegin) { } } -TEST(BinaryProtocolTest, ReadMapEnd) { +TEST_F(BinaryProtocolTest, ReadMapEnd) { Buffer::OwnedImpl buffer; BinaryProtocolImpl proto; EXPECT_TRUE(proto.readMapEnd(buffer)); } -TEST(BinaryProtocolTest, ReadListBegin) { +TEST_F(BinaryProtocolTest, ReadListBegin) { BinaryProtocolImpl proto; // Insufficient data @@ -365,13 +371,13 @@ TEST(BinaryProtocolTest, ReadListBegin) { } } -TEST(BinaryProtocolTest, ReadListEnd) { +TEST_F(BinaryProtocolTest, ReadListEnd) { Buffer::OwnedImpl buffer; BinaryProtocolImpl proto; EXPECT_TRUE(proto.readListEnd(buffer)); } -TEST(BinaryProtocolTest, ReadSetBegin) { +TEST_F(BinaryProtocolTest, ReadSetBegin) { BinaryProtocolImpl proto; // Test only the happy path, since this method is just delegated to readListBegin() @@ -388,13 +394,13 @@ TEST(BinaryProtocolTest, ReadSetBegin) { EXPECT_EQ(buffer.length(), 0); } -TEST(BinaryProtocolTest, ReadSetEnd) { +TEST_F(BinaryProtocolTest, ReadSetEnd) { Buffer::OwnedImpl buffer; BinaryProtocolImpl proto; EXPECT_TRUE(proto.readSetEnd(buffer)); } -TEST(BinaryProtocolTest, ReadIntegerTypes) { +TEST_F(BinaryProtocolTest, ReadIntegerTypes) { BinaryProtocolImpl proto; // Bool @@ -512,7 +518,7 @@ TEST(BinaryProtocolTest, ReadIntegerTypes) { } } -TEST(BinaryProtocolTest, ReadDouble) { +TEST_F(BinaryProtocolTest, ReadDouble) { BinaryProtocolImpl proto; // Insufficient data @@ -540,7 +546,7 @@ TEST(BinaryProtocolTest, ReadDouble) { } } -TEST(BinaryProtocolTest, ReadString) { +TEST_F(BinaryProtocolTest, ReadString) { BinaryProtocolImpl proto; // Insufficient data to read length @@ -606,7 +612,7 @@ TEST(BinaryProtocolTest, ReadString) { } } -TEST(BinaryProtocolTest, ReadBinary) { +TEST_F(BinaryProtocolTest, ReadBinary) { // Test only the happy path, since this method is just delegated to readString() BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; @@ -620,46 +626,54 @@ TEST(BinaryProtocolTest, ReadBinary) { EXPECT_EQ(buffer.length(), 0); } -TEST(BinaryProtocolTest, WriteMessageBegin) { +TEST_F(BinaryProtocolTest, WriteMessageBegin) { BinaryProtocolImpl proto; // Named call { + metadata_.setMethodName("message"); + metadata_.setMessageType(MessageType::Call); + metadata_.setSequenceId(1); + Buffer::OwnedImpl buffer; - proto.writeMessageBegin(buffer, "message", MessageType::Call, 1); + proto.writeMessageBegin(buffer, metadata_); EXPECT_EQ(std::string("\x80\x1\0\x1\0\0\0\x7message\0\0\0\x1", 19), buffer.toString()); } // Unnamed oneway { + metadata_.setMethodName(""); + metadata_.setMessageType(MessageType::Oneway); + metadata_.setSequenceId(2); + Buffer::OwnedImpl buffer; - proto.writeMessageBegin(buffer, "", MessageType::Oneway, 2); + proto.writeMessageBegin(buffer, metadata_); EXPECT_EQ(std::string("\x80\x1\0\x4\0\0\0\0\0\0\0\x2", 12), buffer.toString()); } } -TEST(BinaryProtocolTest, WriteMessageEnd) { +TEST_F(BinaryProtocolTest, WriteMessageEnd) { BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMessageEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(BinaryProtocolTest, WriteStructBegin) { +TEST_F(BinaryProtocolTest, WriteStructBegin) { BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeStructBegin(buffer, "unused"); EXPECT_EQ(0, buffer.length()); } -TEST(BinaryProtocolTest, WriteStructEnd) { +TEST_F(BinaryProtocolTest, WriteStructEnd) { BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeStructEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(BinaryProtocolTest, WriteFieldBegin) { +TEST_F(BinaryProtocolTest, WriteFieldBegin) { BinaryProtocolImpl proto; // Stop field @@ -677,14 +691,14 @@ TEST(BinaryProtocolTest, WriteFieldBegin) { } } -TEST(BinaryProtocolTest, WriteFieldEnd) { +TEST_F(BinaryProtocolTest, WriteFieldEnd) { BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeFieldEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(BinaryProtocolTest, WriteMapBegin) { +TEST_F(BinaryProtocolTest, WriteMapBegin) { BinaryProtocolImpl proto; // Non-empty map @@ -710,14 +724,14 @@ TEST(BinaryProtocolTest, WriteMapBegin) { } } -TEST(BinaryProtocolTest, WriteMapEnd) { +TEST_F(BinaryProtocolTest, WriteMapEnd) { BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMapEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(BinaryProtocolTest, WriteListBegin) { +TEST_F(BinaryProtocolTest, WriteListBegin) { BinaryProtocolImpl proto; // Non-empty list @@ -742,14 +756,14 @@ TEST(BinaryProtocolTest, WriteListBegin) { } } -TEST(BinaryProtocolTest, WriteListEnd) { +TEST_F(BinaryProtocolTest, WriteListEnd) { BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeListEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(BinaryProtocolTest, WriteSetBegin) { +TEST_F(BinaryProtocolTest, WriteSetBegin) { BinaryProtocolImpl proto; // Only test the happy path, as this shares an implementation with writeListBegin @@ -759,14 +773,14 @@ TEST(BinaryProtocolTest, WriteSetBegin) { EXPECT_EQ(std::string("\xb\0\0\0\x3", 5), buffer.toString()); } -TEST(BinaryProtocolTest, WriteSetEnd) { +TEST_F(BinaryProtocolTest, WriteSetEnd) { BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeSetEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(BinaryProtocolTest, WriteBool) { +TEST_F(BinaryProtocolTest, WriteBool) { BinaryProtocolImpl proto; // True @@ -784,7 +798,7 @@ TEST(BinaryProtocolTest, WriteBool) { } } -TEST(BinaryProtocolTest, WriteByte) { +TEST_F(BinaryProtocolTest, WriteByte) { BinaryProtocolImpl proto; { @@ -800,7 +814,7 @@ TEST(BinaryProtocolTest, WriteByte) { } } -TEST(BinaryProtocolTest, WriteInt16) { +TEST_F(BinaryProtocolTest, WriteInt16) { BinaryProtocolImpl proto; { @@ -816,7 +830,7 @@ TEST(BinaryProtocolTest, WriteInt16) { } } -TEST(BinaryProtocolTest, WriteInt32) { +TEST_F(BinaryProtocolTest, WriteInt32) { BinaryProtocolImpl proto; { @@ -832,7 +846,7 @@ TEST(BinaryProtocolTest, WriteInt32) { } } -TEST(BinaryProtocolTest, WriteInt64) { +TEST_F(BinaryProtocolTest, WriteInt64) { BinaryProtocolImpl proto; { @@ -848,14 +862,14 @@ TEST(BinaryProtocolTest, WriteInt64) { } } -TEST(BinaryProtocolTest, WriteDouble) { +TEST_F(BinaryProtocolTest, WriteDouble) { BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeDouble(buffer, 3.0); EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); } -TEST(BinaryProtocolTest, WriteString) { +TEST_F(BinaryProtocolTest, WriteString) { BinaryProtocolImpl proto; { @@ -874,7 +888,7 @@ TEST(BinaryProtocolTest, WriteString) { } } -TEST(BinaryProtocolTest, WriteBinary) { +TEST_F(BinaryProtocolTest, WriteBinary) { BinaryProtocolImpl proto; // Happy path only, since this is just a synonym for writeString @@ -886,121 +900,108 @@ TEST(BinaryProtocolTest, WriteBinary) { buffer.toString()); } -TEST(LaxBinaryProtocolTest, Name) { +TEST_F(LaxBinaryProtocolTest, Name) { LaxBinaryProtocolImpl proto; EXPECT_EQ(proto.name(), "binary/non-strict"); } -TEST(LaxBinaryProtocolTest, ReadMessageBegin) { +TEST_F(LaxBinaryProtocolTest, ReadMessageBegin) { LaxBinaryProtocolImpl proto; // Insufficient data { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addRepeated(buffer, 8, 'x'); - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 8); } // Invalid message type { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt32(buffer, 0); addInt8(buffer, static_cast(MessageType::LastMessageType) + 1); addRepeated(buffer, 4, 'x'); - EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, name, msg_type, seq_id), - EnvoyException, + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, fmt::format("invalid (lax) binary protocol message type {}", static_cast(MessageType::LastMessageType) + 1)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 9); } // Empty name { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt32(buffer, 0); addInt8(buffer, MessageType::Call); addInt32(buffer, 1234); - EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, ""); - EXPECT_EQ(msg_type, MessageType::Call); - EXPECT_EQ(seq_id, 1234); + EXPECT_TRUE(proto.readMessageBegin(buffer, metadata_)); + expectMetadata("", MessageType::Call, 1234); EXPECT_EQ(buffer.length(), 0); } // Insufficient data after checking name length { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt32(buffer, 1); // name length addInt8(buffer, MessageType::Call); addInt32(buffer, 1234); - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 9); } // Named message { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt32(buffer, 8); addString(buffer, "the_name"); addInt8(buffer, MessageType::Call); addInt32(buffer, 5678); - EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "the_name"); - EXPECT_EQ(msg_type, MessageType::Call); - EXPECT_EQ(seq_id, 5678); + EXPECT_TRUE(proto.readMessageBegin(buffer, metadata_)); + expectMetadata("the_name", MessageType::Call, 5678); EXPECT_EQ(buffer.length(), 0); } } -TEST(LaxBinaryProtocolTest, WriteMessageBegin) { +TEST_F(LaxBinaryProtocolTest, WriteMessageBegin) { LaxBinaryProtocolImpl proto; // Named call { + metadata_.setMethodName("message"); + metadata_.setMessageType(MessageType::Call); + metadata_.setSequenceId(1); + Buffer::OwnedImpl buffer; - proto.writeMessageBegin(buffer, "message", MessageType::Call, 1); + proto.writeMessageBegin(buffer, metadata_); EXPECT_EQ(std::string("\0\0\0\x7message\x1\0\0\0\x1", 16), buffer.toString()); } // Unnamed oneway { + metadata_.setMethodName(""); + metadata_.setMessageType(MessageType::Oneway); + metadata_.setSequenceId(2); + Buffer::OwnedImpl buffer; - proto.writeMessageBegin(buffer, "", MessageType::Oneway, 2); + proto.writeMessageBegin(buffer, metadata_); EXPECT_EQ(std::string("\0\0\0\0\x4\0\0\0\x2", 9), buffer.toString()); } } diff --git a/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc index 79187821def52..70e476a5ad9a9 100644 --- a/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc @@ -18,54 +18,73 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -TEST(CompactProtocolTest, Name) { +class CompactProtocolTest : public testing::Test { +public: + void resetMetadata() { + metadata_.setMethodName("-"); + metadata_.setMessageType(MessageType::Oneway); + metadata_.setSequenceId(1); + } + + void expectMetadata(const std::string& name, MessageType msg_type, int32_t seq_id) { + EXPECT_TRUE(metadata_.hasMethodName()); + EXPECT_EQ(name, metadata_.methodName()); + + EXPECT_TRUE(metadata_.hasMessageType()); + EXPECT_EQ(msg_type, metadata_.messageType()); + + EXPECT_TRUE(metadata_.hasSequenceId()); + EXPECT_EQ(seq_id, metadata_.sequenceId()); + + EXPECT_FALSE(metadata_.hasFrameSize()); + EXPECT_FALSE(metadata_.hasProtocol()); + EXPECT_FALSE(metadata_.hasAppException()); + EXPECT_TRUE(metadata_.headers().empty()); + } + + void expectDefaultMetadata() { expectMetadata("-", MessageType::Oneway, 1); } + + MessageMetadata metadata_; +}; + +TEST_F(CompactProtocolTest, Name) { CompactProtocolImpl proto; EXPECT_EQ(proto.name(), "compact"); } -TEST(CompactProtocolTest, ReadMessageBegin) { +TEST_F(CompactProtocolTest, ReadMessageBegin) { CompactProtocolImpl proto; // Insufficient data { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addRepeated(buffer, 3, 'x'); - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 3); } // Wrong protocol version { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x0102); addRepeated(buffer, 2, 'x'); - EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, name, msg_type, seq_id), - EnvoyException, "invalid compact protocol version 0x0102 != 0x8201"); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, + "invalid compact protocol version 0x0102 != 0x8201"); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 4); } // Invalid message type { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); // Message type is encoded in the 3 highest order bits of the second byte. int8_t invalid_msg_type = static_cast(MessageType::LastMessageType) + 1; @@ -73,171 +92,137 @@ TEST(CompactProtocolTest, ReadMessageBegin) { addRepeated(buffer, 2, 'x'); EXPECT_THROW_WITH_MESSAGE( - proto.readMessageBegin(buffer, name, msg_type, seq_id), EnvoyException, + proto.readMessageBegin(buffer, metadata_), EnvoyException, fmt::format("invalid compact protocol message type {}", invalid_msg_type)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 4); } // Insufficient data to read message id { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8221); addRepeated(buffer, 2, 0x81); - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 4); } // Invalid sequence id encoding { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8221); addSeq(buffer, {0x81, 0x81, 0x81, 0x81, 0x81, 0}); // > 32 bit varint addInt8(buffer, 0); - EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, name, msg_type, seq_id), - EnvoyException, "invalid compact protocol varint i32"); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, + "invalid compact protocol varint i32"); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 9); } // Insufficient data to read message name length { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8221); addInt8(buffer, 32); addInt8(buffer, 0x81); // unterminated varint - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 4); } // Insufficient data to read message name { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8221); addInt8(buffer, 32); addInt8(buffer, 10); addString(buffer, "partial"); - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 11); } // Empty name { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8221); addInt8(buffer, 32); addInt8(buffer, 0); - EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, ""); - EXPECT_EQ(msg_type, MessageType::Call); - EXPECT_EQ(seq_id, 32); + EXPECT_TRUE(proto.readMessageBegin(buffer, metadata_)); + expectMetadata("", MessageType::Call, 32); EXPECT_EQ(buffer.length(), 0); } // Invalid name length encoding { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8221); addInt8(buffer, 32); addSeq(buffer, {0x81, 0x81, 0x81, 0x81, 0x81, 0}); // > 32 bit varint - EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, name, msg_type, seq_id), - EnvoyException, "invalid compact protocol varint i32"); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, + "invalid compact protocol varint i32"); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 9); } // Invalid name length { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8221); addInt8(buffer, 32); addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x1F}); // -1 - EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, name, msg_type, seq_id), - EnvoyException, "negative compact protocol message name length -1"); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, 1); + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, + "negative compact protocol message name length -1"); + expectDefaultMetadata(); EXPECT_EQ(buffer.length(), 8); } // Named message { Buffer::OwnedImpl buffer; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); addInt16(buffer, 0x8221); addInt16(buffer, 0x8202); // 0x0102 addInt8(buffer, 8); addString(buffer, "the_name"); - EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "the_name"); - EXPECT_EQ(msg_type, MessageType::Call); - EXPECT_EQ(seq_id, 0x0102); + EXPECT_TRUE(proto.readMessageBegin(buffer, metadata_)); + expectMetadata("the_name", MessageType::Call, 0x102); EXPECT_EQ(buffer.length(), 0); } } -TEST(CompactProtocolTest, ReadMessageEnd) { +TEST_F(CompactProtocolTest, ReadMessageEnd) { Buffer::OwnedImpl buffer; CompactProtocolImpl proto; EXPECT_TRUE(proto.readMessageEnd(buffer)); } -TEST(CompactProtocolTest, ReadStruct) { +TEST_F(CompactProtocolTest, ReadStruct) { Buffer::OwnedImpl buffer; CompactProtocolImpl proto; std::string name = "-"; @@ -251,7 +236,7 @@ TEST(CompactProtocolTest, ReadStruct) { "invalid check for compact protocol struct end") } -TEST(CompactProtocolTest, ReadFieldBegin) { +TEST_F(CompactProtocolTest, ReadFieldBegin) { CompactProtocolImpl proto; // Insufficient data @@ -412,13 +397,13 @@ TEST(CompactProtocolTest, ReadFieldBegin) { } } -TEST(CompactProtocolTest, ReadFieldEnd) { +TEST_F(CompactProtocolTest, ReadFieldEnd) { Buffer::OwnedImpl buffer; CompactProtocolImpl proto; EXPECT_TRUE(proto.readFieldEnd(buffer)); } -TEST(CompactProtocolTest, ReadMapBegin) { +TEST_F(CompactProtocolTest, ReadMapBegin) { CompactProtocolImpl proto; // Insufficient data @@ -557,13 +542,13 @@ TEST(CompactProtocolTest, ReadMapBegin) { } } -TEST(CompactProtocolTest, ReadMapEnd) { +TEST_F(CompactProtocolTest, ReadMapEnd) { Buffer::OwnedImpl buffer; CompactProtocolImpl proto; EXPECT_TRUE(proto.readMapEnd(buffer)); } -TEST(CompactProtocolTest, ReadListBegin) { +TEST_F(CompactProtocolTest, ReadListBegin) { CompactProtocolImpl proto; // Insufficient data @@ -670,13 +655,13 @@ TEST(CompactProtocolTest, ReadListBegin) { } } -TEST(CompactProtocolTest, ReadListEnd) { +TEST_F(CompactProtocolTest, ReadListEnd) { Buffer::OwnedImpl buffer; CompactProtocolImpl proto; EXPECT_TRUE(proto.readListEnd(buffer)); } -TEST(CompactProtocolTest, ReadSetBegin) { +TEST_F(CompactProtocolTest, ReadSetBegin) { CompactProtocolImpl proto; // Test only the happy path, since this method is just delegated to readListBegin() @@ -692,13 +677,13 @@ TEST(CompactProtocolTest, ReadSetBegin) { EXPECT_EQ(buffer.length(), 0); } -TEST(CompactProtocolTest, ReadSetEnd) { +TEST_F(CompactProtocolTest, ReadSetEnd) { Buffer::OwnedImpl buffer; CompactProtocolImpl proto; EXPECT_TRUE(proto.readSetEnd(buffer)); } -TEST(CompactProtocolTest, ReadBool) { +TEST_F(CompactProtocolTest, ReadBool) { CompactProtocolImpl proto; // Bool field values are encoded in the field type @@ -762,7 +747,7 @@ TEST(CompactProtocolTest, ReadBool) { } } -TEST(CompactProtocolTest, ReadIntegerTypes) { +TEST_F(CompactProtocolTest, ReadIntegerTypes) { CompactProtocolImpl proto; // Byte @@ -893,7 +878,7 @@ TEST(CompactProtocolTest, ReadIntegerTypes) { } } -TEST(CompactProtocolTest, ReadDouble) { +TEST_F(CompactProtocolTest, ReadDouble) { CompactProtocolImpl proto; // Insufficient data @@ -923,7 +908,7 @@ TEST(CompactProtocolTest, ReadDouble) { } } -TEST(CompactProtocolTest, ReadString) { +TEST_F(CompactProtocolTest, ReadString) { CompactProtocolImpl proto; // Insufficient data @@ -999,7 +984,7 @@ TEST(CompactProtocolTest, ReadString) { } } -TEST(CompactProtocolTest, ReadBinary) { +TEST_F(CompactProtocolTest, ReadBinary) { // Test only the happy path, since this method is just delegated to readString() CompactProtocolImpl proto; Buffer::OwnedImpl buffer; @@ -1050,32 +1035,40 @@ TEST_P(CompactProtocolFieldTypeTest, ConvertsToFieldType) { INSTANTIATE_TEST_CASE_P(CompactFieldTypes, CompactProtocolFieldTypeTest, Values(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)); -TEST(CompactProtocolTest, WriteMessageBegin) { +TEST_F(CompactProtocolTest, WriteMessageBegin) { CompactProtocolImpl proto; // Named call { + metadata_.setMethodName("message"); + metadata_.setMessageType(MessageType::Call); + metadata_.setSequenceId(1); + Buffer::OwnedImpl buffer; - proto.writeMessageBegin(buffer, "message", MessageType::Call, 1); + proto.writeMessageBegin(buffer, metadata_); EXPECT_EQ(std::string("\x82\x21\x1\x7message", 11), buffer.toString()); } // Unnamed oneway { + metadata_.setMethodName(""); + metadata_.setMessageType(MessageType::Oneway); + metadata_.setSequenceId(2); + Buffer::OwnedImpl buffer; - proto.writeMessageBegin(buffer, "", MessageType::Oneway, 2); + proto.writeMessageBegin(buffer, metadata_); EXPECT_EQ(std::string("\x82\x81\x2\0", 4), buffer.toString()); } } -TEST(CompactProtocolTest, WriteMessageEnd) { +TEST_F(CompactProtocolTest, WriteMessageEnd) { CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMessageEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(CompactProtocolTest, WriteStruct) { +TEST_F(CompactProtocolTest, WriteStruct) { CompactProtocolImpl proto; Buffer::OwnedImpl buffer; @@ -1088,7 +1081,7 @@ TEST(CompactProtocolTest, WriteStruct) { "invalid write of compact protocol struct end") } -TEST(CompactProtocolTest, WriteFieldBegin) { +TEST_F(CompactProtocolTest, WriteFieldBegin) { // Stop field { CompactProtocolImpl proto; @@ -1175,14 +1168,14 @@ TEST(CompactProtocolTest, WriteFieldBegin) { } } -TEST(CompactProtocolTest, WriteFieldEnd) { +TEST_F(CompactProtocolTest, WriteFieldEnd) { CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeFieldEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(CompactProtocolTest, WriteBoolField) { +TEST_F(CompactProtocolTest, WriteBoolField) { // Boolean struct fields are encoded with custom types to save a byte // Short form field @@ -1227,7 +1220,7 @@ TEST(CompactProtocolTest, WriteBoolField) { } } -TEST(CompactProtocolTest, WriteMapBegin) { +TEST_F(CompactProtocolTest, WriteMapBegin) { CompactProtocolImpl proto; // Empty map @@ -1253,14 +1246,14 @@ TEST(CompactProtocolTest, WriteMapBegin) { } } -TEST(CompactProtocolTest, WriteMapEnd) { +TEST_F(CompactProtocolTest, WriteMapEnd) { CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMapEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(CompactProtocolTest, WriteListBegin) { +TEST_F(CompactProtocolTest, WriteListBegin) { CompactProtocolImpl proto; // Empty list @@ -1292,14 +1285,14 @@ TEST(CompactProtocolTest, WriteListBegin) { } } -TEST(CompactProtocolTest, WriteListEnd) { +TEST_F(CompactProtocolTest, WriteListEnd) { CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeListEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(CompactProtocolTest, WriteSetBegin) { +TEST_F(CompactProtocolTest, WriteSetBegin) { CompactProtocolImpl proto; // Empty set only, as writeSetBegin delegates to writeListBegin. @@ -1308,14 +1301,14 @@ TEST(CompactProtocolTest, WriteSetBegin) { EXPECT_EQ("\x5", buffer.toString()); } -TEST(CompactProtocolTest, WriteSetEnd) { +TEST_F(CompactProtocolTest, WriteSetEnd) { CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeSetEnd(buffer); EXPECT_EQ(0, buffer.length()); } -TEST(CompactProtocolTest, WriteBool) { +TEST_F(CompactProtocolTest, WriteBool) { CompactProtocolImpl proto; // Non-field bools (see WriteBoolField test) @@ -1332,7 +1325,7 @@ TEST(CompactProtocolTest, WriteBool) { } } -TEST(CompactProtocolTest, WriteByte) { +TEST_F(CompactProtocolTest, WriteByte) { CompactProtocolImpl proto; { @@ -1348,7 +1341,7 @@ TEST(CompactProtocolTest, WriteByte) { } } -TEST(CompactProtocolTest, WriteInt16) { +TEST_F(CompactProtocolTest, WriteInt16) { CompactProtocolImpl proto; // zigzag(1) = 2 @@ -1387,7 +1380,7 @@ TEST(CompactProtocolTest, WriteInt16) { } } -TEST(CompactProtocolTest, WriteInt32) { +TEST_F(CompactProtocolTest, WriteInt32) { CompactProtocolImpl proto; // zigzag(1) = 2 @@ -1426,7 +1419,7 @@ TEST(CompactProtocolTest, WriteInt32) { } } -TEST(CompactProtocolTest, WriteInt64) { +TEST_F(CompactProtocolTest, WriteInt64) { CompactProtocolImpl proto; // zigzag(1) = 2 @@ -1465,14 +1458,14 @@ TEST(CompactProtocolTest, WriteInt64) { } } -TEST(CompactProtocolTest, WriteDouble) { +TEST_F(CompactProtocolTest, WriteDouble) { CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeDouble(buffer, 3.0); EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); } -TEST(CompactProtocolTest, WriteString) { +TEST_F(CompactProtocolTest, WriteString) { CompactProtocolImpl proto; { @@ -1498,7 +1491,7 @@ TEST(CompactProtocolTest, WriteString) { } } -TEST(CompactProtocolTest, WriteBinary) { +TEST_F(CompactProtocolTest, WriteBinary) { CompactProtocolImpl proto; // writeBinary is an alias for writeString diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index 7e7cff3df58bc..798c156c75314 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -84,11 +84,17 @@ class ThriftConnectionManagerTest : public testing::Test { filter_->onBelowWriteBufferLowWatermark(); } - void writeFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, int32_t seq_id) { + void writeMessage(Buffer::Instance& buffer, TransportType transport_type, + ProtocolType protocol_type, MessageType msg_type, int32_t seq_id) { Buffer::OwnedImpl msg; - ProtocolPtr proto = - NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); - proto->writeMessageBegin(msg, "name", msg_type, seq_id); + ProtocolPtr proto = NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol(); + MessageMetadata metadata; + metadata.setProtocol(protocol_type); + metadata.setMethodName("name"); + metadata.setMessageType(msg_type); + metadata.setSequenceId(seq_id); + + proto->writeMessageBegin(msg, metadata); proto->writeStructBegin(msg, "response"); proto->writeFieldBegin(msg, "success", FieldType::String, 0); proto->writeString(msg, "field"); @@ -98,8 +104,12 @@ class ThriftConnectionManagerTest : public testing::Test { proto->writeMessageEnd(msg); TransportPtr transport = - NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); - transport->encodeFrame(buffer, msg); + NamedTransportConfigFactory::getFactory(transport_type).createTransport(); + transport->encodeFrame(buffer, metadata, msg); + } + + void writeFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, int32_t seq_id) { + writeMessage(buffer, TransportType::Framed, ProtocolType::Binary, msg_type, seq_id); } void writeComplexFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, @@ -107,7 +117,12 @@ class ThriftConnectionManagerTest : public testing::Test { Buffer::OwnedImpl msg; ProtocolPtr proto = NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); - proto->writeMessageBegin(msg, "name", msg_type, seq_id); + MessageMetadata metadata; + metadata.setMethodName("name"); + metadata.setMessageType(msg_type); + metadata.setSequenceId(seq_id); + + proto->writeMessageBegin(msg, metadata); proto->writeStructBegin(msg, "wrapper"); // call args struct or response struct proto->writeFieldBegin(msg, "wrapper_field", FieldType::Struct, 0); // call arg/response success @@ -169,7 +184,7 @@ class ThriftConnectionManagerTest : public testing::Test { TransportPtr transport = NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); - transport->encodeFrame(buffer, msg); + transport->encodeFrame(buffer, metadata, msg); } void writePartialFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, @@ -189,7 +204,12 @@ class ThriftConnectionManagerTest : public testing::Test { Buffer::OwnedImpl msg; ProtocolPtr proto = NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); - proto->writeMessageBegin(msg, "name", MessageType::Exception, seq_id); + MessageMetadata metadata; + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Exception); + metadata.setSequenceId(seq_id); + + proto->writeMessageBegin(msg, metadata); proto->writeStructBegin(msg, ""); proto->writeFieldBegin(msg, "", FieldType::String, 1); proto->writeString(msg, "error"); @@ -203,14 +223,19 @@ class ThriftConnectionManagerTest : public testing::Test { TransportPtr transport = NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); - transport->encodeFrame(buffer, msg); + transport->encodeFrame(buffer, metadata, msg); } void writeFramedBinaryIDLException(Buffer::Instance& buffer, int32_t seq_id) { Buffer::OwnedImpl msg; ProtocolPtr proto = NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); - proto->writeMessageBegin(msg, "name", MessageType::Reply, seq_id); + MessageMetadata metadata; + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Reply); + metadata.setSequenceId(seq_id); + + proto->writeMessageBegin(msg, metadata); proto->writeStructBegin(msg, ""); proto->writeFieldBegin(msg, "", FieldType::Struct, 2); @@ -228,7 +253,7 @@ class ThriftConnectionManagerTest : public testing::Test { TransportPtr transport = NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); - transport->encodeFrame(buffer, msg); + transport->encodeFrame(buffer, metadata, msg); } NiceMock context_; @@ -285,7 +310,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataHandlesStopIterationAndResume) { EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); - EXPECT_CALL(*decoder_filter_, messageBegin(_, _, _)) + EXPECT_CALL(*decoder_filter_, messageBegin(_)) .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); @@ -506,7 +531,7 @@ stat_prefix: test EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); - EXPECT_CALL(*decoder_filter_, messageBegin(_, _, _)) + EXPECT_CALL(*decoder_filter_, messageBegin(_)) .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 9762fddfaf007..f18b0a7f6e435 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -1,5 +1,6 @@ #include "common/buffer/buffer_impl.h" +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" #include "extensions/filters/network/thrift_proxy/decoder.h" #include "test/extensions/filters/network/thrift_proxy/mocks.h" @@ -17,12 +18,14 @@ using testing::DoAll; using testing::Expectation; using testing::ExpectationSet; using testing::InSequence; +using testing::Invoke; using testing::NiceMock; using testing::Ref; using testing::Return; using testing::ReturnRef; using testing::SetArgReferee; using testing::StrictMock; +using testing::Test; using testing::TestParamInfo; using testing::TestWithParam; using testing::Values; @@ -163,7 +166,18 @@ ExpectationSet expectContainerEnd(MockProtocol& proto, ThriftFilters::MockDecode } // end namespace -class DecoderStateMachineNonValueTest : public TestWithParam {}; +class DecoderStateMachineTestBase { +public: + DecoderStateMachineTestBase() : metadata_(std::make_shared()) {} + virtual ~DecoderStateMachineTestBase() {} + + NiceMock proto_; + MessageMetadataSharedPtr metadata_; + NiceMock filter_; +}; + +class DecoderStateMachineNonValueTest : public DecoderStateMachineTestBase, + public TestWithParam {}; static std::string protoStateParamToString(const TestParamInfo& params) { return ProtocolStateNameValues::name(params.param); @@ -178,7 +192,10 @@ INSTANTIATE_TEST_CASE_P(NonValueProtocolStates, DecoderStateMachineNonValueTest, ProtocolState::SetBegin, ProtocolState::SetEnd), protoStateParamToString); -class DecoderStateMachineValueTest : public TestWithParam {}; +class DecoderStateMachineTest : public DecoderStateMachineTestBase, public Test {}; + +class DecoderStateMachineValueTest : public DecoderStateMachineTestBase, + public TestWithParam {}; INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, DecoderStateMachineValueTest, Values(FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, @@ -186,7 +203,8 @@ INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, DecoderStateMachineValueTest, fieldTypeParamToString); class DecoderStateMachineNestingTest - : public TestWithParam> {}; + : public DecoderStateMachineTestBase, + public TestWithParam> {}; static std::string nestedFieldTypesParamToString( const TestParamInfo>& params) { @@ -207,9 +225,8 @@ INSTANTIATE_TEST_CASE_P( TEST_P(DecoderStateMachineNonValueTest, NoData) { ProtocolState state = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - StrictMock filter; - DecoderStateMachine dsm(proto, filter); + + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(state); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), state); @@ -219,19 +236,17 @@ TEST_P(DecoderStateMachineValueTest, NoFieldValueData) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(std::string("")), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, filter, field_type, false); - expectValue(proto, filter, field_type, true); - EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); + expectValue(proto_, filter_, field_type, false); + expectValue(proto_, filter_, field_type, true); + EXPECT_CALL(proto_, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -244,54 +259,48 @@ TEST_P(DecoderStateMachineValueTest, NoFieldValueData) { TEST_P(DecoderStateMachineValueTest, FieldValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(std::string("")), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, filter, field_type); + expectValue(proto_, filter_, field_type); - EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); + EXPECT_CALL(proto_, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), ProtocolState::FieldBegin); } -TEST(DecoderStateMachineTest, NoListValueData) { +TEST_F(DecoderStateMachineTest, NoListValueData) { Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) + EXPECT_CALL(proto_, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); - EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); + EXPECT_CALL(proto_, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), ProtocolState::ListValue); } -TEST(DecoderStateMachineTest, EmptyList) { +TEST_F(DecoderStateMachineTest, EmptyList) { Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) + EXPECT_CALL(proto_, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); - EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -301,18 +310,16 @@ TEST(DecoderStateMachineTest, EmptyList) { TEST_P(DecoderStateMachineValueTest, ListValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) + EXPECT_CALL(proto_, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(1), Return(true))); - expectValue(proto, filter, field_type); + expectValue(proto_, filter_, field_type); - EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -322,75 +329,67 @@ TEST_P(DecoderStateMachineValueTest, ListValue) { TEST_P(DecoderStateMachineValueTest, MultipleListValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) + EXPECT_CALL(proto_, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, filter, field_type); + expectValue(proto_, filter_, field_type); } - EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), ProtocolState::ListEnd); } -TEST(DecoderStateMachineTest, NoMapKeyData) { +TEST_F(DecoderStateMachineTest, NoMapKeyData) { Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(FieldType::String), SetArgReferee<3>(1), Return(true))); - EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); + EXPECT_CALL(proto_, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), ProtocolState::MapKey); } -TEST(DecoderStateMachineTest, NoMapValueData) { +TEST_F(DecoderStateMachineTest, NoMapValueData) { Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(FieldType::String), SetArgReferee<3>(1), Return(true))); - EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(true)); - EXPECT_CALL(proto, readString(Ref(buffer), _)).WillOnce(Return(false)); + EXPECT_CALL(proto_, readInt32(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto_, readString(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), ProtocolState::MapValue); } -TEST(DecoderStateMachineTest, EmptyMap) { +TEST_F(DecoderStateMachineTest, EmptyMap) { Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(FieldType::String), SetArgReferee<3>(0), Return(true))); - EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -400,20 +399,18 @@ TEST(DecoderStateMachineTest, EmptyMap) { TEST_P(DecoderStateMachineValueTest, MapKeyValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(FieldType::String), SetArgReferee<3>(1), Return(true))); - expectValue(proto, filter, field_type); // key - expectValue(proto, filter, FieldType::String); // value + expectValue(proto_, filter_, field_type); // key + expectValue(proto_, filter_, FieldType::String); // value - EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -423,20 +420,18 @@ TEST_P(DecoderStateMachineValueTest, MapKeyValue) { TEST_P(DecoderStateMachineValueTest, MapValueValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, filter, FieldType::I32); // key - expectValue(proto, filter, field_type); // value + expectValue(proto_, filter_, FieldType::I32); // key + expectValue(proto_, filter_, field_type); // value - EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -446,56 +441,50 @@ TEST_P(DecoderStateMachineValueTest, MapValueValue) { TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(field_type), SetArgReferee<3>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, filter, FieldType::I32); // key - expectValue(proto, filter, field_type); // value + expectValue(proto_, filter_, FieldType::I32); // key + expectValue(proto_, filter_, field_type); // value } - EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), ProtocolState::MapEnd); } -TEST(DecoderStateMachineTest, NoSetValueData) { +TEST_F(DecoderStateMachineTest, NoSetValueData) { Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) + EXPECT_CALL(proto_, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); - EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); + EXPECT_CALL(proto_, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), ProtocolState::SetValue); } -TEST(DecoderStateMachineTest, EmptySet) { +TEST_F(DecoderStateMachineTest, EmptySet) { Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) + EXPECT_CALL(proto_, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); - EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -505,18 +494,16 @@ TEST(DecoderStateMachineTest, EmptySet) { TEST_P(DecoderStateMachineValueTest, SetValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) + EXPECT_CALL(proto_, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(1), Return(true))); - expectValue(proto, filter, field_type); + expectValue(proto_, filter_, field_type); - EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -526,42 +513,42 @@ TEST_P(DecoderStateMachineValueTest, SetValue) { TEST_P(DecoderStateMachineValueTest, MultipleSetValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) + EXPECT_CALL(proto_, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, filter, field_type); + expectValue(proto_, filter_, field_type); } - EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(proto_, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), ProtocolState::SetEnd); } -TEST(DecoderStateMachineTest, EmptyStruct) { +TEST_F(DecoderStateMachineTest, EmptyStruct) { Buffer::OwnedImpl buffer; - NiceMock proto; - NiceMock filter; InSequence dummy; - EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); - EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(proto_, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto_, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto_, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -570,88 +557,108 @@ TEST(DecoderStateMachineTest, EmptyStruct) { TEST_P(DecoderStateMachineValueTest, SingleFieldStruct) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - StrictMock filter; InSequence dummy; - EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); - EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + EXPECT_CALL(proto_, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(filter_, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(MessageType::Call, metadata->messageType()); + EXPECT_EQ(100U, metadata->sequenceId()); + return ThriftFilters::FilterStatus::Continue; + })); + + EXPECT_CALL(proto_, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter_, structBegin(absl::string_view())) .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); - EXPECT_CALL(filter, structBegin(absl::string_view())) - .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - EXPECT_CALL(filter, fieldBegin(absl::string_view(), field_type, 1)) + EXPECT_CALL(filter_, fieldBegin(absl::string_view(), field_type, 1)) .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - expectValue(proto, filter, field_type); + expectValue(proto_, filter_, field_type); - EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto_, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter_, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto_, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter_, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto_, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter_, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); } -TEST(DecoderStateMachineTest, MultiFieldStruct) { +TEST_F(DecoderStateMachineTest, MultiFieldStruct) { Buffer::OwnedImpl buffer; - NiceMock proto; - StrictMock filter; InSequence dummy; std::vector field_types = {FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, FieldType::I32, FieldType::I64, FieldType::String}; - EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); - EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) - .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - - EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); - EXPECT_CALL(filter, structBegin(absl::string_view())) + EXPECT_CALL(proto_, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(filter_, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(MessageType::Call, metadata->messageType()); + EXPECT_EQ(100U, metadata->sequenceId()); + return ThriftFilters::FilterStatus::Continue; + })); + + EXPECT_CALL(proto_, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter_, structBegin(absl::string_view())) .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); int16_t field_id = 1; for (FieldType field_type : field_types) { - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(field_id), Return(true))); - EXPECT_CALL(filter, fieldBegin(absl::string_view(), field_type, field_id)) + EXPECT_CALL(filter_, fieldBegin(absl::string_view(), field_type, field_id)) .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); field_id++; - expectValue(proto, filter, field_type); + expectValue(proto_, filter_, field_type); - EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto_, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter_, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); } - EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto_, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto_, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter_, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto_, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter_, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -662,42 +669,52 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { std::tie(outer_field_type, inner_type, value_type) = GetParam(); Buffer::OwnedImpl buffer; - NiceMock proto; - StrictMock filter; InSequence dummy; // start of message and outermost struct - EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); - EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) - .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - - expectContainerStart(proto, filter, FieldType::Struct, outer_field_type); - - expectContainerStart(proto, filter, outer_field_type, inner_type); + EXPECT_CALL(proto_, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(filter_, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(MessageType::Call, metadata->messageType()); + EXPECT_EQ(100U, metadata->sequenceId()); + return ThriftFilters::FilterStatus::Continue; + })); + + expectContainerStart(proto_, filter_, FieldType::Struct, outer_field_type); + + expectContainerStart(proto_, filter_, outer_field_type, inner_type); int outer_reps = outer_field_type == FieldType::Map ? 2 : 1; for (int i = 0; i < outer_reps; i++) { - expectContainerStart(proto, filter, inner_type, value_type); + expectContainerStart(proto_, filter_, inner_type, value_type); int inner_reps = inner_type == FieldType::Map ? 2 : 1; for (int j = 0; j < inner_reps; j++) { - expectValue(proto, filter, value_type); + expectValue(proto_, filter_, value_type); } - expectContainerEnd(proto, filter, inner_type); + expectContainerEnd(proto_, filter_, inner_type); } - expectContainerEnd(proto, filter, outer_field_type); + expectContainerEnd(proto_, filter_, outer_field_type); // end of message and outermost struct - expectContainerEnd(proto, filter, FieldType::Struct); + expectContainerEnd(proto_, filter_, FieldType::Struct); - EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto_, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter_, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto, filter); + DecoderStateMachine dsm(proto_, metadata_, filter_); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -715,15 +732,34 @@ TEST(DecoderTest, OnData) { Buffer::OwnedImpl buffer; EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) - .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); - EXPECT_CALL(filter, transportBegin(absl::optional(100))) - .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); - EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) - .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(filter, transportBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasFrameSize()); + EXPECT_EQ(100U, metadata->frameSize()); + return ThriftFilters::FilterStatus::Continue; + })); + + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(filter, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(MessageType::Call, metadata->messageType()); + EXPECT_EQ(100U, metadata->sequenceId()); + return ThriftFilters::FilterStatus::Continue; + })); EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(filter, structBegin(absl::string_view())) @@ -759,10 +795,17 @@ TEST(DecoderTest, OnDataResumes) { buffer.add("x"); EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) - .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); - EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(*proto, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(false)); bool underflow = false; @@ -796,16 +839,22 @@ TEST(DecoderTest, OnDataResumesTransportFrameStart) { Buffer::OwnedImpl buffer; bool underflow = false; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) - .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(false))); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)).WillOnce(Return(false)); EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) - .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); - EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(*proto, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); @@ -834,10 +883,17 @@ TEST(DecoderTest, OnDataResumesTransportFrameEnd) { Buffer::OwnedImpl buffer; EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) - .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); - EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(*proto, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); @@ -871,18 +927,40 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { Buffer::OwnedImpl buffer; bool underflow = true; + HeaderMap headers{{"test", "header"}}; + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) - .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); - EXPECT_CALL(filter, transportBegin(absl::optional(100))) - .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + return true; + })); + EXPECT_CALL(filter, transportBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasFrameSize()); + EXPECT_EQ(100U, metadata->frameSize()); + + return ThriftFilters::FilterStatus::StopIteration; + })); EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), - SetArgReferee<3>(100), Return(true))); - EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) - .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(filter, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(MessageType::Call, metadata->messageType()); + EXPECT_EQ(100U, metadata->sequenceId()); + return ThriftFilters::FilterStatus::StopIteration; + })); EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); diff --git a/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc index 2999de7edbf53..6c73d82c29402 100644 --- a/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc @@ -28,15 +28,15 @@ TEST(FramedTransportTest, Type) { TEST(FramedTransportTest, NotEnoughData) { Buffer::OwnedImpl buffer; FramedTransportImpl transport; - absl::optional size = 1; + MessageMetadata metadata; - EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); - EXPECT_EQ(absl::optional(1), size); + EXPECT_FALSE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); addRepeated(buffer, 3, 0); - EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); - EXPECT_EQ(absl::optional(1), size); + EXPECT_FALSE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); } TEST(FramedTransportTest, InvalidFrameSize) { @@ -46,20 +46,20 @@ TEST(FramedTransportTest, InvalidFrameSize) { Buffer::OwnedImpl buffer; addInt32(buffer, -1); - absl::optional size = 1; - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, + MessageMetadata metadata; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, "invalid thrift framed transport frame size -1"); - EXPECT_EQ(absl::optional(1), size); + EXPECT_THAT(metadata, IsEmptyMetadata()); } { Buffer::OwnedImpl buffer; addInt32(buffer, 0x7fffffff); - absl::optional size = 1; - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, + MessageMetadata metadata; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, "invalid thrift framed transport frame size 2147483647"); - EXPECT_EQ(absl::optional(1), size); + EXPECT_THAT(metadata, IsEmptyMetadata()); } } @@ -70,9 +70,9 @@ TEST(FramedTransportTest, DecodeFrameStart) { addInt32(buffer, 100); EXPECT_EQ(buffer.length(), 4); - absl::optional size; - EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); - EXPECT_EQ(absl::optional(100U), size); + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasOnlyFrameSize(100U)); EXPECT_EQ(buffer.length(), 0); } @@ -88,11 +88,12 @@ TEST(FramedTransportTest, EncodeFrame) { FramedTransportImpl transport; { + MessageMetadata metadata; Buffer::OwnedImpl message; message.add("fake message"); Buffer::OwnedImpl buffer; - transport.encodeFrame(buffer, message); + transport.encodeFrame(buffer, metadata, message); EXPECT_EQ(0, message.length()); EXPECT_EQ(std::string("\0\0\0\xC" @@ -102,9 +103,10 @@ TEST(FramedTransportTest, EncodeFrame) { } { + MessageMetadata metadata; Buffer::OwnedImpl message; Buffer::OwnedImpl buffer; - EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, message), EnvoyException, + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, message), EnvoyException, "invalid thrift framed transport frame size 0"); } } diff --git a/test/extensions/filters/network/thrift_proxy/metadata_test.cc b/test/extensions/filters/network/thrift_proxy/metadata_test.cc new file mode 100644 index 0000000000000..d06fa46e6b81f --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/metadata_test.cc @@ -0,0 +1,173 @@ +#include "extensions/filters/network/thrift_proxy/metadata.h" + +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +TEST(HeaderTest, HeaderKeyIsNotTransformed) { + Header hdr("KEY", "VALUE"); + EXPECT_EQ(hdr.key(), "KEY"); + EXPECT_EQ(hdr.value(), "VALUE"); +} + +TEST(HeaderTest, HeaderIsCopyable) { + Header hdr("KEY", "VALUE"); + Header hdrCopy(hdr); + EXPECT_EQ(hdrCopy.key(), "KEY"); + EXPECT_EQ(hdrCopy.value(), "VALUE"); +} + +TEST(HeaderMapTest, AddHeaders) { + HeaderMap headers; + headers.add(Header("k", "v")); + + Header* hdr = headers.get("k"); + EXPECT_NE(hdr, nullptr); + EXPECT_EQ(hdr->key(), "k"); + EXPECT_EQ(hdr->value(), "v"); +} + +TEST(HeaderMapTest, GetHeaders) { + HeaderMap headers({ + {"a", "b"}, + {"c", "d"}, + {"e", "f"}, + }); + + EXPECT_EQ(headers.get("a")->value(), "b"); + EXPECT_EQ(headers.get("c")->value(), "d"); + EXPECT_EQ(headers.get("e")->value(), "f"); +} + +TEST(HeaderMapTest, Clear) { + HeaderMap headers({ + {"a", "b"}, + {"c", "d"}, + {"e", "f"}, + }); + + headers.clear(); + EXPECT_EQ(headers.get("a"), nullptr); + EXPECT_EQ(headers.get("c"), nullptr); + EXPECT_EQ(headers.get("e"), nullptr); +} + +TEST(HeaderMapTest, Size) { + HeaderMap headers({ + {"a", "b"}, + {"c", "d"}, + {"e", "f"}, + }); + + EXPECT_EQ(3U, headers.size()); +} + +TEST(HeaderMapTest, Equality) { + HeaderMap headers1({{"FIRST", "1"}, {"Second", "2"}}); + HeaderMap headers2({{"FIRST", "1"}, {"Second", "2"}}); + HeaderMap headers3({{"FIRST", "1"}}); + HeaderMap headers4({{"FIRST", "_"}, {"Second", "2"}}); + HeaderMap headers5({{"First", "1"}, {"Second", "2"}}); + + EXPECT_EQ(headers1, headers2); + EXPECT_EQ(headers2, headers1); + + EXPECT_FALSE(headers1 == headers3); + EXPECT_FALSE(headers3 == headers1); + + EXPECT_FALSE(headers1 == headers4); + EXPECT_FALSE(headers4 == headers1); + + EXPECT_FALSE(headers1 == headers5); + EXPECT_FALSE(headers5 == headers1); +} + +TEST(HeaderMapTest, Iteration) { + HeaderMap headers({{"first", "1"}, {"second", "2"}}); + + int i = 0; + for (const Header& header : headers) { + switch (i) { + case 0: + EXPECT_EQ("first", header.key()); + EXPECT_EQ("1", header.value()); + break; + case 1: + EXPECT_EQ("second", header.key()); + EXPECT_EQ("2", header.value()); + break; + default: + ASSERT(false); + break; + } + i++; + } +} + +TEST(HeaderMapTest, CopyConstructor) { + HeaderMap headers({{"first", "1"}, {"second", "2"}, {"third", "3"}}); + HeaderMap copy(headers); + EXPECT_EQ(copy, headers); +} + +TEST(MessageMetadataTest, Fields) { + MessageMetadata metadata; + + EXPECT_FALSE(metadata.hasFrameSize()); + EXPECT_THROW(metadata.frameSize(), absl::bad_optional_access); + metadata.setFrameSize(100); + EXPECT_TRUE(metadata.hasFrameSize()); + EXPECT_EQ(100, metadata.frameSize()); + + EXPECT_FALSE(metadata.hasProtocol()); + EXPECT_THROW(metadata.protocol(), absl::bad_optional_access); + metadata.setProtocol(ProtocolType::Binary); + EXPECT_TRUE(metadata.hasProtocol()); + EXPECT_EQ(ProtocolType::Binary, metadata.protocol()); + + EXPECT_FALSE(metadata.hasMethodName()); + EXPECT_THROW(metadata.methodName(), absl::bad_optional_access); + metadata.setMethodName("method"); + EXPECT_TRUE(metadata.hasMethodName()); + EXPECT_EQ("method", metadata.methodName()); + + EXPECT_FALSE(metadata.hasMessageType()); + EXPECT_THROW(metadata.messageType(), absl::bad_optional_access); + metadata.setMessageType(MessageType::Call); + EXPECT_TRUE(metadata.hasMessageType()); + EXPECT_EQ(MessageType::Call, metadata.messageType()); + + EXPECT_FALSE(metadata.hasSequenceId()); + EXPECT_THROW(metadata.sequenceId(), absl::bad_optional_access); + metadata.setSequenceId(101); + EXPECT_TRUE(metadata.hasSequenceId()); + EXPECT_EQ(101, metadata.sequenceId()); + + EXPECT_FALSE(metadata.hasAppException()); + EXPECT_THROW(metadata.appExceptionType(), absl::bad_optional_access); + EXPECT_THROW(metadata.appExceptionMessage(), absl::bad_optional_access); + metadata.setAppException(AppExceptionType::InternalError, "oops"); + EXPECT_TRUE(metadata.hasAppException()); + EXPECT_EQ(AppExceptionType::InternalError, metadata.appExceptionType()); + EXPECT_EQ("oops", metadata.appExceptionMessage()); +} + +TEST(MessageMetadataTest, Headers) { + MessageMetadata metadata; + + EXPECT_TRUE(metadata.headers().empty()); + + metadata.addHeader(Header("k", "v")); + EXPECT_FALSE(metadata.headers().empty()); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index caa93654233e8..85cdd7377a32b 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -2,6 +2,7 @@ #include "gtest/gtest.h" +using testing::Invoke; using testing::Return; using testing::ReturnRef; using testing::_; @@ -23,6 +24,9 @@ MockTransport::~MockTransport() {} MockProtocol::MockProtocol() { ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); ON_CALL(*this, type()).WillByDefault(Return(type_)); + ON_CALL(*this, setType(_)).WillByDefault(Invoke([&](ProtocolType type) -> void { + type_ = type; + })); } MockProtocol::~MockProtocol() {} @@ -34,7 +38,7 @@ namespace ThriftFilters { MockDecoderFilter::MockDecoderFilter() { ON_CALL(*this, transportBegin(_)).WillByDefault(Return(FilterStatus::Continue)); ON_CALL(*this, transportEnd()).WillByDefault(Return(FilterStatus::Continue)); - ON_CALL(*this, messageBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageBegin(_)).WillByDefault(Return(FilterStatus::Continue)); ON_CALL(*this, messageEnd()).WillByDefault(Return(FilterStatus::Continue)); ON_CALL(*this, structBegin(_)).WillByDefault(Return(FilterStatus::Continue)); ON_CALL(*this, structEnd()).WillByDefault(Return(FilterStatus::Continue)); diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index f932bc808d418..d60a40cd9b759 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -2,6 +2,7 @@ #include "extensions/filters/network/thrift_proxy/conn_manager.h" #include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/metadata.h" #include "extensions/filters/network/thrift_proxy/protocol.h" #include "extensions/filters/network/thrift_proxy/router/router.h" #include "extensions/filters/network/thrift_proxy/transport.h" @@ -38,9 +39,9 @@ class MockTransport : public Transport { // ThriftProxy::Transport MOCK_CONST_METHOD0(name, const std::string&()); MOCK_CONST_METHOD0(type, TransportType()); - MOCK_METHOD2(decodeFrameStart, bool(Buffer::Instance&, absl::optional&)); + MOCK_METHOD2(decodeFrameStart, bool(Buffer::Instance&, MessageMetadata&)); MOCK_METHOD1(decodeFrameEnd, bool(Buffer::Instance&)); - MOCK_METHOD2(encodeFrame, void(Buffer::Instance&, Buffer::Instance&)); + MOCK_METHOD3(encodeFrame, void(Buffer::Instance&, const MessageMetadata&, Buffer::Instance&)); std::string name_{"mock"}; TransportType type_{TransportType::Auto}; @@ -54,8 +55,8 @@ class MockProtocol : public Protocol { // ThriftProxy::Protocol MOCK_CONST_METHOD0(name, const std::string&()); MOCK_CONST_METHOD0(type, ProtocolType()); - MOCK_METHOD4(readMessageBegin, bool(Buffer::Instance& buffer, std::string& name, - MessageType& msg_type, int32_t& seq_id)); + MOCK_METHOD1(setType, void(ProtocolType)); + MOCK_METHOD2(readMessageBegin, bool(Buffer::Instance& buffer, MessageMetadata& metadata)); MOCK_METHOD1(readMessageEnd, bool(Buffer::Instance& buffer)); MOCK_METHOD2(readStructBegin, bool(Buffer::Instance& buffer, std::string& name)); MOCK_METHOD1(readStructEnd, bool(Buffer::Instance& buffer)); @@ -78,8 +79,7 @@ class MockProtocol : public Protocol { MOCK_METHOD2(readString, bool(Buffer::Instance& buffer, std::string& value)); MOCK_METHOD2(readBinary, bool(Buffer::Instance& buffer, std::string& value)); - MOCK_METHOD4(writeMessageBegin, void(Buffer::Instance& buffer, const std::string& name, - MessageType msg_type, int32_t seq_id)); + MOCK_METHOD2(writeMessageBegin, void(Buffer::Instance& buffer, const MessageMetadata& metadata)); MOCK_METHOD1(writeMessageEnd, void(Buffer::Instance& buffer)); MOCK_METHOD2(writeStructBegin, void(Buffer::Instance& buffer, const std::string& name)); MOCK_METHOD1(writeStructEnd, void(Buffer::Instance& buffer)); @@ -126,10 +126,9 @@ class MockDecoderFilter : public DecoderFilter { MOCK_METHOD0(onDestroy, void()); MOCK_METHOD1(setDecoderFilterCallbacks, void(DecoderFilterCallbacks& callbacks)); MOCK_METHOD0(resetUpstreamConnection, void()); - MOCK_METHOD1(transportBegin, FilterStatus(absl::optional size)); + MOCK_METHOD1(transportBegin, FilterStatus(MessageMetadataSharedPtr metadata)); MOCK_METHOD0(transportEnd, FilterStatus()); - MOCK_METHOD3(messageBegin, - FilterStatus(const absl::string_view name, MessageType msg_type, int32_t seq_id)); + MOCK_METHOD1(messageBegin, FilterStatus(MessageMetadataSharedPtr metadata)); MOCK_METHOD0(messageEnd, FilterStatus()); MOCK_METHOD1(structBegin, FilterStatus(const absl::string_view name)); MOCK_METHOD0(structEnd, FilterStatus()); @@ -163,13 +162,11 @@ class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { MOCK_METHOD0(route, Router::RouteConstSharedPtr()); MOCK_CONST_METHOD0(downstreamTransportType, TransportType()); MOCK_CONST_METHOD0(downstreamProtocolType, ProtocolType()); - void sendLocalReply(DirectResponsePtr&& response) override { sendLocalReply_(response); } + MOCK_METHOD1(sendLocalReply, void(const DirectResponse&)); MOCK_METHOD2(startUpstreamResponse, void(TransportType, ProtocolType)); MOCK_METHOD1(upstreamData, bool(Buffer::Instance&)); MOCK_METHOD0(resetDownstreamConnection, void()); - MOCK_METHOD1(sendLocalReply_, void(DirectResponsePtr&)); - uint64_t stream_id_{1}; NiceMock connection_; }; diff --git a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc index 7a8fef74a1490..da9f0f90ed93f 100644 --- a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc @@ -24,6 +24,35 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +class AutoProtocolTest : public testing::Test { +public: + void resetMetadata() { + metadata_.setMethodName("-"); + metadata_.setMessageType(MessageType::Oneway); + metadata_.setSequenceId(-1); + } + + void expectMetadata(const std::string& name, MessageType msg_type, int32_t seq_id) { + EXPECT_TRUE(metadata_.hasMethodName()); + EXPECT_EQ(name, metadata_.methodName()); + + EXPECT_TRUE(metadata_.hasMessageType()); + EXPECT_EQ(msg_type, metadata_.messageType()); + + EXPECT_TRUE(metadata_.hasSequenceId()); + EXPECT_EQ(seq_id, metadata_.sequenceId()); + + EXPECT_FALSE(metadata_.hasFrameSize()); + EXPECT_FALSE(metadata_.hasProtocol()); + EXPECT_FALSE(metadata_.hasAppException()); + EXPECT_TRUE(metadata_.headers().empty()); + } + + void expectDefaultMetadata() { expectMetadata("-", MessageType::Oneway, -1); } + + MessageMetadata metadata_; +}; + TEST(ProtocolNames, FromType) { for (int i = 0; i <= static_cast(ProtocolType::LastProtocolType); i++) { ProtocolType type = static_cast(i); @@ -31,43 +60,33 @@ TEST(ProtocolNames, FromType) { } } -TEST(AutoProtocolTest, NotEnoughData) { +TEST_F(AutoProtocolTest, NotEnoughData) { Buffer::OwnedImpl buffer; AutoProtocolImpl proto; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = -1; + resetMetadata(); addInt8(buffer, 0); - EXPECT_FALSE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, -1); + EXPECT_FALSE(proto.readMessageBegin(buffer, metadata_)); + expectDefaultMetadata(); } -TEST(AutoProtocolTest, UnknownProtocol) { +TEST_F(AutoProtocolTest, UnknownProtocol) { Buffer::OwnedImpl buffer; AutoProtocolImpl proto; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = -1; + resetMetadata(); addInt16(buffer, 0x0102); - EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, name, msg_type, seq_id), EnvoyException, + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, "unknown thrift auto protocol message start 0102"); - EXPECT_EQ(name, "-"); - EXPECT_EQ(msg_type, MessageType::Oneway); - EXPECT_EQ(seq_id, -1); + expectDefaultMetadata(); } -TEST(AutoProtocolTest, ReadMessageBegin) { +TEST_F(AutoProtocolTest, ReadMessageBegin) { // Binary Protocol { AutoProtocolImpl proto; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = -1; + resetMetadata(); Buffer::OwnedImpl buffer; addInt16(buffer, 0x8001); @@ -77,10 +96,8 @@ TEST(AutoProtocolTest, ReadMessageBegin) { addString(buffer, "the_name"); addInt32(buffer, 1); - EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "the_name"); - EXPECT_EQ(msg_type, MessageType::Call); - EXPECT_EQ(seq_id, 1); + EXPECT_TRUE(proto.readMessageBegin(buffer, metadata_)); + expectMetadata("the_name", MessageType::Call, 1); EXPECT_EQ(buffer.length(), 0); EXPECT_EQ(proto.name(), "binary(auto)"); EXPECT_EQ(proto.type(), ProtocolType::Binary); @@ -89,9 +106,7 @@ TEST(AutoProtocolTest, ReadMessageBegin) { // Compact protocol { AutoProtocolImpl proto; - std::string name = "-"; - MessageType msg_type = MessageType::Oneway; - int32_t seq_id = 1; + resetMetadata(); Buffer::OwnedImpl buffer; addInt16(buffer, 0x8221); @@ -99,36 +114,32 @@ TEST(AutoProtocolTest, ReadMessageBegin) { addInt8(buffer, 8); addString(buffer, "the_name"); - EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); - EXPECT_EQ(name, "the_name"); - EXPECT_EQ(msg_type, MessageType::Call); - EXPECT_EQ(seq_id, 0x0102); + EXPECT_TRUE(proto.readMessageBegin(buffer, metadata_)); + expectMetadata("the_name", MessageType::Call, 0x0102); EXPECT_EQ(buffer.length(), 0); EXPECT_EQ(proto.name(), "compact(auto)"); EXPECT_EQ(proto.type(), ProtocolType::Compact); } } -TEST(AutoProtocolTest, ReadDelegation) { +TEST_F(AutoProtocolTest, ReadDelegation) { NiceMock* proto = new NiceMock(); AutoProtocolImpl auto_proto; auto_proto.setProtocol(ProtocolPtr{proto}); // readMessageBegin Buffer::OwnedImpl buffer; - std::string name = "x"; - MessageType msg_type = MessageType::Call; - int32_t seq_id = 1; + resetMetadata(); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), Ref(name), Ref(msg_type), Ref(seq_id))) - .WillOnce(Return(true)); - EXPECT_TRUE(auto_proto.readMessageBegin(buffer, name, msg_type, seq_id)); + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), Ref(metadata_))).WillOnce(Return(true)); + EXPECT_TRUE(auto_proto.readMessageBegin(buffer, metadata_)); // readMessageEnd EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_TRUE(auto_proto.readMessageEnd(buffer)); // readStructBegin + std::string name; EXPECT_CALL(*proto, readStructBegin(Ref(buffer), Ref(name))).WillOnce(Return(true)); EXPECT_TRUE(auto_proto.readStructBegin(buffer, name)); @@ -234,15 +245,15 @@ TEST(AutoProtocolTest, ReadDelegation) { } } -TEST(AutoProtocolTest, WriteDelegation) { +TEST_F(AutoProtocolTest, WriteDelegation) { NiceMock* proto = new NiceMock(); AutoProtocolImpl auto_proto; auto_proto.setProtocol(ProtocolPtr{proto}); // writeMessageBegin Buffer::OwnedImpl buffer; - EXPECT_CALL(*proto, writeMessageBegin(Ref(buffer), "name", MessageType::Call, 100)); - auto_proto.writeMessageBegin(buffer, "name", MessageType::Call, 100); + EXPECT_CALL(*proto, writeMessageBegin(Ref(buffer), Ref(metadata_))); + auto_proto.writeMessageBegin(buffer, metadata_); // writeMessageEnd EXPECT_CALL(*proto, writeMessageEnd(Ref(buffer))); @@ -321,12 +332,12 @@ TEST(AutoProtocolTest, WriteDelegation) { auto_proto.writeBinary(buffer, "binary"); } -TEST(AutoProtocolTest, Name) { +TEST_F(AutoProtocolTest, Name) { AutoProtocolImpl proto; EXPECT_EQ(proto.name(), "auto"); } -TEST(AutoProtocolTest, Type) { +TEST_F(AutoProtocolTest, Type) { AutoProtocolImpl proto; EXPECT_EQ(proto.type(), ProtocolType::Auto); } diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index 94afe87c750d1..e0b92b515016c 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -89,15 +89,24 @@ class ThriftRouterTestBase { router_->setDecoderFilterCallbacks(callbacks_); } - void startRequest(MessageType msg_type) { + void initializeMetadata(MessageType msg_type) { msg_type_ = msg_type; - EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportBegin({})); + metadata_.reset(new MessageMetadata()); + metadata_->setMethodName("method"); + metadata_->setMessageType(msg_type_); + metadata_->setSequenceId(1); + } + + void startRequest(MessageType msg_type) { + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportBegin(metadata_)); EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + initializeMetadata(msg_type); + EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, newConnection(_)) .WillOnce( Invoke([&](Tcp::ConnectionPool::Callbacks& cb) -> Tcp::ConnectionPool::Cancellable* { @@ -105,8 +114,7 @@ class ThriftRouterTestBase { return &handle_; })); - EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, - router_->messageBegin(method_name_, msg_type_, seq_id_)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, router_->messageBegin(metadata_)); EXPECT_NE(nullptr, conn_pool_callbacks_); NiceMock connection; @@ -132,7 +140,12 @@ class ThriftRouterTestBase { EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); protocol_ = new NiceMock(); ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); - EXPECT_CALL(*protocol_, writeMessageBegin(_, method_name_, msg_type_, seq_id_)); + EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); EXPECT_CALL(callbacks_, continueDecoding()); conn_pool_callbacks_->onPoolReady(conn_data_, host_ptr_); @@ -193,7 +206,7 @@ class ThriftRouterTestBase { void completeRequest() { EXPECT_CALL(*protocol_, writeMessageEnd(_)); - EXPECT_CALL(*transport_, encodeFrame(_, _)); + EXPECT_CALL(*transport_, encodeFrame(_, _, _)); EXPECT_CALL(conn_data_.connection_, write(_, false)); if (msg_type_ == MessageType::Oneway) { @@ -242,9 +255,8 @@ class ThriftRouterTestBase { std::string cluster_name_{"cluster"}; - std::string method_name_{"method"}; MessageType msg_type_{MessageType::Call}; - int32_t seq_id_{1}; + MessageMetadataSharedPtr metadata_; NiceMock handle_; NiceMock conn_data_; @@ -281,14 +293,11 @@ TEST_F(ThriftRouterTest, PoolRemoteConnectionFailure) { startRequest(MessageType::Call); - EXPECT_CALL(callbacks_, sendLocalReply_(_)) - .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { - auto* app_ex = dynamic_cast(response.get()); - EXPECT_NE(nullptr, app_ex); - EXPECT_EQ(method_name_, app_ex->method_name_); - EXPECT_EQ(seq_id_, app_ex->seq_id_); - EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); - EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*connection failure.*")); })); conn_pool_callbacks_->onPoolFailure( Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure, host_ptr_); @@ -299,14 +308,11 @@ TEST_F(ThriftRouterTest, PoolLocalConnectionFailure) { startRequest(MessageType::Call); - EXPECT_CALL(callbacks_, sendLocalReply_(_)) - .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { - auto* app_ex = dynamic_cast(response.get()); - EXPECT_NE(nullptr, app_ex); - EXPECT_EQ(method_name_, app_ex->method_name_); - EXPECT_EQ(seq_id_, app_ex->seq_id_); - EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); - EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*connection failure.*")); })); conn_pool_callbacks_->onPoolFailure( Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure, host_ptr_); @@ -317,14 +323,11 @@ TEST_F(ThriftRouterTest, PoolTimeout) { startRequest(MessageType::Call); - EXPECT_CALL(callbacks_, sendLocalReply_(_)) - .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { - auto* app_ex = dynamic_cast(response.get()); - EXPECT_NE(nullptr, app_ex); - EXPECT_EQ(method_name_, app_ex->method_name_); - EXPECT_EQ(seq_id_, app_ex->seq_id_); - EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); - EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*connection failure.*")); })); conn_pool_callbacks_->onPoolFailure(Tcp::ConnectionPool::PoolFailureReason::Timeout, host_ptr_); } @@ -334,59 +337,49 @@ TEST_F(ThriftRouterTest, PoolOverflowFailure) { startRequest(MessageType::Call); - EXPECT_CALL(callbacks_, sendLocalReply_(_)) - .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { - auto* app_ex = dynamic_cast(response.get()); - EXPECT_NE(nullptr, app_ex); - EXPECT_EQ(method_name_, app_ex->method_name_); - EXPECT_EQ(seq_id_, app_ex->seq_id_); - EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); - EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*too many connections.*")); + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*too many connections.*")); })); conn_pool_callbacks_->onPoolFailure(Tcp::ConnectionPool::PoolFailureReason::Overflow, host_ptr_); } TEST_F(ThriftRouterTest, NoRoute) { initializeRouter(); + initializeMetadata(MessageType::Call); EXPECT_CALL(callbacks_, route()).WillOnce(Return(nullptr)); - EXPECT_CALL(callbacks_, sendLocalReply_(_)) - .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { - auto* app_ex = dynamic_cast(response.get()); - EXPECT_NE(nullptr, app_ex); - if (app_ex != nullptr) { - EXPECT_EQ(method_name_, app_ex->method_name_); - EXPECT_EQ(seq_id_, app_ex->seq_id_); - EXPECT_EQ(AppExceptionType::UnknownMethod, app_ex->type_); - EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*no route.*")); - } + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::UnknownMethod, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*no route.*")); })); - EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, - router_->messageBegin(method_name_, MessageType::Call, seq_id_)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, router_->messageBegin(metadata_)); } TEST_F(ThriftRouterTest, NoCluster) { initializeRouter(); + initializeMetadata(MessageType::Call); EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); EXPECT_CALL(context_.cluster_manager_, get(cluster_name_)).WillOnce(Return(nullptr)); - EXPECT_CALL(callbacks_, sendLocalReply_(_)) - .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { - auto* app_ex = dynamic_cast(response.get()); - EXPECT_NE(nullptr, app_ex); - EXPECT_EQ(method_name_, app_ex->method_name_); - EXPECT_EQ(seq_id_, app_ex->seq_id_); - EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); - EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*unknown cluster.*")); + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*unknown cluster.*")); })); - EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, - router_->messageBegin(method_name_, MessageType::Call, seq_id_)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, router_->messageBegin(metadata_)); } TEST_F(ThriftRouterTest, ClusterMaintenanceMode) { initializeRouter(); + initializeMetadata(MessageType::Call); EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); @@ -394,21 +387,18 @@ TEST_F(ThriftRouterTest, ClusterMaintenanceMode) { EXPECT_CALL(*context_.cluster_manager_.thread_local_cluster_.cluster_.info_, maintenanceMode()) .WillOnce(Return(true)); - EXPECT_CALL(callbacks_, sendLocalReply_(_)) - .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { - auto* app_ex = dynamic_cast(response.get()); - EXPECT_NE(nullptr, app_ex); - EXPECT_EQ(method_name_, app_ex->method_name_); - EXPECT_EQ(seq_id_, app_ex->seq_id_); - EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); - EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*maintenance mode.*")); + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*maintenance mode.*")); })); - EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, - router_->messageBegin(method_name_, MessageType::Call, seq_id_)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, router_->messageBegin(metadata_)); } TEST_F(ThriftRouterTest, NoHealthyHosts) { initializeRouter(); + initializeMetadata(MessageType::Call); EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); @@ -416,17 +406,14 @@ TEST_F(ThriftRouterTest, NoHealthyHosts) { EXPECT_CALL(context_.cluster_manager_, tcpConnPoolForCluster(cluster_name_, _, _)) .WillOnce(Return(nullptr)); - EXPECT_CALL(callbacks_, sendLocalReply_(_)) - .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { - auto* app_ex = dynamic_cast(response.get()); - EXPECT_NE(nullptr, app_ex); - EXPECT_EQ(method_name_, app_ex->method_name_); - EXPECT_EQ(seq_id_, app_ex->seq_id_); - EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); - EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*no healthy upstream.*")); + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*no healthy upstream.*")); })); - EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, - router_->messageBegin(method_name_, MessageType::Call, seq_id_)); + + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, router_->messageBegin(metadata_)); } TEST_F(ThriftRouterTest, TruncatedResponse) { @@ -468,6 +455,36 @@ TEST_F(ThriftRouterTest, UpstreamDataTriggersReset) { destroyRouter(); } +TEST_F(ThriftRouterTest, UnexpectedUpstreamRemoteClose) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(FieldType::String); + + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*connection failure.*")); + })); + router_->onEvent(Network::ConnectionEvent::RemoteClose); +} + +TEST_F(ThriftRouterTest, UnexpectedUpstreamLocalClose) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(FieldType::String); + + EXPECT_CALL(callbacks_, sendLocalReply(_)) + .WillOnce(Invoke([&](const DirectResponse& response) -> void { + auto& app_ex = dynamic_cast(response); + EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_); + EXPECT_THAT(app_ex.what(), ContainsRegex(".*connection failure.*")); + })); + router_->onEvent(Network::ConnectionEvent::RemoteClose); +} + TEST_F(ThriftRouterTest, UnexpectedRouterDestroyBeforeUpstreamConnect) { initializeRouter(); startRequest(MessageType::Call); @@ -586,14 +603,20 @@ name: config parseRouteConfigurationFromV2Yaml(yaml); RouteMatcher matcher(config); - EXPECT_EQ(nullptr, matcher.route("unknown")); - EXPECT_EQ(nullptr, matcher.route("METHOD1")); - - RouteConstSharedPtr route = matcher.route("method1"); + MessageMetadata metadata; + EXPECT_EQ(nullptr, matcher.route(metadata)); + metadata.setMethodName("unknown"); + EXPECT_EQ(nullptr, matcher.route(metadata)); + metadata.setMethodName("METHOD1"); + EXPECT_EQ(nullptr, matcher.route(metadata)); + + metadata.setMethodName("method1"); + RouteConstSharedPtr route = matcher.route(metadata); EXPECT_NE(nullptr, route); EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); - RouteConstSharedPtr route2 = matcher.route("method2"); + metadata.setMethodName("method2"); + RouteConstSharedPtr route2 = matcher.route(metadata); EXPECT_NE(nullptr, route2); EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); } @@ -615,13 +638,26 @@ name: config parseRouteConfigurationFromV2Yaml(yaml); RouteMatcher matcher(config); - RouteConstSharedPtr route = matcher.route("method1"); - EXPECT_NE(nullptr, route); - EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); - RouteConstSharedPtr route2 = matcher.route("anything"); - EXPECT_NE(nullptr, route2); - EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); + { + MessageMetadata metadata; + metadata.setMethodName("method1"); + RouteConstSharedPtr route = matcher.route(metadata); + EXPECT_NE(nullptr, route); + EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); + + metadata.setMethodName("anything"); + RouteConstSharedPtr route2 = matcher.route(metadata); + EXPECT_NE(nullptr, route2); + EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); + } + + { + MessageMetadata metadata; + RouteConstSharedPtr route2 = matcher.route(metadata); + EXPECT_NE(nullptr, route2); + EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); + } } } // namespace Router diff --git a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc index 64ab3cce319be..93ddab8928a63 100644 --- a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc @@ -30,15 +30,15 @@ TEST(TransportNames, FromType) { TEST(AutoTransportTest, NotEnoughData) { Buffer::OwnedImpl buffer; AutoTransportImpl transport; - absl::optional size = 100; + MessageMetadata metadata; - EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); - EXPECT_EQ(absl::optional(100), size); + EXPECT_FALSE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); addRepeated(buffer, 7, 0); - EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); - EXPECT_EQ(absl::optional(100), size); + EXPECT_FALSE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); } TEST(AutoTransportTest, UnknownTransport) { @@ -50,10 +50,10 @@ TEST(AutoTransportTest, UnknownTransport) { addInt32(buffer, 0); addInt32(buffer, 0); - absl::optional size = 100; - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, + MessageMetadata metadata; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, "unknown thrift auto transport frame start 00 00 00 00 00 00 00 00"); - EXPECT_EQ(absl::optional(100), size); + EXPECT_THAT(metadata, IsEmptyMetadata()); } // Looks like framed, but fails protocol check. @@ -62,10 +62,10 @@ TEST(AutoTransportTest, UnknownTransport) { addInt32(buffer, 0xFF); addInt32(buffer, 0); - absl::optional size = 100; - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, + MessageMetadata metadata; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, "unknown thrift auto transport frame start 00 00 00 ff 00 00 00 00"); - EXPECT_EQ(absl::optional(100), size); + EXPECT_THAT(metadata, IsEmptyMetadata()); } } @@ -78,9 +78,9 @@ TEST(AutoTransportTest, DecodeFrameStart) { addInt16(buffer, 0x8001); addInt16(buffer, 0); - absl::optional size; - EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); - EXPECT_EQ(absl::optional(255), size); + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasOnlyFrameSize(255U)); EXPECT_EQ(transport.name(), "framed(auto)"); EXPECT_EQ(transport.type(), TransportType::Framed); EXPECT_EQ(buffer.length(), 4); @@ -94,9 +94,9 @@ TEST(AutoTransportTest, DecodeFrameStart) { addInt16(buffer, 0x8201); addInt16(buffer, 0); - absl::optional size; - EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); - EXPECT_EQ(absl::optional(4095), size); + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasOnlyFrameSize(4095U)); EXPECT_EQ(transport.name(), "framed(auto)"); EXPECT_EQ(transport.type(), TransportType::Framed); EXPECT_EQ(buffer.length(), 4); @@ -109,9 +109,9 @@ TEST(AutoTransportTest, DecodeFrameStart) { addInt16(buffer, 0x8001); addRepeated(buffer, 6, 0); - absl::optional size = 1; - EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); - EXPECT_FALSE(size.has_value()); + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); EXPECT_EQ(transport.name(), "unframed(auto)"); EXPECT_EQ(transport.type(), TransportType::Unframed); EXPECT_EQ(buffer.length(), 8); @@ -124,9 +124,9 @@ TEST(AutoTransportTest, DecodeFrameStart) { addInt16(buffer, 0x8201); addRepeated(buffer, 6, 0); - absl::optional size = 1; - EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); - EXPECT_FALSE(size.has_value()); + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); EXPECT_EQ(transport.name(), "unframed(auto)"); EXPECT_EQ(transport.type(), TransportType::Unframed); EXPECT_EQ(buffer.length(), 8); @@ -140,8 +140,9 @@ TEST(AutoTransportTest, DecodeFrameEnd) { addInt16(buffer, 0x8001); addInt16(buffer, 0); - absl::optional size; - EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_EQ(buffer.length(), 4); EXPECT_TRUE(transport.decodeFrameEnd(buffer)); @@ -153,11 +154,12 @@ TEST(AutoTransportTest, EncodeFrame) { AutoTransportImpl transport; transport.setTransport(TransportPtr{mock_transport}); + MessageMetadata metadata; Buffer::OwnedImpl buffer; Buffer::OwnedImpl message; - EXPECT_CALL(*mock_transport, encodeFrame(Ref(buffer), Ref(message))); - transport.encodeFrame(buffer, message); + EXPECT_CALL(*mock_transport, encodeFrame(Ref(buffer), Ref(metadata), Ref(message))); + transport.encodeFrame(buffer, metadata, message); } TEST(AutoTransportTest, Name) { diff --git a/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc index f83119ffaf383..7c3c36b49a9b4 100644 --- a/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc @@ -30,9 +30,9 @@ TEST(UnframedTransportTest, DecodeFrameStart) { addInt32(buffer, 0xDEADBEEF); EXPECT_EQ(buffer.length(), 4); - absl::optional size = 1; - EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); - EXPECT_FALSE(size.has_value()); + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); EXPECT_EQ(buffer.length(), 4); } @@ -46,11 +46,13 @@ TEST(UnframedTransportTest, DecodeFrameEnd) { TEST(UnframedTransportTest, EncodeFrame) { UnframedTransportImpl transport; + MessageMetadata metadata; + Buffer::OwnedImpl message; message.add("fake message"); Buffer::OwnedImpl buffer; - transport.encodeFrame(buffer, message); + transport.encodeFrame(buffer, metadata, message); EXPECT_EQ(0, message.length()); EXPECT_EQ("fake message", buffer.toString()); diff --git a/test/extensions/filters/network/thrift_proxy/utility.h b/test/extensions/filters/network/thrift_proxy/utility.h index 8058f9b337999..21585f704c0b5 100644 --- a/test/extensions/filters/network/thrift_proxy/utility.h +++ b/test/extensions/filters/network/thrift_proxy/utility.h @@ -7,6 +7,7 @@ #include "extensions/filters/network/thrift_proxy/protocol.h" +#include "gmock/gmock.h" #include "gtest/gtest.h" using testing::TestParamInfo; @@ -93,6 +94,76 @@ inline std::string fieldTypeParamToString(const TestParamInfo& params return fieldTypeToString(params.param); } +MATCHER(IsEmptyMetadata, "") { + if (arg.hasFrameSize()) { + *result_listener << "has a frame size of " << arg.frameSize(); + return false; + } + if (arg.hasProtocol()) { + *result_listener << "has a protocol of " << ProtocolNames::get().fromType(arg.protocol()); + return false; + } + if (arg.hasMethodName()) { + *result_listener << "has a method name of " << arg.methodName(); + return false; + } + if (arg.hasSequenceId()) { + *result_listener << "has a sequence id " << arg.sequenceId(); + return false; + } + if (arg.hasMessageType()) { + *result_listener << "has a message type of " << static_cast(arg.messageType()); + return false; + } + if (!arg.headers().empty()) { + *result_listener << "has " << arg.headers().size() << " headers"; + return false; + } + if (arg.hasAppException()) { + *result_listener << "has an app exception"; + return false; + } + return true; +} + +MATCHER_P(HasOnlyFrameSize, n, "") { + return arg.hasFrameSize() && arg.frameSize() == n && !arg.hasProtocol() && !arg.hasMethodName() && + !arg.hasSequenceId() && !arg.hasMessageType() && arg.headers().empty() && + !arg.hasAppException(); +} + +MATCHER_P(HasFrameSize, n, "") { + if (!arg.hasFrameSize()) { + *result_listener << "has no frame size"; + return false; + } + *result_listener << "has frame size = " << arg.frameSize(); + return arg.frameSize() == n; +} + +MATCHER_P(HasProtocol, p, "") { return arg.hasProtocol() && arg.protocol() == p; } +MATCHER_P(HasSequenceId, id, "") { return arg.hasSequenceId() && arg.sequenceId() == id; } +MATCHER(HasNoHeaders, "") { return arg.headers().empty(); } + +MATCHER_P2(HasAppException, t, m, "") { + if (!arg.hasAppException()) { + *result_listener << "has no exception"; + return false; + } + + if (arg.appExceptionType() != t) { + *result_listener << "has exception with type " << static_cast(arg.appExceptionType()); + return false; + } + + if (std::string(m) != arg.appExceptionMessage()) { + *result_listener << "has exception with message " << arg.appExceptionMessage(); + return false; + } + + return true; +} + } // namespace } // namespace ThriftProxy } // namespace NetworkFilters