diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index 25c6dfda0dda7..e2d1789976b63 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -105,6 +105,7 @@ envoy_cc_library( external_deps = ["abseil_optional"], deps = [ ":thrift_lib", + "//include/envoy/buffer:buffer_interface", "//source/common/common:macros", ], ) @@ -128,14 +129,18 @@ envoy_cc_library( ], external_deps = ["abseil_optional"], deps = [ + ":conn_state_lib", ":decoder_events_lib", ":metadata_lib", ":thrift_lib", + ":thrift_object_interface", + ":transport_interface", "//include/envoy/buffer:buffer_interface", "//include/envoy/registry", "//source/common/common:assert_lib", "//source/common/config:utility_lib", "//source/common/singleton:const_singleton", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) @@ -221,6 +226,14 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "conn_state_lib", + hdrs = ["conn_state.h"], + deps = [ + "//include/envoy/tcp:conn_pool_interface", + ], +) + envoy_cc_library( name = "thrift_lib", hdrs = ["thrift.h"], @@ -230,6 +243,27 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "thrift_object_interface", + hdrs = ["thrift_object.h"], + deps = [ + "//include/envoy/buffer:buffer_interface", + ], +) + +envoy_cc_library( + name = "thrift_object_lib", + srcs = ["thrift_object_impl.cc"], + hdrs = ["thrift_object_impl.h"], + deps = [ + ":decoder_lib", + ":thrift_lib", + ":thrift_object_interface", + ":unframed_transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + ], +) + envoy_cc_library( name = "auto_transport_lib", srcs = [ diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index bba8a035a25bf..359cf7f77ed5c 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -142,10 +142,6 @@ void ConfigImpl::createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& c } } -DecoderPtr ConfigImpl::createDecoder(DecoderCallbacks& callbacks) { - return std::make_unique(createTransport(), createProtocol(), callbacks); -} - TransportPtr ConfigImpl::createTransport() { return NamedTransportConfigFactory::getFactory(transport_).createTransport(); } diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index 0b32e7964c4ea..71c2d1c580c91 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -76,13 +76,11 @@ class ConfigImpl : public Config, // Config ThriftFilterStats& stats() override { return stats_; } ThriftFilters::FilterChainFactory& filterFactory() override { return *this; } - DecoderPtr createDecoder(DecoderCallbacks& callbacks) override; + TransportPtr createTransport() override; + ProtocolPtr createProtocol() override; Router::Config& routerConfig() override { return *this; } private: - TransportPtr createTransport(); - ProtocolPtr createProtocol(); - Server::Configuration::FactoryContext& context_; const std::string stats_prefix_; ThriftFilterStats stats_; diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index 6b83435351b9d..de34169fc385f 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -16,7 +16,9 @@ namespace NetworkFilters { namespace ThriftProxy { ConnectionManager::ConnectionManager(Config& config) - : config_(config), stats_(config_.stats()), decoder_(config_.createDecoder(*this)) {} + : config_(config), stats_(config_.stats()), transport_(config.createTransport()), + protocol_(config.createProtocol()), + decoder_(std::make_unique(*transport_, *protocol_, *this)) {} ConnectionManager::~ConnectionManager() {} @@ -67,22 +69,14 @@ void ConnectionManager::dispatch() { } void ConnectionManager::sendLocalReply(MessageMetadata& metadata, const DirectResponse& response) { - // Use the factory to get the concrete protocol from the decoder protocol (as opposed to - // potentially pre-detection auto protocol). - ProtocolType proto_type = decoder_->protocolType(); - ProtocolPtr proto = NamedProtocolConfigFactory::getFactory(proto_type).createProtocol(); Buffer::OwnedImpl buffer; - response.encode(metadata, *proto, buffer); - - // Same logic as protocol above. - TransportPtr transport = - NamedTransportConfigFactory::getFactory(decoder_->transportType()).createTransport(); + response.encode(metadata, *protocol_, buffer); Buffer::OwnedImpl response_buffer; - metadata.setProtocol(proto_type); - transport->encodeFrame(response_buffer, metadata, buffer); + metadata.setProtocol(protocol_->type()); + transport_->encodeFrame(response_buffer, metadata, buffer); read_callbacks_->connection().write(response_buffer, false); } @@ -230,12 +224,30 @@ FilterStatus ConnectionManager::ActiveRpc::transportEnd() { break; } - return decoder_filter_->transportEnd(); + FilterStatus status = event_handler_->transportEnd(); + + if (metadata_->isProtocolUpgradeMessage()) { + ENVOY_CONN_LOG(error, "thrift: sending protocol upgrade response", + parent_.read_callbacks_->connection()); + sendLocalReply(*parent_.protocol_->upgradeResponse(*upgrade_handler_)); + } + + return status; } FilterStatus ConnectionManager::ActiveRpc::messageBegin(MessageMetadataSharedPtr metadata) { metadata_ = metadata; + if (metadata_->isProtocolUpgradeMessage()) { + ASSERT(parent_.protocol_->supportsUpgrade()); + + ENVOY_CONN_LOG(error, "thrift: decoding protocol upgrade request", + parent_.read_callbacks_->connection()); + upgrade_handler_ = parent_.protocol_->upgradeRequestDecoder(); + ASSERT(upgrade_handler_ != nullptr); + event_handler_ = upgrade_handler_.get(); + } + return event_handler_->messageBegin(metadata); } @@ -282,11 +294,10 @@ void ConnectionManager::ActiveRpc::sendLocalReply(const DirectResponse& response parent_.doDeferredRpcDestroy(*this); } -void ConnectionManager::ActiveRpc::startUpstreamResponse(TransportType transport_type, - ProtocolType protocol_type) { +void ConnectionManager::ActiveRpc::startUpstreamResponse(Transport& transport, Protocol& protocol) { ASSERT(response_decoder_ == nullptr); - response_decoder_ = std::make_unique(*this, transport_type, protocol_type); + response_decoder_ = std::make_unique(*this, transport, protocol); } bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) { diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index 5d4f4f9cf1151..6432a51377960 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -31,7 +31,8 @@ class Config { virtual ThriftFilters::FilterChainFactory& filterFactory() PURE; virtual ThriftFilterStats& stats() PURE; - virtual DecoderPtr createDecoder(DecoderCallbacks& callbacks) PURE; + virtual TransportPtr createTransport() PURE; + virtual ProtocolPtr createProtocol() PURE; virtual Router::Config& routerConfig() PURE; }; @@ -74,18 +75,10 @@ class ConnectionManager : public Network::ReadFilter, struct ActiveRpc; struct ResponseDecoder : public DecoderCallbacks, public ProtocolConverter { - ResponseDecoder(ActiveRpc& parent, TransportType transport_type, ProtocolType protocol_type) - : parent_(parent), - decoder_(std::make_unique( - NamedTransportConfigFactory::getFactory(transport_type).createTransport(), - NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol(), *this)), + ResponseDecoder(ActiveRpc& parent, Transport& transport, Protocol& protocol) + : parent_(parent), decoder_(std::make_unique(transport, protocol, *this)), complete_(false), first_reply_field_(false) { - // Use the factory to get the concrete protocol from the decoder protocol (as opposed to - // potentially pre-detection auto protocol). - initProtocolConverter( - NamedProtocolConfigFactory::getFactory(parent_.parent_.decoder_->protocolType()) - .createProtocol(), - parent_.response_buffer_); + initProtocolConverter(*parent_.parent_.protocol_, parent_.response_buffer_); } bool onData(Buffer::Instance& data); @@ -149,7 +142,7 @@ class ConnectionManager : public Network::ReadFilter, return parent_.decoder_->protocolType(); } void sendLocalReply(const DirectResponse& response) override; - void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) override; + void startUpstreamResponse(Transport& transport, Protocol& protocol) override; bool upstreamData(Buffer::Instance& buffer) override; void resetDownstreamConnection() override; @@ -170,6 +163,7 @@ class ConnectionManager : public Network::ReadFilter, uint64_t stream_id_; MessageMetadataSharedPtr metadata_; ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + DecoderEventHandlerSharedPtr upgrade_handler_; ResponseDecoderPtr response_decoder_; absl::optional cached_route_; Buffer::OwnedImpl response_buffer_; @@ -188,6 +182,8 @@ class ConnectionManager : public Network::ReadFilter, Network::ReadFilterCallbacks* read_callbacks_{}; + TransportPtr transport_; + ProtocolPtr protocol_; DecoderPtr decoder_; std::list rpcs_; Buffer::OwnedImpl request_buffer_; diff --git a/source/extensions/filters/network/thrift_proxy/conn_state.h b/source/extensions/filters/network/thrift_proxy/conn_state.h new file mode 100644 index 0000000000000..f5db30d458461 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_state.h @@ -0,0 +1,48 @@ +#pragma once + +#include "envoy/tcp/conn_pool.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ThriftConnectionState tracks thrift-related connection state for pooled connections. + */ +class ThriftConnectionState : public Tcp::ConnectionPool::ConnectionState { +public: + /** + * @return true if this upgrade has been attempted on this connection. + */ + bool upgradeAttempted() const { return upgrade_attempted_; } + /** + * @return true if this connection has been upgraded + */ + bool isUpgraded() const { return upgraded_; } + + /** + * Marks the connection as successfully upgraded. + */ + void markUpgraded() { + upgrade_attempted_ = true; + upgraded_ = true; + } + + /** + * Marks the connection as not upgraded. + */ + void markUpgradeFailed() { + upgrade_attempted_ = true; + upgraded_ = false; + } + +private: + bool upgrade_attempted_{false}; + bool upgraded_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index c6d4b470ea890..bc56bffea0b00 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -334,6 +334,9 @@ ProtocolState DecoderStateMachine::popReturnState() { ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { while (state_ != ProtocolState::Done) { + ENVOY_LOG(trace, "thrift: state {}, {} bytes available", ProtocolStateNameValues::name(state_), + buffer.length()); + DecoderStatus s = handleState(buffer); if (s.next_state_ == ProtocolState::WaitForData) { return ProtocolState::WaitForData; @@ -350,8 +353,8 @@ ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { return state_; } -Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks) - : transport_(std::move(transport)), protocol_(std::move(protocol)), callbacks_(callbacks) {} +Decoder::Decoder(Transport& transport, Protocol& protocol, DecoderCallbacks& callbacks) + : transport_(transport), protocol_(protocol), callbacks_(callbacks) {} void Decoder::complete() { request_.reset(); @@ -377,22 +380,22 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { metadata_ = std::make_shared(); } - if (!transport_->decodeFrameStart(data, *metadata_)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); + if (!transport_.decodeFrameStart(data, *metadata_)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_.name()); buffer_underflow = true; return FilterStatus::Continue; } - ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); + ENVOY_LOG(debug, "thrift: {} transport started", transport_.name()); if (metadata_->hasProtocol()) { - if (protocol_->type() == ProtocolType::Auto) { - protocol_->setType(metadata_->protocol()); - ENVOY_LOG(debug, "thrift: {} transport forced {} protocol", transport_->name(), - protocol_->name()); - } else if (metadata_->protocol() != protocol_->type()) { + if (protocol_.type() == ProtocolType::Auto) { + protocol_.setType(metadata_->protocol()); + ENVOY_LOG(debug, "thrift: {} transport forced {} protocol", transport_.name(), + protocol_.name()); + } else if (metadata_->protocol() != protocol_.type()) { throw EnvoyException(fmt::format("transport reports protocol {}, but configured for {}", ProtocolNames::get().fromType(metadata_->protocol()), - ProtocolNames::get().fromType(protocol_->type()))); + ProtocolNames::get().fromType(protocol_.type()))); } } if (metadata_->hasAppException()) { @@ -406,7 +409,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { request_ = std::make_unique(callbacks_.newDecoderEventHandler()); frame_started_ = true; state_machine_ = - std::make_unique(*protocol_, metadata_, request_->handler_); + std::make_unique(protocol_, metadata_, request_->handler_); if (request_->handler_.transportBegin(metadata_) == FilterStatus::StopIteration) { return FilterStatus::StopIteration; @@ -415,7 +418,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ASSERT(state_machine_ != nullptr); - ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), + ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_.name(), ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); ProtocolState rv = state_machine_->run(data); @@ -431,8 +434,8 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ASSERT(rv == ProtocolState::Done); // Message complete, decode end of frame. - if (!transport_->decodeFrameEnd(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); + if (!transport_.decodeFrameEnd(data)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_.name()); buffer_underflow = true; return FilterStatus::Continue; } @@ -440,7 +443,7 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { frame_ended_ = true; metadata_.reset(); - ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + ENVOY_LOG(debug, "thrift: {} transport ended", transport_.name()); if (request_->handler_.transportEnd() == FilterStatus::StopIteration) { return FilterStatus::StopIteration; } diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index cda5d75b4a8b7..e2886aedebd60 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -2,7 +2,6 @@ #include "envoy/buffer/buffer.h" -#include "common/buffer/buffer_impl.h" #include "common/common/assert.h" #include "common/common/logger.h" @@ -61,7 +60,7 @@ class ProtocolStateNameValues { * DecoderStateMachine is the Thrift message state machine as described in * source/extensions/filters/network/thrift_proxy/docs. */ -class DecoderStateMachine { +class DecoderStateMachine : public Logger::Loggable { public: DecoderStateMachine(Protocol& proto, MessageMetadataSharedPtr& metadata, DecoderEventHandler& handler) @@ -183,15 +182,15 @@ class DecoderCallbacks { }; /** - * Decoder encapsulates a configured TransportPtr and ProtocolPtr. + * Decoder encapsulates a configured Transport and Protocol and provides the ability to decode + * Thrift messages. */ class Decoder : public Logger::Loggable { public: - Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks); - Decoder(TransportType transport_type, ProtocolType protocol_type, DecoderCallbacks& callbacks); + Decoder(Transport& transport, Protocol& protocol, DecoderCallbacks& callbacks); /** - * Drains data from the given buffer while executing a DecoderStateMachine over the data. + * Drains data from the given buffer while executing a state machine over the data. * * @param data a Buffer containing Thrift protocol data * @param buffer_underflow bool set to true if more data is required to continue decoding @@ -201,8 +200,8 @@ class Decoder : public Logger::Loggable { */ FilterStatus onData(Buffer::Instance& data, bool& buffer_underflow); - TransportType transportType() { return transport_->type(); } - ProtocolType protocolType() { return protocol_->type(); } + TransportType transportType() { return transport_.type(); } + ProtocolType protocolType() { return protocol_.type(); } private: struct ActiveRequest { @@ -214,8 +213,8 @@ class Decoder : public Logger::Loggable { void complete(); - TransportPtr transport_; - ProtocolPtr protocol_; + Transport& transport_; + Protocol& protocol_; DecoderCallbacks& callbacks_; ActiveRequestPtr request_; MessageMetadataSharedPtr metadata_; diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h index 304db348d9e29..4183514455b91 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -68,10 +68,10 @@ class DecoderFilterCallbacks { /** * Indicates the start of an upstream response. May only be called once. - * @param transport_type TransportType the upstream is using - * @param protocol_type ProtocolType the upstream is using + * @param transport the transport used by the upstream response + * @param protocol the protocol used by the upstream response */ - virtual void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) PURE; + virtual void startUpstreamResponse(Transport& transport, Protocol& protocol) PURE; /** * Called with upstream response data. diff --git a/source/extensions/filters/network/thrift_proxy/metadata.h b/source/extensions/filters/network/thrift_proxy/metadata.h index e9498d88eef43..bb659e5afdc47 100644 --- a/source/extensions/filters/network/thrift_proxy/metadata.h +++ b/source/extensions/filters/network/thrift_proxy/metadata.h @@ -4,8 +4,11 @@ #include #include +#include #include +#include "envoy/buffer/buffer.h" + #include "common/common/macros.h" #include "common/http/header_map_impl.h" @@ -62,6 +65,11 @@ class MessageMetadata { AppExceptionType appExceptionType() const { return app_ex_type_.value(); } const std::string& appExceptionMessage() const { return app_ex_msg_.value(); } + bool isProtocolUpgradeMessage() const { return protocol_upgrade_message_; } + void setProtocolUpgradeMessage(bool upgrade_message) { + protocol_upgrade_message_ = upgrade_message; + } + private: absl::optional frame_size_{}; absl::optional proto_{}; @@ -71,6 +79,7 @@ class MessageMetadata { Http::HeaderMapImpl headers_; absl::optional app_ex_type_; absl::optional app_ex_msg_; + bool protocol_upgrade_message_{false}; }; typedef std::shared_ptr MessageMetadataSharedPtr; diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 0e4066021ce9b..316ca0e79df3a 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -11,8 +11,12 @@ #include "common/config/utility.h" #include "common/singleton/const_singleton.h" +#include "extensions/filters/network/thrift_proxy/conn_state.h" +#include "extensions/filters/network/thrift_proxy/decoder_events.h" #include "extensions/filters/network/thrift_proxy/metadata.h" #include "extensions/filters/network/thrift_proxy/thrift.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" +#include "extensions/filters/network/thrift_proxy/transport.h" #include "absl/strings/string_view.h" @@ -21,6 +25,9 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +class DirectResponse; +typedef std::unique_ptr DirectResponsePtr; + /** * Protocol represents the operations necessary to implement the a generic Thrift protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-protocol-spec.md @@ -394,10 +401,86 @@ class Protocol { * @param value std::string to write */ virtual void writeBinary(Buffer::Instance& buffer, const std::string& value) PURE; + + /** + * Indicates whether a protocol uses start-of-connection messages to negotiate protocol options. + * If this method returns true, the Protocol must invoke setProtocolUpgradeMessage during + * readMessageBegin if it detects an upgrade request. + * + * @return true for protocols that exchange messages at the start of a connection to negotiate + * protocol upgrade (or options) + */ + virtual bool supportsUpgrade() { return false; } + + /** + * Creates an opaque DecoderEventHandlerSharedPtr that can decode a downstream client's upgrade + * request. When the request is complete, the decoder is passed back to writeUpgradeResponse + * to allow the Protocol to update its internal state and generate a response to the request. + * + * @return a DecoderEventHandlerSharedPtr that decodes a downstream client's upgrade request + */ + virtual DecoderEventHandlerSharedPtr upgradeRequestDecoder() { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + + /** + * Writes a response to a downstream client's upgrade request. + * @param decoder DecoderEventHandlerSharedPtr created by upgradeRequestDecoder + * @return DirectResponsePtr containing an upgrade response + */ + virtual DirectResponsePtr upgradeResponse(const DecoderEventHandler& decoder) { + UNREFERENCED_PARAMETER(decoder); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + + /** + * Checks whether a given upstream connection can be upgraded and generates an upgrade request + * message. If this method returns a ThriftObject it will be used to decode the upstream's next + * response. + * + * @param transport the Transport to use for decoding the response + * @param state ThriftConnectionState tracking whether upgrade has already been performed + * @param buffer Buffer::Instance to modify with an upgrade request + * @return a ThriftObject capable of decoding an upgrade response or nullptr if upgrade was + * already completed (successfully or not) + */ + virtual ThriftObjectPtr attemptUpgrade(Transport& transport, ThriftConnectionState& state, + Buffer::Instance& buffer) { + UNREFERENCED_PARAMETER(transport); + UNREFERENCED_PARAMETER(state); + UNREFERENCED_PARAMETER(buffer); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + + /** + * Completes an upgrade previously started via attemptUpgrade. + * @param response ThriftObject created by attemptUpgrade, after the response has completed + * decoding + */ + virtual void completeUpgrade(ThriftConnectionState& state, ThriftObject& response) { + UNREFERENCED_PARAMETER(state); + UNREFERENCED_PARAMETER(response); + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } }; typedef std::unique_ptr ProtocolPtr; +/** + * A DirectResponse manipulates a Protocol to directly create a Thrift response message. + */ +class DirectResponse { +public: + virtual ~DirectResponse() {} + + /** + * Encodes the response via the given Protocol. + * @param metadata the MessageMetadata for the request that generated this response + * @param proto the Protocol to be used for message encoding + * @param buffer the Buffer into which the message should be encoded + */ + virtual void encode(MessageMetadata& metadata, Protocol& proto, + Buffer::Instance& buffer) const PURE; +}; + /** * Implemented by each Thrift protocol and registered via Registry::registerFactory or the * convenience class RegisterFactory. @@ -444,25 +527,6 @@ template class ProtocolFactoryBase : public NamedProtocolCo const std::string name_; }; -/** - * A DirectResponse manipulates a Protocol to directly create a Thrift response message. - */ -class DirectResponse { -public: - virtual ~DirectResponse() {} - - /** - * Encodes the response via the given Protocol. - * @param metadata the MessageMetadata for the request that generated this response - * @param proto the Protocol to be used for message encoding - * @param buffer the Buffer into which the message should be encoded - */ - virtual void encode(MessageMetadata& metadata, Protocol& proto, - Buffer::Instance& buffer) const PURE; -}; - -typedef std::unique_ptr DirectResponsePtr; - } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h index 8a3fde7894505..22a868f1a40af 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_converter.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -19,8 +19,8 @@ class ProtocolConverter : public virtual DecoderEventHandler { ProtocolConverter() {} virtual ~ProtocolConverter() {} - void initProtocolConverter(ProtocolPtr&& proto, Buffer::Instance& buffer) { - proto_ = std::move(proto); + void initProtocolConverter(Protocol& proto, Buffer::Instance& buffer) { + proto_ = &proto; buffer_ = &buffer; } @@ -125,7 +125,7 @@ class ProtocolConverter : public virtual DecoderEventHandler { ProtocolType protocolType() const { return proto_->type(); } private: - ProtocolPtr proto_; + Protocol* proto_; Buffer::Instance* buffer_{}; }; diff --git a/source/extensions/filters/network/thrift_proxy/router/BUILD b/source/extensions/filters/network/thrift_proxy/router/BUILD index e5051335b04e9..ba6cf07cc045b 100644 --- a/source/extensions/filters/network/thrift_proxy/router/BUILD +++ b/source/extensions/filters/network/thrift_proxy/router/BUILD @@ -26,7 +26,9 @@ envoy_cc_library( name = "router_interface", hdrs = ["router.h"], external_deps = ["abseil_optional"], - deps = [], + deps = [ + "//source/extensions/filters/network/thrift_proxy:metadata_lib", + ], ) envoy_cc_library( @@ -46,8 +48,9 @@ envoy_cc_library( "//source/extensions/filters/network/thrift_proxy:app_exception_lib", "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", "//source/extensions/filters/network/thrift_proxy:protocol_converter_lib", - "//source/extensions/filters/network/thrift_proxy:protocol_lib", - "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_interface", + "//source/extensions/filters/network/thrift_proxy:thrift_object_interface", + "//source/extensions/filters/network/thrift_proxy:transport_interface", "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", "@envoy_api//envoy/config/filter/network/thrift_proxy/v2alpha1:thrift_proxy_cc", ], diff --git a/source/extensions/filters/network/thrift_proxy/router/router.h b/source/extensions/filters/network/thrift_proxy/router/router.h index 7f995e5e25a1d..b456a24ae7da1 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router.h +++ b/source/extensions/filters/network/thrift_proxy/router/router.h @@ -3,6 +3,8 @@ #include #include +#include "extensions/filters/network/thrift_proxy/metadata.h" + namespace Envoy { namespace Extensions { namespace NetworkFilters { diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc index 57e8be9f6caf1..a854b7432aa87 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -213,13 +213,12 @@ FilterStatus Router::messageBegin(MessageMetadataSharedPtr metadata) { FilterStatus Router::messageEnd() { ProtocolConverter::messageEnd(); - TransportPtr transport = - NamedTransportConfigFactory::getFactory(upstream_request_->transport_type_).createTransport(); Buffer::OwnedImpl transport_buffer; - upstream_request_->metadata_->setProtocol(upstream_request_->protocol_type_); + upstream_request_->metadata_->setProtocol(upstream_request_->protocol_->type()); - transport->encodeFrame(transport_buffer, *upstream_request_->metadata_, upstream_request_buffer_); + upstream_request_->transport_->encodeFrame(transport_buffer, *upstream_request_->metadata_, + upstream_request_buffer_); upstream_request_->conn_data_->connection().write(transport_buffer, false); upstream_request_->onRequestComplete(); return FilterStatus::Continue; @@ -228,16 +227,32 @@ FilterStatus Router::messageEnd() { void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { ASSERT(!upstream_request_->response_complete_); - if (!upstream_request_->response_started_) { - callbacks_->startUpstreamResponse(upstream_request_->transport_type_, - upstream_request_->protocol_type_); - upstream_request_->response_started_ = true; - } + if (upstream_request_->upgrade_response_ != nullptr) { + // Handle upgrade response. + if (!upstream_request_->upgrade_response_->onData(data)) { + // Wait for more data. + return; + } - if (callbacks_->upstreamData(data)) { - upstream_request_->onResponseComplete(); - cleanup(); - return; + upstream_request_->protocol_->completeUpgrade( + *upstream_request_->conn_data_->connectionStateTyped(), + *upstream_request_->upgrade_response_); + + upstream_request_->upgrade_response_.reset(); + upstream_request_->onRequestStart(true); + } else { + // Handle normal response. + if (!upstream_request_->response_started_) { + callbacks_->startUpstreamResponse(*upstream_request_->transport_, + *upstream_request_->protocol_); + upstream_request_->response_started_ = true; + } + + if (callbacks_->upstreamData(data)) { + upstream_request_->onResponseComplete(); + cleanup(); + return; + } } if (end_stream) { @@ -284,9 +299,10 @@ void Router::cleanup() { upstream_request_.reset(); } Router::UpstreamRequest::UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, MessageMetadataSharedPtr& metadata, TransportType transport_type, ProtocolType protocol_type) - : parent_(parent), conn_pool_(pool), metadata_(metadata), transport_type_(transport_type), - protocol_type_(protocol_type), request_complete_(false), response_started_(false), - response_complete_(false) {} + : parent_(parent), conn_pool_(pool), metadata_(metadata), + transport_(NamedTransportConfigFactory::getFactory(transport_type).createTransport()), + protocol_(NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol()), + request_complete_(false), response_started_(false), response_complete_(false) {} Router::UpstreamRequest::~UpstreamRequest() {} @@ -298,6 +314,11 @@ FilterStatus Router::UpstreamRequest::start() { return FilterStatus::StopIteration; } + if (upgrade_response_ != nullptr) { + // Pause while we wait for an upgrade response. + return FilterStatus::StopIteration; + } + return FilterStatus::Continue; } @@ -329,12 +350,28 @@ void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr onUpstreamHostSelected(host); conn_data_ = std::move(conn_data); conn_data_->addUpstreamCallbacks(parent_); - conn_pool_handle_ = nullptr; - parent_.initProtocolConverter( - NamedProtocolConfigFactory::getFactory(protocol_type_).createProtocol(), - parent_.upstream_request_buffer_); + ThriftConnectionState* state = conn_data_->connectionStateTyped(); + if (state == nullptr) { + conn_data_->setConnectionState(std::make_unique()); + state = conn_data_->connectionStateTyped(); + } + + if (protocol_->supportsUpgrade()) { + upgrade_response_ = + protocol_->attemptUpgrade(*transport_, *state, parent_.upstream_request_buffer_); + if (upgrade_response_ != nullptr) { + conn_data_->connection().write(parent_.upstream_request_buffer_, false); + return; + } + } + + onRequestStart(continue_decoding); +} + +void Router::UpstreamRequest::onRequestStart(bool continue_decoding) { + parent_.initProtocolConverter(*protocol_, parent_.upstream_request_buffer_); // TODO(zuercher): need to use an upstream-connection-specific sequence id parent_.convertMessageBegin(metadata_); diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h index 224b6fb09997c..8d6033bdb410a 100644 --- a/source/extensions/filters/network/thrift_proxy/router/router_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -15,6 +15,7 @@ #include "extensions/filters/network/thrift_proxy/conn_manager.h" #include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/router/router.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" #include "absl/types/optional.h" @@ -135,6 +136,7 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, void onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn, Upstream::HostDescriptionConstSharedPtr host) override; + void onRequestStart(bool continue_decoding); void onRequestComplete(); void onResponseComplete(); void onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host); @@ -147,8 +149,9 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks, Tcp::ConnectionPool::Cancellable* conn_pool_handle_{}; Tcp::ConnectionPool::ConnectionDataPtr conn_data_; Upstream::HostDescriptionConstSharedPtr upstream_host_; - TransportType transport_type_; - ProtocolType protocol_type_; + TransportPtr transport_; + ProtocolPtr protocol_; + ThriftObjectPtr upgrade_response_; bool request_complete_ : 1; bool response_started_ : 1; diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object.h b/source/extensions/filters/network/thrift_proxy/thrift_object.h new file mode 100644 index 0000000000000..321e6496674cd --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object.h @@ -0,0 +1,247 @@ +#pragma once + +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/exception.h" + +#include "extensions/filters/network/thrift_proxy/thrift.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class ThriftBase; + +/** + * ThriftValue is a field or container (list, set, or map) element. + */ +class ThriftValue { +public: + virtual ~ThriftValue() {} + + /** + * @return FieldType the type of this value + */ + virtual FieldType type() const PURE; + + /** + * @return const T& pointer to the value, provided that it can be cast to the given type + * @throw EnvoyException if the type T does not match the type + */ + template const T& getValueTyped() const { + // Use the Traits template to determine what FieldType the value must have to be cast to T + // and throw if the value's type doesn't match. + FieldType expected_field_type = Traits::getFieldType(); + if (expected_field_type != type()) { + throw EnvoyException(fmt::format("expected field type {}, got {}", + static_cast(expected_field_type), + static_cast(type()))); + } + + return *static_cast(getValue()); + } + +protected: + /** + * @return void* pointing to the underlying value, to be dynamically cast in getValueTyped + */ + virtual const void* getValue() const PURE; + +private: + /** + * Traits allows getValueTyped() to enforce that the field type is being cast to the desired type. + */ + template class Traits { + public: + // Compilation failures where T does not have a member getFieldType typically mean that + // getValueTyped was called with a type T that is not used to encode Thrift values. + // The specializations below encode the valid types for Thrift primitive types. + static FieldType getFieldType() { return T::getFieldType(); } + }; +}; + +// Explicit specializations of ThriftValue::Types for primitive types. +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Bool; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Byte; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I16; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I32; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::I64; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::Double; } +}; + +template <> class ThriftValue::Traits { +public: + static FieldType getFieldType() { return FieldType::String; } +}; + +typedef std::unique_ptr ThriftValuePtr; +typedef std::list ThriftValuePtrList; +typedef std::list> ThriftValuePtrPairList; + +/** + * ThriftField is a field within a ThriftStruct. + */ +class ThriftField { +public: + virtual ~ThriftField() {} + + /** + * @return FieldType this field's type + */ + virtual FieldType fieldType() const PURE; + + /** + * @return int16_t the field's identifier + */ + virtual int16_t fieldId() const PURE; + + /** + * @return const ThriftValue& containing the field's value + */ + virtual const ThriftValue& getValue() const PURE; +}; + +typedef std::unique_ptr ThriftFieldPtr; +typedef std::list ThriftFieldPtrList; + +/** + * ThriftListValue is an ordered list of ThriftValues. + */ +class ThriftListValue { +public: + virtual ~ThriftListValue() {} + + /** + * @return const ThriftValuePtrList& containing the ThriftValues that comprise the list + */ + virtual const ThriftValuePtrList& elements() const PURE; + + /** + * @return FieldType of the underlying elements + */ + virtual FieldType elementType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::List; } +}; + +/** + * ThriftSetValue is a set of ThriftValues, maintained in their original order. + */ +class ThriftSetValue { +public: + virtual ~ThriftSetValue() {} + + /** + * @return const ThriftValuePtrList& containing the ThriftValues that comprise the set + */ + virtual const ThriftValuePtrList& elements() const PURE; + + /** + * @return FieldType of the underlying elements + */ + virtual FieldType elementType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Set; } +}; + +/** + * ThriftMapValue is a map of pairs of ThriftValues, maintained in their original order. + */ +class ThriftMapValue { +public: + virtual ~ThriftMapValue() {} + + /** + * @return const ThriftValuePtrPairList& containing the ThriftValue key-value paris that comprise + * the map. + */ + virtual const ThriftValuePtrPairList& elements() const PURE; + + /** + * @return FieldType of the underlying keys + */ + virtual FieldType keyType() const PURE; + + /** + * @return FieldType of the underlying values + */ + virtual FieldType valueType() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Map; } +}; + +/** + * ThriftStructValue is a sequence of ThriftFields. + */ +class ThriftStructValue { +public: + virtual ~ThriftStructValue() {} + + /** + * @return const ThriftFieldPtrList& containing the ThriftFields that comprise the struct. + */ + virtual const ThriftFieldPtrList& fields() const PURE; + + /** + * Used by ThriftValue::Traits to enforce type safety. + */ + static FieldType getFieldType() { return FieldType::Struct; } +}; + +/** + * ThriftObject is a ThrfitStructValue that can be read from a Buffer::Instance. + */ +class ThriftObject : public ThriftStructValue { +public: + virtual ~ThriftObject() {} + + /* + * Consumes bytes from the buffer until a single complete Thrift struct has been consumed. + * @param buffer starting with a Thrift struct + * @return true when a single complete struct has been consumed; false if more data is needed to + * complete decoding + * @throw EnvoyException if the struct is invalid + */ + virtual bool onData(Buffer::Instance& buffer) PURE; +}; + +typedef std::unique_ptr ThriftObjectPtr; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc new file mode 100644 index 0000000000000..fb9271331f3a2 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.cc @@ -0,0 +1,394 @@ +#include "extensions/filters/network/thrift_proxy/thrift_object_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace { + +std::unique_ptr makeValue(ThriftBase* parent, FieldType type) { + switch (type) { + case FieldType::Stop: + NOT_REACHED_GCOVR_EXCL_LINE; + + case FieldType::List: + return std::make_unique(parent); + + case FieldType::Set: + return std::make_unique(parent); + + case FieldType::Map: + return std::make_unique(parent); + + case FieldType::Struct: + return std::make_unique(parent); + + default: + return std::make_unique(parent, type); + } +} + +} // namespace + +ThriftBase::ThriftBase(ThriftBase* parent) : parent_(parent) {} + +FilterStatus ThriftBase::structBegin(absl::string_view name) { + ASSERT(delegate_ != nullptr); + return delegate_->structBegin(name); +} + +FilterStatus ThriftBase::structEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->structEnd(); +} + +FilterStatus ThriftBase::fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) { + ASSERT(delegate_ != nullptr); + return delegate_->fieldBegin(name, field_type, field_id); +} + +FilterStatus ThriftBase::fieldEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->fieldEnd(); +} + +FilterStatus ThriftBase::boolValue(bool value) { + ASSERT(delegate_ != nullptr); + return delegate_->boolValue(value); +} + +FilterStatus ThriftBase::byteValue(uint8_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->byteValue(value); +} + +FilterStatus ThriftBase::int16Value(int16_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int16Value(value); +} + +FilterStatus ThriftBase::int32Value(int32_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int32Value(value); +} + +FilterStatus ThriftBase::int64Value(int64_t value) { + ASSERT(delegate_ != nullptr); + return delegate_->int64Value(value); +} + +FilterStatus ThriftBase::doubleValue(double value) { + ASSERT(delegate_ != nullptr); + return delegate_->doubleValue(value); +} + +FilterStatus ThriftBase::stringValue(absl::string_view value) { + ASSERT(delegate_ != nullptr); + return delegate_->stringValue(value); +} + +FilterStatus ThriftBase::mapBegin(FieldType key_type, FieldType value_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->mapBegin(key_type, value_type, size); +} + +FilterStatus ThriftBase::mapEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->mapEnd(); +} + +FilterStatus ThriftBase::listBegin(FieldType elem_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->listBegin(elem_type, size); +} + +FilterStatus ThriftBase::listEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->listEnd(); +} + +FilterStatus ThriftBase::setBegin(FieldType elem_type, uint32_t size) { + ASSERT(delegate_ != nullptr); + return delegate_->setBegin(elem_type, size); +} + +FilterStatus ThriftBase::setEnd() { + ASSERT(delegate_ != nullptr); + return delegate_->setEnd(); +} + +void ThriftBase::delegateComplete() { + ASSERT(delegate_ != nullptr); + delegate_ = nullptr; +} + +ThriftFieldImpl::ThriftFieldImpl(ThriftStructValueImpl* parent, absl::string_view name, + FieldType field_type, int16_t field_id) + : ThriftBase(parent), name_(name), field_type_(field_type), field_id_(field_id) { + auto value = makeValue(this, field_type_); + delegate_ = value.get(); + value_ = std::move(value); +} + +FilterStatus ThriftFieldImpl::fieldEnd() { + if (delegate_) { + return delegate_->fieldEnd(); + } + + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftListValueImpl::listBegin(FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->listBegin(elem_type, size); + } + + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftListValueImpl::listEnd() { + if (delegate_) { + return delegate_->listEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftListValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + auto elem = makeValue(this, elem_type_); + delegate_ = elem.get(); + elements_.push_back(std::move(elem)); + remaining_--; +} + +FilterStatus ThriftSetValueImpl::setBegin(FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->setBegin(elem_type, size); + } + + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftSetValueImpl::setEnd() { + if (delegate_) { + return delegate_->setEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftSetValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + auto elem = makeValue(this, elem_type_); + delegate_ = elem.get(); + elements_.push_back(std::move(elem)); + remaining_--; +} + +FilterStatus ThriftMapValueImpl::mapBegin(FieldType key_type, FieldType elem_type, uint32_t size) { + if (delegate_) { + return delegate_->mapBegin(key_type, elem_type, size); + } + + key_type_ = key_type; + elem_type_ = elem_type; + remaining_ = size; + + delegateComplete(); + + return FilterStatus::Continue; +} + +FilterStatus ThriftMapValueImpl::mapEnd() { + if (delegate_) { + return delegate_->mapEnd(); + } + + ASSERT(remaining_ == 0); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +void ThriftMapValueImpl::delegateComplete() { + delegate_ = nullptr; + + if (remaining_ == 0) { + return; + } + + // Prepare for first element's key. + if (elements_.empty()) { + auto key = makeValue(this, key_type_); + delegate_ = key.get(); + elements_.emplace_back(std::move(key), nullptr); + return; + } + + // Prepare for any elements's value. + auto& elem = elements_.back(); + if (elem.second == nullptr) { + auto value = makeValue(this, elem_type_); + delegate_ = value.get(); + elem.second = std::move(value); + + remaining_--; + return; + } + + // Key-value pair completed, prepare for next key. + auto key = makeValue(this, key_type_); + delegate_ = key.get(); + elements_.emplace_back(std::move(key), nullptr); +} + +FilterStatus ThriftValueImpl::boolValue(bool value) { + ASSERT(value_type_ == FieldType::Bool); + bool_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::byteValue(uint8_t value) { + ASSERT(value_type_ == FieldType::Byte); + byte_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int16Value(int16_t value) { + ASSERT(value_type_ == FieldType::I16); + int16_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int32Value(int32_t value) { + ASSERT(value_type_ == FieldType::I32); + int32_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::int64Value(int64_t value) { + ASSERT(value_type_ == FieldType::I64); + int64_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::doubleValue(double value) { + ASSERT(value_type_ == FieldType::Double); + double_value_ = value; + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +FilterStatus ThriftValueImpl::stringValue(absl::string_view value) { + ASSERT(value_type_ == FieldType::String); + string_value_ = std::string(value); + parent_->delegateComplete(); + return FilterStatus::Continue; +} + +const void* ThriftValueImpl::getValue() const { + switch (value_type_) { + case FieldType::Bool: + return &bool_value_; + case FieldType::Byte: + return &byte_value_; + case FieldType::I16: + return &int16_value_; + case FieldType::I32: + return &int32_value_; + case FieldType::I64: + return &int64_value_; + case FieldType::Double: + return &double_value_; + case FieldType::String: + return &string_value_; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +FilterStatus ThriftStructValueImpl::structBegin(absl::string_view name) { + if (delegate_) { + return delegate_->structBegin(name); + } + + return FilterStatus::Continue; +} + +FilterStatus ThriftStructValueImpl::structEnd() { + if (delegate_) { + return delegate_->structEnd(); + } + + if (parent_) { + parent_->delegateComplete(); + } + + return FilterStatus::Continue; +} + +FilterStatus ThriftStructValueImpl::fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) { + if (delegate_) { + return delegate_->fieldBegin(name, field_type, field_id); + } + + if (field_type != FieldType::Stop) { + auto field = std::make_unique(this, name, field_type, field_id); + delegate_ = field.get(); + fields_.emplace_back(std::move(field)); + } + + return FilterStatus::Continue; +} + +ThriftObjectImpl::ThriftObjectImpl(Transport& transport, Protocol& protocol) + : ThriftStructValueImpl(nullptr), + decoder_(std::make_unique(transport, protocol, *this)) {} + +bool ThriftObjectImpl::onData(Buffer::Instance& buffer) { + bool underflow = false; + auto result = decoder_->onData(buffer, underflow); + ASSERT(result == FilterStatus::Continue); + + if (complete_) { + decoder_.reset(); + } + return complete_; +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h new file mode 100644 index 0000000000000..b9057dfab2bc7 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h @@ -0,0 +1,262 @@ +#pragma once + +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/thrift_object.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ThriftBase is a base class for decoding Thrift objects. It implements methods from + * DecoderEventHandler to automatically delegate to an underlying ThriftBase so that, for example, + * the fieldBegin call for a struct field nested within a list is automatically delegated down the + * object hierarchy to the correct ThriftBase subclass. + */ +class ThriftBase : public DecoderEventHandler { +public: + ThriftBase(ThriftBase* parent); + ~ThriftBase() {} + + // DecoderEventHandler + FilterStatus transportBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } + FilterStatus transportEnd() override { return FilterStatus::Continue; } + FilterStatus messageBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } + FilterStatus messageEnd() override { return FilterStatus::Continue; } + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType field_type, int16_t field_id) override; + FilterStatus fieldEnd() override; + FilterStatus boolValue(bool value) override; + FilterStatus byteValue(uint8_t value) override; + FilterStatus int16Value(int16_t value) override; + FilterStatus int32Value(int32_t value) override; + FilterStatus int64Value(int64_t value) override; + FilterStatus doubleValue(double value) override; + FilterStatus stringValue(absl::string_view value) override; + FilterStatus mapBegin(FieldType key_type, FieldType value_type, uint32_t size) override; + FilterStatus mapEnd() override; + FilterStatus listBegin(FieldType elem_type, uint32_t size) override; + FilterStatus listEnd() override; + FilterStatus setBegin(FieldType elem_type, uint32_t size) override; + FilterStatus setEnd() override; + + // Invoked when the current delegate is complete. Completion implies that the delegate is fully + // specified (all list values processed, all struct fields processed, etc). + virtual void delegateComplete(); + +protected: + ThriftBase* parent_; + ThriftBase* delegate_{nullptr}; +}; + +/** + * ThriftValueBase is a base class for all struct field values, list values, set values, map keys, + * and map values. + */ +class ThriftValueBase : public ThriftValue, public ThriftBase { +public: + ThriftValueBase(ThriftBase* parent, FieldType value_type) + : ThriftBase(parent), value_type_(value_type) {} + ~ThriftValueBase() {} + + // ThriftValue + FieldType type() const override { return value_type_; } + +protected: + const FieldType value_type_; +}; + +class ThriftStructValueImpl; + +/** + * ThriftField represents a field in a thrift Struct. It always delegates DecoderEventHandler + * methods to a subclass of ThriftValueBase. + */ +class ThriftFieldImpl : public ThriftField, public ThriftBase { +public: + ThriftFieldImpl(ThriftStructValueImpl* parent, absl::string_view name, FieldType field_type, + int16_t field_id); + + // DecoderEventHandler + FilterStatus fieldEnd() override; + + // ThriftField + FieldType fieldType() const override { return field_type_; } + int16_t fieldId() const override { return field_id_; } + const ThriftValue& getValue() const override { return *value_; } + +private: + std::string name_; + FieldType field_type_; + int16_t field_id_; + ThriftValuePtr value_; +}; + +/** + * ThriftStructValueImpl implements ThriftStruct. + */ +class ThriftStructValueImpl : public ThriftStructValue, public ThriftValueBase { +public: + ThriftStructValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Struct) {} + + // DecoderEventHandler + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType field_type, int16_t field_id) override; + + // ThriftStructValue + const ThriftFieldPtrList& fields() const override { return fields_; } + +private: + // ThriftValue + const void* getValue() const override { return this; }; + + ThriftFieldPtrList fields_; +}; + +/** + * ThriftListValueImpl represents Thrift lists. + */ +class ThriftListValueImpl : public ThriftListValue, public ThriftValueBase { +public: + ThriftListValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::List) {} + + // DecoderEventHandler + FilterStatus listBegin(FieldType elem_type, uint32_t size) override; + FilterStatus listEnd() override; + + // ThriftListValue + const ThriftValuePtrList& elements() const override { return elements_; } + FieldType elementType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrList elements_; +}; + +/** + * ThriftSetValueImpl represents Thrift sets. + */ +class ThriftSetValueImpl : public ThriftSetValue, public ThriftValueBase { +public: + ThriftSetValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Set) {} + + // DecoderEventHandler + FilterStatus setBegin(FieldType elem_type, uint32_t size) override; + FilterStatus setEnd() override; + + // ThriftSetValue + const ThriftValuePtrList& elements() const override { return elements_; } + FieldType elementType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrList elements_; // maintain original order +}; + +/** + * ThriftMapValueImpl represents Thrift maps. + */ +class ThriftMapValueImpl : public ThriftMapValue, public ThriftValueBase { +public: + ThriftMapValueImpl(ThriftBase* parent) : ThriftValueBase(parent, FieldType::Map) {} + + // DecoderEventHandler + FilterStatus mapBegin(FieldType key_type, FieldType elem_type, uint32_t size) override; + FilterStatus mapEnd() override; + + // ThriftMapValue + const ThriftValuePtrPairList& elements() const override { return elements_; } + FieldType keyType() const override { return key_type_; } + FieldType valueType() const override { return elem_type_; } + + void delegateComplete() override; + +protected: + // ThriftValue + const void* getValue() const override { return this; }; + + FieldType key_type_{FieldType::Stop}; + FieldType elem_type_{FieldType::Stop}; + uint32_t remaining_{0}; + ThriftValuePtrPairList elements_; // maintain original order +}; + +/** + * ThriftValueImpl represents primitive Thrift types, including strings. + */ +class ThriftValueImpl : public ThriftValueBase { +public: + ThriftValueImpl(ThriftBase* parent, FieldType value_type) : ThriftValueBase(parent, value_type) {} + + // DecoderEventHandler + FilterStatus boolValue(bool value) override; + FilterStatus byteValue(uint8_t value) override; + FilterStatus int16Value(int16_t value) override; + FilterStatus int32Value(int32_t value) override; + FilterStatus int64Value(int64_t value) override; + FilterStatus doubleValue(double value) override; + FilterStatus stringValue(absl::string_view value) override; + +protected: + // ThriftValue + const void* getValue() const override; + +private: + union { + bool bool_value_; + uint8_t byte_value_; + int16_t int16_value_; + int32_t int32_value_; + int64_t int64_value_; + double double_value_; + }; + std::string string_value_; +}; + +/** + * ThriftObjectImpl is a generic representation of a Thrift struct. + */ +class ThriftObjectImpl : public ThriftObject, + public ThriftStructValueImpl, + public DecoderCallbacks { +public: + ThriftObjectImpl(Transport& transport, Protocol& protocol); + + // DecoderCallbacks + DecoderEventHandler& newDecoderEventHandler() override { return *this; } + FilterStatus transportEnd() override { + complete_ = true; + return FilterStatus::Continue; + } + + // ThriftObject + bool onData(Buffer::Instance& buffer) override; + + // ThriftStruct + const ThriftFieldPtrList& fields() const override { return ThriftStructValueImpl::fields(); } + +private: + DecoderPtr decoder_; + bool complete_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index 1cc27974ea7b2..7a9d3c4b23ff3 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -245,6 +245,19 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "thrift_object_impl_test", + srcs = ["thrift_object_impl_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:thrift_object_lib", + "//test/test_common:printers_lib", + "//test/test_common:registry_lib", + ], +) + envoy_extension_cc_test( name = "integration_test", srcs = ["integration_test.cc"], diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index 43f3352149723..338ce458123df 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -2,9 +2,12 @@ #include "common/buffer/buffer_impl.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" #include "extensions/filters/network/thrift_proxy/buffer_helper.h" #include "extensions/filters/network/thrift_proxy/config.h" #include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/header_transport_impl.h" #include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" @@ -17,8 +20,10 @@ #include "gtest/gtest.h" using testing::_; +using testing::AnyNumber; using testing::Invoke; using testing::NiceMock; +using testing::Ref; using testing::Return; using testing::ReturnRef; @@ -39,10 +44,23 @@ class TestConfigImpl : public ConfigImpl { void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override { callbacks.addDecoderFilter(decoder_filter_); } + TransportPtr createTransport() override { + if (transport_) { + return TransportPtr{transport_}; + } + return ConfigImpl::createTransport(); + } + ProtocolPtr createProtocol() override { + if (protocol_) { + return ProtocolPtr{protocol_}; + } + return ConfigImpl::createProtocol(); + } -private: ThriftFilters::DecoderFilterSharedPtr decoder_filter_; ThriftFilterStats& stats_; + MockTransport* transport_{}; + MockProtocol* protocol_{}; }; class ThriftConnectionManagerTest : public testing::Test { @@ -72,7 +90,14 @@ class ThriftConnectionManagerTest : public testing::Test { proto_config_.set_stat_prefix("test"); decoder_filter_.reset(new NiceMock()); + config_.reset(new TestConfigImpl(proto_config_, context_, decoder_filter_, stats_)); + if (custom_transport_) { + config_->transport_ = custom_transport_; + } + if (custom_protocol_) { + config_->protocol_ = custom_protocol_; + } filter_.reset(new ConnectionManager(*config_)); filter_->initializeReadFilterCallbacks(filter_callbacks_); @@ -267,6 +292,9 @@ class ThriftConnectionManagerTest : public testing::Test { Buffer::OwnedImpl write_buffer_; std::unique_ptr filter_; NiceMock filter_callbacks_; + + MockTransport* custom_transport_{}; + MockProtocol* custom_protocol_{}; }; TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftCall) { @@ -602,7 +630,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponse) { writeComplexFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -634,7 +664,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndExceptionResponse) { writeFramedBinaryTApplicationException(write_buffer_, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -667,7 +699,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndErrorResponse) { writeFramedBinaryIDLException(write_buffer_, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -700,7 +734,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndInvalidResponse) { // Call is not valid in a response writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -739,7 +775,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponseProtocolError) { 0x08, 0xff, 0xff // illegal field id }); - callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_, write(_, false)); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); @@ -781,7 +819,9 @@ TEST_F(ThriftConnectionManagerTest, RequestAndTransportApplicationException) { 0x01, 0x02, 0x00, 0x00, // transforms: 1, 2; padding }); - callbacks->startUpstreamResponse(TransportType::Header, ProtocolType::Binary); + HeaderTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); @@ -817,15 +857,18 @@ TEST_F(ThriftConnectionManagerTest, PipelinedRequestAndResponse) { EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(2); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x01); - callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + callbacks.front()->startUpstreamResponse(transport, proto); EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); callbacks.pop_front(); EXPECT_EQ(1U, store_.counter("test.response").value()); EXPECT_EQ(1U, store_.counter("test.response_reply").value()); writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x02); - callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + callbacks.front()->startUpstreamResponse(transport, proto); EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); callbacks.pop_front(); EXPECT_EQ(2U, store_.counter("test.response").value()); @@ -857,6 +900,68 @@ TEST_F(ThriftConnectionManagerTest, ResetDownstreamConnection) { EXPECT_EQ(0U, store_.gauge("test.request_active").value()); } +TEST_F(ThriftConnectionManagerTest, DownstreamProtocolUpgrade) { + custom_transport_ = new NiceMock(); + custom_protocol_ = new NiceMock(); + initializeFilter(); + + EXPECT_CALL(*custom_transport_, decodeFrameStart(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*custom_protocol_, readMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMessageType(MessageType::Call); + metadata.setProtocolUpgradeMessage(true); + return true; + })); + EXPECT_CALL(*custom_protocol_, supportsUpgrade()).Times(AnyNumber()).WillRepeatedly(Return(true)); + + MockDecoderEventHandler* upgrade_decoder = new NiceMock(); + EXPECT_CALL(*custom_protocol_, upgradeRequestDecoder()) + .WillOnce(Invoke([&]() -> DecoderEventHandlerSharedPtr { + return DecoderEventHandlerSharedPtr{upgrade_decoder}; + })); + EXPECT_CALL(*upgrade_decoder, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, structBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readFieldBegin(_, _, _, _)) + .WillOnce(Invoke( + [&](Buffer::Instance&, std::string&, FieldType& field_type, int16_t& field_id) -> bool { + field_type = FieldType::Stop; + field_id = 0; + return true; + })); + EXPECT_CALL(*custom_protocol_, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, structEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_protocol_, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_transport_, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*upgrade_decoder, transportEnd()).WillOnce(Return(FilterStatus::Continue)); + + MockDirectResponse* direct_response = new NiceMock(); + + EXPECT_CALL(*custom_protocol_, upgradeResponse(Ref(*upgrade_decoder))) + .WillOnce(Invoke([&](const DecoderEventHandler&) -> DirectResponsePtr { + return DirectResponsePtr{direct_response}; + })); + + EXPECT_CALL(*direct_response, encode(_, Ref(*custom_protocol_), _)) + .WillOnce(Invoke([&](MessageMetadata&, Protocol&, Buffer::Instance& buffer) -> void { + buffer.add("response"); + })); + EXPECT_CALL(*custom_transport_, encodeFrame(_, _, _)) + .WillOnce(Invoke( + [&](Buffer::Instance& buffer, const MessageMetadata&, Buffer::Instance& message) -> void { + EXPECT_EQ("response", message.toString()); + buffer.add("transport-encoded response"); + })); + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ("transport-encoded response", buffer.toString()); + })); + + Buffer::OwnedImpl buffer; + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 6fd60d2256814..a8369e757b220 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -710,17 +710,17 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { } TEST(DecoderTest, OnData) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; @@ -732,7 +732,7 @@ TEST(DecoderTest, OnData) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -750,18 +750,18 @@ TEST(DecoderTest, OnData) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::Continue)); bool underflow = false; @@ -770,24 +770,24 @@ TEST(DecoderTest, OnData) { } TEST(DecoderTest, OnDataWithProtocolHint) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); metadata.setProtocol(ProtocolType::Binary); return true; })); - EXPECT_CALL(*proto, type()).WillOnce(Return(ProtocolType::Auto)); - EXPECT_CALL(*proto, setType(ProtocolType::Binary)); + EXPECT_CALL(proto, type()).WillOnce(Return(ProtocolType::Auto)); + EXPECT_CALL(proto, setType(ProtocolType::Binary)); EXPECT_CALL(handler, transportBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { EXPECT_TRUE(metadata->hasFrameSize()); @@ -799,7 +799,7 @@ TEST(DecoderTest, OnDataWithProtocolHint) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -817,18 +817,18 @@ TEST(DecoderTest, OnDataWithProtocolHint) { return FilterStatus::Continue; })); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::Continue)); bool underflow = false; @@ -837,23 +837,23 @@ TEST(DecoderTest, OnDataWithProtocolHint) { } TEST(DecoderTest, OnDataWithInconsistentProtocolHint) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); metadata.setProtocol(ProtocolType::Binary); return true; })); - EXPECT_CALL(*proto, type()).WillRepeatedly(Return(ProtocolType::Compact)); + EXPECT_CALL(proto, type()).WillRepeatedly(Return(ProtocolType::Compact)); bool underflow = false; EXPECT_THROW_WITH_MESSAGE(decoder.onData(buffer, underflow), EnvoyException, @@ -861,17 +861,17 @@ TEST(DecoderTest, OnDataWithInconsistentProtocolHint) { } TEST(DecoderTest, OnDataThrowsTransportAppException) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setAppException(AppExceptionType::InvalidTransform, "unknown xform"); return true; @@ -882,85 +882,85 @@ TEST(DecoderTest, OnDataThrowsTransportAppException) { } TEST(DecoderTest, OnDataResumes) { - NiceMock* transport = new NiceMock(); - NiceMock* proto = new NiceMock(); + NiceMock transport; + NiceMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; buffer.add("x"); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(false)); + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(false)); bool underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); // buffer.length() == 1 } TEST(DecoderTest, OnDataResumesTransportFrameStart) { - StrictMock* transport = new StrictMock(); - StrictMock* proto = new StrictMock(); + StrictMock transport; + StrictMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); - EXPECT_CALL(*transport, name()).Times(AnyNumber()); - EXPECT_CALL(*proto, name()).Times(AnyNumber()); + EXPECT_CALL(transport, name()).Times(AnyNumber()); + EXPECT_CALL(proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; bool underflow = false; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)).WillOnce(Return(false)); + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)).WillOnce(Return(false)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); @@ -968,66 +968,65 @@ TEST(DecoderTest, OnDataResumesTransportFrameStart) { } TEST(DecoderTest, OnDataResumesTransportFrameEnd) { - StrictMock* transport = new StrictMock(); - StrictMock* proto = new StrictMock(); + StrictMock transport; + StrictMock proto; NiceMock callbacks; NiceMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); - EXPECT_CALL(*transport, name()).Times(AnyNumber()); - EXPECT_CALL(*proto, name()).Times(AnyNumber()); + EXPECT_CALL(transport, name()).Times(AnyNumber()); + EXPECT_CALL(proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; })); - EXPECT_CALL(*proto, readMessageBegin(_, _)) + EXPECT_CALL(proto, readMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); metadata.setSequenceId(100); return true; })); - EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(false)); + EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(false)); bool underflow = false; EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); - EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(_)).WillOnce(Return(true)); EXPECT_EQ(FilterStatus::Continue, decoder.onData(buffer, underflow)); EXPECT_TRUE(underflow); // buffer.length() == 0 } TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { + StrictMock transport; + EXPECT_CALL(transport, name()).WillRepeatedly(ReturnRef(transport.name_)); - StrictMock* transport = new StrictMock(); - EXPECT_CALL(*transport, name()).WillRepeatedly(ReturnRef(transport->name_)); - - StrictMock* proto = new StrictMock(); - EXPECT_CALL(*proto, name()).WillRepeatedly(ReturnRef(proto->name_)); + StrictMock proto; + EXPECT_CALL(proto, name()).WillRepeatedly(ReturnRef(proto.name_)); NiceMock callbacks; StrictMock handler; ON_CALL(callbacks, newDecoderEventHandler()).WillByDefault(ReturnRef(handler)); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Decoder decoder(transport, proto, callbacks); Buffer::OwnedImpl buffer; bool underflow = true; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + EXPECT_CALL(transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); return true; @@ -1042,7 +1041,7 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setMethodName("name"); metadata.setMessageType(MessageType::Call); @@ -1062,42 +1061,42 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(handler, structBegin(absl::string_view())) .WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::I32), SetArgReferee<3>(1), Return(true))); EXPECT_CALL(handler, fieldBegin(absl::string_view(), FieldType::I32, 1)) .WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readInt32(_, _)).WillOnce(Return(true)); + EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(true)); EXPECT_CALL(handler, int32Value(_)).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, fieldEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); - EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, structEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, messageEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); - EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(handler, transportEnd()).WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(FilterStatus::StopIteration, decoder.onData(buffer, underflow)); EXPECT_FALSE(underflow); diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index ad14d3c1149e1..ac0359dce7c17 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -27,6 +27,7 @@ MockProtocol::MockProtocol() { ON_CALL(*this, setType(_)).WillByDefault(Invoke([&](ProtocolType type) -> void { type_ = type; })); + ON_CALL(*this, supportsUpgrade()).WillByDefault(Return(false)); } MockProtocol::~MockProtocol() {} @@ -39,6 +40,9 @@ MockDecoderEventHandler::~MockDecoderEventHandler() {} MockDirectResponse::MockDirectResponse() {} MockDirectResponse::~MockDirectResponse() {} +MockThriftObject::MockThriftObject() {} +MockThriftObject::~MockThriftObject() {} + namespace ThriftFilters { MockDecoderFilter::MockDecoderFilter() { diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index b93b7717c8780..2067434b2db88 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -1,6 +1,7 @@ #pragma once #include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/conn_state.h" #include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/metadata.h" #include "extensions/filters/network/thrift_proxy/protocol.h" @@ -101,6 +102,12 @@ class MockProtocol : public Protocol { MOCK_METHOD2(writeDouble, void(Buffer::Instance& buffer, double value)); MOCK_METHOD2(writeString, void(Buffer::Instance& buffer, const std::string& value)); MOCK_METHOD2(writeBinary, void(Buffer::Instance& buffer, const std::string& value)); + MOCK_METHOD0(supportsUpgrade, bool()); + MOCK_METHOD0(upgradeRequestDecoder, DecoderEventHandlerSharedPtr()); + MOCK_METHOD1(upgradeResponse, DirectResponsePtr(const DecoderEventHandler&)); + MOCK_METHOD3(attemptUpgrade, + ThriftObjectPtr(Transport&, ThriftConnectionState&, Buffer::Instance&)); + MOCK_METHOD2(completeUpgrade, void(ThriftConnectionState&, ThriftObject&)); std::string name_{"mock"}; ProtocolType type_{ProtocolType::Auto}; @@ -154,6 +161,15 @@ class MockDirectResponse : public DirectResponse { MOCK_CONST_METHOD3(encode, void(MessageMetadata&, Protocol&, Buffer::Instance&)); }; +class MockThriftObject : public ThriftObject { +public: + MockThriftObject(); + ~MockThriftObject(); + + MOCK_CONST_METHOD0(fields, ThriftFieldPtrList&()); + MOCK_METHOD1(onData, bool(Buffer::Instance&)); +}; + namespace ThriftFilters { class MockDecoderFilter : public DecoderFilter { @@ -204,7 +220,7 @@ class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { MOCK_CONST_METHOD0(downstreamTransportType, TransportType()); MOCK_CONST_METHOD0(downstreamProtocolType, ProtocolType()); MOCK_METHOD1(sendLocalReply, void(const DirectResponse&)); - MOCK_METHOD2(startUpstreamResponse, void(TransportType, ProtocolType)); + MOCK_METHOD2(startUpstreamResponse, void(Transport&, Protocol&)); MOCK_METHOD1(upstreamData, bool(Buffer::Instance&)); MOCK_METHOD0(resetDownstreamConnection, void()); diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc index 3b07f4d9186dd..31ce50e8327c7 100644 --- a/test/extensions/filters/network/thrift_proxy/router_test.cc +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -21,6 +21,7 @@ using testing::_; using testing::ContainsRegex; +using testing::InSequence; using testing::Invoke; using testing::NiceMock; using testing::Ref; @@ -71,8 +72,22 @@ class TestNamedProtocolConfigFactory : public NamedProtocolConfigFactory { class ThriftRouterTestBase { public: ThriftRouterTestBase() - : transport_factory_([&]() -> MockTransport* { return transport_; }), - protocol_factory_([&]() -> MockProtocol* { return protocol_; }), + : transport_factory_([&]() -> MockTransport* { + ASSERT(transport_ == nullptr); + transport_ = new NiceMock(); + if (mock_transport_cb_) { + mock_transport_cb_(transport_); + } + return transport_; + }), + protocol_factory_([&]() -> MockProtocol* { + ASSERT(protocol_ == nullptr); + protocol_ = new NiceMock(); + if (mock_protocol_cb_) { + mock_protocol_cb_(protocol_); + } + return protocol_; + }), transport_register_(transport_factory_), protocol_register_(protocol_factory_) {} void initializeRouter() { @@ -124,9 +139,6 @@ class ThriftRouterTestBase { upstream_callbacks_ = &cb; })); - protocol_ = new NiceMock(); - - ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { EXPECT_EQ(metadata_->methodName(), metadata.methodName()); @@ -165,15 +177,16 @@ class ThriftRouterTestBase { EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed)); EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); - protocol_ = new NiceMock(); - ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); - EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) - .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { - EXPECT_EQ(metadata_->methodName(), metadata.methodName()); - EXPECT_EQ(metadata_->messageType(), metadata.messageType()); - EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); - })); + mock_protocol_cb_ = [&](MockProtocol* protocol) -> void { + ON_CALL(*protocol, type()).WillByDefault(Return(ProtocolType::Binary)); + EXPECT_CALL(*protocol, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + }; EXPECT_CALL(callbacks_, continueDecoding()).Times(0); EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, newConnection(_)) .WillOnce( @@ -240,8 +253,6 @@ class ThriftRouterTestBase { } void completeRequest() { - transport_ = new NiceMock(); - EXPECT_CALL(*protocol_, writeMessageEnd(_)); EXPECT_CALL(*transport_, encodeFrame(_, _, _)); EXPECT_CALL(upstream_connection_, write(_, false)); @@ -257,7 +268,7 @@ class ThriftRouterTestBase { void returnResponse() { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); upstream_callbacks_->onUpstreamData(buffer, false); @@ -277,6 +288,9 @@ class ThriftRouterTestBase { Registry::InjectFactory transport_register_; Registry::InjectFactory protocol_register_; + std::function mock_transport_cb_{}; + std::function mock_protocol_cb_{}; + NiceMock context_; NiceMock callbacks_; NiceMock* transport_{}; @@ -472,7 +486,7 @@ TEST_F(ThriftRouterTest, TruncatedResponse) { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, released(Ref(upstream_connection_))); EXPECT_CALL(callbacks_, resetDownstreamConnection()); @@ -531,7 +545,7 @@ TEST_F(ThriftRouterTest, UpstreamDataTriggersReset) { Buffer::OwnedImpl buffer; - EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, startUpstreamResponse(_, _)); EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))) .WillOnce(Invoke([&](Buffer::Instance&) -> bool { router_->resetUpstreamConnection(); @@ -590,6 +604,100 @@ TEST_F(ThriftRouterTest, UnexpectedRouterDestroy) { destroyRouter(); } +TEST_F(ThriftRouterTest, ProtocolUpgrade) { + initializeRouter(); + startRequest(MessageType::Call); + + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { upstream_callbacks_ = &cb; })); + + Tcp::ConnectionPool::ConnectionStatePtr conn_state; + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, connectionState()) + .WillRepeatedly( + Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { return conn_state.get(); })); + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, setConnectionState_(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::ConnectionStatePtr& cs) -> void { conn_state.swap(cs); })); + + EXPECT_CALL(*protocol_, supportsUpgrade()).WillOnce(Return(true)); + + MockThriftObject* upgrade_response = new NiceMock(); + + EXPECT_CALL(*protocol_, attemptUpgrade(_, _, _)) + .WillOnce(Invoke( + [&](Transport&, ThriftConnectionState&, Buffer::Instance& buffer) -> ThriftObjectPtr { + buffer.add("upgrade request"); + return ThriftObjectPtr{upgrade_response}; + })); + EXPECT_CALL(upstream_connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ("upgrade request", buffer.toString()); + })); + + context_.cluster_manager_.tcp_conn_pool_.poolReady(upstream_connection_); + EXPECT_NE(nullptr, upstream_callbacks_); + + Buffer::OwnedImpl buffer; + EXPECT_CALL(*upgrade_response, onData(Ref(buffer))).WillOnce(Return(false)); + upstream_callbacks_->onUpstreamData(buffer, false); + + EXPECT_CALL(*upgrade_response, onData(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(*protocol_, completeUpgrade(_, Ref(*upgrade_response))); + EXPECT_CALL(callbacks_, continueDecoding()); + EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + upstream_callbacks_->onUpstreamData(buffer, false); + + // Then the actual request... + sendTrivialStruct(FieldType::String); + completeRequest(); + returnResponse(); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, ProtocolUpgradeSkippedOnExistingConnection) { + initializeRouter(); + startRequest(MessageType::Call); + + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke( + [&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { upstream_callbacks_ = &cb; })); + + Tcp::ConnectionPool::ConnectionStatePtr conn_state = std::make_unique(); + EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, connectionState()) + .WillRepeatedly( + Invoke([&]() -> Tcp::ConnectionPool::ConnectionState* { return conn_state.get(); })); + + EXPECT_CALL(*protocol_, supportsUpgrade()).WillOnce(Return(true)); + + // Protocol determines that connection state shows upgrade already occurred + EXPECT_CALL(*protocol_, attemptUpgrade(_, _, _)) + .WillOnce(Invoke([&](Transport&, ThriftConnectionState&, + Buffer::Instance&) -> ThriftObjectPtr { return nullptr; })); + + EXPECT_CALL(*protocol_, writeMessageBegin(_, _)) + .WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void { + EXPECT_EQ(metadata_->methodName(), metadata.methodName()); + EXPECT_EQ(metadata_->messageType(), metadata.messageType()); + EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId()); + })); + EXPECT_CALL(callbacks_, continueDecoding()); + + context_.cluster_manager_.tcp_conn_pool_.poolReady(upstream_connection_); + EXPECT_NE(nullptr, upstream_callbacks_); + + // Then the actual request... + sendTrivialStruct(FieldType::String); + completeRequest(); + returnResponse(); + destroyRouter(); +} + TEST_P(ThriftRouterFieldTypeTest, OneWay) { FieldType field_type = GetParam(); diff --git a/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc b/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc new file mode 100644 index 0000000000000..3e8d12403dcd8 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/thrift_object_impl_test.cc @@ -0,0 +1,494 @@ +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/thrift_object_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::AnyNumber; +using testing::Expectation; +using testing::ExpectationSet; +using testing::InSequence; +using testing::NiceMock; +using testing::Ref; +using testing::Return; +using testing::ReturnRef; +using testing::Test; +using testing::TestWithParam; +using testing::Values; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class ThriftObjectImplTestBase { +public: + virtual ~ThriftObjectImplTestBase() {} + + Expectation expectValue(FieldType field_type) { + switch (field_type) { + case FieldType::Bool: + return EXPECT_CALL(protocol_, readBool(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, bool& value) -> bool { + value = true; + return true; + })); + case FieldType::Byte: + return EXPECT_CALL(protocol_, readByte(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, uint8_t& value) -> bool { + value = 1; + return true; + })); + case FieldType::Double: + return EXPECT_CALL(protocol_, readDouble(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, double& value) -> bool { + value = 2.0; + return true; + })); + case FieldType::I16: + return EXPECT_CALL(protocol_, readInt16(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int16_t& value) -> bool { + value = 3; + return true; + })); + case FieldType::I32: + return EXPECT_CALL(protocol_, readInt32(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int32_t& value) -> bool { + value = 4; + return true; + })); + case FieldType::I64: + return EXPECT_CALL(protocol_, readInt64(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, int64_t& value) -> bool { + value = 5; + return true; + })); + case FieldType::String: + return EXPECT_CALL(protocol_, readString(Ref(buffer_), _)) + .WillOnce(Invoke([](Buffer::Instance&, std::string& value) -> bool { + value = "six"; + return true; + })); + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + Expectation expectFieldBegin(FieldType field_type, int16_t field_id) { + return EXPECT_CALL(protocol_, readFieldBegin(Ref(buffer_), _, _, _)) + .WillOnce( + Invoke([=](Buffer::Instance&, std::string&, FieldType& type, int16_t& id) -> bool { + type = field_type; + id = field_id; + return true; + })); + } + + Expectation expectFieldEnd() { + return EXPECT_CALL(protocol_, readFieldEnd(Ref(buffer_))).WillOnce(Return(true)); + } + + ExpectationSet expectField(FieldType field_type, int16_t field_id) { + ExpectationSet s; + s += expectFieldBegin(field_type, field_id); + s += expectValue(field_type); + s += expectFieldEnd(); + return s; + } + + Expectation expectStopField() { return expectFieldBegin(FieldType::Stop, 0); } + + void checkValue(FieldType field_type, const ThriftValue& value) { + EXPECT_EQ(field_type, value.type()); + + switch (field_type) { + case FieldType::Bool: + EXPECT_EQ(true, value.getValueTyped()); + break; + case FieldType::Byte: + EXPECT_EQ(1, value.getValueTyped()); + break; + case FieldType::Double: + EXPECT_EQ(2.0, value.getValueTyped()); + break; + case FieldType::I16: + EXPECT_EQ(3, value.getValueTyped()); + break; + case FieldType::I32: + EXPECT_EQ(4, value.getValueTyped()); + break; + case FieldType::I64: + EXPECT_EQ(5, value.getValueTyped()); + break; + case FieldType::String: + EXPECT_EQ("six", value.getValueTyped()); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + void checkFieldValue(const ThriftField& field) { + const ThriftValue& value = field.getValue(); + checkValue(field.fieldType(), value); + } + + NiceMock transport_; + NiceMock protocol_; + Buffer::OwnedImpl buffer_; +}; + +class ThriftObjectImplTest : public ThriftObjectImplTestBase, public Test {}; + +// Test parsing an empty struct (just a stop field). +TEST_F(ThriftObjectImplTest, ParseEmptyStruct) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_TRUE(thrift_obj.fields().empty()); +} + +class ThriftObjectImplValueTest : public ThriftObjectImplTestBase, + public TestWithParam {}; + +INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, ThriftObjectImplValueTest, + Values(FieldType::Bool, FieldType::Byte, FieldType::Double, FieldType::I16, + FieldType::I32, FieldType::I64, FieldType::String), + fieldTypeParamToString); + +// Test parsing a struct with a single field with a simple value. +TEST_P(ThriftObjectImplValueTest, ParseSingleValueStruct) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(field_type, 1); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + EXPECT_EQ(field_type, thrift_obj.fields().front()->fieldType()); + EXPECT_EQ(1, thrift_obj.fields().front()->fieldId()); + checkFieldValue(*thrift_obj.fields().front()); +} + +// Test parsing nested structs (struct -> struct -> simple field). +TEST_P(ThriftObjectImplValueTest, ParseNestedSingleValueStruct) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Struct, 1); + + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(field_type, 2); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(FieldType::Struct, field.fieldType()); + + const ThriftStructValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(1, nested.fields().size()); + EXPECT_EQ(field_type, nested.fields().front()->fieldType()); + EXPECT_EQ(2, nested.fields().front()->fieldId()); + checkFieldValue(*nested.fields().front()); +} + +// Test parsing a struct with a single list field (struct -> list). +TEST_P(ThriftObjectImplValueTest, ParseNestedListValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::List, 1); + + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& type, uint32_t& size) -> bool { + type = field_type; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(field_type); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::List, field.fieldType()); + + const ThriftListValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.elementType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value); + } +} + +// Test parsing a struct with a single set field (struct -> set). +TEST_P(ThriftObjectImplValueTest, ParseNestedSetValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Set, 1); + + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& type, uint32_t& size) -> bool { + type = field_type; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(field_type); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::Set, field.fieldType()); + + const ThriftSetValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.elementType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value); + } +} + +// Test parsing a struct with a single map field (struct -> map). +TEST_P(ThriftObjectImplValueTest, ParseNestedMapValue) { + FieldType field_type = GetParam(); + + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Map, 1); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = field_type; + value_type = FieldType::String; + size = 2; + return true; + })); + expectValue(field_type); + expectValue(FieldType::String); + expectValue(field_type); + expectValue(FieldType::String); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + const ThriftField& field = *thrift_obj.fields().front(); + EXPECT_EQ(1, field.fieldId()); + EXPECT_EQ(FieldType::Map, field.fieldType()); + + const ThriftMapValue& nested = field.getValue().getValueTyped(); + EXPECT_EQ(field_type, nested.keyType()); + EXPECT_EQ(FieldType::String, nested.valueType()); + EXPECT_EQ(2, nested.elements().size()); + for (auto& value : nested.elements()) { + checkValue(field_type, *value.first); + checkValue(FieldType::String, *value.second); + } +} + +// Test a struct with a map -> list -> set -> map -> list -> set -> struct. +TEST_F(ThriftObjectImplTest, DeeplyNestedStruct) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectFieldBegin(FieldType::Map, 1); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = FieldType::I32; + value_type = FieldType::List; + size = 1; + return true; + })); + expectValue(FieldType::I32); + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Set; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Map; + size = 1; + return true; + })); + + EXPECT_CALL(protocol_, readMapBegin(Ref(buffer_), _, _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& key_type, FieldType& value_type, + uint32_t& size) -> bool { + key_type = FieldType::I32; + value_type = FieldType::List; + size = 1; + return true; + })); + expectValue(FieldType::I32); + EXPECT_CALL(protocol_, readListBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Set; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readSetBegin(Ref(buffer_), _, _)) + .WillOnce(Invoke([&](Buffer::Instance&, FieldType& elem_type, uint32_t& size) -> bool { + elem_type = FieldType::Struct; + size = 1; + return true; + })); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(FieldType::I64, 100); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readSetEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readListEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMapEnd(Ref(buffer_))).WillOnce(Return(true)); + + expectFieldEnd(); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + + EXPECT_EQ(FieldType::Map, thrift_obj.fields().front()->fieldType()); + const ThriftMapValue& map_value = + thrift_obj.fields().front()->getValue().getValueTyped(); + EXPECT_EQ(1, map_value.elements().size()); + + const ThriftListValue& list_value = + map_value.elements().front().second->getValueTyped(); + EXPECT_EQ(1, list_value.elements().size()); + + const ThriftSetValue& set_value = list_value.elements().front()->getValueTyped(); + EXPECT_EQ(1, set_value.elements().size()); + + const ThriftMapValue& map_value2 = set_value.elements().front()->getValueTyped(); + EXPECT_EQ(1, map_value2.elements().size()); + + const ThriftListValue& list_value2 = + map_value2.elements().front().second->getValueTyped(); + EXPECT_EQ(1, list_value2.elements().size()); + + const ThriftSetValue& set_value2 = + list_value2.elements().front()->getValueTyped(); + EXPECT_EQ(1, set_value2.elements().size()); + + const ThriftStructValue& struct_value = + set_value2.elements().front()->getValueTyped(); + EXPECT_EQ(1, struct_value.fields().size()); + + EXPECT_EQ(5, struct_value.fields().front()->getValue().getValueTyped()); +} + +// Tests when caller requests wrong value type. +TEST_F(ThriftObjectImplTest, WrongValueType) { + ThriftObjectImpl thrift_obj(transport_, protocol_); + + InSequence s; + EXPECT_CALL(transport_, decodeFrameStart(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageBegin(Ref(buffer_), _)).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readStructBegin(Ref(buffer_), _)).WillOnce(Return(true)); + expectField(FieldType::String, 1); + expectStopField(); + EXPECT_CALL(protocol_, readStructEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(protocol_, readMessageEnd(Ref(buffer_))).WillOnce(Return(true)); + EXPECT_CALL(transport_, decodeFrameEnd(Ref(buffer_))).WillOnce(Return(true)); + + EXPECT_TRUE(thrift_obj.onData(buffer_)); + EXPECT_EQ(1, thrift_obj.fields().size()); + + const ThriftValue& value = thrift_obj.fields().front()->getValue(); + EXPECT_THROW_WITH_MESSAGE(value.getValueTyped(), EnvoyException, + fmt::format("expected field type {}, got {}", + static_cast(FieldType::I32), + static_cast(FieldType::String))); +} + +} // Namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy