diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD index 60540cd8079a4..da39334babd17 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/BUILD @@ -4,5 +4,8 @@ licenses(["notice"]) # Apache 2 api_proto_library_internal( name = "thrift_proxy", - srcs = ["thrift_proxy.proto"], + srcs = [ + "route.proto", + "thrift_proxy.proto", + ], ) diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto new file mode 100644 index 0000000000000..5c9af1c48755a --- /dev/null +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto @@ -0,0 +1,41 @@ +syntax = "proto3"; + +package envoy.extensions.filters.network.thrift_proxy.v2alpha1; +option go_package = "v2"; + +import "validate/validate.proto"; +import "gogoproto/gogo.proto"; + +// [#protodoc-title: Thrift route configuration] + +// [#comment:next free field: 3] +message RouteConfiguration { + // The name of the route configuration. Reserved for future use in asynchronous route discovery. + string name = 1; + + // The list of routes that will be matched, in order, against incoming requests. The first route + // that matches will be used. + repeated Route routes = 2 [(gogoproto.nullable) = false]; +} + +// [#comment:next free field: 3] +message Route { + // Route matching prarameters. + RouteMatch match = 1 [(validate.rules).message.required = true, (gogoproto.nullable) = false]; + + // Route request to some upstream cluster. + RouteAction route = 2 [(validate.rules).message.required = true, (gogoproto.nullable) = false]; +} + +// [#comment:next free field: 2] +message RouteMatch { + // If specified, the route must exactly match the request method name. As a special case, an + // empty string matches any request method name. + string method = 1; +} + +// [#comment:next free field: 2] +message RouteAction { + // Indicates the upstream cluster to which the request should be routed. + string cluster = 1 [(validate.rules).string.min_bytes = 1]; +} diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/BUILD b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/BUILD new file mode 100644 index 0000000000000..ce0ad0e254f03 --- /dev/null +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/BUILD @@ -0,0 +1,8 @@ +load("//bazel:api_build_system.bzl", "api_proto_library_internal") + +licenses(["notice"]) # Apache 2 + +api_proto_library_internal( + name = "router", + srcs = ["router.proto"], +) diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.proto b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.proto new file mode 100644 index 0000000000000..2818c044e7b07 --- /dev/null +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.proto @@ -0,0 +1,9 @@ +syntax = "proto3"; + +package envoy.extensions.filters.network.thrift_proxy.v2alpha1.router; +option go_package = "router"; + +// [#protodoc-title: Thrift Router] +// Thrift Router configuration. +message Router { +} diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto index e2d6bd02cb261..8897292273202 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto +++ b/api/envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.proto @@ -3,11 +3,55 @@ syntax = "proto3"; package envoy.extensions.filters.network.thrift_proxy.v2alpha1; option go_package = "v2"; +import "envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.proto"; + import "validate/validate.proto"; +import "gogoproto/gogo.proto"; // [#protodoc-title: Extensions Thrift Proxy] // Thrift Proxy filter configuration. +// [#comment:next free field: 5] message ThriftProxy { + enum TransportType { + option (gogoproto.goproto_enum_prefix) = false; + + // For every new connection, the Thrift proxy will determine which transport to use. + AUTO_TRANSPORT = 0; + + // The Thrift proxy will assume the client is using the Thrift framed transport. + FRAMED = 1; + + // The Thrift proxy will assume the client is using the Thrift unframed transport. + UNFRAMED = 2; + } + + // Supplies the type of transport that the Thrift proxy should use. Defaults to `AUTO_TRANSPORT`. + TransportType transport = 2 [(validate.rules).enum.defined_only = true]; + + enum ProtocolType { + option (gogoproto.goproto_enum_prefix) = false; + + // For every new connection, the Thrift proxy will determine which protocol to use. + // N.B. The older, non-strict binary protocol is not included in automatic protocol + // detection. + AUTO_PROTOCOL = 0; + + // The Thrift proxy will assume the client is using the Thrift binary protocol. + BINARY = 1; + + // The Thrift proxy will assume the client is using the Thrift non-strict binary protocol. + LAX_BINARY = 2; + + // The Thrift proxy will assume the client is using the Thrift compact protocol. + COMPACT = 3; + } + + // Supplies the type of protocol that the Thrift proxy should use. Defaults to `AUTO_PROTOCOL`. + ProtocolType protocol = 3 [(validate.rules).enum.defined_only = true]; + // The human readable prefix to use when emitting statistics. string stat_prefix = 1 [(validate.rules).string.min_bytes = 1]; + + // The route table for the connection manager is static and is specified in this property. + RouteConfiguration route_config = 4; } diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index dad0243f89dc5..3adfeb6836c9e 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -85,6 +85,12 @@ EXTENSIONS = { "envoy.stat_sinks.metrics_service": "//source/extensions/stat_sinks/metrics_service:config", "envoy.stat_sinks.statsd": "//source/extensions/stat_sinks/statsd:config", + # + # Thrift filters + # + + "envoy.filters.thrift.router": "//source/extensions/filters/network/thrift_proxy/router:config", + # # Tracers # diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index 48555347641b0..c69373d53e375 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -8,6 +8,17 @@ load( envoy_package() +envoy_cc_library( + name = "app_exception_lib", + srcs = ["app_exception_impl.cc"], + hdrs = ["app_exception_impl.h"], + deps = [ + ":protocol_interface", + "//include/envoy/buffer:buffer_interface", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + ], +) + envoy_cc_library( name = "buffer_helper_lib", srcs = ["buffer_helper.cc"], @@ -24,41 +35,68 @@ envoy_cc_library( srcs = ["config.cc"], hdrs = ["config.h"], deps = [ - ":filter_lib", + ":conn_manager_lib", + ":decoder_lib", + ":protocol_lib", "//include/envoy/registry", - "//source/common/config:filter_json_lib", + "//source/common/config:utility_lib", "//source/extensions/filters/network:well_known_names", "//source/extensions/filters/network/common:factory_base_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_config_interface", + "//source/extensions/filters/network/thrift_proxy/filters:well_known_names", + "//source/extensions/filters/network/thrift_proxy/router:router_lib", "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v2alpha1:thrift_proxy_cc", ], ) +envoy_cc_library( + name = "conn_manager_lib", + srcs = ["conn_manager.cc"], + hdrs = ["conn_manager.h"], + deps = [ + ":app_exception_lib", + ":decoder_lib", + ":protocol_converter_lib", + ":protocol_lib", + ":stats_lib", + ":transport_lib", + "//include/envoy/event:deferred_deletable", + "//include/envoy/event:dispatcher_interface", + "//include/envoy/network:connection_interface", + "//include/envoy/network:filter_interface", + "//include/envoy/stats:stats_interface", + "//include/envoy/stats:timespan", + "//source/common/buffer:buffer_lib", + "//source/common/common:assert_lib", + "//source/common/common:linked_object", + "//source/common/common:logger_lib", + "//source/common/network:filter_lib", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + ], +) + envoy_cc_library( name = "decoder_lib", srcs = ["decoder.cc"], hdrs = ["decoder.h"], deps = [ ":protocol_lib", + ":stats_lib", ":transport_lib", "//source/common/buffer:buffer_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) envoy_cc_library( - name = "filter_lib", - srcs = ["filter.cc"], - hdrs = ["filter.h"], + name = "protocol_converter_lib", + hdrs = [ + "protocol_converter.h", + ], deps = [ - ":decoder_lib", - "//include/envoy/network:connection_interface", - "//include/envoy/network:filter_interface", - "//include/envoy/stats:stats_interface", - "//include/envoy/stats:stats_macros", - "//include/envoy/stats:timespan", - "//source/common/buffer:buffer_lib", - "//source/common/common:assert_lib", - "//source/common/common:logger_lib", - "//source/common/network:filter_lib", + ":protocol_interface", + "//include/envoy/buffer:buffer_interface", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", ], ) @@ -70,6 +108,9 @@ envoy_cc_library( external_deps = ["abseil_optional"], deps = [ "//include/envoy/buffer:buffer_interface", + "//include/envoy/registry", + "//source/common/common:assert_lib", + "//source/common/config:utility_lib", "//source/common/singleton:const_singleton", ], ) @@ -94,12 +135,24 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "stats_lib", + hdrs = ["stats.h"], + deps = [ + "//include/envoy/stats:stats_interface", + "//include/envoy/stats:stats_macros", + ], +) + envoy_cc_library( name = "transport_interface", hdrs = ["transport.h"], external_deps = ["abseil_optional"], deps = [ "//include/envoy/buffer:buffer_interface", + "//include/envoy/registry", + "//source/common/common:assert_lib", + "//source/common/config:utility_lib", "//source/common/singleton:const_singleton", ], ) @@ -109,6 +162,7 @@ envoy_cc_library( srcs = [ "framed_transport_impl.cc", "transport_impl.cc", + "unframed_transport_impl.cc", ], hdrs = [ "framed_transport_impl.h", diff --git a/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc b/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc new file mode 100644 index 0000000000000..65455c12b3609 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/app_exception_impl.cc @@ -0,0 +1,34 @@ +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +static const std::string TApplicationException = "TApplicationException"; +static const std::string MessageField = "message"; +static const std::string TypeField = "type"; +static const std::string StopField = ""; + +void AppException::encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) { + proto.writeMessageBegin(buffer, method_name_, ThriftProxy::MessageType::Exception, seq_id_); + proto.writeStructBegin(buffer, TApplicationException); + + proto.writeFieldBegin(buffer, MessageField, ThriftProxy::FieldType::String, 1); + proto.writeString(buffer, error_message_); + proto.writeFieldEnd(buffer); + + proto.writeFieldBegin(buffer, TypeField, ThriftProxy::FieldType::I32, 2); + proto.writeInt32(buffer, static_cast(type_)); + proto.writeFieldEnd(buffer); + + proto.writeFieldBegin(buffer, StopField, ThriftProxy::FieldType::Stop, 0); + + proto.writeStructEnd(buffer); + proto.writeMessageEnd(buffer); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/app_exception_impl.h b/source/extensions/filters/network/thrift_proxy/app_exception_impl.h new file mode 100644 index 0000000000000..4a0335704100a --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/app_exception_impl.h @@ -0,0 +1,44 @@ +#pragma once + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * Thrift Application Exception types. + * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md + */ +enum class AppExceptionType { + Unknown = 0, + UnknownMethod = 1, + InvalidMessageType = 2, + WrongMethodName = 3, + BadSequenceId = 4, + MissingResult = 5, + InternalError = 6, + ProtocolError = 7, + InvalidTransform = 8, + InvalidProtocol = 9, + UnsupportedClientType = 10, +}; + +struct AppException : public ThriftFilters::DirectResponse { + AppException(const absl::string_view method_name, int32_t seq_id, AppExceptionType type, + const std::string& error_message) + : method_name_(method_name), seq_id_(seq_id), type_(type), error_message_(error_message) {} + + void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) override; + + const std::string method_name_; + const int32_t seq_id_; + const AppExceptionType type_; + const std::string error_message_; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc index 3fc1038962e8a..ee5a734eda209 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.cc @@ -60,26 +60,22 @@ bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& msg_type = type; seq_id = BufferHelper::drainI32(buffer); - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } bool BinaryProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onMessageComplete(); return true; } bool BinaryProtocolImpl::readStructBegin(Buffer::Instance& buffer, std::string& name) { UNREFERENCED_PARAMETER(buffer); name.clear(); // binary protocol does not transmit struct names - onStructBegin(absl::string_view(name)); return true; } bool BinaryProtocolImpl::readStructEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onStructEnd(); return true; } @@ -110,7 +106,6 @@ bool BinaryProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& n name.clear(); // binary protocol does not transmit field names field_type = type; - onStructField(absl::string_view(name), field_type, field_id); return true; } @@ -402,7 +397,6 @@ bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::stri seq_id = BufferHelper::peekI32(buffer, 1); buffer.drain(5); - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } @@ -413,6 +407,27 @@ void LaxBinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const st BufferHelper::writeI32(buffer, seq_id); } +class BinaryProtocolConfigFactory : public ProtocolFactoryBase { +public: + BinaryProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().BINARY) {} +}; + +/** + * Static registration for the binary protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + +class LaxBinaryProtocolConfigFactory : public ProtocolFactoryBase { +public: + LaxBinaryProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().LAX_BINARY) {} +}; + +/** + * Static registration for the auto protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_lax_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h index fac781ea563b9..e292d6cd036b9 100644 --- a/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/binary_protocol_impl.h @@ -16,12 +16,13 @@ namespace ThriftProxy { * BinaryProtocolImpl implements the Thrift Binary protocol with strict message encoding. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-binary-protocol.md */ -class BinaryProtocolImpl : public ProtocolImplBase { +class BinaryProtocolImpl : public Protocol { public: - BinaryProtocolImpl(ProtocolCallbacks& callbacks) : ProtocolImplBase(callbacks) {} + BinaryProtocolImpl() {} // Protocol const std::string& name() const override { return ProtocolNames::get().BINARY; } + ProtocolType type() const override { return ProtocolType::Binary; } bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; bool readMessageEnd(Buffer::Instance& buffer) override; @@ -81,7 +82,7 @@ class BinaryProtocolImpl : public ProtocolImplBase { */ class LaxBinaryProtocolImpl : public BinaryProtocolImpl { public: - LaxBinaryProtocolImpl(ProtocolCallbacks& callbacks) : BinaryProtocolImpl(callbacks) {} + LaxBinaryProtocolImpl() {} const std::string& name() const override { return ProtocolNames::get().LAX_BINARY; } diff --git a/source/extensions/filters/network/thrift_proxy/buffer_helper.h b/source/extensions/filters/network/thrift_proxy/buffer_helper.h index 8eb633a0140a6..c4945cd5da6b5 100644 --- a/source/extensions/filters/network/thrift_proxy/buffer_helper.h +++ b/source/extensions/filters/network/thrift_proxy/buffer_helper.h @@ -10,57 +10,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/** - * BufferWrapper provides a partial implementation of Buffer::Instance that is sufficient for - * BufferHelper to read Thrift protocol data without draining the buffer's contents. - */ -class BufferWrapper : public Buffer::Instance { -public: - BufferWrapper(Buffer::Instance& underlying) : underlying_(underlying) {} - - uint64_t position() { return position_; } - - // Buffer::Instance - void copyOut(size_t start, uint64_t size, void* data) const override { - ASSERT(position_ + start + size <= underlying_.length()); - underlying_.copyOut(start + position_, size, data); - } - void drain(uint64_t size) override { - ASSERT(position_ + size <= underlying_.length()); - position_ += size; - } - uint64_t length() const override { - ASSERT(underlying_.length() >= position_); - return underlying_.length() - position_; - } - void* linearize(uint32_t size) override { - ASSERT(position_ + size <= underlying_.length()); - uint8_t* p = static_cast(underlying_.linearize(position_ + size)); - return p + position_; - } - void add(const void*, uint64_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - void addBufferFragment(Buffer::BufferFragment&) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - void add(const std::string&) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - void add(const Buffer::Instance&) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - void commit(Buffer::RawSlice*, uint64_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - uint64_t getRawSlices(Buffer::RawSlice*, uint64_t) const override { - NOT_IMPLEMENTED_GCOVR_EXCL_LINE; - } - void move(Buffer::Instance&) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - void move(Buffer::Instance&, uint64_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - std::tuple read(int, uint64_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - uint64_t reserve(uint64_t, Buffer::RawSlice*, uint64_t) override { - NOT_IMPLEMENTED_GCOVR_EXCL_LINE; - } - ssize_t search(const void*, uint64_t, size_t) const override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - std::tuple write(int) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - std::string toString() const override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } - -private: - Buffer::Instance& underlying_; - uint64_t position_{0}; -}; - /** * BufferHelper provides buffer operations for reading bytes and numbers in the various encodings * used by Thrift protocols. diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc index 5ea3a3cafb483..417a80d8b6197 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.cc @@ -72,13 +72,11 @@ bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string msg_type = type; seq_id = id; - onMessageStart(absl::string_view(name), msg_type, seq_id); return true; } bool CompactProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { UNREFERENCED_PARAMETER(buffer); - onMessageComplete(); return true; } @@ -91,7 +89,6 @@ bool CompactProtocolImpl::readStructBegin(Buffer::Instance& buffer, std::string& last_field_id_stack_.push(last_field_id_); last_field_id_ = 0; - onStructBegin(absl::string_view(name)); return true; } @@ -105,7 +102,6 @@ bool CompactProtocolImpl::readStructEnd(Buffer::Instance& buffer) { last_field_id_ = last_field_id_stack_.top(); last_field_id_stack_.pop(); - onStructEnd(); return true; } @@ -124,7 +120,6 @@ bool CompactProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& field_type = FieldType::Stop; buffer.drain(1); - onStructField(absl::string_view(name), field_type, field_id); return true; } @@ -166,7 +161,6 @@ bool CompactProtocolImpl::readFieldBegin(Buffer::Instance& buffer, std::string& buffer.drain(id_size + 1); - onStructField(absl::string_view(name), field_type, field_id); return true; } @@ -459,7 +453,7 @@ void CompactProtocolImpl::writeFieldBeginInternal( static_cast(compact_field_type)); } else { BufferHelper::writeI8(buffer, static_cast(compact_field_type)); - BufferHelper::writeI16(buffer, field_id); + BufferHelper::writeZigZagI32(buffer, static_cast(field_id)); } last_field_id_ = field_id; @@ -623,6 +617,17 @@ CompactProtocolImpl::CompactFieldType CompactProtocolImpl::convertFieldType(Fiel } } +class CompactProtocolConfigFactory : public ProtocolFactoryBase { +public: + CompactProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().COMPACT) {} +}; + +/** + * Static registration for the binary protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h index 72bad6dee4482..322d03a3a83da 100644 --- a/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/compact_protocol_impl.h @@ -19,12 +19,13 @@ namespace ThriftProxy { * CompactProtocolImpl implements the Thrift Compact protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-compact-protocol.md */ -class CompactProtocolImpl : public ProtocolImplBase { +class CompactProtocolImpl : public Protocol { public: - CompactProtocolImpl(ProtocolCallbacks& callbacks) : ProtocolImplBase(callbacks) {} + CompactProtocolImpl() {} // Protocol const std::string& name() const override { return ProtocolNames::get().COMPACT; } + ProtocolType type() const override { return ProtocolType::Compact; } bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; bool readMessageEnd(Buffer::Instance& buffer) override; diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index 49f169eee4815..ffab3d37f53c9 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -1,26 +1,81 @@ #include "extensions/filters/network/thrift_proxy/config.h" +#include #include #include "envoy/network/connection.h" #include "envoy/registry/registry.h" -#include "extensions/filters/network/thrift_proxy/filter.h" +#include "common/config/utility.h" + +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter_config.h" +#include "extensions/filters/network/thrift_proxy/filters/well_known_names.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/stats.h" +#include "extensions/filters/network/thrift_proxy/transport_impl.h" +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +namespace { + +typedef std::map< + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType, + TransportType> + TransportTypeMap; + +static const TransportTypeMap& transportTypeMap() { + CONSTRUCT_ON_FIRST_USE(TransportTypeMap, + { + {envoy::extensions::filters::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_AUTO_TRANSPORT, + TransportType::Auto}, + {envoy::extensions::filters::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_FRAMED, + TransportType::Framed}, + {envoy::extensions::filters::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_UNFRAMED, + TransportType::Unframed}, + }); +} + +typedef std::map< + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType, + ProtocolType> + ProtocolTypeMap; + +static const ProtocolTypeMap& protocolTypeMap() { + CONSTRUCT_ON_FIRST_USE(ProtocolTypeMap, { + {envoy::extensions::filters::network::thrift_proxy:: + v2alpha1::ThriftProxy_ProtocolType_AUTO_PROTOCOL, + ProtocolType::Auto}, + {envoy::extensions::filters::network::thrift_proxy:: + v2alpha1::ThriftProxy_ProtocolType_BINARY, + ProtocolType::Binary}, + {envoy::extensions::filters::network::thrift_proxy:: + v2alpha1::ThriftProxy_ProtocolType_LAX_BINARY, + ProtocolType::LaxBinary}, + {envoy::extensions::filters::network::thrift_proxy:: + v2alpha1::ThriftProxy_ProtocolType_COMPACT, + ProtocolType::Compact}, + }); +} + +} // namespace Network::FilterFactoryCb ThriftProxyFilterConfigFactory::createFilterFactoryFromProtoTyped( const envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy& proto_config, Server::Configuration::FactoryContext& context) { - ASSERT(!proto_config.stat_prefix().empty()); - - const std::string stat_prefix = fmt::format("thrift.{}.", proto_config.stat_prefix()); + std::shared_ptr filter_config(new ConfigImpl(proto_config, context)); - return [stat_prefix, &context](Network::FilterManager& filter_manager) -> void { - filter_manager.addFilter(std::make_shared(stat_prefix, context.scope())); + return [filter_config](Network::FilterManager& filter_manager) -> void { + filter_manager.addReadFilter(std::make_shared(*filter_config)); }; } @@ -31,6 +86,48 @@ static Registry::RegisterFactory registered_; +ConfigImpl::ConfigImpl( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy& config, + Server::Configuration::FactoryContext& context) + : context_(context), stats_prefix_(fmt::format("thrift.{}.", config.stat_prefix())), + stats_(ThriftFilterStats::generateStats(stats_prefix_, context_.scope())), + transport_(config.transport()), proto_(config.protocol()), + route_matcher_(new Router::RouteMatcher(config.route_config())) { + + // Construct the only Thrift DecoderFilter: the Router + auto& factory = + Envoy::Config::Utility::getAndCheckFactory( + ThriftFilters::ThriftFilterNames::get().ROUTER); + ThriftFilters::FilterFactoryCb callback; + + auto empty_config = factory.createEmptyConfigProto(); + callback = factory.createFilterFactoryFromProto(*empty_config, stats_prefix_, context_); + filter_factories_.push_back(callback); +} + +void ConfigImpl::createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) { + for (const ThriftFilters::FilterFactoryCb& factory : filter_factories_) { + factory(callbacks); + } +} + +DecoderPtr ConfigImpl::createDecoder(DecoderCallbacks& callbacks) { + return std::make_unique(createTransport(), createProtocol(), callbacks); +} + +TransportPtr ConfigImpl::createTransport() { + TransportTypeMap::const_iterator i = transportTypeMap().find(transport_); + RELEASE_ASSERT(i != transportTypeMap().end(), "invalid transport type"); + + return NamedTransportConfigFactory::getFactory(i->second).createTransport(); +} + +ProtocolPtr ConfigImpl::createProtocol() { + ProtocolTypeMap::const_iterator i = protocolTypeMap().find(proto_); + RELEASE_ASSERT(i != protocolTypeMap().end(), "invalid protocol type"); + return NamedProtocolConfigFactory::getFactory(i->second).createProtocol(); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index 40aa4b418774d..7dcc4e094353e 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -1,11 +1,16 @@ #pragma once +#include #include #include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" #include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.validate.h" +#include "envoy/stats/stats.h" #include "extensions/filters/network/common/factory_base.h" +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" #include "extensions/filters/network/well_known_names.h" namespace Envoy { @@ -28,6 +33,42 @@ class ThriftProxyFilterConfigFactory Server::Configuration::FactoryContext& context) override; }; +class ConfigImpl : public Config, + public Router::Config, + public ThriftFilters::FilterChainFactory, + Logger::Loggable { +public: + ConfigImpl(const envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy& config, + Server::Configuration::FactoryContext& context); + + // ThriftFilters::FilterChainFactory + void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override; + + // Router::Config + Router::RouteConstSharedPtr route(const std::string& method_name) const override { + return route_matcher_->route(method_name); + } + + // Config + ThriftFilterStats& stats() override { return stats_; } + ThriftFilters::FilterChainFactory& filterFactory() override { return *this; } + DecoderPtr createDecoder(DecoderCallbacks& callbacks) override; + Router::Config& routerConfig() override { return *this; } + +private: + TransportPtr createTransport(); + ProtocolPtr createProtocol(); + + Server::Configuration::FactoryContext& context_; + const std::string stats_prefix_; + ThriftFilterStats stats_; + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType transport_; + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType proto_; + std::unique_ptr route_matcher_; + + std::list filter_factories_; +}; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc new file mode 100644 index 0000000000000..c94bbeefcec36 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -0,0 +1,290 @@ +#include "extensions/filters/network/thrift_proxy/conn_manager.h" + +#include "envoy/common/exception.h" +#include "envoy/event/dispatcher.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" +#include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +ConnectionManager::ConnectionManager(Config& config) + : config_(config), stats_(config_.stats()), decoder_(config_.createDecoder(*this)) {} + +ConnectionManager::~ConnectionManager() {} + +Network::FilterStatus ConnectionManager::onData(Buffer::Instance& data, bool end_stream) { + UNREFERENCED_PARAMETER(end_stream); + + request_buffer_.move(data); + dispatch(); + + return Network::FilterStatus::StopIteration; +} + +void ConnectionManager::dispatch() { + if (stopped_) { + ENVOY_LOG(error, "thrift filter stopped"); + return; + } + + try { + bool underflow = false; + while (!underflow) { + ThriftFilters::FilterStatus status = decoder_->onData(request_buffer_, underflow); + if (status == ThriftFilters::FilterStatus::StopIteration) { + stopped_ = true; + break; + } + } + } catch (const EnvoyException& ex) { + ENVOY_LOG(error, "thrift error: {}", ex.what()); + stats_.request_decoding_error_.inc(); + + // Use the current rpc to send an error downstream, if possible. + rpcs_.front()->onError(ex.what()); + + resetAllRpcs(); + read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite); + } +} + +void ConnectionManager::continueDecoding() { + stopped_ = false; + dispatch(); +} + +void ConnectionManager::doDeferredRpcDestroy(ConnectionManager::ActiveRpc& rpc) { + read_callbacks_->connection().dispatcher().deferredDelete(rpc.removeFromList(rpcs_)); +} + +void ConnectionManager::resetAllRpcs() { + while (!rpcs_.empty()) { + rpcs_.front()->onReset(); + } +} + +void ConnectionManager::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) { + read_callbacks_ = &callbacks; +} + +void ConnectionManager::onEvent(Network::ConnectionEvent event) { + if (!rpcs_.empty()) { + if (event == Network::ConnectionEvent::RemoteClose) { + stats_.cx_destroy_remote_with_active_rq_.inc(); + } else if (event == Network::ConnectionEvent::LocalClose) { + stats_.cx_destroy_local_with_active_rq_.inc(); + } + + resetAllRpcs(); + } +} + +ThriftFilters::DecoderFilter& ConnectionManager::newDecoderFilter() { + ENVOY_LOG(debug, "new decoder filter"); + + ActiveRpcPtr new_rpc(new ActiveRpc(*this)); + new_rpc->createFilterChain(); + new_rpc->moveIntoList(std::move(new_rpc), rpcs_); + + return **rpcs_.begin(); +} + +bool ConnectionManager::ResponseDecoder::onData(Buffer::Instance& data) { + upstream_buffer_.move(data); + + bool underflow = false; + decoder_->onData(upstream_buffer_, underflow); + ASSERT(complete_ || underflow); + return complete_; +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::messageBegin(absl::string_view name, + MessageType msg_type, + int32_t seq_id) { + reply_.emplace(std::string(name), msg_type, seq_id); + first_reply_field_ = (msg_type == MessageType::Reply); + return ProtocolConverter::messageBegin(name, msg_type, seq_id); +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::fieldBegin(absl::string_view name, + FieldType field_type, + int16_t field_id) { + if (first_reply_field_) { + // Reply messages contain a struct where field 0 is the call result and fields 1+ are + // exceptions, if defined. At most one field may be set. Therefore, the very first field we + // encounter in a reply is either field 0 (success) or not (IDL exception returned). + ASSERT(reply_.has_value()); + reply_.value().success_ = field_id == 0 && field_type != FieldType::Stop; + first_reply_field_ = false; + } + + return ProtocolConverter::fieldBegin(name, field_type, field_id); +} + +ThriftFilters::FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { + ConnectionManager& cm = parent_.parent_; + + Buffer::OwnedImpl buffer; + + // Use the factory to get the concrete transport from the decoder transport (as opposed to + // potentially pre-detection auto transport). + TransportPtr transport = + NamedTransportConfigFactory::getFactory(parent_.parent_.decoder_->transportType()) + .createTransport(); + transport->encodeFrame(buffer, parent_.response_buffer_); + complete_ = true; + + cm.read_callbacks_->connection().write(buffer, false); + + cm.stats_.response_.inc(); + + ASSERT(reply_.has_value()); + switch (reply_.value().msg_type_) { + case MessageType::Reply: + cm.stats_.response_reply_.inc(); + if (reply_.value().success_.value_or(false)) { + cm.stats_.response_success_.inc(); + } else { + cm.stats_.response_error_.inc(); + } + + break; + + case MessageType::Exception: + cm.stats_.response_exception_.inc(); + break; + + default: + cm.stats_.response_invalid_type_.inc(); + break; + } + + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus ConnectionManager::ActiveRpc::transportEnd() { + ASSERT(call_.has_value()); + + parent_.stats_.request_.inc(); + + switch (call_.value().msg_type_) { + case MessageType::Call: + parent_.stats_.request_call_.inc(); + break; + + case MessageType::Oneway: + parent_.stats_.request_oneway_.inc(); + + // No response forthcoming, we're done. + parent_.doDeferredRpcDestroy(*this); + break; + + default: + parent_.stats_.request_invalid_type_.inc(); + break; + } + + return decoder_filter_->transportEnd(); +} + +void ConnectionManager::ActiveRpc::createFilterChain() { + parent_.config_.filterFactory().createFilterChain(*this); +} + +void ConnectionManager::ActiveRpc::onReset() { + // TODO(zuercher): e.g., parent_.stats_.named_.downstream_rq_rx_reset_.inc(); + parent_.doDeferredRpcDestroy(*this); +} + +void ConnectionManager::ActiveRpc::onError(const std::string& what) { + if (call_.has_value()) { + const Message& msg = call_.value(); + sendLocalReply(std::make_unique(msg.method_name_, msg.seq_id_, + AppExceptionType::ProtocolError, what)); + return; + } + + // Transport or protocol error happened before (or during message begin) parsing. It's not + // possible to provide a valid response, so don't try. +} + +const Network::Connection* ConnectionManager::ActiveRpc::connection() const { + return &parent_.read_callbacks_->connection(); +} + +void ConnectionManager::ActiveRpc::continueDecoding() { parent_.continueDecoding(); } + +Router::RouteConstSharedPtr ConnectionManager::ActiveRpc::route() { + if (!cached_route_) { + if (call_.has_value()) { + Router::RouteConstSharedPtr route = + parent_.config_.routerConfig().route(call_.value().method_name_); + cached_route_ = std::move(route); + } else { + cached_route_ = nullptr; + } + } + + return cached_route_.value(); +} + +void ConnectionManager::ActiveRpc::sendLocalReply(ThriftFilters::DirectResponsePtr&& response) { + // Use the factory to get the concrete protocol from the decoder protocol (as opposed to + // potentially pre-detection auto protocol). + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(parent_.decoder_->protocolType()).createProtocol(); + Buffer::OwnedImpl buffer; + + response->encode(*proto, buffer); + + // Same logic as protocol above. + TransportPtr transport = + NamedTransportConfigFactory::getFactory(parent_.decoder_->transportType()).createTransport(); + transport->encodeFrame(response_buffer_, buffer); + + parent_.read_callbacks_->connection().write(response_buffer_, false); + parent_.doDeferredRpcDestroy(*this); +} + +void ConnectionManager::ActiveRpc::startUpstreamResponse(TransportType transport_type, + ProtocolType protocol_type) { + ASSERT(response_decoder_ == nullptr); + + response_decoder_ = std::make_unique(*this, transport_type, protocol_type); +} + +bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) { + ASSERT(response_decoder_ != nullptr); + + try { + bool complete = response_decoder_->onData(buffer); + if (complete) { + parent_.doDeferredRpcDestroy(*this); + } + return complete; + } catch (const EnvoyException& ex) { + ENVOY_LOG(error, "thrift response error: {}", ex.what()); + parent_.stats_.response_decoding_error_.inc(); + + onError(ex.what()); + decoder_filter_->resetUpstreamConnection(); + return true; + } +} + +void ConnectionManager::ActiveRpc::resetDownstreamConnection() { + parent_.read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush); + parent_.doDeferredRpcDestroy(*this); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h new file mode 100644 index 0000000000000..c366a40c0f2a5 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -0,0 +1,256 @@ +#pragma once + +#include "envoy/common/pure.h" +#include "envoy/event/deferred_deletable.h" +#include "envoy/network/connection.h" +#include "envoy/network/filter.h" +#include "envoy/stats/stats.h" +#include "envoy/stats/timespan.h" + +#include "common/buffer/buffer_impl.h" +#include "common/common/linked_object.h" +#include "common/common/logger.h" + +#include "extensions/filters/network/thrift_proxy/decoder.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/protocol_converter.h" +#include "extensions/filters/network/thrift_proxy/stats.h" +#include "extensions/filters/network/thrift_proxy/transport.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * Config is a configuration interface for ConnectionManager. + */ +class Config { +public: + virtual ~Config() {} + + virtual ThriftFilters::FilterChainFactory& filterFactory() PURE; + virtual ThriftFilterStats& stats() PURE; + virtual DecoderPtr createDecoder(DecoderCallbacks& callbacks) PURE; + virtual Router::Config& routerConfig() PURE; +}; + +/** + * ConnectionManager is a Network::Filter that will perform Thrift request handling on a connection. + */ +class ConnectionManager : public Network::ReadFilter, + public Network::ConnectionCallbacks, + public DecoderCallbacks, + Logger::Loggable { +public: + ConnectionManager(Config& config); + ~ConnectionManager(); + + // Network::ReadFilter + Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) override; + Network::FilterStatus onNewConnection() override { return Network::FilterStatus::Continue; } + void initializeReadFilterCallbacks(Network::ReadFilterCallbacks&) override; + + // Network::ConnectionCallbacks + void onEvent(Network::ConnectionEvent) override; + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + + // DecoderCallbacks + ThriftFilters::DecoderFilter& newDecoderFilter() override; + +private: + class Message { + public: + Message(const std::string& method_name, MessageType msg_type, int32_t seq_id) + : method_name_(method_name), msg_type_(msg_type), seq_id_(seq_id) {} + + const std::string method_name_; + const MessageType msg_type_; + const int32_t seq_id_; + absl::optional success_; + }; + + struct ActiveRpc; + + struct ResponseDecoder : public DecoderCallbacks, public ProtocolConverter { + ResponseDecoder(ActiveRpc& parent, TransportType transport_type, ProtocolType protocol_type) + : parent_(parent), + decoder_(std::make_unique( + NamedTransportConfigFactory::getFactory(transport_type).createTransport(), + NamedProtocolConfigFactory::getFactory(protocol_type).createProtocol(), *this)), + complete_(false), first_reply_field_(false) { + // Use the factory to get the concrete protocol from the decoder protocol (as opposed to + // potentially pre-detection auto protocol). + initProtocolConverter( + NamedProtocolConfigFactory::getFactory(parent_.parent_.decoder_->protocolType()) + .createProtocol(), + parent_.response_buffer_); + } + + bool onData(Buffer::Instance& data); + + // ProtocolConverter + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override; + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override; + ThriftFilters::FilterStatus transportBegin(absl::optional size) override { + UNREFERENCED_PARAMETER(size); + return ThriftFilters::FilterStatus::Continue; + } + ThriftFilters::FilterStatus transportEnd() override; + + // DecoderCallbacks + ThriftFilters::DecoderFilter& newDecoderFilter() override { return *this; } + + ActiveRpc& parent_; + DecoderPtr decoder_; + Buffer::OwnedImpl upstream_buffer_; + absl::optional reply_; + bool complete_ : 1; + bool first_reply_field_ : 1; + }; + typedef std::unique_ptr ResponseDecoderPtr; + + // ActiveRpc tracks request/response pairs. + struct ActiveRpc : LinkedObject, + public Event::DeferredDeletable, + public ThriftFilters::DecoderFilter, + public ThriftFilters::DecoderFilterCallbacks, + public ThriftFilters::FilterChainFactoryCallbacks { + ActiveRpc(ConnectionManager& parent) + : parent_(parent), request_timer_(new Stats::Timespan(parent_.stats_.request_time_ms_)), + stream_id_(parent_.stream_id_++) { + parent_.stats_.request_active_.inc(); + } + ~ActiveRpc() { + request_timer_->complete(); + parent_.stats_.request_active_.dec(); + + if (decoder_filter_ != nullptr) { + decoder_filter_->onDestroy(); + } + } + + // ThriftFilters::DecoderFilter + void onDestroy() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks&) override { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + void resetUpstreamConnection() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + ThriftFilters::FilterStatus transportBegin(absl::optional size) override { + return decoder_filter_->transportBegin(size); + } + ThriftFilters::FilterStatus transportEnd() override; + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override { + call_.emplace(std::string(name), msg_type, seq_id); + return decoder_filter_->messageBegin(name, msg_type, seq_id); + } + ThriftFilters::FilterStatus messageEnd() override { return decoder_filter_->messageEnd(); } + ThriftFilters::FilterStatus structBegin(absl::string_view name) override { + return decoder_filter_->structBegin(name); + } + ThriftFilters::FilterStatus structEnd() override { return decoder_filter_->structEnd(); } + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override { + return decoder_filter_->fieldBegin(name, field_type, field_id); + } + ThriftFilters::FilterStatus fieldEnd() override { return decoder_filter_->fieldEnd(); } + ThriftFilters::FilterStatus boolValue(bool value) override { + return decoder_filter_->boolValue(value); + } + ThriftFilters::FilterStatus byteValue(uint8_t value) override { + return decoder_filter_->byteValue(value); + } + ThriftFilters::FilterStatus int16Value(int16_t value) override { + return decoder_filter_->int16Value(value); + } + ThriftFilters::FilterStatus int32Value(int32_t value) override { + return decoder_filter_->int32Value(value); + } + ThriftFilters::FilterStatus int64Value(int64_t value) override { + return decoder_filter_->int64Value(value); + } + ThriftFilters::FilterStatus doubleValue(double value) override { + return decoder_filter_->doubleValue(value); + } + ThriftFilters::FilterStatus stringValue(absl::string_view value) override { + return decoder_filter_->stringValue(value); + } + ThriftFilters::FilterStatus mapBegin(FieldType key_type, FieldType value_type, + uint32_t size) override { + return decoder_filter_->mapBegin(key_type, value_type, size); + } + ThriftFilters::FilterStatus mapEnd() override { return decoder_filter_->mapEnd(); } + ThriftFilters::FilterStatus listBegin(FieldType elem_type, uint32_t size) override { + return decoder_filter_->listBegin(elem_type, size); + } + ThriftFilters::FilterStatus listEnd() override { return decoder_filter_->listEnd(); } + ThriftFilters::FilterStatus setBegin(FieldType elem_type, uint32_t size) override { + return decoder_filter_->setBegin(elem_type, size); + } + ThriftFilters::FilterStatus setEnd() override { return decoder_filter_->setEnd(); } + + // ThriftFilters::DecoderFilterCallbacks + uint64_t streamId() const override { return stream_id_; } + const Network::Connection* connection() const override; + void continueDecoding() override; + Router::RouteConstSharedPtr route() override; + TransportType downstreamTransportType() const override { + return parent_.decoder_->transportType(); + } + ProtocolType downstreamProtocolType() const override { + return parent_.decoder_->protocolType(); + } + void sendLocalReply(ThriftFilters::DirectResponsePtr&& response) override; + void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) override; + bool upstreamData(Buffer::Instance& buffer) override; + void resetDownstreamConnection() override; + + // Thrift::FilterChainFactoryCallbacks + void addDecoderFilter(ThriftFilters::DecoderFilterSharedPtr filter) override { + // TODO(zuercher): support multiple filters + filter->setDecoderFilterCallbacks(*this); + decoder_filter_ = filter; + } + + void createFilterChain(); + void onReset(); + void onError(const std::string& what); + + ConnectionManager& parent_; + Stats::TimespanPtr request_timer_; + uint64_t stream_id_; + ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + ResponseDecoderPtr response_decoder_; + absl::optional cached_route_; + absl::optional call_; + Buffer::OwnedImpl response_buffer_; + }; + + typedef std::unique_ptr ActiveRpcPtr; + + void continueDecoding(); + void dispatch(); + void doDeferredRpcDestroy(ActiveRpc& rpc); + void resetAllRpcs(); + + Config& config_; + ThriftFilterStats& stats_; + + Network::ReadFilterCallbacks* read_callbacks_{}; + + DecoderPtr decoder_; + std::list rpcs_; + Buffer::OwnedImpl request_buffer_; + uint64_t stream_id_{1}; + bool stopped_{false}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index 84b9df721faa6..39aa8af6a2836 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -13,69 +13,72 @@ namespace NetworkFilters { namespace ThriftProxy { // MessageBegin -> StructBegin -ProtocolState DecoderStateMachine::messageBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::messageBegin(Buffer::Instance& buffer) { std::string message_name; MessageType msg_type; int32_t seq_id; if (!proto_.readMessageBegin(buffer, message_name, msg_type, seq_id)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.clear(); stack_.emplace_back(Frame(ProtocolState::MessageEnd)); - return ProtocolState::StructBegin; + return DecoderStatus(ProtocolState::StructBegin, + filter_.messageBegin(absl::string_view(message_name), msg_type, seq_id)); } // MessageEnd -> Done -ProtocolState DecoderStateMachine::messageEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::messageEnd(Buffer::Instance& buffer) { if (!proto_.readMessageEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return ProtocolState::Done; + return DecoderStatus(ProtocolState::Done, filter_.messageEnd()); } // StructBegin -> FieldBegin -ProtocolState DecoderStateMachine::structBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::structBegin(Buffer::Instance& buffer) { std::string name; if (!proto_.readStructBegin(buffer, name)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return ProtocolState::FieldBegin; + return DecoderStatus(ProtocolState::FieldBegin, filter_.structBegin(absl::string_view(name))); } // StructEnd -> stack's return state -ProtocolState DecoderStateMachine::structEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::structEnd(Buffer::Instance& buffer) { if (!proto_.readStructEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.structEnd()); } // FieldBegin -> FieldValue, or // FieldBegin -> StructEnd (stop field) -ProtocolState DecoderStateMachine::fieldBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldBegin(Buffer::Instance& buffer) { std::string name; FieldType field_type; int16_t field_id; if (!proto_.readFieldBegin(buffer, name, field_type, field_id)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } if (field_type == FieldType::Stop) { - return ProtocolState::StructEnd; + return DecoderStatus(ProtocolState::StructEnd, ThriftFilters::FilterStatus::Continue); } stack_.emplace_back(Frame(ProtocolState::FieldEnd, field_type)); - return ProtocolState::FieldValue; + return DecoderStatus(ProtocolState::FieldValue, + filter_.fieldBegin(absl::string_view(name), field_type, field_id)); } // FieldValue -> FieldEnd (via stack return state) -ProtocolState DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); @@ -83,36 +86,36 @@ ProtocolState DecoderStateMachine::fieldValue(Buffer::Instance& buffer) { } // FieldEnd -> FieldBegin -ProtocolState DecoderStateMachine::fieldEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::fieldEnd(Buffer::Instance& buffer) { if (!proto_.readFieldEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } popReturnState(); - return ProtocolState::FieldBegin; + return DecoderStatus(ProtocolState::FieldBegin, filter_.fieldEnd()); } // ListBegin -> ListValue -ProtocolState DecoderStateMachine::listBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listBegin(Buffer::Instance& buffer) { FieldType elem_type; uint32_t size; if (!proto_.readListBegin(buffer, elem_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::ListEnd, elem_type, size)); - return ProtocolState::ListValue; + return DecoderStatus(ProtocolState::ListValue, filter_.listBegin(elem_type, size)); } // ListValue -> ListValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // ListValue -> ListEnd -ProtocolState DecoderStateMachine::listValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } frame.remaining_--; @@ -120,34 +123,35 @@ ProtocolState DecoderStateMachine::listValue(Buffer::Instance& buffer) { } // ListEnd -> stack's return state -ProtocolState DecoderStateMachine::listEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::listEnd(Buffer::Instance& buffer) { if (!proto_.readListEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.listEnd()); } // MapBegin -> MapKey -ProtocolState DecoderStateMachine::mapBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapBegin(Buffer::Instance& buffer) { FieldType key_type, value_type; uint32_t size; if (!proto_.readMapBegin(buffer, key_type, value_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::MapEnd, key_type, value_type, size)); - return ProtocolState::MapKey; + return DecoderStatus(ProtocolState::MapKey, filter_.mapBegin(key_type, value_type, size)); } // MapKey -> MapValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on key type), or // MapKey -> MapEnd -ProtocolState DecoderStateMachine::mapKey(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapKey(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } return handleValue(buffer, frame.elem_type_, ProtocolState::MapValue); @@ -155,7 +159,7 @@ ProtocolState DecoderStateMachine::mapKey(Buffer::Instance& buffer) { // MapValue -> MapKey, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // MapValue -> MapKey -ProtocolState DecoderStateMachine::mapValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); ASSERT(frame.remaining_ != 0); @@ -165,34 +169,35 @@ ProtocolState DecoderStateMachine::mapValue(Buffer::Instance& buffer) { } // MapEnd -> stack's return state -ProtocolState DecoderStateMachine::mapEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::mapEnd(Buffer::Instance& buffer) { if (!proto_.readMapEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.mapEnd()); } // SetBegin -> SetValue -ProtocolState DecoderStateMachine::setBegin(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setBegin(Buffer::Instance& buffer) { FieldType elem_type; uint32_t size; if (!proto_.readSetBegin(buffer, elem_type, size)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } stack_.emplace_back(Frame(ProtocolState::SetEnd, elem_type, size)); - return ProtocolState::SetValue; + return DecoderStatus(ProtocolState::SetValue, filter_.setBegin(elem_type, size)); } // SetValue -> SetValue, ListBegin, MapBegin, SetBegin, StructBegin (depending on value type), or // SetValue -> SetEnd -ProtocolState DecoderStateMachine::setValue(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setValue(Buffer::Instance& buffer) { ASSERT(!stack_.empty()); Frame& frame = stack_.back(); if (frame.remaining_ == 0) { - return popReturnState(); + return DecoderStatus(popReturnState(), ThriftFilters::FilterStatus::Continue); } frame.remaining_--; @@ -200,85 +205,88 @@ ProtocolState DecoderStateMachine::setValue(Buffer::Instance& buffer) { } // SetEnd -> stack's return state -ProtocolState DecoderStateMachine::setEnd(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::setEnd(Buffer::Instance& buffer) { if (!proto_.readSetEnd(buffer)) { - return ProtocolState::WaitForData; + return DecoderStatus(ProtocolState::WaitForData); } - return popReturnState(); + ProtocolState next_state = popReturnState(); + return DecoderStatus(next_state, filter_.setEnd()); } -ProtocolState DecoderStateMachine::handleValue(Buffer::Instance& buffer, FieldType elem_type, - ProtocolState return_state) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::handleValue(Buffer::Instance& buffer, + FieldType elem_type, + ProtocolState return_state) { switch (elem_type) { - case FieldType::Bool: - bool value; - if (!proto_.readBool(buffer, value)) { - return ProtocolState::WaitForData; + case FieldType::Bool: { + bool value{}; + if (proto_.readBool(buffer, value)) { + return DecoderStatus(return_state, filter_.boolValue(value)); } break; + } case FieldType::Byte: { - uint8_t value; - if (!proto_.readByte(buffer, value)) { - return ProtocolState::WaitForData; + uint8_t value{}; + if (proto_.readByte(buffer, value)) { + return DecoderStatus(return_state, filter_.byteValue(value)); } break; } case FieldType::I16: { - int16_t value; - if (!proto_.readInt16(buffer, value)) { - return ProtocolState::WaitForData; + int16_t value{}; + if (proto_.readInt16(buffer, value)) { + return DecoderStatus(return_state, filter_.int16Value(value)); } break; } case FieldType::I32: { - int32_t value; - if (!proto_.readInt32(buffer, value)) { - return ProtocolState::WaitForData; + int32_t value{}; + if (proto_.readInt32(buffer, value)) { + return DecoderStatus(return_state, filter_.int32Value(value)); } break; } case FieldType::I64: { - int64_t value; - if (!proto_.readInt64(buffer, value)) { - return ProtocolState::WaitForData; + int64_t value{}; + if (proto_.readInt64(buffer, value)) { + return DecoderStatus(return_state, filter_.int64Value(value)); } break; } case FieldType::Double: { - double value; - if (!proto_.readDouble(buffer, value)) { - return ProtocolState::WaitForData; + double value{}; + if (proto_.readDouble(buffer, value)) { + return DecoderStatus(return_state, filter_.doubleValue(value)); } break; } case FieldType::String: { std::string value; - if (!proto_.readString(buffer, value)) { - return ProtocolState::WaitForData; + if (proto_.readString(buffer, value)) { + return DecoderStatus(return_state, filter_.stringValue(value)); } break; } case FieldType::Struct: stack_.emplace_back(Frame(return_state)); - return ProtocolState::StructBegin; + return DecoderStatus(ProtocolState::StructBegin, ThriftFilters::FilterStatus::Continue); case FieldType::Map: stack_.emplace_back(Frame(return_state)); - return ProtocolState::MapBegin; + return DecoderStatus(ProtocolState::MapBegin, ThriftFilters::FilterStatus::Continue); case FieldType::List: stack_.emplace_back(Frame(return_state)); - return ProtocolState::ListBegin; + return DecoderStatus(ProtocolState::ListBegin, ThriftFilters::FilterStatus::Continue); case FieldType::Set: stack_.emplace_back(Frame(return_state)); - return ProtocolState::SetBegin; + return DecoderStatus(ProtocolState::SetBegin, ThriftFilters::FilterStatus::Continue); default: throw EnvoyException(fmt::format("unknown field type {}", static_cast(elem_type))); } - return return_state; + return DecoderStatus(ProtocolState::WaitForData); } -ProtocolState DecoderStateMachine::handleState(Buffer::Instance& buffer) { +DecoderStateMachine::DecoderStatus DecoderStateMachine::handleState(Buffer::Instance& buffer) { switch (state_) { case ProtocolState::MessageBegin: return messageBegin(buffer); @@ -328,61 +336,97 @@ ProtocolState DecoderStateMachine::popReturnState() { ProtocolState DecoderStateMachine::run(Buffer::Instance& buffer) { while (state_ != ProtocolState::Done) { - ProtocolState s = handleState(buffer); - if (s == ProtocolState::WaitForData) { - return s; + DecoderStatus s = handleState(buffer); + if (s.next_state_ == ProtocolState::WaitForData) { + return ProtocolState::WaitForData; } - state_ = s; + state_ = s.next_state_; + + ASSERT(s.filter_status_.has_value()); + if (s.filter_status_.value() == ThriftFilters::FilterStatus::StopIteration) { + return ProtocolState::StopIteration; + } } return state_; } -Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol) - : transport_(std::move(transport)), protocol_(std::move(protocol)), state_machine_{}, - frame_started_(false) {} +Decoder::Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks) + : transport_(std::move(transport)), protocol_(std::move(protocol)), callbacks_(callbacks) {} + +void Decoder::complete() { + request_.reset(); + state_machine_ = nullptr; + frame_started_ = false; + frame_ended_ = false; +} -void Decoder::onData(Buffer::Instance& data) { +ThriftFilters::FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { ENVOY_LOG(debug, "thrift: {} bytes available", data.length()); + buffer_underflow = false; - while (true) { - if (!frame_started_) { - // Look for start of next frame. - if (!transport_->decodeFrameStart(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); - return; - } - ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); - - frame_started_ = true; - state_machine_ = std::make_unique(*protocol_); - } + if (frame_ended_) { + // Continuation after filter stopped iteration on transportComplete callback. + complete(); + buffer_underflow = (data.length() == 0); + return ThriftFilters::FilterStatus::Continue; + } - ASSERT(state_machine_ != nullptr); + if (!frame_started_) { + // Look for start of next frame. + absl::optional size{}; + if (!transport_->decodeFrameStart(data, size)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport start", transport_->name()); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } + ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); - ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), - ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); + request_ = std::make_unique(callbacks_.newDecoderFilter()); + frame_started_ = true; + state_machine_ = std::make_unique(*protocol_, request_->filter_); - ProtocolState rv = state_machine_->run(data); - if (rv == ProtocolState::WaitForData) { - ENVOY_LOG(debug, "thrift: wait for data"); - return; + if (request_->filter_.transportBegin(size) == ThriftFilters::FilterStatus::StopIteration) { + return ThriftFilters::FilterStatus::StopIteration; } + } - ASSERT(rv == ProtocolState::Done); + ASSERT(state_machine_ != nullptr); - // Message complete, get decode end of frame. - if (!transport_->decodeFrameEnd(data)) { - ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); - return; - } - ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + ENVOY_LOG(debug, "thrift: protocol {}, state {}, {} bytes available", protocol_->name(), + ProtocolStateNameValues::name(state_machine_->currentState()), data.length()); - // Reset for next frame. - state_machine_ = nullptr; - frame_started_ = false; + ProtocolState rv = state_machine_->run(data); + if (rv == ProtocolState::WaitForData) { + ENVOY_LOG(debug, "thrift: wait for data"); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } else if (rv == ProtocolState::StopIteration) { + ENVOY_LOG(debug, "thrift: wait for continuation"); + return ThriftFilters::FilterStatus::StopIteration; } + + ASSERT(rv == ProtocolState::Done); + + // Message complete, decode end of frame. + if (!transport_->decodeFrameEnd(data)) { + ENVOY_LOG(debug, "thrift: need more data for {} transport end", transport_->name()); + buffer_underflow = true; + return ThriftFilters::FilterStatus::Continue; + } + + frame_ended_ = true; + + ENVOY_LOG(debug, "thrift: {} transport ended", transport_->name()); + if (request_->filter_.transportEnd() == ThriftFilters::FilterStatus::StopIteration) { + return ThriftFilters::FilterStatus::StopIteration; + } + + // Reset for next frame. + complete(); + buffer_underflow = (data.length() == 0); + return ThriftFilters::FilterStatus::Continue; } } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index a05d70cee30c8..26068858beb6d 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -6,6 +6,7 @@ #include "common/common/assert.h" #include "common/common/logger.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/protocol.h" #include "extensions/filters/network/thrift_proxy/transport.h" @@ -15,6 +16,7 @@ namespace NetworkFilters { namespace ThriftProxy { #define ALL_PROTOCOL_STATES(FUNCTION) \ + FUNCTION(StopIteration) \ FUNCTION(WaitForData) \ FUNCTION(MessageBegin) \ FUNCTION(MessageEnd) \ @@ -61,7 +63,8 @@ class ProtocolStateNameValues { */ class DecoderStateMachine { public: - DecoderStateMachine(Protocol& proto) : proto_(proto), state_(ProtocolState::MessageBegin) {} + DecoderStateMachine(Protocol& proto, ThriftFilters::DecoderFilter& filter) + : proto_(proto), filter_(filter), state_(ProtocolState::MessageBegin) {} /** * Consumes as much data from the configured Buffer as possible and executes the decoding state @@ -114,72 +117,107 @@ class DecoderStateMachine { uint32_t remaining_; }; + struct DecoderStatus { + DecoderStatus(ProtocolState next_state) : next_state_(next_state), filter_status_{} {}; + DecoderStatus(ProtocolState next_state, ThriftFilters::FilterStatus filter_status) + : next_state_(next_state), filter_status_(filter_status){}; + + ProtocolState next_state_; + absl::optional filter_status_; + }; + // These functions map directly to the matching ProtocolState values. Each returns the next state // or ProtocolState::WaitForData if more data is required. - ProtocolState messageBegin(Buffer::Instance& buffer); - ProtocolState messageEnd(Buffer::Instance& buffer); - ProtocolState structBegin(Buffer::Instance& buffer); - ProtocolState structEnd(Buffer::Instance& buffer); - ProtocolState fieldBegin(Buffer::Instance& buffer); - ProtocolState fieldValue(Buffer::Instance& buffer); - ProtocolState fieldEnd(Buffer::Instance& buffer); - ProtocolState listBegin(Buffer::Instance& buffer); - ProtocolState listValue(Buffer::Instance& buffer); - ProtocolState listEnd(Buffer::Instance& buffer); - ProtocolState mapBegin(Buffer::Instance& buffer); - ProtocolState mapKey(Buffer::Instance& buffer); - ProtocolState mapValue(Buffer::Instance& buffer); - ProtocolState mapEnd(Buffer::Instance& buffer); - ProtocolState setBegin(Buffer::Instance& buffer); - ProtocolState setValue(Buffer::Instance& buffer); - ProtocolState setEnd(Buffer::Instance& buffer); + DecoderStatus messageBegin(Buffer::Instance& buffer); + DecoderStatus messageEnd(Buffer::Instance& buffer); + DecoderStatus structBegin(Buffer::Instance& buffer); + DecoderStatus structEnd(Buffer::Instance& buffer); + DecoderStatus fieldBegin(Buffer::Instance& buffer); + DecoderStatus fieldValue(Buffer::Instance& buffer); + DecoderStatus fieldEnd(Buffer::Instance& buffer); + DecoderStatus listBegin(Buffer::Instance& buffer); + DecoderStatus listValue(Buffer::Instance& buffer); + DecoderStatus listEnd(Buffer::Instance& buffer); + DecoderStatus mapBegin(Buffer::Instance& buffer); + DecoderStatus mapKey(Buffer::Instance& buffer); + DecoderStatus mapValue(Buffer::Instance& buffer); + DecoderStatus mapEnd(Buffer::Instance& buffer); + DecoderStatus setBegin(Buffer::Instance& buffer); + DecoderStatus setValue(Buffer::Instance& buffer); + DecoderStatus setEnd(Buffer::Instance& buffer); // handleValue represents the generic Value state from the state machine documentation. It // returns either ProtocolState::WaitForData if more data is required or the next state. For // structs, lists, maps, or sets the return_state is pushed onto the stack and the next state is // based on elem_type. For primitive value types, return_state is returned as the next state // (unless WaitForData is returned). - ProtocolState handleValue(Buffer::Instance& buffer, FieldType elem_type, + DecoderStatus handleValue(Buffer::Instance& buffer, FieldType elem_type, ProtocolState return_state); // handleState delegates to the appropriate method based on state_. - ProtocolState handleState(Buffer::Instance& buffer); + DecoderStatus handleState(Buffer::Instance& buffer); // Helper method to retrieve the current frame's return state and remove the frame from the // stack. ProtocolState popReturnState(); Protocol& proto_; + ThriftFilters::DecoderFilter& filter_; ProtocolState state_; std::vector stack_; }; typedef std::unique_ptr DecoderStateMachinePtr; +class DecoderCallbacks { +public: + virtual ~DecoderCallbacks() {} + + /** + * @return DecoderFilter& a new DecoderFilter for a message. + */ + virtual ThriftFilters::DecoderFilter& newDecoderFilter() PURE; +}; + /** * Decoder encapsulates a configured TransportPtr and ProtocolPtr. */ class Decoder : public Logger::Loggable { public: - Decoder(TransportPtr&& transport, ProtocolPtr&& protocol); + Decoder(TransportPtr&& transport, ProtocolPtr&& protocol, DecoderCallbacks& callbacks); + Decoder(TransportType transport_type, ProtocolType protocol_type, DecoderCallbacks& callbacks); /** - * Drains data from the given buffer while executing a DecoderStateMachine over the data. A new - * DecoderStateMachine is instantiated for each message. + * Drains data from the given buffer while executing a DecoderStateMachine over the data. * * @param data a Buffer containing Thrift protocol data + * @param buffer_underflow bool set to true if more data is required to continue decoding + * @return ThriftFilters::FilterStatus::StopIteration when waiting for filter continuation, + * Continue otherwise. * @throw EnvoyException on Thrift protocol errors */ - void onData(Buffer::Instance& data); + ThriftFilters::FilterStatus onData(Buffer::Instance& data, bool& buffer_underflow); - const Transport& transport() { return *transport_; } - const Protocol& protocol() { return *protocol_; } + TransportType transportType() { return transport_->type(); } + ProtocolType protocolType() { return protocol_->type(); } private: + struct ActiveRequest { + ActiveRequest(ThriftFilters::DecoderFilter& filter) : filter_(filter) {} + + ThriftFilters::DecoderFilter& filter_; + }; + typedef std::unique_ptr ActiveRequestPtr; + + void complete(); + TransportPtr transport_; ProtocolPtr protocol_; + DecoderCallbacks& callbacks_; + ActiveRequestPtr request_; DecoderStateMachinePtr state_machine_; - bool frame_started_; + bool frame_started_{false}; + bool frame_ended_{false}; }; typedef std::unique_ptr DecoderPtr; diff --git a/source/extensions/filters/network/thrift_proxy/filter.cc b/source/extensions/filters/network/thrift_proxy/filter.cc deleted file mode 100644 index 3244c153d961c..0000000000000 --- a/source/extensions/filters/network/thrift_proxy/filter.cc +++ /dev/null @@ -1,310 +0,0 @@ -#include "extensions/filters/network/thrift_proxy/filter.h" - -#include "envoy/common/exception.h" - -#include "common/common/assert.h" - -#include "extensions/filters/network/thrift_proxy/buffer_helper.h" -#include "extensions/filters/network/thrift_proxy/protocol_impl.h" -#include "extensions/filters/network/thrift_proxy/transport_impl.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -Filter::Filter(const std::string& stat_prefix, Stats::Scope& scope) - : req_callbacks_(*this), resp_callbacks_(*this), stats_(generateStats(stat_prefix, scope)) {} - -Filter::~Filter() {} - -void Filter::onEvent(Network::ConnectionEvent event) { - if (active_call_map_.empty() && req_ == nullptr && resp_ == nullptr) { - return; - } - - if (event == Network::ConnectionEvent::RemoteClose) { - stats_.cx_destroy_local_with_active_rq_.inc(); - } - - if (event == Network::ConnectionEvent::LocalClose) { - stats_.cx_destroy_remote_with_active_rq_.inc(); - } -} - -Network::FilterStatus Filter::onData(Buffer::Instance& data, bool) { - if (!sniffing_) { - if (req_buffer_.length() > 0) { - // Stopped sniffing during response (in onWrite). Make sure leftover req_buffer_ contents are - // at the start of data or the upstream will see a corrupted request. - req_buffer_.move(data); - data.move(req_buffer_); - ASSERT(req_buffer_.length() == 0); - } - - return Network::FilterStatus::Continue; - } - - if (req_decoder_ == nullptr) { - req_decoder_ = std::make_unique(std::make_unique(req_callbacks_), - std::make_unique(req_callbacks_)); - } - - ENVOY_LOG(trace, "thrift: read {} bytes", data.length()); - req_buffer_.move(data); - - try { - BufferWrapper wrapped(req_buffer_); - - req_decoder_->onData(wrapped); - - // Move consumed portion of request back to data for the upstream to consume. - uint64_t pos = wrapped.position(); - if (pos > 0) { - data.move(req_buffer_, pos); - } - } catch (const EnvoyException& ex) { - ENVOY_LOG(error, "thrift error: {}", ex.what()); - req_decoder_.reset(); - data.move(req_buffer_); - stats_.request_decoding_error_.inc(); - sniffing_ = false; - } - - return Network::FilterStatus::Continue; -} - -Network::FilterStatus Filter::onWrite(Buffer::Instance& data, bool) { - if (!sniffing_) { - if (resp_buffer_.length() > 0) { - // Stopped sniffing during request (in onData). Make sure resp_buffer_ contents are at the - // start of data or the downstream will see a corrupted response. - resp_buffer_.move(data); - data.move(resp_buffer_); - ASSERT(resp_buffer_.length() == 0); - } - - return Network::FilterStatus::Continue; - } - - if (resp_decoder_ == nullptr) { - resp_decoder_ = std::make_unique(std::make_unique(resp_callbacks_), - std::make_unique(resp_callbacks_)); - } - - ENVOY_LOG(trace, "thrift wrote {} bytes", data.length()); - resp_buffer_.move(data); - - try { - BufferWrapper wrapped(resp_buffer_); - - resp_decoder_->onData(wrapped); - - // Move consumed portion of response back to data for the downstream to consume. - uint64_t pos = wrapped.position(); - if (pos > 0) { - data.move(resp_buffer_, pos); - } - } catch (const EnvoyException& ex) { - ENVOY_LOG(error, "thrift error: {}", ex.what()); - resp_decoder_.reset(); - data.move(resp_buffer_); - - stats_.response_decoding_error_.inc(); - sniffing_ = false; - } - - return Network::FilterStatus::Continue; -} - -void Filter::chargeDownstreamRequestStart(MessageType msg_type, int32_t seq_id) { - if (req_ != nullptr) { - throw EnvoyException("unexpected request messageStart callback"); - } - - if (active_call_map_.size() >= 64) { - throw EnvoyException("too many pending calls (64), disabling sniffing"); - } - - req_ = std::make_unique(*this, msg_type, seq_id); - - stats_.request_.inc(); - switch (msg_type) { - case MessageType::Call: - stats_.request_call_.inc(); - break; - case MessageType::Oneway: - stats_.request_oneway_.inc(); - break; - default: - stats_.request_invalid_type_.inc(); - break; - } -} - -void Filter::chargeDownstreamRequestComplete() { - if (req_ == nullptr) { - throw EnvoyException("unexpected request messageComplete callback"); - } - - // One-way messages do not receive responses. - if (req_->msg_type_ == MessageType::Oneway) { - req_.reset(); - return; - } - - int32_t seq_id = req_->seq_id_; - active_call_map_.emplace(seq_id, std::move(req_)); -} - -void Filter::chargeUpstreamResponseStart(MessageType msg_type, int32_t seq_id) { - if (resp_ != nullptr) { - throw EnvoyException("unexpected response messageStart callback"); - } - - auto i = active_call_map_.find(seq_id); - if (i == active_call_map_.end()) { - throw EnvoyException(fmt::format("unknown reply seq_id {}", seq_id)); - } - - resp_ = std::move(i->second); - resp_->response_msg_type_ = msg_type; - active_call_map_.erase(i); -} - -void Filter::chargeUpstreamResponseField(FieldType field_type, int16_t field_id) { - if (resp_ == nullptr) { - throw EnvoyException("unexpected response messageField callback"); - } - - if (resp_->response_msg_type_ != MessageType::Reply) { - // If this is not a reply, we'll count an exception instead of an error, so leave - // resp_->success_ unset. - return; - } - - if (resp_->success_.has_value()) { - // If resp->success_ is already set, leave the existing value. - return; - } - - // Successful replies have a single field, with field_id 0 that contains the response value. - // IDL-level exceptions are encoded as a single field with field_id >= 1. - resp_->success_ = field_id == 0 && field_type != FieldType::Stop; -} - -void Filter::chargeUpstreamResponseComplete() { - if (resp_ == nullptr) { - throw EnvoyException("unexpected response messageComplete callback"); - } - - stats_.response_.inc(); - switch (resp_->response_msg_type_) { - case MessageType::Reply: - stats_.response_reply_.inc(); - break; - case MessageType::Exception: - stats_.response_exception_.inc(); - break; - default: - stats_.response_invalid_type_.inc(); - break; - } - - if (resp_->success_.has_value()) { - if (resp_->success_.value()) { - stats_.response_success_.inc(); - } else { - stats_.response_error_.inc(); - } - } - - resp_.reset(); -} - -void Filter::RequestCallbacks::transportFrameStart(absl::optional size) { - UNREFERENCED_PARAMETER(size); - ENVOY_LOG(debug, "thrift request: started {} frame", parent_.req_decoder_->transport().name()); -} - -void Filter::RequestCallbacks::transportFrameComplete() { - ENVOY_LOG(debug, "thrift request: ended {} frame", parent_.req_decoder_->transport().name()); -} - -void Filter::RequestCallbacks::messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) { - ENVOY_LOG(debug, "thrift request: started {} message {}: {}", - parent_.req_decoder_->protocol().name(), name, seq_id); - parent_.chargeDownstreamRequestStart(msg_type, seq_id); -} - -void Filter::RequestCallbacks::structBegin(const absl::string_view name) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift request: started {} struct", parent_.req_decoder_->protocol().name()); -} - -void Filter::RequestCallbacks::structField(const absl::string_view name, FieldType field_type, - int16_t field_id) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift request: started {} field {}, type {}", - parent_.req_decoder_->protocol().name(), field_id, static_cast(field_type)); -} - -void Filter::RequestCallbacks::structEnd() { - ENVOY_LOG(debug, "thrift request: ended {} struct", parent_.req_decoder_->protocol().name()); -} - -void Filter::RequestCallbacks::messageComplete() { - ENVOY_LOG(debug, "thrift request: ended {} message", parent_.req_decoder_->protocol().name()); - parent_.chargeDownstreamRequestComplete(); -} - -void Filter::ResponseCallbacks::transportFrameStart(absl::optional size) { - UNREFERENCED_PARAMETER(size); - ENVOY_LOG(debug, "thrift response: started {} frame", parent_.resp_decoder_->transport().name()); -} - -void Filter::ResponseCallbacks::transportFrameComplete() { - ENVOY_LOG(debug, "thrift response: ended {} frame", parent_.resp_decoder_->transport().name()); -} - -void Filter::ResponseCallbacks::messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) { - ENVOY_LOG(debug, "thrift response: started {} message {}: {}", - parent_.resp_decoder_->protocol().name(), name, seq_id); - parent_.chargeUpstreamResponseStart(msg_type, seq_id); -} - -void Filter::ResponseCallbacks::structBegin(const absl::string_view name) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift response: started {} struct", parent_.req_decoder_->protocol().name()); - depth_++; -} - -void Filter::ResponseCallbacks::structField(const absl::string_view name, FieldType field_type, - int16_t field_id) { - UNREFERENCED_PARAMETER(name); - ENVOY_LOG(debug, "thrift response: started {} field {}, type {}", - parent_.req_decoder_->protocol().name(), field_id, static_cast(field_type)); - - if (depth_ == 1) { - // Only care about the outermost struct, which corresponds to the success or failure of the - // request. - parent_.chargeUpstreamResponseField(field_type, field_id); - } -} - -void Filter::ResponseCallbacks::structEnd() { - ENVOY_LOG(debug, "thrift request: ended {} struct", parent_.req_decoder_->protocol().name()); - depth_--; -} - -void Filter::ResponseCallbacks::messageComplete() { - ENVOY_LOG(debug, "thrift response: ended {} message", parent_.resp_decoder_->protocol().name()); - parent_.chargeUpstreamResponseComplete(); -} - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filter.h b/source/extensions/filters/network/thrift_proxy/filter.h deleted file mode 100644 index a0f028ed0221d..0000000000000 --- a/source/extensions/filters/network/thrift_proxy/filter.h +++ /dev/null @@ -1,173 +0,0 @@ -#pragma once - -#include - -#include "envoy/network/connection.h" -#include "envoy/network/filter.h" -#include "envoy/stats/stats.h" -#include "envoy/stats/stats_macros.h" -#include "envoy/stats/timespan.h" - -#include "common/buffer/buffer_impl.h" -#include "common/common/logger.h" - -#include "extensions/filters/network/thrift_proxy/decoder.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -/** - * All thrift filter stats. @see stats_macros.h - */ -// clang-format off -#define ALL_THRIFT_FILTER_STATS(COUNTER, GAUGE, HISTOGRAM) \ - COUNTER(request) \ - COUNTER(request_call) \ - COUNTER(request_oneway) \ - COUNTER(request_invalid_type) \ - GAUGE(request_active) \ - COUNTER(request_decoding_error) \ - HISTOGRAM(request_time_ms) \ - COUNTER(response) \ - COUNTER(response_reply) \ - COUNTER(response_success) \ - COUNTER(response_error) \ - COUNTER(response_exception) \ - COUNTER(response_invalid_type) \ - COUNTER(response_decoding_error) \ - COUNTER(cx_destroy_local_with_active_rq) \ - COUNTER(cx_destroy_remote_with_active_rq) -// clang-format on - -/** - * Struct definition for all mongo proxy stats. @see stats_macros.h - */ -struct ThriftFilterStats { - ALL_THRIFT_FILTER_STATS(GENERATE_COUNTER_STRUCT, GENERATE_GAUGE_STRUCT, GENERATE_HISTOGRAM_STRUCT) -}; - -/** - * A sniffing filter for thrift traffic. - */ -class Filter : public Network::Filter, - public Network::ConnectionCallbacks, - Logger::Loggable { -public: - Filter(const std::string& stat_prefix, Stats::Scope& scope); - ~Filter(); - - // Network::ReadFilter - Network::FilterStatus onData(Buffer::Instance& data, bool end_stream) override; - Network::FilterStatus onNewConnection() override { return Network::FilterStatus::Continue; } - void initializeReadFilterCallbacks(Network::ReadFilterCallbacks&) override {} - - // Network::WriteFilter - Network::FilterStatus onWrite(Buffer::Instance& data, bool end_stream) override; - - // Network::ConnectionCallbacks - void onEvent(Network::ConnectionEvent) override; - void onAboveWriteBufferHighWatermark() override {} - void onBelowWriteBufferLowWatermark() override {} - -private: - // RequestCallbacks handles callbacks related to decoding downstream requests. - class RequestCallbacks : public virtual ProtocolCallbacks, public virtual TransportCallbacks { - public: - RequestCallbacks(Filter& parent) : parent_(parent) {} - - // TransportCallbacks - void transportFrameStart(absl::optional size) override; - void transportFrameComplete() override; - - // ProtocolCallbacks - void messageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) override; - void structBegin(const absl::string_view name) override; - void structField(const absl::string_view name, FieldType field_type, int16_t field_id) override; - void structEnd() override; - void messageComplete() override; - - private: - Filter& parent_; - }; - - // ResponseCallbacks handles callbacks related to decoding upstream responses. - class ResponseCallbacks : public virtual ProtocolCallbacks, public virtual TransportCallbacks { - public: - ResponseCallbacks(Filter& parent) : parent_(parent) {} - - // TransportCallbacks - void transportFrameStart(absl::optional size) override; - void transportFrameComplete() override; - - // ProtocolCallbacks - void messageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) override; - void structBegin(const absl::string_view name) override; - void structField(const absl::string_view name, FieldType field_type, int16_t field_id) override; - void structEnd() override; - void messageComplete() override; - - private: - Filter& parent_; - int depth_{0}; - }; - - // ActiveMessage tracks downstream requests for which no response has been received. - struct ActiveMessage { - ActiveMessage(Filter& parent, MessageType msg_type, int32_t seq_id) - : parent_(parent), request_timer_(new Stats::Timespan(parent_.stats_.request_time_ms_)), - msg_type_(msg_type), seq_id_(seq_id) { - parent_.stats_.request_active_.inc(); - } - ~ActiveMessage() { - request_timer_->complete(); - parent_.stats_.request_active_.dec(); - } - - Filter& parent_; - Stats::TimespanPtr request_timer_; - const MessageType msg_type_; - const int32_t seq_id_; - MessageType response_msg_type_{}; - absl::optional success_{}; - }; - typedef std::unique_ptr ActiveMessagePtr; - - ThriftFilterStats generateStats(const std::string& prefix, Stats::Scope& scope) { - return ThriftFilterStats{ALL_THRIFT_FILTER_STATS(POOL_COUNTER_PREFIX(scope, prefix), - POOL_GAUGE_PREFIX(scope, prefix), - POOL_HISTOGRAM_PREFIX(scope, prefix))}; - } - - void chargeDownstreamRequestStart(MessageType msg_type, int32_t seq_id); - void chargeDownstreamRequestComplete(); - void chargeUpstreamResponseStart(MessageType msg_type, int32_t seq_id); - void chargeUpstreamResponseField(FieldType field_type, int16_t field_id); - void chargeUpstreamResponseComplete(); - - // Downstream request decoder, callbacks, and buffer. - DecoderPtr req_decoder_{}; - RequestCallbacks req_callbacks_; - Buffer::OwnedImpl req_buffer_; - // Request currently being decoded. - ActiveMessagePtr req_; - - // Upstream response decoder, callbacks, and buffer. - DecoderPtr resp_decoder_{}; - ResponseCallbacks resp_callbacks_; - Buffer::OwnedImpl resp_buffer_; - // Response currently being decoded. - ActiveMessagePtr resp_; - - // List of active request messages. - std::unordered_map active_call_map_; - - bool sniffing_{true}; - ThriftFilterStats stats_; -}; - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/BUILD b/source/extensions/filters/network/thrift_proxy/filters/BUILD new file mode 100644 index 0000000000000..7f374b1fe5fd1 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/BUILD @@ -0,0 +1,50 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "filter_config_interface", + hdrs = ["filter_config.h"], + deps = [ + ":filter_interface", + "//include/envoy/server:filter_config_interface", + "//source/common/common:macros", + "//source/common/protobuf:cc_wkt_protos", + ], +) + +envoy_cc_library( + name = "factory_base_lib", + hdrs = ["factory_base.h"], + deps = [ + ":filter_config_interface", + "//source/common/protobuf:utility_lib", + ], +) + +envoy_cc_library( + name = "filter_interface", + hdrs = ["filter.h"], + external_deps = ["abseil_optional"], + deps = [ + "//include/envoy/buffer:buffer_interface", + "//include/envoy/network:connection_interface", + "//source/extensions/filters/network/thrift_proxy:protocol_interface", + "//source/extensions/filters/network/thrift_proxy:transport_interface", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + ], +) + +envoy_cc_library( + name = "well_known_names", + hdrs = ["well_known_names.h"], + deps = [ + "//source/common/singleton:const_singleton", + ], +) diff --git a/source/extensions/filters/network/thrift_proxy/filters/factory_base.h b/source/extensions/filters/network/thrift_proxy/filters/factory_base.h new file mode 100644 index 0000000000000..bf2bb292f043b --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/factory_base.h @@ -0,0 +1,45 @@ +#pragma once + +#include "common/protobuf/utility.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter_config.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +template class FactoryBase : public NamedThriftFilterConfigFactory { +public: + FilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message& proto_config, + const std::string& stats_prefix, + Server::Configuration::FactoryContext& context) override { + return createFilterFactoryFromProtoTyped( + MessageUtil::downcastAndValidate(proto_config), stats_prefix, context); + } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return std::make_unique(); + } + + std::string name() override { return name_; } + +protected: + FactoryBase(const std::string& name) : name_(name) {} + +private: + virtual FilterFactoryCb + createFilterFactoryFromProtoTyped(const ConfigProto& proto_config, + const std::string& stats_prefix, + Server::Configuration::FactoryContext& context) PURE; + + const std::string name_; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h new file mode 100644 index 0000000000000..969ffcadfc46c --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -0,0 +1,306 @@ +#pragma once + +#include +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/network/connection.h" + +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" +#include "extensions/filters/network/thrift_proxy/transport.h" + +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +class DirectResponse { +public: + virtual ~DirectResponse() {} + + /** + * Encodes the response via the given Protocol. + * @param proto the Protocol to be used for message encoding + * @param buffer the Buffer into which the message should be encoded + */ + virtual void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) PURE; +}; + +typedef std::unique_ptr DirectResponsePtr; + +/** + * Decoder filter callbacks add additional callbacks. + */ +class DecoderFilterCallbacks { +public: + virtual ~DecoderFilterCallbacks() {} + + /** + * @return uint64_t the ID of the originating stream for logging purposes. + */ + virtual uint64_t streamId() const PURE; + + /** + * @return const Network::Connection* the originating connection, or nullptr if there is none. + */ + virtual const Network::Connection* connection() const PURE; + + /** + * Continue iterating through the filter chain with buffered data. This routine can only be + * called if the filter has previously returned StopIteration from one of the DecoderFilter + * methods. The connection manager will callbacks to the next filter in the chain. Further note + * that if the request is not complete, the calling filter may receive further callbacks and must + * return an appropriate status code depending on what the filter needs to do. + */ + virtual void continueDecoding() PURE; + + /** + * @return RouteConstSharedPtr the route for the current request. + */ + virtual Router::RouteConstSharedPtr route() PURE; + + /** + * @return TransportType the originating transport. + */ + virtual TransportType downstreamTransportType() const PURE; + + /** + * @return ProtocolType the originating protocol. + */ + virtual ProtocolType downstreamProtocolType() const PURE; + + /** + * Create a locally generated response using the provided response object. + * @param response DirectResponsePtr the response to send to the downstream client + */ + virtual void sendLocalReply(DirectResponsePtr&& response) PURE; + + /** + * Indicates the start of an upstream response. May only be called once. + * @param transport_type TransportType the upstream is using + * @param protocol_type ProtocolType the upstream is using + */ + virtual void startUpstreamResponse(TransportType transport_type, ProtocolType protocol_type) PURE; + + /** + * Called with upstream response data. + * @param data supplies the upstream's data + * @return true if the upstream response is complete; false if more data is expected + */ + virtual bool upstreamData(Buffer::Instance& data) PURE; + + /** + * Reset the downstream connection. + */ + virtual void resetDownstreamConnection() PURE; +}; + +enum class FilterStatus { + // Continue filter chain iteration. + Continue, + + // Stop iterating over filters in the filter chain. Iteration must be explicitly restarted via + // continueDecoding(). + StopIteration +}; + +/** + * Decoder filter interface. + */ +class DecoderFilter { +public: + virtual ~DecoderFilter() {} + + /** + * This routine is called prior to a filter being destroyed. This may happen after normal stream + * finish (both downstream and upstream) or due to reset. Every filter is responsible for making + * sure that any async events are cleaned up in the context of this routine. This includes timers, + * network calls, etc. The reason there is an onDestroy() method vs. doing this type of cleanup + * in the destructor is due to the deferred deletion model that Envoy uses to avoid stack unwind + * complications. Filters must not invoke either encoder or decoder filter callbacks after having + * onDestroy() invoked. + */ + virtual void onDestroy() PURE; + + /** + * Called by the connection manager once to initialize the filter decoder callbacks that the + * filter should use. Callbacks will not be invoked by the filter after onDestroy() is called. + */ + virtual void setDecoderFilterCallbacks(DecoderFilterCallbacks& callbacks) PURE; + + /** + * Resets the upstream connection. + */ + virtual void resetUpstreamConnection() PURE; + + /** + * Indicates the start of a Thrift transport frame was detected. Unframed transports generate + * simulated start messages. + * @param size the size of the message, if available to the transport + */ + virtual FilterStatus transportBegin(absl::optional size) PURE; + + /** + * Indicates the end of a Thrift transport frame was detected. Unframed transport generate + * simulated complete messages. + */ + virtual FilterStatus transportEnd() PURE; + + /** + * Indicates that the start of a Thrift protocol message was detected. + * @param name the name of the message, if available + * @param msg_type the type of the message + * @param seq_id the message sequence id + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) PURE; + + /** + * Indicates that the end of a Thrift protocol message was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus messageEnd() PURE; + + /** + * Indicates that the start of a Thrift protocol struct was detected. + * @param name the name of the struct, if available + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus structBegin(absl::string_view name) PURE; + + /** + * Indicates that the end of a Thrift protocol struct was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus structEnd() PURE; + + /** + * Indicates that the start of Thrift protocol struct field was detected. + * @param name the name of the field, if available + * @param field_type the type of the field + * @param field_id the field id + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) PURE; + + /** + * Indicates that the end of a Thrift protocol struct field was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus fieldEnd() PURE; + + /** + * A struct field, map key, map value, list element or set element was detected. + * @param value type value of the field + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus boolValue(bool value) PURE; + virtual FilterStatus byteValue(uint8_t value) PURE; + virtual FilterStatus int16Value(int16_t value) PURE; + virtual FilterStatus int32Value(int32_t value) PURE; + virtual FilterStatus int64Value(int64_t value) PURE; + virtual FilterStatus doubleValue(double value) PURE; + virtual FilterStatus stringValue(absl::string_view value) PURE; + + /** + * Indicates the start of a Thrift protocol map was detected. + * @param key_type the map key type + * @param value_type the map value type + * @param size the number of key-value pairs + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus mapBegin(FieldType key_type, FieldType value_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol map was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus mapEnd() PURE; + + /** + * Indicates the start of a Thrift protocol list was detected. + * @param elem_type the list value type + * @param size the number of values in the list + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus listBegin(FieldType elem_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol list was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus listEnd() PURE; + + /** + * Indicates the start of a Thrift protocol set was detected. + * @param elem_type the set value type + * @param size the number of values in the set + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus setBegin(FieldType elem_type, uint32_t size) PURE; + + /** + * Indicates that the end of a Thrift protocol set was detected. + * @return FilterStatus to indicate if filter chain iteration should continue + */ + virtual FilterStatus setEnd() PURE; +}; + +typedef std::shared_ptr DecoderFilterSharedPtr; + +/** + * These callbacks are provided by the connection manager to the factory so that the factory can + * build the filter chain in an application specific way. + */ +class FilterChainFactoryCallbacks { +public: + virtual ~FilterChainFactoryCallbacks() {} + + /** + * Add a decoder filter that is used when reading connection data. + * @param filter supplies the filter to add. + */ + virtual void addDecoderFilter(DecoderFilterSharedPtr filter) PURE; +}; + +/** + * This function is used to wrap the creation of a Thrift filter chain for new connections as they + * come in. Filter factories create the function at configuration initialization time, and then + * they are used at runtime. + * @param callbacks supplies the callbacks for the stream to install filters to. Typically the + * function will install a single filter, but it's technically possibly to install more than one + * if desired. + */ +typedef std::function FilterFactoryCb; + +/** + * A FilterChainFactory is used by a connection manager to create a Thrift level filter chain when + * a new connection is created. Typically it would be implemented by a configuration engine that + * would install a set of filters that are able to process an application scenario on top of a + * stream of Thrift requests. + */ +class FilterChainFactory { +public: + virtual ~FilterChainFactory() {} + + /** + * Called when a new Thrift stream is created on the connection. + * @param callbacks supplies the "sink" that is used for actually creating the filter chain. @see + * FilterChainFactoryCallbacks. + */ + virtual void createFilterChain(FilterChainFactoryCallbacks& callbacks) PURE; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter_config.h b/source/extensions/filters/network/thrift_proxy/filters/filter_config.h new file mode 100644 index 0000000000000..86f4b7730517b --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/filter_config.h @@ -0,0 +1,55 @@ +#pragma once + +#include "envoy/server/filter_config.h" + +#include "common/common/macros.h" +#include "common/protobuf/protobuf.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +/** + * Implemented by each Thrift filter and registered via Registry::registerFactory or the + * convenience class RegisterFactory. + */ +class NamedThriftFilterConfigFactory { +public: + virtual ~NamedThriftFilterConfigFactory() {} + + /** + * Create a particular thrift filter factory implementation. If the implementation is unable to + * produce a factory with the provided parameters, it should throw an EnvoyException in the case + * of general error. The returned callback should always be initialized. + * @param config supplies the configuration for the filter + * @param stat_prefix prefix for stat logging + * @param context supplies the filter's context. + * @return FilterFactoryCb the factory creation function. + */ + virtual FilterFactoryCb + createFilterFactoryFromProto(const Protobuf::Message& config, const std::string& stat_prefix, + Server::Configuration::FactoryContext& context) PURE; + + /** + * @return ProtobufTypes::MessagePtr create empty config proto message for v2. The filter + * config, which arrives in an opaque google.protobuf.Struct message, will be converted to + * JSON and then parsed into this empty proto. + */ + virtual ProtobufTypes::MessagePtr createEmptyConfigProto() PURE; + + /** + * @return std::string the identifying name for a particular implementation of a thrift filter + * produced by the factory. + */ + virtual std::string name() PURE; +}; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h b/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h new file mode 100644 index 0000000000000..41abac8d54753 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filters/well_known_names.h @@ -0,0 +1,25 @@ +#pragma once + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +/** + * Well-known http filter names. + * NOTE: New filters should use the well known name: envoy.filters.thrift.name. + */ +class ThriftFilterNameValues { +public: + // Router filter + const std::string ROUTER = "envoy.filters.thrift.router"; +}; + +typedef ConstSingleton ThriftFilterNames; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc index b66863e9bfbfb..a45861a349e4b 100644 --- a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.cc @@ -10,28 +10,26 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { +bool FramedTransportImpl::decodeFrameStart(Buffer::Instance& buffer, + absl::optional& size) { if (buffer.length() < 4) { return false; } - int32_t size = BufferHelper::peekI32(buffer); + int32_t thrift_size = BufferHelper::peekI32(buffer); - if (size <= 0 || size > MaxFrameSize) { - throw EnvoyException(fmt::format("invalid thrift framed transport frame size {}", size)); + if (thrift_size <= 0 || thrift_size > MaxFrameSize) { + throw EnvoyException(fmt::format("invalid thrift framed transport frame size {}", thrift_size)); } - onFrameStart(absl::optional(static_cast(size))); - buffer.drain(4); - return true; -} -bool FramedTransportImpl::decodeFrameEnd(Buffer::Instance&) { - onFrameComplete(); + size = static_cast(thrift_size); return true; } +bool FramedTransportImpl::decodeFrameEnd(Buffer::Instance&) { return true; } + void FramedTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) { uint64_t size = message.length(); if (size == 0 || size > MaxFrameSize) { @@ -44,6 +42,17 @@ void FramedTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance buffer.move(message); } +class FramedTransportConfigFactory : public TransportFactoryBase { +public: + FramedTransportConfigFactory() : TransportFactoryBase(TransportNames::get().FRAMED) {} +}; + +/** + * Static registration for the framed transport. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h index b51b5b3819cd2..4c7569487ea38 100644 --- a/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/framed_transport_impl.h @@ -17,13 +17,14 @@ namespace ThriftProxy { * FramedTransportImpl implements the Thrift Framed transport. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md */ -class FramedTransportImpl : public TransportImplBase { +class FramedTransportImpl : public Transport { public: - FramedTransportImpl(TransportCallbacks& callbacks) : TransportImplBase(callbacks) {} + FramedTransportImpl() {} // Transport const std::string& name() const override { return TransportNames::get().FRAMED; } - bool decodeFrameStart(Buffer::Instance& buffer) override; + TransportType type() const override { return TransportType::Framed; } + bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) override; bool decodeFrameEnd(Buffer::Instance& buffer) override; void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override; diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 472f2d1852d51..02f2808427e08 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -5,7 +5,10 @@ #include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" +#include "envoy/registry/registry.h" +#include "common/common/assert.h" +#include "common/config/utility.h" #include "common/singleton/const_singleton.h" #include "absl/strings/string_view.h" @@ -15,6 +18,16 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +enum class ProtocolType { + Binary, + LaxBinary, + Compact, + Auto, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST PROTOCOL TYPE + LastProtocolType = Auto, +}; + /** * Names of available Protocol implementations. */ @@ -29,11 +42,23 @@ class ProtocolNameValues { // Compact protocol const std::string COMPACT = "compact"; - // JSON protocol - const std::string JSON = "json"; - // Auto-detection protocol const std::string AUTO = "auto"; + + const std::string& fromType(ProtocolType type) const { + switch (type) { + case ProtocolType::Binary: + return BINARY; + case ProtocolType::LaxBinary: + return LAX_BINARY; + case ProtocolType::Compact: + return COMPACT; + case ProtocolType::Auto: + return AUTO; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } }; typedef ConstSingleton ProtocolNames; @@ -75,48 +100,6 @@ enum class FieldType { LastFieldType = List, }; -/** - * ProtocolCallbacks are Thrift protocol-level callbacks. - */ -class ProtocolCallbacks { -public: - virtual ~ProtocolCallbacks() {} - - /** - * Indicates that the start of a Thrift protocol message was detected. - * @param name the name of the message, if available - * @param msg_type the type of the message - * @param seq_id the message sequence id - */ - virtual void messageStart(const absl::string_view name, MessageType msg_type, - int32_t seq_id) PURE; - - /** - * Indicates that the start of a Thrift protocol struct was detected. - * @param name the name of the struct, if available - */ - virtual void structBegin(const absl::string_view name) PURE; - - /** - * Indicates that the start of Thrift protocol struct field was detected. - * @param name the name of the field, if available - * @param field_type the type of the field - * @param field_id the field id - */ - virtual void structField(const absl::string_view name, FieldType field_type, - int16_t field_id) PURE; - - /** - * Indicates that the end of a Thrift protocol struct was detected. - */ - virtual void structEnd() PURE; - - /** - * Indicates that the end of a Thrift protocol message was detected. - */ - virtual void messageComplete() PURE; -}; - /** * Protocol represents the operations necessary to implement the a generic Thrift protocol. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-protocol-spec.md @@ -125,8 +108,16 @@ class Protocol { public: virtual ~Protocol() {} + /** + * @return const std::string& the human-readable name of the protocol + */ virtual const std::string& name() const PURE; + /** + * @return ProtocolType the protocol type + */ + virtual ProtocolType type() const PURE; + /** * Reads the start of a Thrift protocol message from the buffer and updates the name, msg_type, * and seq_id parameters with values from the message header. If successful, the message header @@ -485,6 +476,52 @@ class Protocol { typedef std::unique_ptr ProtocolPtr; +/** + * Implemented by each Thrift protocol and registered via Registry::registerFactory or the + * convenience class RegisterFactory. + */ +class NamedProtocolConfigFactory { +public: + virtual ~NamedProtocolConfigFactory() {} + + /** + * Create a particular Thrift protocol + * @return ProtocolFactoryCb the protocol + */ + virtual ProtocolPtr createProtocol() PURE; + + /** + * @return std::string the identifying name for a particular implementation of thrift protocol + * produced by the factory. + */ + virtual std::string name() PURE; + + /** + * Convenience method to lookup a factory by type. + * @param ProtocolType the protocol type + * @return NamedProtocolConfigFactory& for the ProtocolType + */ + static NamedProtocolConfigFactory& getFactory(ProtocolType type) { + const std::string& name = ProtocolNames::get().fromType(type); + return Envoy::Config::Utility::getAndCheckFactory(name); + } +}; + +/** + * ProtocolFactoryBase provides a template for a trivial NamedProtocolConfigFactory. + */ +template class ProtocolFactoryBase : public NamedProtocolConfigFactory { + ProtocolPtr createProtocol() override { return std::move(std::make_unique()); } + + std::string name() override { return name_; } + +protected: + ProtocolFactoryBase(const std::string& name) : name_(name) {} + +private: + const std::string name_; +}; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h new file mode 100644 index 0000000000000..af7b6dfa2af3d --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -0,0 +1,144 @@ +#pragma once + +#include "envoy/buffer/buffer.h" + +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * ProtocolConverter is an abstract class that implements protocol-related methods on + * ThriftFilters::DecoderFilter in terms of converting the decoded messages into a different + * protocol. + */ +class ProtocolConverter : public ThriftFilters::DecoderFilter { +public: + ProtocolConverter() {} + ~ProtocolConverter() {} + + void initProtocolConverter(ProtocolPtr&& proto, Buffer::Instance& buffer) { + proto_ = std::move(proto); + buffer_ = &buffer; + } + + // ThiftFilters::DecoderFilter + void onDestroy() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks&) override { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } + void resetUpstreamConnection() override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override { + proto_->writeMessageBegin(*buffer_, std::string(name), msg_type, seq_id); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus messageEnd() override { + proto_->writeMessageEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus structBegin(absl::string_view name) override { + proto_->writeStructBegin(*buffer_, std::string(name)); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus structEnd() override { + proto_->writeFieldBegin(*buffer_, "", FieldType::Stop, 0); + proto_->writeStructEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus fieldBegin(absl::string_view name, FieldType field_type, + int16_t field_id) override { + proto_->writeFieldBegin(*buffer_, std::string(name), field_type, field_id); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus fieldEnd() override { + proto_->writeFieldEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus boolValue(bool value) override { + proto_->writeBool(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus byteValue(uint8_t value) override { + proto_->writeByte(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int16Value(int16_t value) override { + proto_->writeInt16(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int32Value(int32_t value) override { + proto_->writeInt32(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus int64Value(int64_t value) override { + proto_->writeInt64(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus doubleValue(double value) override { + proto_->writeDouble(*buffer_, value); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus stringValue(absl::string_view value) override { + proto_->writeString(*buffer_, std::string(value)); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus mapBegin(FieldType key_type, FieldType value_type, + uint32_t size) override { + proto_->writeMapBegin(*buffer_, key_type, value_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus mapEnd() override { + proto_->writeMapEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus listBegin(FieldType elem_type, uint32_t size) override { + proto_->writeListBegin(*buffer_, elem_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus listEnd() override { + proto_->writeListEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus setBegin(FieldType elem_type, uint32_t size) override { + proto_->writeSetBegin(*buffer_, elem_type, size); + return ThriftFilters::FilterStatus::Continue; + } + + ThriftFilters::FilterStatus setEnd() override { + proto_->writeSetEnd(*buffer_); + return ThriftFilters::FilterStatus::Continue; + } + +protected: + ProtocolType protocolType() const { return proto_->type(); } + +private: + ProtocolPtr proto_; + Buffer::Instance* buffer_{}; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc index f2cac638d2925..46636099f5430 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc @@ -26,9 +26,9 @@ bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& n uint16_t version = BufferHelper::peekU16(buffer); if (BinaryProtocolImpl::isMagic(version)) { - setProtocol(std::make_unique(callbacks_)); + setProtocol(std::make_unique()); } else if (CompactProtocolImpl::isMagic(version)) { - setProtocol(std::make_unique(callbacks_)); + setProtocol(std::make_unique()); } else { throw EnvoyException( fmt::format("unknown thrift auto protocol message start {:04x}", version)); @@ -45,6 +45,16 @@ bool AutoProtocolImpl::readMessageEnd(Buffer::Instance& buffer) { return protocol_->readMessageEnd(buffer); } +class AutoProtocolConfigFactory : public ProtocolFactoryBase { +public: + AutoProtocolConfigFactory() : ProtocolFactoryBase(ProtocolNames::get().AUTO) {} +}; + +/** + * Static registration for the auto protocol. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.h b/source/extensions/filters/network/thrift_proxy/protocol_impl.h index def5061fa8e9e..0eb374e9f9869 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.h @@ -13,39 +13,24 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/* - * ProtocolImplBase provides a base class for Protocol implementations. - */ -class ProtocolImplBase : public virtual Protocol { -public: - ProtocolImplBase(ProtocolCallbacks& callbacks) : callbacks_(callbacks) {} - -protected: - void onMessageStart(const absl::string_view name, MessageType msg_type, int32_t seq_id) const { - callbacks_.messageStart(name, msg_type, seq_id); - } - void onStructBegin(const absl::string_view name) const { callbacks_.structBegin(name); } - void onStructField(const absl::string_view name, FieldType field_type, int16_t field_id) const { - callbacks_.structField(name, field_type, field_id); - } - void onStructEnd() const { callbacks_.structEnd(); } - void onMessageComplete() const { callbacks_.messageComplete(); } - - ProtocolCallbacks& callbacks_; -}; - /** * AutoProtocolImpl attempts to distinguish between the Thrift binary (strict mode only) and * compact protocols and then delegates subsequent decoding operations to the appropriate Protocol * implementation. */ -class AutoProtocolImpl : public ProtocolImplBase { +class AutoProtocolImpl : public Protocol { public: - AutoProtocolImpl(ProtocolCallbacks& callbacks) - : ProtocolImplBase(callbacks), name_(ProtocolNames::get().AUTO) {} + AutoProtocolImpl() : name_(ProtocolNames::get().AUTO) {} // Protocol const std::string& name() const override { return name_; } + ProtocolType type() const override { + if (protocol_ != nullptr) { + return protocol_->type(); + } + return ProtocolType::Auto; + } + bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id) override; bool readMessageEnd(Buffer::Instance& buffer) override; diff --git a/source/extensions/filters/network/thrift_proxy/router/BUILD b/source/extensions/filters/network/thrift_proxy/router/BUILD new file mode 100644 index 0000000000000..2d32db0ee154c --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/BUILD @@ -0,0 +1,51 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "config", + srcs = ["config.cc"], + hdrs = ["config.h"], + deps = [ + ":router_lib", + "//include/envoy/registry", + "//source/extensions/filters/network/thrift_proxy/filters:factory_base_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_config_interface", + "//source/extensions/filters/network/thrift_proxy/filters:well_known_names", + "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v2alpha1/router:router_cc", + ], +) + +envoy_cc_library( + name = "router_interface", + hdrs = ["router.h"], + external_deps = ["abseil_optional"], + deps = [], +) + +envoy_cc_library( + name = "router_lib", + srcs = ["router_impl.cc"], + hdrs = ["router_impl.h"], + deps = [ + ":router_interface", + "//include/envoy/tcp:conn_pool_interface", + "//include/envoy/upstream:cluster_manager_interface", + "//include/envoy/upstream:load_balancer_interface", + "//include/envoy/upstream:thread_local_cluster_interface", + "//source/common/common:logger_lib", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_converter_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_lib", + "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "@envoy_api//envoy/extensions/filters/network/thrift_proxy/v2alpha1:thrift_proxy_cc", + ], +) diff --git a/source/extensions/filters/network/thrift_proxy/router/config.cc b/source/extensions/filters/network/thrift_proxy/router/config.cc new file mode 100644 index 0000000000000..abf7cafbc2655 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/config.cc @@ -0,0 +1,34 @@ +#include "extensions/filters/network/thrift_proxy/router/config.h" + +#include "envoy/registry/registry.h" + +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +ThriftFilters::FilterFactoryCb RouterFilterConfig::createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::router::Router& proto_config, + const std::string& stat_prefix, Server::Configuration::FactoryContext& context) { + UNREFERENCED_PARAMETER(proto_config); + UNREFERENCED_PARAMETER(stat_prefix); + + return [&context](ThriftFilters::FilterChainFactoryCallbacks& callbacks) -> void { + callbacks.addDecoderFilter(std::make_shared(context.clusterManager())); + }; +} + +/** + * Static registration for the router filter. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/config.h b/source/extensions/filters/network/thrift_proxy/router/config.h new file mode 100644 index 0000000000000..ae4629bfd0ada --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/config.h @@ -0,0 +1,32 @@ +#pragma once + +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.pb.h" +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/router/router.pb.validate.h" + +#include "extensions/filters/network/thrift_proxy/filters/factory_base.h" +#include "extensions/filters/network/thrift_proxy/filters/well_known_names.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +class RouterFilterConfig + : public ThriftFilters::FactoryBase< + envoy::extensions::filters::network::thrift_proxy::v2alpha1::router::Router> { +public: + RouterFilterConfig() : FactoryBase(ThriftFilters::ThriftFilterNames::get().ROUTER) {} + +private: + ThriftFilters::FilterFactoryCb createFilterFactoryFromProtoTyped( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::router::Router& + proto_config, + const std::string& stat_prefix, Server::Configuration::FactoryContext& context) override; +}; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router.h b/source/extensions/filters/network/thrift_proxy/router/router.h new file mode 100644 index 0000000000000..32d717de52351 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router.h @@ -0,0 +1,62 @@ +#pragma once + +#include +#include + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +/** + * RouteEntry is an individual resolved route entry. + */ +class RouteEntry { +public: + virtual ~RouteEntry() {} + + /** + * @return const std::string& the upstream cluster that owns the route. + */ + virtual const std::string& clusterName() const PURE; +}; + +/** + * Route holds the RouteEntry for a request. + */ +class Route { +public: + virtual ~Route() {} + + /** + * @return the route entry or nullptr if there is no matching route for the request. + */ + virtual const RouteEntry* routeEntry() const PURE; +}; + +typedef std::shared_ptr RouteConstSharedPtr; + +/** + * The router configuration. + */ +class Config { +public: + virtual ~Config() {} + + /** + * Based on the incoming Thrift request transport and/or protocol data, determine the target + * route for the request. + * @param method supplies the thrift method name + * @return the route or nullptr if there is no matching route for the request. + */ + virtual RouteConstSharedPtr route(const std::string& method) const PURE; +}; + +typedef std::shared_ptr ConfigConstSharedPtr; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.cc b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc new file mode 100644 index 0000000000000..057d529f13981 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.cc @@ -0,0 +1,305 @@ +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" +#include "envoy/upstream/cluster_manager.h" +#include "envoy/upstream/thread_local_cluster.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +RouteEntryImplBase::RouteEntryImplBase( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::Route& route) + : cluster_name_(route.route().cluster()) {} + +const std::string& RouteEntryImplBase::clusterName() const { return cluster_name_; } + +const RouteEntry* RouteEntryImplBase::routeEntry() const { return this; } + +RouteConstSharedPtr RouteEntryImplBase::clusterEntry() const { return shared_from_this(); } + +MethodNameRouteEntryImpl::MethodNameRouteEntryImpl( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::Route& route) + : RouteEntryImplBase(route), method_name_(route.match().method()) {} + +RouteConstSharedPtr MethodNameRouteEntryImpl::matches(const std::string& method_name) const { + if (method_name_.empty() || method_name_ == method_name) { + return clusterEntry(); + } + + return nullptr; +} + +RouteMatcher::RouteMatcher( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration& config) { + for (const auto& route : config.routes()) { + routes_.emplace_back(new MethodNameRouteEntryImpl(route)); + } +} + +RouteConstSharedPtr RouteMatcher::route(const std::string& method_name) const { + for (const auto& route : routes_) { + RouteConstSharedPtr route_entry = route->matches(method_name); + if (nullptr != route_entry) { + return route_entry; + } + } + + return nullptr; +} + +void Router::onDestroy() { + if (upstream_request_ != nullptr) { + upstream_request_->resetStream(); + } + cleanup(); +} + +void Router::setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) { + callbacks_ = &callbacks; + + // TODO(zuercher): handle buffer limits +} + +void Router::resetUpstreamConnection() { + if (upstream_request_ != nullptr) { + upstream_request_->resetStream(); + } +} + +ThriftFilters::FilterStatus Router::transportBegin(absl::optional size) { + UNREFERENCED_PARAMETER(size); + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus Router::transportEnd() { + if (upstream_request_->msg_type_ == MessageType::Oneway) { + // No response expected + upstream_request_->onResponseComplete(); + cleanup(); + } + return ThriftFilters::FilterStatus::Continue; +} + +ThriftFilters::FilterStatus Router::messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) { + // TODO(zuercher): route stats (e.g., no_route, no_cluster, upstream_rq_maintenance_mode, no + // healtthy upstream) + + route_ = callbacks_->route(); + if (!route_) { + ENVOY_STREAM_LOG(debug, "no cluster match for method '{}'", *callbacks_, name); + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{ + new AppException(name, seq_id, AppExceptionType::UnknownMethod, + fmt::format("no route for method '{}'", name))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + route_entry_ = route_->routeEntry(); + + Upstream::ThreadLocalCluster* cluster = cluster_manager_.get(route_entry_->clusterName()); + if (!cluster) { + ENVOY_STREAM_LOG(debug, "unknown cluster '{}'", *callbacks_, route_entry_->clusterName()); + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{ + new AppException(name, seq_id, AppExceptionType::InternalError, + fmt::format("unknown cluster '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + cluster_ = cluster->info(); + ENVOY_STREAM_LOG(debug, "cluster '{}' match for method '{}'", *callbacks_, + route_entry_->clusterName(), name); + + if (cluster_->maintenanceMode()) { + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + name, seq_id, AppExceptionType::InternalError, + fmt::format("maintenance mode for cluster '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + Tcp::ConnectionPool::Instance* conn_pool = cluster_manager_.tcpConnPoolForCluster( + route_entry_->clusterName(), Upstream::ResourcePriority::Default, this); + if (!conn_pool) { + callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + name, seq_id, AppExceptionType::InternalError, + fmt::format("no healthy upstream for '{}'", route_entry_->clusterName()))}); + return ThriftFilters::FilterStatus::StopIteration; + } + + ENVOY_STREAM_LOG(debug, "router decoding request", *callbacks_); + + upstream_request_.reset(new UpstreamRequest(*this, *conn_pool, name, msg_type, seq_id)); + upstream_request_->start(); + return ThriftFilters::FilterStatus::StopIteration; +} + +ThriftFilters::FilterStatus Router::messageEnd() { + ProtocolConverter::messageEnd(); + + Buffer::OwnedImpl transport_buffer; + upstream_request_->transport_->encodeFrame(transport_buffer, upstream_request_buffer_); + upstream_request_->conn_data_->connection().write(transport_buffer, false); + upstream_request_->onRequestComplete(); + return ThriftFilters::FilterStatus::Continue; +} + +void Router::onUpstreamData(Buffer::Instance& data, bool end_stream) { + ASSERT(!upstream_request_->response_complete_); + + if (!upstream_request_->response_started_) { + callbacks_->startUpstreamResponse(upstream_request_->transport_->type(), protocolType()); + upstream_request_->response_started_ = true; + } + + if (callbacks_->upstreamData(data)) { + upstream_request_->onResponseComplete(); + cleanup(); + return; + } + + if (end_stream) { + // Response is incomplete, but no more data is coming. + upstream_request_->onResponseComplete(); + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); + cleanup(); + } +} + +void Router::onEvent(Network::ConnectionEvent event) { + if (!upstream_request_ || upstream_request_->response_complete_) { + // Client closed connection after completing response. + return; + } + + switch (event) { + case Network::ConnectionEvent::RemoteClose: + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure); + break; + case Network::ConnectionEvent::LocalClose: + upstream_request_->onResetStream( + Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure); + break; + default: + // Connected is consumed by the connection pool. + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +const Network::Connection* Router::downstreamConnection() const { + if (callbacks_ != nullptr) { + return callbacks_->connection(); + } + + return nullptr; +} + +void Router::convertMessageBegin(const std::string& name, MessageType msg_type, int32_t seq_id) { + ProtocolConverter::messageBegin(absl::string_view(name), msg_type, seq_id); +} + +void Router::cleanup() { upstream_request_.reset(); } + +Router::UpstreamRequest::UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, + absl::string_view method_name, MessageType msg_type, + int32_t seq_id) + : parent_(parent), conn_pool_(pool), method_name_(std::string(method_name)), + msg_type_(msg_type), seq_id_(seq_id), request_complete_(false), response_started_(false), + response_complete_(false) {} + +Router::UpstreamRequest::~UpstreamRequest() {} + +void Router::UpstreamRequest::start() { + Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(*this); + if (handle) { + conn_pool_handle_ = handle; + } +} + +void Router::UpstreamRequest::resetStream() { + if (conn_data_ != nullptr) { + conn_data_->connection().close(Network::ConnectionCloseType::NoFlush); + conn_data_ = nullptr; + } +} + +void Router::UpstreamRequest::onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) { + // Mimic an upstream reset. + onUpstreamHostSelected(host); + onResetStream(reason); +} + +void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionData& conn_data, + Upstream::HostDescriptionConstSharedPtr host) { + onUpstreamHostSelected(host); + conn_data_ = &conn_data; + conn_data_->addUpstreamCallbacks(parent_); + + conn_pool_handle_ = nullptr; + + // TODO(zuercher): let cluster specify a specific transport and protocol + transport_ = + NamedTransportConfigFactory::getFactory(parent_.callbacks_->downstreamTransportType()) + .createTransport(); + + parent_.initProtocolConverter( + NamedProtocolConfigFactory::getFactory(parent_.callbacks_->downstreamProtocolType()) + .createProtocol(), + parent_.upstream_request_buffer_); + + // TODO(zuercher): need to use an upstream-connection-specific sequence id + parent_.convertMessageBegin(method_name_, msg_type_, seq_id_); + + parent_.callbacks_->continueDecoding(); +} + +void Router::UpstreamRequest::onRequestComplete() { request_complete_ = true; } + +void Router::UpstreamRequest::onResponseComplete() { + response_complete_ = true; + if (conn_data_ != nullptr) { + conn_data_->release(); + } + conn_data_ = nullptr; +} + +void Router::UpstreamRequest::onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host) { + upstream_host_ = host; +} + +void Router::UpstreamRequest::onResetStream(Tcp::ConnectionPool::PoolFailureReason reason) { + switch (reason) { + case Tcp::ConnectionPool::PoolFailureReason::Overflow: + parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + method_name_, seq_id_, AppExceptionType::InternalError, + fmt::format("too many connections to '{}'", upstream_host_->address()->asString()))}); + break; + case Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure: + case Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure: + case Tcp::ConnectionPool::PoolFailureReason::Timeout: + // TODO(zuercher): distinguish between these cases where appropriate (particularly timeout) + if (!response_started_) { + parent_.callbacks_->sendLocalReply(ThriftFilters::DirectResponsePtr{new AppException( + method_name_, seq_id_, AppExceptionType::InternalError, + fmt::format("connection failure '{}'", upstream_host_->address()->asString()))}); + return; + } + + parent_.callbacks_->resetDownstreamConnection(); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } +} + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/router/router_impl.h b/source/extensions/filters/network/thrift_proxy/router/router_impl.h new file mode 100644 index 0000000000000..aa202734b5595 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/router/router_impl.h @@ -0,0 +1,160 @@ +#pragma once + +#include +#include +#include + +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" +#include "envoy/tcp/conn_pool.h" +#include "envoy/upstream/load_balancer.h" + +#include "common/common/logger.h" + +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +class RouteEntryImplBase : public RouteEntry, + public Route, + public std::enable_shared_from_this { +public: + RouteEntryImplBase( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::Route& route); + + // Router::RouteEntry + const std::string& clusterName() const override; + + // Router::Route + const RouteEntry* routeEntry() const override; + + virtual RouteConstSharedPtr matches(const std::string& method_name) const PURE; + +protected: + RouteConstSharedPtr clusterEntry() const; + +private: + const std::string cluster_name_; +}; + +typedef std::shared_ptr RouteEntryImplBaseConstSharedPtr; + +class MethodNameRouteEntryImpl : public RouteEntryImplBase { +public: + MethodNameRouteEntryImpl( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::Route& route); + + const std::string& methodName() const { return method_name_; } + + // RoutEntryImplBase + RouteConstSharedPtr matches(const std::string& method_name) const override; + +private: + const std::string method_name_; +}; + +class RouteMatcher { +public: + RouteMatcher( + const envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration&); + + RouteConstSharedPtr route(const std::string& method_name) const; + +private: + std::vector routes_; +}; + +class Router : public Tcp::ConnectionPool::UpstreamCallbacks, + public Upstream::LoadBalancerContext, + public ProtocolConverter, + Logger::Loggable { +public: + Router(Upstream::ClusterManager& cluster_manager) : cluster_manager_(cluster_manager) {} + + ~Router() {} + + // ProtocolConverter + void onDestroy() override; + void setDecoderFilterCallbacks(ThriftFilters::DecoderFilterCallbacks& callbacks) override; + void resetUpstreamConnection() override; + ThriftFilters::FilterStatus transportBegin(absl::optional size) override; + ThriftFilters::FilterStatus transportEnd() override; + ThriftFilters::FilterStatus messageBegin(absl::string_view name, MessageType msg_type, + int32_t seq_id) override; + ThriftFilters::FilterStatus messageEnd() override; + + // Upstream::LoadBalancerContext + absl::optional computeHashKey() override { return {}; } + const Envoy::Router::MetadataMatchCriteria* metadataMatchCriteria() override { return nullptr; } + const Network::Connection* downstreamConnection() const override; + const Http::HeaderMap* downstreamHeaders() const override { return nullptr; } + + // Tcp::ConnectionPool::UpstreamCallbacks + void onUpstreamData(Buffer::Instance& data, bool end_stream) override; + void onEvent(Network::ConnectionEvent event) override; + void onAboveWriteBufferHighWatermark() override {} + void onBelowWriteBufferLowWatermark() override {} + +private: + struct UpstreamRequest : public Tcp::ConnectionPool::Callbacks { + UpstreamRequest(Router& parent, Tcp::ConnectionPool::Instance& pool, + absl::string_view method_name, MessageType msg_type, int32_t seq_id); + ~UpstreamRequest(); + + void start(); + void resetStream(); + + // Tcp::ConnectionPool::Callbacks + void onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason, + Upstream::HostDescriptionConstSharedPtr host) override; + void onPoolReady(Tcp::ConnectionPool::ConnectionData& conn, + Upstream::HostDescriptionConstSharedPtr host) override; + + void onRequestComplete(); + void onResponseComplete(); + void onUpstreamHostSelected(Upstream::HostDescriptionConstSharedPtr host); + void onResetStream(Tcp::ConnectionPool::PoolFailureReason reason); + + Router& parent_; + Tcp::ConnectionPool::Instance& conn_pool_; + const std::string method_name_; + const MessageType msg_type_; + const int32_t seq_id_; + + Tcp::ConnectionPool::Cancellable* conn_pool_handle_{}; + Tcp::ConnectionPool::ConnectionData* conn_data_{}; + Upstream::HostDescriptionConstSharedPtr upstream_host_; + TransportPtr transport_; + ProtocolType proto_type_{ProtocolType::Auto}; + + bool request_complete_ : 1; + bool response_started_ : 1; + bool response_complete_ : 1; + }; + + void convertMessageBegin(const std::string& name, MessageType msg_type, int32_t seq_id); + void cleanup(); + + Upstream::ClusterManager& cluster_manager_; + + ThriftFilters::DecoderFilterCallbacks* callbacks_{}; + RouteConstSharedPtr route_{}; + const RouteEntry* route_entry_{}; + Upstream::ClusterInfoConstSharedPtr cluster_; + + std::unique_ptr upstream_request_; + Buffer::OwnedImpl upstream_request_buffer_; +}; + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/stats.h b/source/extensions/filters/network/thrift_proxy/stats.h new file mode 100644 index 0000000000000..a630fc735d034 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/stats.h @@ -0,0 +1,52 @@ +#pragma once + +#include + +#include "envoy/stats/stats.h" +#include "envoy/stats/stats_macros.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * All thrift filter stats. @see stats_macros.h + */ +// clang-format off +#define ALL_THRIFT_FILTER_STATS(COUNTER, GAUGE, HISTOGRAM) \ + COUNTER(request) \ + COUNTER(request_call) \ + COUNTER(request_oneway) \ + COUNTER(request_invalid_type) \ + GAUGE(request_active) \ + COUNTER(request_decoding_error) \ + HISTOGRAM(request_time_ms) \ + COUNTER(response) \ + COUNTER(response_reply) \ + COUNTER(response_success) \ + COUNTER(response_error) \ + COUNTER(response_exception) \ + COUNTER(response_invalid_type) \ + COUNTER(response_decoding_error) \ + COUNTER(cx_destroy_local_with_active_rq) \ + COUNTER(cx_destroy_remote_with_active_rq) +// clang-format on + +/** + * Struct definition for all mongo proxy stats. @see stats_macros.h + */ +struct ThriftFilterStats { + ALL_THRIFT_FILTER_STATS(GENERATE_COUNTER_STRUCT, GENERATE_GAUGE_STRUCT, GENERATE_HISTOGRAM_STRUCT) + + static ThriftFilterStats generateStats(const std::string& prefix, Stats::Scope& scope) { + return ThriftFilterStats{ALL_THRIFT_FILTER_STATS(POOL_COUNTER_PREFIX(scope, prefix), + POOL_GAUGE_PREFIX(scope, prefix), + POOL_HISTOGRAM_PREFIX(scope, prefix))}; + } +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/transport.h b/source/extensions/filters/network/thrift_proxy/transport.h index 516c055ee72b9..1bda083e6bc3e 100644 --- a/source/extensions/filters/network/thrift_proxy/transport.h +++ b/source/extensions/filters/network/thrift_proxy/transport.h @@ -4,7 +4,10 @@ #include #include "envoy/buffer/buffer.h" +#include "envoy/registry/registry.h" +#include "common/common/assert.h" +#include "common/config/utility.h" #include "common/singleton/const_singleton.h" #include "absl/types/optional.h" @@ -14,6 +17,16 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +enum class TransportType { + Framed, + Unframed, + Auto, + + // ATTENTION: MAKE SURE THIS REMAINS EQUAL TO THE LAST TRANSPORT TYPE + LastTransportType = Auto, + +}; + /** * Names of available Transport implementations. */ @@ -27,29 +40,23 @@ class TransportNameValues { // Auto-detection transport const std::string AUTO = "auto"; + + const std::string& fromType(TransportType type) const { + switch (type) { + case TransportType::Framed: + return FRAMED; + case TransportType::Unframed: + return UNFRAMED; + case TransportType::Auto: + return AUTO; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } }; typedef ConstSingleton TransportNames; -/** - * TransportCallbacks are Thrift transport-level callbacks. - */ -class TransportCallbacks { -public: - virtual ~TransportCallbacks() {} - - /** - * Indicates the start of a Thrift transport frame was detected. - * @param size the size of the message, if available to the transport - */ - virtual void transportFrameStart(absl::optional size) PURE; - - /** - * Indicates the end of a Thrift transport frame was detected. - */ - virtual void transportFrameComplete() PURE; -}; - /** * Transport represents a Thrift transport. The Thrift transport is nominally a generic, * bi-directional byte stream. In Envoy we assume it always represents a network byte stream and @@ -66,20 +73,27 @@ class Transport { */ virtual const std::string& name() const PURE; + /** + * @return TransportType the transport type + */ + virtual TransportType type() const PURE; + /* - * decodeFrameStart decodes the start of a transport message, potentially invoking callbacks. - * If successful, the start of the frame is removed from the buffer. + * Decodes the start of a transport message. If successful, the start of the frame is removed + * from the buffer. * * @param buffer the currently buffered thrift data. + * @param size updated with the frame size on success. If frame size is not encoded, the size + * is cleared on success. * @return bool true if a complete frame header was successfully consumed, false if more data * is required. * @throws EnvoyException if the data is not valid for this transport. */ - virtual bool decodeFrameStart(Buffer::Instance& buffer) PURE; + virtual bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) PURE; /* - * decodeFrameEnd decodes the end of a transport message, potentially invoking callbacks. - * If successful, the end of the frame is removed from the buffer. + * Decodes the end of a transport message. If successful, the end of the frame is removed from + * the buffer. * * @param buffer the currently buffered thrift data. * @return bool true if a complete frame trailer was successfully consumed, false if more data @@ -89,8 +103,8 @@ class Transport { virtual bool decodeFrameEnd(Buffer::Instance& buffer) PURE; /** - * encodeFrame wraps the given message buffer with the transport's header and trailer (if any). - * After encoding, message will be empty. + * Wraps the given message buffer with the transport's header and trailer (if any). After + * encoding, message will be empty. * @param buffer is the output buffer * @param message a protocol-encoded message * @throws EnvoyException if the message is too large for the transport @@ -100,6 +114,52 @@ class Transport { typedef std::unique_ptr TransportPtr; +/** + * Implemented by each Thrift transport and registered via Registry::registerFactory or the + * convenience class RegisterFactory. + */ +class NamedTransportConfigFactory { +public: + virtual ~NamedTransportConfigFactory() {} + + /** + * Create a particular Thrift transport. + * @return TransportPtr the transport + */ + virtual TransportPtr createTransport() PURE; + + /** + * @return std::string the identifying name for a particular implementation of thrift transport + * produced by the factory. + */ + virtual std::string name() PURE; + + /** + * Convenience method to lookup a factory by type. + * @param TransportType the transport type + * @return NamedTransportConfigFactory& for the TransportType + */ + static NamedTransportConfigFactory& getFactory(TransportType type) { + const std::string& name = TransportNames::get().fromType(type); + return Envoy::Config::Utility::getAndCheckFactory(name); + } +}; + +/** + * TransportFactoryBase provides a template for a trivial NamedTransportConfigFactory. + */ +template class TransportFactoryBase : public NamedTransportConfigFactory { + TransportPtr createTransport() override { return std::move(std::make_unique()); } + + std::string name() override { return name_; } + +protected: + TransportFactoryBase(const std::string& name) : name_(name) {} + +private: + const std::string name_; +}; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/transport_impl.cc b/source/extensions/filters/network/thrift_proxy/transport_impl.cc index 4caa88efcb313..ff177884d94c5 100644 --- a/source/extensions/filters/network/thrift_proxy/transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.cc @@ -15,7 +15,7 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { +bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) { if (transport_ == nullptr) { // Not enough data to select a transport. if (buffer.length() < 8) { @@ -30,13 +30,13 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { // is configurable, but defaults to 256 MB (0x1000000). THeaderTransport will take up to ~1GB // (0x3FFFFFFF) when it falls back to framed mode. if (BinaryProtocolImpl::isMagic(proto_start) || CompactProtocolImpl::isMagic(proto_start)) { - setTransport(std::make_unique(callbacks_)); + setTransport(std::make_unique()); } } else { // Check for sane unframed protocol. proto_start = static_cast((size >> 16) & 0xFFFF); if (BinaryProtocolImpl::isMagic(proto_start) || CompactProtocolImpl::isMagic(proto_start)) { - setTransport(std::make_unique(callbacks_)); + setTransport(std::make_unique()); } } @@ -51,7 +51,7 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer) { } } - return transport_->decodeFrameStart(buffer); + return transport_->decodeFrameStart(buffer, size); } bool AutoTransportImpl::decodeFrameEnd(Buffer::Instance& buffer) { @@ -64,6 +64,16 @@ void AutoTransportImpl::encodeFrame(Buffer::Instance& buffer, Buffer::Instance& transport_->encodeFrame(buffer, message); } +class AutoTransportConfigFactory : public TransportFactoryBase { +public: + AutoTransportConfigFactory() : TransportFactoryBase(TransportNames::get().AUTO) {} +}; + +/** + * Static registration for the auto transport. @see RegisterFactory. + */ +static Registry::RegisterFactory register_; + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/thrift_proxy/transport_impl.h b/source/extensions/filters/network/thrift_proxy/transport_impl.h index 137a58a622c9f..08281c53c6d16 100644 --- a/source/extensions/filters/network/thrift_proxy/transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.h @@ -15,33 +15,25 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/* - * TransportImplBase provides a base class for Transport implementations. - */ -class TransportImplBase : public virtual Transport { -public: - TransportImplBase(TransportCallbacks& callbacks) : callbacks_(callbacks) {} - -protected: - void onFrameStart(absl::optional size) const { callbacks_.transportFrameStart(size); } - void onFrameComplete() const { callbacks_.transportFrameComplete(); } - - TransportCallbacks& callbacks_; -}; - /** * AutoTransportImpl implements Transport and attempts to distinguish between the Thrift framed and * unframed transports. Once the transport is detected, subsequent operations are delegated to the * appropriate implementation. */ -class AutoTransportImpl : public TransportImplBase { +class AutoTransportImpl : public Transport { public: - AutoTransportImpl(TransportCallbacks& callbacks) - : TransportImplBase(callbacks), name_(TransportNames::get().AUTO){}; + AutoTransportImpl() : name_(TransportNames::get().AUTO){}; // Transport const std::string& name() const override { return name_; } - bool decodeFrameStart(Buffer::Instance& buffer) override; + TransportType type() const override { + if (transport_ != nullptr) { + return transport_->type(); + } + + return TransportType::Auto; + } + bool decodeFrameStart(Buffer::Instance& buffer, absl::optional& size) override; bool decodeFrameEnd(Buffer::Instance& buffer) override; void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override; diff --git a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc new file mode 100644 index 0000000000000..d3a2744540c96 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.cc @@ -0,0 +1,22 @@ +#include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class UnframedTransportConfigFactory : public TransportFactoryBase { +public: + UnframedTransportConfigFactory() : TransportFactoryBase(TransportNames::get().UNFRAMED) {} +}; + +/** + * Static registration for the unframed transport. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h index af7ea884ddba0..29992dfc0812f 100644 --- a/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/unframed_transport_impl.h @@ -17,20 +17,18 @@ namespace ThriftProxy { * UnframedTransportImpl implements the Thrift Unframed transport. * See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md */ -class UnframedTransportImpl : public TransportImplBase { +class UnframedTransportImpl : public Transport { public: - UnframedTransportImpl(TransportCallbacks& callbacks) : TransportImplBase(callbacks) {} + UnframedTransportImpl() {} // Transport const std::string& name() const override { return TransportNames::get().UNFRAMED; } - bool decodeFrameStart(Buffer::Instance&) override { - onFrameStart(absl::optional()); - return true; - } - bool decodeFrameEnd(Buffer::Instance&) override { - onFrameComplete(); + TransportType type() const override { return TransportType::Unframed; } + bool decodeFrameStart(Buffer::Instance&, absl::optional& size) override { + size.reset(); return true; } + bool decodeFrameEnd(Buffer::Instance&) override { return true; } void encodeFrame(Buffer::Instance& buffer, Buffer::Instance& message) override { buffer.move(message); } diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index faa42cb93aff6..1eea2452232bb 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -19,7 +19,12 @@ envoy_extension_cc_mock( hdrs = ["mocks.h"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy:protocol_lib", "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + "//test/mocks/network:network_mocks", "//test/test_common:printers_lib", ], ) @@ -79,33 +84,40 @@ envoy_extension_cc_test( extension_name = "envoy.filters.network.thrift_proxy", deps = [ "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy/router:config", "//test/mocks/server:server_mocks", ], ) envoy_extension_cc_test( - name = "decoder_test", - srcs = ["decoder_test.cc"], + name = "conn_manager_test", + srcs = ["conn_manager_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ ":mocks", ":utility_lib", - "//source/extensions/filters/network/thrift_proxy:decoder_lib", + "//source/extensions/filters/network/thrift_proxy:config", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + "//source/extensions/filters/network/thrift_proxy/router:config", + "//source/extensions/filters/network/thrift_proxy/router:router_interface", + "//test/mocks/network:network_mocks", + "//test/mocks/server:server_mocks", + "//test/mocks/upstream:upstream_mocks", "//test/test_common:printers_lib", - "//test/test_common:utility_lib", ], ) envoy_extension_cc_test( - name = "filter_test", - srcs = ["filter_test.cc"], + name = "decoder_test", + srcs = ["decoder_test.cc"], extension_name = "envoy.filters.network.thrift_proxy", deps = [ + ":mocks", ":utility_lib", - "//source/common/stats:stats_lib", - "//source/extensions/filters/network/thrift_proxy:filter_lib", - "//test/mocks/network:network_mocks", + "//source/extensions/filters/network/thrift_proxy:decoder_lib", "//test/test_common:printers_lib", + "//test/test_common:utility_lib", ], ) @@ -117,7 +129,6 @@ envoy_extension_cc_test( ":mocks", ":utility_lib", "//source/extensions/filters/network/thrift_proxy:transport_lib", - "//test/mocks/buffer:buffer_mocks", "//test/test_common:printers_lib", "//test/test_common:utility_lib", ], @@ -136,6 +147,24 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "router_test", + srcs = ["router_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:app_exception_lib", + "//source/extensions/filters/network/thrift_proxy/router:config", + "//source/extensions/filters/network/thrift_proxy/router:router_lib", + "//test/mocks/network:network_mocks", + "//test/mocks/server:server_mocks", + "//test/mocks/upstream:upstream_mocks", + "//test/test_common:printers_lib", + "//test/test_common:registry_lib", + ], +) + envoy_extension_cc_test( name = "transport_impl_test", srcs = ["transport_impl_test.cc"], @@ -163,8 +192,8 @@ envoy_extension_cc_test( ) envoy_extension_cc_test( - name = "filter_integration_test", - srcs = ["filter_integration_test.cc"], + name = "integration_test", + srcs = ["integration_test.cc"], data = [ "//test/extensions/filters/network/thrift_proxy/driver:generate_fixture", ], @@ -172,7 +201,8 @@ envoy_extension_cc_test( deps = [ "//source/extensions/filters/network/tcp_proxy:config", "//source/extensions/filters/network/thrift_proxy:config", - "//source/extensions/filters/network/thrift_proxy:filter_lib", + "//source/extensions/filters/network/thrift_proxy:conn_manager_lib", + "//source/extensions/filters/network/thrift_proxy/router:config", "//test/integration:integration_lib", "//test/test_common:environment_lib", "//test/test_common:network_utility_lib", diff --git a/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc index 6ed1723fee094..3db9d588d5338 100644 --- a/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc @@ -4,29 +4,24 @@ #include "extensions/filters/network/thrift_proxy/binary_protocol_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" #include "gtest/gtest.h" -using testing::StrictMock; - namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { TEST(BinaryProtocolTest, Name) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_EQ(proto.name(), "binary"); } TEST(BinaryProtocolTest, ReadMessageBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -97,7 +92,6 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addInt32(buffer, 0); addInt32(buffer, 1234); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 1234)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -139,7 +133,6 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { addString(buffer, "the_name"); addInt32(buffer, 5678); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 5678)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -150,34 +143,29 @@ TEST(BinaryProtocolTest, ReadMessageBegin) { TEST(BinaryProtocolTest, ReadMessageEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; - EXPECT_CALL(cb, messageComplete()); EXPECT_TRUE(proto.readMessageEnd(buffer)); } TEST(BinaryProtocolTest, ReadStructBegin) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; std::string name = "-"; - EXPECT_CALL(cb, structBegin(absl::string_view(""))); + EXPECT_TRUE(proto.readStructBegin(buffer, name)); EXPECT_EQ(name, ""); } TEST(BinaryProtocolTest, ReadStructEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); - EXPECT_CALL(cb, structEnd()); + BinaryProtocolImpl proto; + EXPECT_TRUE(proto.readStructEnd(buffer)); } TEST(BinaryProtocolTest, ReadFieldBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -201,7 +189,6 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { addInt8(buffer, FieldType::Stop); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Stop, 0)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Stop); @@ -234,7 +221,6 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { addInt8(buffer, FieldType::I32); addInt16(buffer, 99); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 99)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); @@ -263,14 +249,12 @@ TEST(BinaryProtocolTest, ReadFieldBegin) { TEST(BinaryProtocolTest, ReadFieldEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readFieldEnd(buffer)); } TEST(BinaryProtocolTest, ReadMapBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -328,14 +312,12 @@ TEST(BinaryProtocolTest, ReadMapBegin) { TEST(BinaryProtocolTest, ReadMapEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readMapEnd(buffer)); } TEST(BinaryProtocolTest, ReadListBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -385,14 +367,12 @@ TEST(BinaryProtocolTest, ReadListBegin) { TEST(BinaryProtocolTest, ReadListEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readListEnd(buffer)); } TEST(BinaryProtocolTest, ReadSetBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Test only the happy path, since this method is just delegated to readListBegin() Buffer::OwnedImpl buffer; @@ -410,14 +390,12 @@ TEST(BinaryProtocolTest, ReadSetBegin) { TEST(BinaryProtocolTest, ReadSetEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; EXPECT_TRUE(proto.readSetEnd(buffer)); } TEST(BinaryProtocolTest, ReadIntegerTypes) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Bool { @@ -535,8 +513,7 @@ TEST(BinaryProtocolTest, ReadIntegerTypes) { } TEST(BinaryProtocolTest, ReadDouble) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data { @@ -564,8 +541,7 @@ TEST(BinaryProtocolTest, ReadDouble) { } TEST(BinaryProtocolTest, ReadString) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Insufficient data to read length { @@ -632,8 +608,7 @@ TEST(BinaryProtocolTest, ReadString) { TEST(BinaryProtocolTest, ReadBinary) { // Test only the happy path, since this method is just delegated to readString() - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; std::string value = "-"; @@ -646,8 +621,7 @@ TEST(BinaryProtocolTest, ReadBinary) { } TEST(BinaryProtocolTest, WriteMessageBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Named call { @@ -665,32 +639,28 @@ TEST(BinaryProtocolTest, WriteMessageBegin) { } TEST(BinaryProtocolTest, WriteMessageEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMessageEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteStructBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeStructBegin(buffer, "unused"); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteStructEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeStructEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteFieldBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Stop field { @@ -708,16 +678,14 @@ TEST(BinaryProtocolTest, WriteFieldBegin) { } TEST(BinaryProtocolTest, WriteFieldEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeFieldEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteMapBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Non-empty map { @@ -743,16 +711,14 @@ TEST(BinaryProtocolTest, WriteMapBegin) { } TEST(BinaryProtocolTest, WriteMapEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMapEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteListBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Non-empty list { @@ -777,16 +743,14 @@ TEST(BinaryProtocolTest, WriteListBegin) { } TEST(BinaryProtocolTest, WriteListEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeListEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteSetBegin) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Only test the happy path, as this shares an implementation with writeListBegin // Non-empty list @@ -796,16 +760,14 @@ TEST(BinaryProtocolTest, WriteSetBegin) { } TEST(BinaryProtocolTest, WriteSetEnd) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeSetEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(BinaryProtocolTest, WriteBool) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // True { @@ -823,8 +785,7 @@ TEST(BinaryProtocolTest, WriteBool) { } TEST(BinaryProtocolTest, WriteByte) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -840,8 +801,7 @@ TEST(BinaryProtocolTest, WriteByte) { } TEST(BinaryProtocolTest, WriteInt16) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -857,8 +817,7 @@ TEST(BinaryProtocolTest, WriteInt16) { } TEST(BinaryProtocolTest, WriteInt32) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -874,8 +833,7 @@ TEST(BinaryProtocolTest, WriteInt32) { } TEST(BinaryProtocolTest, WriteInt64) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -891,16 +849,14 @@ TEST(BinaryProtocolTest, WriteInt64) { } TEST(BinaryProtocolTest, WriteDouble) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeDouble(buffer, 3.0); EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); } TEST(BinaryProtocolTest, WriteString) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -919,8 +875,7 @@ TEST(BinaryProtocolTest, WriteString) { } TEST(BinaryProtocolTest, WriteBinary) { - StrictMock cb; - BinaryProtocolImpl proto(cb); + BinaryProtocolImpl proto; // Happy path only, since this is just a synonym for writeString Buffer::OwnedImpl buffer; @@ -932,14 +887,12 @@ TEST(BinaryProtocolTest, WriteBinary) { } TEST(LaxBinaryProtocolTest, Name) { - StrictMock cb; - LaxBinaryProtocolImpl proto(cb); + LaxBinaryProtocolImpl proto; EXPECT_EQ(proto.name(), "binary/non-strict"); } TEST(LaxBinaryProtocolTest, ReadMessageBegin) { - StrictMock cb; - LaxBinaryProtocolImpl proto(cb); + LaxBinaryProtocolImpl proto; // Insufficient data { @@ -989,7 +942,6 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { addInt8(buffer, MessageType::Call); addInt32(buffer, 1234); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 1234)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -1027,7 +979,6 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { addInt8(buffer, MessageType::Call); addInt32(buffer, 5678); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 5678)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -1037,8 +988,7 @@ TEST(LaxBinaryProtocolTest, ReadMessageBegin) { } TEST(LaxBinaryProtocolTest, WriteMessageBegin) { - StrictMock cb; - LaxBinaryProtocolImpl proto(cb); + LaxBinaryProtocolImpl proto; // Named call { diff --git a/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc b/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc index 9536177a96552..26030cd7bd9a7 100644 --- a/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc +++ b/test/extensions/filters/network/thrift_proxy/buffer_helper_test.cc @@ -17,50 +17,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -TEST(BufferWrapperTest, ImplementedFunctions) { - Buffer::OwnedImpl buffer; - addString(buffer, "abcdefghij"); - - BufferWrapper wrapper(buffer); - { - char s[4] = {0}; - wrapper.copyOut(0, 3, s); - EXPECT_EQ("abc", std::string(s)); - EXPECT_EQ(10, wrapper.length()); - EXPECT_EQ(0, wrapper.position()); - } - - { - char s[6] = {0}; - wrapper.copyOut(5, 5, s); - EXPECT_EQ("fghij", std::string(s)); - EXPECT_EQ(10, wrapper.length()); - EXPECT_EQ(0, wrapper.position()); - } - - { - std::string s(static_cast(wrapper.linearize(5)), 5); - EXPECT_EQ("abcde", s); - EXPECT_EQ(0, wrapper.position()); - } - - wrapper.drain(2); - - { - char s[4] = {0}; - wrapper.copyOut(4, 3, s); - EXPECT_EQ("ghi", std::string(s)); - EXPECT_EQ(8, wrapper.length()); - EXPECT_EQ(2, wrapper.position()); - } - - { - std::string s(static_cast(wrapper.linearize(8)), 8); - EXPECT_EQ("cdefghij", s); - EXPECT_EQ(2, wrapper.position()); - } -} - TEST(BufferHelperTest, PeekI8) { { Buffer::OwnedImpl buffer; diff --git a/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc index b784d85e85297..79187821def52 100644 --- a/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc @@ -4,15 +4,12 @@ #include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" #include "gtest/gtest.h" -using testing::NiceMock; -using testing::StrictMock; using testing::TestWithParam; using testing::Values; @@ -22,14 +19,12 @@ namespace NetworkFilters { namespace ThriftProxy { TEST(CompactProtocolTest, Name) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_EQ(proto.name(), "compact"); } TEST(CompactProtocolTest, ReadMessageBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -170,7 +165,6 @@ TEST(CompactProtocolTest, ReadMessageBegin) { addInt8(buffer, 32); addInt8(buffer, 0); - EXPECT_CALL(cb, messageStart(absl::string_view(""), MessageType::Call, 32)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, ""); EXPECT_EQ(msg_type, MessageType::Call); @@ -228,7 +222,6 @@ TEST(CompactProtocolTest, ReadMessageBegin) { addInt8(buffer, 8); addString(buffer, "the_name"); - EXPECT_CALL(cb, messageStart(absl::string_view("the_name"), MessageType::Call, 0x0102)); EXPECT_TRUE(proto.readMessageBegin(buffer, name, msg_type, seq_id)); EXPECT_EQ(name, "the_name"); EXPECT_EQ(msg_type, MessageType::Call); @@ -239,22 +232,19 @@ TEST(CompactProtocolTest, ReadMessageBegin) { TEST(CompactProtocolTest, ReadMessageEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); - EXPECT_CALL(cb, messageComplete()); + CompactProtocolImpl proto; + EXPECT_TRUE(proto.readMessageEnd(buffer)); } TEST(CompactProtocolTest, ReadStruct) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; std::string name = "-"; - EXPECT_CALL(cb, structBegin(absl::string_view(""))); + EXPECT_TRUE(proto.readStructBegin(buffer, name)); EXPECT_EQ(name, ""); - EXPECT_CALL(cb, structEnd()); EXPECT_TRUE(proto.readStructEnd(buffer)); EXPECT_THROW_WITH_MESSAGE(proto.readStructEnd(buffer), EnvoyException, @@ -262,8 +252,7 @@ TEST(CompactProtocolTest, ReadStruct) { } TEST(CompactProtocolTest, ReadFieldBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -287,7 +276,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0xF0); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Stop, 0)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Stop); @@ -400,7 +388,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0x05); addInt8(buffer, 0x04); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 2)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); @@ -417,7 +404,6 @@ TEST(CompactProtocolTest, ReadFieldBegin) { addInt8(buffer, 0xF5); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::I32, 17)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::I32); @@ -428,14 +414,12 @@ TEST(CompactProtocolTest, ReadFieldBegin) { TEST(CompactProtocolTest, ReadFieldEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readFieldEnd(buffer)); } TEST(CompactProtocolTest, ReadMapBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -575,14 +559,12 @@ TEST(CompactProtocolTest, ReadMapBegin) { TEST(CompactProtocolTest, ReadMapEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readMapEnd(buffer)); } TEST(CompactProtocolTest, ReadListBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -690,14 +672,12 @@ TEST(CompactProtocolTest, ReadListBegin) { TEST(CompactProtocolTest, ReadListEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readListEnd(buffer)); } TEST(CompactProtocolTest, ReadSetBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Test only the happy path, since this method is just delegated to readListBegin() Buffer::OwnedImpl buffer; @@ -714,14 +694,12 @@ TEST(CompactProtocolTest, ReadSetBegin) { TEST(CompactProtocolTest, ReadSetEnd) { Buffer::OwnedImpl buffer; - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; EXPECT_TRUE(proto.readSetEnd(buffer)); } TEST(CompactProtocolTest, ReadBool) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Bool field values are encoded in the field type { @@ -734,7 +712,6 @@ TEST(CompactProtocolTest, ReadBool) { addInt8(buffer, 0x01); addInt8(buffer, 0x04); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Bool, 2)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Bool); @@ -751,7 +728,6 @@ TEST(CompactProtocolTest, ReadBool) { addInt8(buffer, 0x02); addInt8(buffer, 0x06); - EXPECT_CALL(cb, structField(absl::string_view(""), FieldType::Bool, 3)); EXPECT_TRUE(proto.readFieldBegin(buffer, name, field_type, field_id)); EXPECT_EQ(name, ""); EXPECT_EQ(field_type, FieldType::Bool); @@ -787,8 +763,7 @@ TEST(CompactProtocolTest, ReadBool) { } TEST(CompactProtocolTest, ReadIntegerTypes) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Byte { @@ -919,8 +894,7 @@ TEST(CompactProtocolTest, ReadIntegerTypes) { } TEST(CompactProtocolTest, ReadDouble) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -950,8 +924,7 @@ TEST(CompactProtocolTest, ReadDouble) { } TEST(CompactProtocolTest, ReadString) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Insufficient data { @@ -1028,8 +1001,7 @@ TEST(CompactProtocolTest, ReadString) { TEST(CompactProtocolTest, ReadBinary) { // Test only the happy path, since this method is just delegated to readString() - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; std::string value = "-"; @@ -1046,8 +1018,7 @@ class CompactProtocolFieldTypeTest : public TestWithParam {}; TEST_P(CompactProtocolFieldTypeTest, ConvertsToFieldType) { uint8_t compact_field_type = GetParam(); - NiceMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; std::string name = "-"; int8_t invalid_field_type = static_cast(FieldType::LastFieldType) + 1; FieldType field_type = static_cast(invalid_field_type); @@ -1080,8 +1051,7 @@ INSTANTIATE_TEST_CASE_P(CompactFieldTypes, CompactProtocolFieldTypeTest, Values(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12)); TEST(CompactProtocolTest, WriteMessageBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Named call { @@ -1099,16 +1069,14 @@ TEST(CompactProtocolTest, WriteMessageBegin) { } TEST(CompactProtocolTest, WriteMessageEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMessageEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(CompactProtocolTest, WriteStruct) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeStructBegin(buffer, "unused"); @@ -1123,16 +1091,14 @@ TEST(CompactProtocolTest, WriteStruct) { TEST(CompactProtocolTest, WriteFieldBegin) { // Stop field { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Stop, 1); EXPECT_EQ(std::string("\0", 1), buffer.toString()); } { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Short form { @@ -1145,7 +1111,7 @@ TEST(CompactProtocolTest, WriteFieldBegin) { { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Struct, 17); - EXPECT_EQ(std::string("\xC\0\x11", 3), buffer.toString()); + EXPECT_EQ(std::string("\xC\x22", 2), buffer.toString()); } // Short form @@ -1164,14 +1130,13 @@ TEST(CompactProtocolTest, WriteFieldBegin) { } { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Long form { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::I32, 16); - EXPECT_EQ(std::string("\x5\0\x10", 3), buffer.toString()); + EXPECT_EQ(std::string("\x5\x20", 2), buffer.toString()); } // Short form @@ -1185,21 +1150,20 @@ TEST(CompactProtocolTest, WriteFieldBegin) { { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Byte, 33); - EXPECT_EQ(std::string("\x3\0\x21", 3), buffer.toString()); + EXPECT_EQ(std::string("\x3\x42", 2), buffer.toString()); } - // Long form + // Long form (3 bytes) { Buffer::OwnedImpl buffer; - proto.writeFieldBegin(buffer, "unused", FieldType::String, 49); - EXPECT_EQ(std::string("\x8\0\x31", 3), buffer.toString()); + proto.writeFieldBegin(buffer, "unused", FieldType::String, 64); + EXPECT_EQ(std::string("\x8\x80\x1", 3), buffer.toString()); } } // Unknown field type { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; int8_t invalid_field_type = static_cast(FieldType::LastFieldType) + 1; @@ -1212,8 +1176,7 @@ TEST(CompactProtocolTest, WriteFieldBegin) { } TEST(CompactProtocolTest, WriteFieldEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeFieldEnd(buffer); EXPECT_EQ(0, buffer.length()); @@ -1224,8 +1187,7 @@ TEST(CompactProtocolTest, WriteBoolField) { // Short form field { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 8); @@ -1245,15 +1207,14 @@ TEST(CompactProtocolTest, WriteBoolField) { // Long form field { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; { Buffer::OwnedImpl buffer; proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 16); EXPECT_EQ(0, buffer.length()); proto.writeBool(buffer, true); - EXPECT_EQ(std::string("\x1\0\x10", 3), buffer.toString()); + EXPECT_EQ(std::string("\x1\x20", 2), buffer.toString()); } { @@ -1261,14 +1222,13 @@ TEST(CompactProtocolTest, WriteBoolField) { proto.writeFieldBegin(buffer, "unused", FieldType::Bool, 32); EXPECT_EQ(0, buffer.length()); proto.writeBool(buffer, false); - EXPECT_EQ(std::string("\x2\0\x20", 3), buffer.toString()); + EXPECT_EQ(std::string("\x2\x40", 2), buffer.toString()); } } } TEST(CompactProtocolTest, WriteMapBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Empty map { @@ -1294,16 +1254,14 @@ TEST(CompactProtocolTest, WriteMapBegin) { } TEST(CompactProtocolTest, WriteMapEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeMapEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(CompactProtocolTest, WriteListBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Empty list { @@ -1335,16 +1293,14 @@ TEST(CompactProtocolTest, WriteListBegin) { } TEST(CompactProtocolTest, WriteListEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeListEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(CompactProtocolTest, WriteSetBegin) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Empty set only, as writeSetBegin delegates to writeListBegin. Buffer::OwnedImpl buffer; @@ -1353,16 +1309,14 @@ TEST(CompactProtocolTest, WriteSetBegin) { } TEST(CompactProtocolTest, WriteSetEnd) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeSetEnd(buffer); EXPECT_EQ(0, buffer.length()); } TEST(CompactProtocolTest, WriteBool) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // Non-field bools (see WriteBoolField test) { @@ -1379,8 +1333,7 @@ TEST(CompactProtocolTest, WriteBool) { } TEST(CompactProtocolTest, WriteByte) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -1396,8 +1349,7 @@ TEST(CompactProtocolTest, WriteByte) { } TEST(CompactProtocolTest, WriteInt16) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // zigzag(1) = 2 { @@ -1436,8 +1388,7 @@ TEST(CompactProtocolTest, WriteInt16) { } TEST(CompactProtocolTest, WriteInt32) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // zigzag(1) = 2 { @@ -1476,8 +1427,7 @@ TEST(CompactProtocolTest, WriteInt32) { } TEST(CompactProtocolTest, WriteInt64) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // zigzag(1) = 2 { @@ -1516,16 +1466,14 @@ TEST(CompactProtocolTest, WriteInt64) { } TEST(CompactProtocolTest, WriteDouble) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; Buffer::OwnedImpl buffer; proto.writeDouble(buffer, 3.0); EXPECT_EQ(std::string("\x40\x8\0\0\0\0\0\0", 8), buffer.toString()); } TEST(CompactProtocolTest, WriteString) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; { Buffer::OwnedImpl buffer; @@ -1551,8 +1499,7 @@ TEST(CompactProtocolTest, WriteString) { } TEST(CompactProtocolTest, WriteBinary) { - StrictMock cb; - CompactProtocolImpl proto(cb); + CompactProtocolImpl proto; // writeBinary is an alias for writeString Buffer::OwnedImpl buffer; diff --git a/test/extensions/filters/network/thrift_proxy/config_test.cc b/test/extensions/filters/network/thrift_proxy/config_test.cc index 3047282f9c771..56c481075a01f 100644 --- a/test/extensions/filters/network/thrift_proxy/config_test.cc +++ b/test/extensions/filters/network/thrift_proxy/config_test.cc @@ -31,7 +31,7 @@ TEST(ThriftFilterConfigTest, ValidProtoConfiguration) { ThriftProxyFilterConfigFactory factory; Network::FilterFactoryCb cb = factory.createFilterFactoryFromProto(config, context); Network::MockConnection connection; - EXPECT_CALL(connection, addFilter(_)); + EXPECT_CALL(connection, addReadFilter(_)); cb(connection); } @@ -45,7 +45,7 @@ TEST(ThriftFilterConfigTest, ThriftProxyWithEmptyProto) { Network::FilterFactoryCb cb = factory.createFilterFactoryFromProto(config, context); Network::MockConnection connection; - EXPECT_CALL(connection, addFilter(_)); + EXPECT_CALL(connection, addReadFilter(_)); cb(connection); } diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc new file mode 100644 index 0000000000000..48dce17b7d40e --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -0,0 +1,760 @@ +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/thrift_proxy.pb.h" + +#include "common/buffer/buffer_impl.h" +#include "common/stats/stats_impl.h" + +#include "extensions/filters/network/thrift_proxy/buffer_helper.h" +#include "extensions/filters/network/thrift_proxy/config.h" +#include "extensions/filters/network/thrift_proxy/conn_manager.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/server/mocks.h" +#include "test/mocks/upstream/mocks.h" +#include "test/test_common/printers.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::Invoke; +using testing::NiceMock; +using testing::Return; +using testing::ReturnRef; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +class TestConfigImpl : public ConfigImpl { +public: + TestConfigImpl( + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy proto_config, + Server::Configuration::MockFactoryContext& context, + ThriftFilters::DecoderFilterSharedPtr decoder_filter, ThriftFilterStats& stats) + : ConfigImpl(proto_config, context), decoder_filter_(decoder_filter), stats_(stats) {} + + // ConfigImpl + ThriftFilterStats& stats() override { return stats_; } + void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override { + callbacks.addDecoderFilter(decoder_filter_); + } + +private: + ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + ThriftFilterStats& stats_; +}; + +class ThriftConnectionManagerTest : public testing::Test { +public: + ThriftConnectionManagerTest() : stats_(ThriftFilterStats::generateStats("test.", store_)) {} + ~ThriftConnectionManagerTest() { + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + void initializeFilter() { initializeFilter(""); } + + void initializeFilter(const std::string& yaml) { + // Destroy any existing filter first. + filter_ = nullptr; + + for (auto counter : store_.counters()) { + counter->reset(); + } + + if (yaml.empty()) { + proto_config_.set_stat_prefix("test"); + } else { + MessageUtil::loadFromYaml(yaml, proto_config_); + MessageUtil::validate(proto_config_); + } + + proto_config_.set_stat_prefix("test"); + + decoder_filter_.reset(new NiceMock()); + config_.reset(new TestConfigImpl(proto_config_, context_, decoder_filter_, stats_)); + + filter_.reset(new ConnectionManager(*config_)); + filter_->initializeReadFilterCallbacks(filter_callbacks_); + filter_->onNewConnection(); + + // NOP currently. + filter_->onAboveWriteBufferHighWatermark(); + filter_->onBelowWriteBufferLowWatermark(); + } + + void writeFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", msg_type, seq_id); + proto->writeStructBegin(msg, "response"); + proto->writeFieldBegin(msg, "success", FieldType::String, 0); + proto->writeString(msg, "field"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writeComplexFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, + int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", msg_type, seq_id); + proto->writeStructBegin(msg, "wrapper"); // call args struct or response struct + proto->writeFieldBegin(msg, "wrapper_field", FieldType::Struct, 0); // call arg/response success + + proto->writeStructBegin(msg, "payload"); + proto->writeFieldBegin(msg, "f1", FieldType::Bool, 1); + proto->writeBool(msg, true); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f2", FieldType::Byte, 2); + proto->writeByte(msg, 2); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f3", FieldType::Double, 3); + proto->writeDouble(msg, 3.0); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f4", FieldType::I16, 4); + proto->writeInt16(msg, 4); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f5", FieldType::I32, 5); + proto->writeInt32(msg, 5); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f6", FieldType::I64, 6); + proto->writeInt64(msg, 6); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f7", FieldType::String, 7); + proto->writeString(msg, "seven"); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f8", FieldType::Map, 8); + proto->writeMapBegin(msg, FieldType::I32, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeInt32(msg, 8); + proto->writeMapEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f9", FieldType::List, 9); + proto->writeListBegin(msg, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeListEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "f10", FieldType::Set, 10); + proto->writeSetBegin(msg, FieldType::I32, 1); + proto->writeInt32(msg, 8); + proto->writeSetEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); // payload stop field + proto->writeStructEnd(msg); + proto->writeFieldEnd(msg); + + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); // wrapper stop field + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writePartialFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, + int32_t seq_id, bool start) { + Buffer::OwnedImpl frame; + writeFramedBinaryMessage(frame, msg_type, seq_id); + + if (start) { + buffer.move(frame, 27); + } else { + frame.drain(27); + buffer.move(frame); + } + } + + void writeFramedBinaryTApplicationException(Buffer::Instance& buffer, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", MessageType::Exception, seq_id); + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::String, 1); + proto->writeString(msg, "error"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::I32, 2); + proto->writeInt32(msg, 1); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + void writeFramedBinaryIDLException(Buffer::Instance& buffer, int32_t seq_id) { + Buffer::OwnedImpl msg; + ProtocolPtr proto = + NamedProtocolConfigFactory::getFactory(ProtocolType::Binary).createProtocol(); + proto->writeMessageBegin(msg, "name", MessageType::Reply, seq_id); + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::Struct, 2); + + proto->writeStructBegin(msg, ""); + proto->writeFieldBegin(msg, "", FieldType::String, 1); + proto->writeString(msg, "err"); + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + + proto->writeFieldEnd(msg); + proto->writeFieldBegin(msg, "", FieldType::Stop, 0); + proto->writeStructEnd(msg); + proto->writeMessageEnd(msg); + + TransportPtr transport = + NamedTransportConfigFactory::getFactory(TransportType::Framed).createTransport(); + transport->encodeFrame(buffer, msg); + } + + NiceMock context_; + std::shared_ptr decoder_filter_; + Stats::IsolatedStoreImpl store_; + ThriftFilterStats stats_; + envoy::extensions::filters::network::thrift_proxy::v2alpha1::ThriftProxy proto_config_; + + std::unique_ptr config_; + + Buffer::OwnedImpl buffer_; + Buffer::OwnedImpl write_buffer_; + std::unique_ptr filter_; + NiceMock filter_callbacks_; +}; + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftCall) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftOneWay) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesStopIterationAndResume) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + EXPECT_CALL(*decoder_filter_, messageBegin(_, _, _)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + // Nothing further happens: we're stopped. + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_EQ(1, callbacks->streamId()); + EXPECT_EQ(TransportType::Framed, callbacks->downstreamTransportType()); + EXPECT_EQ(ProtocolType::Binary, callbacks->downstreamProtocolType()); + EXPECT_EQ(&filter_callbacks_.connection_, callbacks->connection()); + + // Resume processing. + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + callbacks->continueDecoding(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesFrameSplitAcrossBuffers) { + initializeFilter(); + + writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0, buffer_.length()); + + // Complete the buffer + writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0, buffer_.length()); + + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesInvalidMsgType) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Reply, 0x0F); // reply is invalid for a request + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(0U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); + EXPECT_EQ(1U, store_.counter("test.request_invalid_type").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesProtocolError) { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1f, // framed: 31 bytes + 0x80, 0x01, 0x00, 0x01, // binary, call + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x08, 0xff, 0xff // illegal field id + }); + + std::string err = "invalid binary protocol field id -1"; + addSeq(write_buffer_, { + 0x00, 0x00, 0x00, 0x42, // framed: 66 bytes + 0x80, 0x01, 0x00, 0x03, // binary, exception + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x0b, 0x00, 0x01, // begin string field + }); + addInt32(write_buffer_, err.length()); + addString(write_buffer_, err); + addSeq(write_buffer_, { + 0x08, 0x00, 0x02, // begin i32 field + 0x00, 0x00, 0x00, 0x07, // protocol error + 0x00, // stop field + }); + + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ(bufferToString(write_buffer_), bufferToString(buffer)); + })); + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnDataHandlesProtocolErrorDuringMessageBegin) { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0xff, // binary, invalid type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x00, // stop field + }); + + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, OnEvent) { + // No active calls + { + initializeFilter(); + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + EXPECT_EQ(0U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + EXPECT_EQ(0U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + } + + // Remote close mid-request + { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0x01, // binary proto, call type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x0F, // seq id + }); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + // Local close mid-request + { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes + 0x80, 0x01, 0x00, 0x01, // binary proto, call type + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x0F, // seq id + }); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + // Remote close before response + { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::RemoteClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } + + // Local close before response + { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + filter_->onEvent(Network::ConnectionEvent::LocalClose); + + EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); + + buffer_.drain(buffer_.length()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + } +} + +TEST_F(ThriftConnectionManagerTest, Routing) { + const std::string yaml = R"EOF( +transport: FRAMED +protocol: BINARY +stat_prefix: test +route_config: + name: "routes" + routes: + - match: + method: name + route: + cluster: cluster +)EOF"; + + initializeFilter(yaml); + writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + EXPECT_CALL(*decoder_filter_, messageBegin(_, _, _)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(0U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + Router::RouteConstSharedPtr route = callbacks->route(); + EXPECT_NE(nullptr, route); + EXPECT_NE(nullptr, route->routeEntry()); + EXPECT_EQ("cluster", route->routeEntry()->clusterName()); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + callbacks->continueDecoding(); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndResponse) { + initializeFilter(); + writeComplexFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeComplexFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(1U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndExceptionResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeFramedBinaryTApplicationException(write_buffer_, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); + EXPECT_EQ(1U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndErrorResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + writeFramedBinaryIDLException(write_buffer_, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(1U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndInvalidResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + // Call is not valid in a response + writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(1U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, RequestAndResponseProtocolError) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + // illegal field id + addSeq(write_buffer_, { + 0x00, 0x00, 0x00, 0x1f, // framed: 31 bytes + 0x80, 0x01, 0x00, 0x02, // binary, reply + 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x08, 0xff, 0xff // illegal field id + }); + + callbacks->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_CALL(*decoder_filter_, resetUpstreamConnection()); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); + EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); +} + +TEST_F(ThriftConnectionManagerTest, PipelinedRequestAndResponse) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x01); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x02); + + std::list callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillRepeatedly(Invoke( + [&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks.push_back(&cb); })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(2U, store_.gauge("test.request_active").value()); + EXPECT_EQ(2U, store_.counter("test.request").value()); + EXPECT_EQ(2U, store_.counter("test.request_call").value()); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(2); + + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x01); + callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); + callbacks.pop_front(); + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x02); + callbacks.front()->startUpstreamResponse(TransportType::Framed, ProtocolType::Binary); + EXPECT_EQ(true, callbacks.front()->upstreamData(write_buffer_)); + callbacks.pop_front(); + EXPECT_EQ(2U, store_.counter("test.response").value()); + EXPECT_EQ(2U, store_.counter("test.response_reply").value()); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +TEST_F(ThriftConnectionManagerTest, ResetDownstreamConnection) { + initializeFilter(); + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(1U, store_.gauge("test.request_active").value()); + + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::NoFlush)); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + callbacks->resetDownstreamConnection(); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 834b04d662e20..9762fddfaf007 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -7,6 +7,7 @@ #include "test/test_common/printers.h" #include "test/test_common/utility.h" +#include "absl/strings/string_view.h" #include "gmock/gmock.h" #include "gtest/gtest.h" @@ -33,48 +34,96 @@ namespace NetworkFilters { namespace ThriftProxy { namespace { -Expectation expectValue(NiceMock& proto, FieldType field_type, bool result = true) { +ExpectationSet expectValue(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type, bool result = true) { + ExpectationSet s; switch (field_type) { case FieldType::Bool: - return EXPECT_CALL(proto, readBool(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readBool(_, _)).WillOnce(Return(result)); + if (result) { + s += + EXPECT_CALL(filter, boolValue(_)).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::Byte: - return EXPECT_CALL(proto, readByte(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readByte(_, _)).WillOnce(Return(result)); + if (result) { + s += + EXPECT_CALL(filter, byteValue(_)).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::Double: - return EXPECT_CALL(proto, readDouble(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readDouble(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, doubleValue(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I16: - return EXPECT_CALL(proto, readInt16(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt16(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int16Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I32: - return EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt32(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int32Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::I64: - return EXPECT_CALL(proto, readInt64(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readInt64(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, int64Value(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; case FieldType::String: - return EXPECT_CALL(proto, readString(_, _)).WillOnce(Return(result)); + s += EXPECT_CALL(proto, readString(_, _)).WillOnce(Return(result)); + if (result) { + s += EXPECT_CALL(filter, stringValue(_)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + } + break; default: NOT_REACHED_GCOVR_EXCL_LINE; } + return s; } -ExpectationSet expectContainerStart(NiceMock& proto, FieldType field_type, - FieldType inner_type) { +ExpectationSet expectContainerStart(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type, FieldType inner_type) { ExpectationSet s; switch (field_type) { case FieldType::Struct: s += EXPECT_CALL(proto, readStructBegin(_, _)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); s += EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(inner_type), SetArgReferee<3>(1), Return(true))); + s += EXPECT_CALL(filter, fieldBegin(absl::string_view(), inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::List: s += EXPECT_CALL(proto, readListBegin(_, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(1), Return(true))); + s += EXPECT_CALL(filter, listBegin(inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Map: s += EXPECT_CALL(proto, readMapBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(inner_type), SetArgReferee<3>(1), Return(true))); + s += EXPECT_CALL(filter, mapBegin(inner_type, inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Set: s += EXPECT_CALL(proto, readSetBegin(_, _, _)) .WillOnce(DoAll(SetArgReferee<1>(inner_type), SetArgReferee<2>(1), Return(true))); + s += EXPECT_CALL(filter, setBegin(inner_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; default: NOT_REACHED_GCOVR_EXCL_LINE; @@ -82,23 +131,29 @@ ExpectationSet expectContainerStart(NiceMock& proto, FieldType fie return s; } -ExpectationSet expectContainerEnd(NiceMock& proto, FieldType field_type) { +ExpectationSet expectContainerEnd(MockProtocol& proto, ThriftFilters::MockDecoderFilter& filter, + FieldType field_type) { ExpectationSet s; switch (field_type) { case FieldType::Struct: s += EXPECT_CALL(proto, readFieldEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); s += EXPECT_CALL(proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); s += EXPECT_CALL(proto, readStructEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::List: s += EXPECT_CALL(proto, readListEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, listEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Map: s += EXPECT_CALL(proto, readMapEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, mapEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; case FieldType::Set: s += EXPECT_CALL(proto, readSetEnd(_)).WillOnce(Return(true)); + s += EXPECT_CALL(filter, setEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); break; default: NOT_REACHED_GCOVR_EXCL_LINE; @@ -153,7 +208,8 @@ TEST_P(DecoderStateMachineNonValueTest, NoData) { ProtocolState state = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; - DecoderStateMachine dsm(proto); + StrictMock filter; + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(state); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); EXPECT_EQ(dsm.currentState(), state); @@ -164,17 +220,18 @@ TEST_P(DecoderStateMachineValueTest, NoFieldValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(std::string("")), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type, false); - expectValue(proto, field_type, true); + expectValue(proto, filter, field_type, false); + expectValue(proto, filter, field_type, true); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -188,18 +245,19 @@ TEST_P(DecoderStateMachineValueTest, FieldValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(std::string("")), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::FieldBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -209,13 +267,14 @@ TEST_P(DecoderStateMachineValueTest, FieldValue) { TEST(DecoderStateMachineTest, NoListValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -225,13 +284,14 @@ TEST(DecoderStateMachineTest, NoListValueData) { TEST(DecoderStateMachineTest, EmptyList) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -242,16 +302,17 @@ TEST_P(DecoderStateMachineValueTest, ListValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -262,18 +323,19 @@ TEST_P(DecoderStateMachineValueTest, MultipleListValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readListBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, field_type); + expectValue(proto, filter, field_type); } EXPECT_CALL(proto, readListEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::ListBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -283,6 +345,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleListValues) { TEST(DecoderStateMachineTest, NoMapKeyData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -290,7 +353,7 @@ TEST(DecoderStateMachineTest, NoMapKeyData) { SetArgReferee<3>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -300,6 +363,7 @@ TEST(DecoderStateMachineTest, NoMapKeyData) { TEST(DecoderStateMachineTest, NoMapValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -308,7 +372,7 @@ TEST(DecoderStateMachineTest, NoMapValueData) { EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(true)); EXPECT_CALL(proto, readString(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -318,6 +382,7 @@ TEST(DecoderStateMachineTest, NoMapValueData) { TEST(DecoderStateMachineTest, EmptyMap) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -325,7 +390,7 @@ TEST(DecoderStateMachineTest, EmptyMap) { SetArgReferee<3>(0), Return(true))); EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -336,18 +401,19 @@ TEST_P(DecoderStateMachineValueTest, MapKeyValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(FieldType::String), SetArgReferee<3>(1), Return(true))); - expectValue(proto, field_type); // key - expectValue(proto, FieldType::String); // value + expectValue(proto, filter, field_type); // key + expectValue(proto, filter, FieldType::String); // value EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -358,18 +424,19 @@ TEST_P(DecoderStateMachineValueTest, MapValueValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); - expectValue(proto, FieldType::I32); // key - expectValue(proto, field_type); // value + expectValue(proto, filter, FieldType::I32); // key + expectValue(proto, filter, field_type); // value EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -380,6 +447,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMapBegin(Ref(buffer), _, _, _)) @@ -387,13 +455,13 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { SetArgReferee<3>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, FieldType::I32); // key - expectValue(proto, field_type); // value + expectValue(proto, filter, FieldType::I32); // key + expectValue(proto, filter, field_type); // value } EXPECT_CALL(proto, readMapEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::MapBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -403,13 +471,14 @@ TEST_P(DecoderStateMachineValueTest, MultipleMapKeyValues) { TEST(DecoderStateMachineTest, NoSetValueData) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(1), Return(true))); EXPECT_CALL(proto, readInt32(Ref(buffer), _)).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -419,13 +488,14 @@ TEST(DecoderStateMachineTest, NoSetValueData) { TEST(DecoderStateMachineTest, EmptySet) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(FieldType::I32), SetArgReferee<2>(0), Return(true))); EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -436,16 +506,17 @@ TEST_P(DecoderStateMachineValueTest, SetValue) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(1), Return(true))); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -456,18 +527,19 @@ TEST_P(DecoderStateMachineValueTest, MultipleSetValues) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readSetBegin(Ref(buffer), _, _)) .WillOnce(DoAll(SetArgReferee<1>(field_type), SetArgReferee<2>(5), Return(true))); for (int i = 0; i < 5; i++) { - expectValue(proto, field_type); + expectValue(proto, filter, field_type); } EXPECT_CALL(proto, readSetEnd(Ref(buffer))).WillOnce(Return(false)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); dsm.setCurrentState(ProtocolState::SetBegin); EXPECT_EQ(dsm.run(buffer), ProtocolState::WaitForData); @@ -477,6 +549,7 @@ TEST_P(DecoderStateMachineValueTest, MultipleSetValues) { TEST(DecoderStateMachineTest, EmptyStruct) { Buffer::OwnedImpl buffer; NiceMock proto; + NiceMock filter; InSequence dummy; EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) @@ -488,7 +561,7 @@ TEST(DecoderStateMachineTest, EmptyStruct) { EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -498,24 +571,39 @@ TEST_P(DecoderStateMachineValueTest, SingleFieldStruct) { FieldType field_type = GetParam(); Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(1), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), field_type, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -524,6 +612,7 @@ TEST_P(DecoderStateMachineValueTest, SingleFieldStruct) { TEST(DecoderStateMachineTest, MultiFieldStruct) { Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; std::vector field_types = {FieldType::Bool, FieldType::Byte, FieldType::Double, @@ -533,24 +622,36 @@ TEST(DecoderStateMachineTest, MultiFieldStruct) { EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); int16_t field_id = 1; for (FieldType field_type : field_types) { EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) - .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(field_id++), Return(true))); + .WillOnce(DoAll(SetArgReferee<2>(field_type), SetArgReferee<3>(field_id), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), field_type, field_id)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + field_id++; - expectValue(proto, field_type); + expectValue(proto, filter, field_type); EXPECT_CALL(proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); } EXPECT_CALL(proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); EXPECT_CALL(proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -562,35 +663,41 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { Buffer::OwnedImpl buffer; NiceMock proto; + StrictMock filter; InSequence dummy; // start of message and outermost struct EXPECT_CALL(proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); - expectContainerStart(proto, FieldType::Struct, outer_field_type); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + + expectContainerStart(proto, filter, FieldType::Struct, outer_field_type); - expectContainerStart(proto, outer_field_type, inner_type); + expectContainerStart(proto, filter, outer_field_type, inner_type); int outer_reps = outer_field_type == FieldType::Map ? 2 : 1; for (int i = 0; i < outer_reps; i++) { - expectContainerStart(proto, inner_type, value_type); + expectContainerStart(proto, filter, inner_type, value_type); int inner_reps = inner_type == FieldType::Map ? 2 : 1; for (int j = 0; j < inner_reps; j++) { - expectValue(proto, value_type); + expectValue(proto, filter, value_type); } - expectContainerEnd(proto, inner_type); + expectContainerEnd(proto, filter, inner_type); } - expectContainerEnd(proto, outer_field_type); + expectContainerEnd(proto, filter, outer_field_type); // end of message and outermost struct - expectContainerEnd(proto, FieldType::Struct); + expectContainerEnd(proto, filter, FieldType::Struct); + EXPECT_CALL(proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - DecoderStateMachine dsm(proto); + DecoderStateMachine dsm(proto, filter); EXPECT_EQ(dsm.run(buffer), ProtocolState::Done); EXPECT_EQ(dsm.currentState(), ProtocolState::Done); @@ -599,63 +706,135 @@ TEST_P(DecoderStateMachineNestingTest, NestedTypes) { TEST(DecoderTest, OnData) { NiceMock* transport = new NiceMock(); NiceMock* proto = new NiceMock(); + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(filter, transportBegin(absl::optional(100))) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(filter, transportEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); - decoder.onData(buffer); + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); } TEST(DecoderTest, OnDataResumes) { NiceMock* transport = new NiceMock(); NiceMock* proto = new NiceMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; + buffer.add("x"); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(false)); - decoder.onData(buffer); + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); + + EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); + EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); // buffer.length() == 1 +} + +TEST(DecoderTest, OnDataResumesTransportFrameStart) { + StrictMock* transport = new StrictMock(); + StrictMock* proto = new StrictMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + EXPECT_CALL(*transport, name()).Times(AnyNumber()); + EXPECT_CALL(*proto, name()).Times(AnyNumber()); + + InSequence dummy; + + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + bool underflow = false; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(false))); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) + .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), + SetArgReferee<3>(100), Return(true))); EXPECT_CALL(*proto, readStructBegin(_, _)).WillOnce(Return(true)); EXPECT_CALL(*proto, readFieldBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(false)); - decoder.onData(buffer); + + underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); // buffer.length() == 0 } TEST(DecoderTest, OnDataResumesTransportFrameEnd) { StrictMock* transport = new StrictMock(); StrictMock* proto = new StrictMock(); + NiceMock callbacks; + NiceMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); EXPECT_CALL(*transport, name()).Times(AnyNumber()); EXPECT_CALL(*proto, name()).Times(AnyNumber()); InSequence dummy; - Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}); + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); Buffer::OwnedImpl buffer; - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(true)); + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); EXPECT_CALL(*proto, readMessageBegin(_, _, _, _)) .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), SetArgReferee<3>(100), Return(true))); @@ -665,11 +844,90 @@ TEST(DecoderTest, OnDataResumesTransportFrameEnd) { EXPECT_CALL(*proto, readStructEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*proto, readMessageEnd(_)).WillOnce(Return(true)); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(false)); - decoder.onData(buffer); + + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); EXPECT_CALL(*transport, decodeFrameEnd(_)).WillOnce(Return(true)); - EXPECT_CALL(*transport, decodeFrameStart(_)).WillOnce(Return(false)); - decoder.onData(buffer); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); // buffer.length() == 0 +} + +TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { + + StrictMock* transport = new StrictMock(); + EXPECT_CALL(*transport, name()).WillRepeatedly(ReturnRef(transport->name_)); + + StrictMock* proto = new StrictMock(); + EXPECT_CALL(*proto, name()).WillRepeatedly(ReturnRef(proto->name_)); + + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + InSequence dummy; + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + bool underflow = true; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(DoAll(SetArgReferee<1>(absl::optional(100)), Return(true))); + EXPECT_CALL(filter, transportBegin(absl::optional(100))) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<1>("name"), SetArgReferee<2>(MessageType::Call), + SetArgReferee<3>(100), Return(true))); + EXPECT_CALL(filter, messageBegin(absl::string_view("name"), MessageType::Call, 100)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::I32), SetArgReferee<3>(1), Return(true))); + EXPECT_CALL(filter, fieldBegin(absl::string_view(), FieldType::I32, 1)) + .WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readInt32(_, _)).WillOnce(Return(true)); + EXPECT_CALL(filter, int32Value(_)).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, fieldEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, transportEnd()).WillOnce(Return(ThriftFilters::FilterStatus::StopIteration)); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, decoder.onData(buffer, underflow)); + EXPECT_FALSE(underflow); + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); } #define TEST_NAME(X) EXPECT_EQ(ProtocolStateNameValues::name(ProtocolState::X), #X); diff --git a/test/extensions/filters/network/thrift_proxy/filter_test.cc b/test/extensions/filters/network/thrift_proxy/filter_test.cc deleted file mode 100644 index d192e47372098..0000000000000 --- a/test/extensions/filters/network/thrift_proxy/filter_test.cc +++ /dev/null @@ -1,559 +0,0 @@ -#include "common/buffer/buffer_impl.h" -#include "common/stats/stats_impl.h" - -#include "extensions/filters/network/thrift_proxy/buffer_helper.h" -#include "extensions/filters/network/thrift_proxy/filter.h" - -#include "test/extensions/filters/network/thrift_proxy/utility.h" -#include "test/mocks/network/mocks.h" -#include "test/test_common/printers.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using testing::NiceMock; - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -class ThriftFilterTest : public testing::Test { -public: - ThriftFilterTest() {} - - void initializeFilter() { - for (auto counter : store_.counters()) { - counter->reset(); - } - - filter_.reset(new Filter("test.", store_)); - filter_->initializeReadFilterCallbacks(read_filter_callbacks_); - filter_->onNewConnection(); - - // NOP currently. - filter_->onAboveWriteBufferHighWatermark(); - filter_->onBelowWriteBufferLowWatermark(); - } - - void writeFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, int32_t seq_id) { - uint8_t mt = static_cast(msg_type); - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, mt, // binary proto, type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0b, 0x00, 0x00, // begin string field - 0x00, 0x00, 0x00, 0x05, 'f', 'i', 'e', 'l', 'd', // string - 0x00, // stop field - }); - } - - void writePartialFramedBinaryMessage(Buffer::Instance& buffer, MessageType msg_type, - int32_t seq_id, bool start) { - if (start) { - uint8_t mt = static_cast(msg_type); - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x2d, // framed: 45 bytes - 0x80, 0x01, 0x00, mt, // binary proto, type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0c, 0x00, 0x00, // begin struct field - 0x0b, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x05 // string length only - }); - } else { - addSeq(buffer, { - 'f', 'i', 'e', 'l', 'd', // string data - 0x0b, 0x00, 0x02, // begin string field - 0x00, 0x00, 0x00, 0x05, 'x', 'x', 'x', 'x', 'x', // string - 0x00, // stop field - 0x00, // stop field - }); - } - } - - void writeFramedBinaryTApplicationException(Buffer::Instance& buffer, int32_t seq_id) { - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x24, // framed: 36 bytes - 0x80, 0x01, 0x00, 0x03, // binary, exception - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0B, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x05, 'e', 'r', 'r', 'o', 'r', // string - 0x08, 0x00, 0x02, // begin i32 field - 0x00, 0x00, 0x00, 0x01, // exception type 1 - 0x00, // stop field - }); - } - - void writeFramedBinaryIDLException(Buffer::Instance& buffer, int32_t seq_id) { - uint8_t s1 = (seq_id >> 24) & 0xFF; - uint8_t s2 = (seq_id >> 16) & 0xFF; - uint8_t s3 = (seq_id >> 8) & 0xFF; - uint8_t s4 = seq_id & 0xFF; - - addSeq(buffer, { - 0x00, 0x00, 0x00, 0x23, // framed: 35 bytes - 0x80, 0x01, 0x00, 0x02, // binary proto, reply - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - s1, s2, s3, s4, // sequence id - 0x0C, 0x00, 0x02, // begin exception struct - 0x0B, 0x00, 0x01, // begin string field - 0x00, 0x00, 0x00, 0x03, 'e', 'r', 'r', // string - 0x00, // exception struct stop - 0x00, // reply struct stop field - }); - } - - Buffer::OwnedImpl buffer_; - Buffer::OwnedImpl write_buffer_; - Stats::IsolatedStoreImpl store_; - std::unique_ptr filter_; - NiceMock read_filter_callbacks_; -}; - -TEST_F(ThriftFilterTest, OnDataHandlesThriftCall) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesThriftOneWay) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(0U, store_.counter("test.request_call").value()); - EXPECT_EQ(1U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(0U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesFrameSplitAcrossBuffers) { - initializeFilter(); - - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); - std::string expected_contents = bufferToString(buffer_); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - // Filter passes on the partial buffer, up to the last 4 bytes which it needs to resume the - // decoder on the next call. - std::string contents = bufferToString(buffer_); - EXPECT_EQ(len - 4, buffer_.length()); - EXPECT_EQ(expected_contents.substr(0, len - 4), contents); - - buffer_.drain(buffer_.length()); - - // Complete the buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); - expected_contents = expected_contents.substr(len - 4) + bufferToString(buffer_); - len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - // Filter buffered bytes from end of first buffer and passes them on now. - contents = bufferToString(buffer_); - EXPECT_EQ(len + 4, buffer_.length()); - EXPECT_EQ(expected_contents, contents); - - EXPECT_EQ(1U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesInvalidMsgType) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Reply, 0x0F); // reply is invalid for a request - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(0U, store_.counter("test.request_call").value()); - EXPECT_EQ(0U, store_.counter("test.request_oneway").value()); - EXPECT_EQ(1U, store_.counter("test.request_invalid_type").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - EXPECT_EQ(0U, store_.counter("test.response").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesProtocolError) { - initializeFilter(); - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); - EXPECT_EQ(len, buffer_.length()); - - // Sniffing is now disabled. - buffer_.drain(buffer_.length()); - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(0U, store_.counter("test.request").value()); -} - -TEST_F(ThriftFilterTest, OnDataHandlesProtocolErrorOnWrite) { - initializeFilter(); - - // Start the read buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, true); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - len -= buffer_.length(); - - // Disable sniffing - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); - - // Complete the read buffer - writePartialFramedBinaryMessage(buffer_, MessageType::Call, 0x10, false); - len += buffer_.length(); - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - len -= buffer_.length(); - EXPECT_EQ(0, len); -} - -TEST_F(ThriftFilterTest, OnDataStopsSniffingWithTooManyPendingCalls) { - initializeFilter(); - for (int i = 0; i < 64; i++) { - writeFramedBinaryMessage(buffer_, MessageType::Call, i); - } - - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(64U, store_.gauge("test.request_active").value()); - buffer_.drain(buffer_.length()); - - // Sniffing is now disabled. - writeFramedBinaryMessage(buffer_, MessageType::Oneway, 100); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(64U, store_.gauge("test.request_active").value()); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesThriftReply) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesOutOrOrderThriftReply) { - initializeFilter(); - - // set up two requests - writeFramedBinaryMessage(buffer_, MessageType::Call, 1); - writeFramedBinaryMessage(buffer_, MessageType::Call, 2); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(2U, store_.counter("test.request").value()); - EXPECT_EQ(2U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 2); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - write_buffer_.drain(write_buffer_.length()); - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 1); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(2U, store_.counter("test.response").value()); - EXPECT_EQ(2U, store_.counter("test.response_reply").value()); - EXPECT_EQ(2U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesFrameSplitAcrossBuffers) { - initializeFilter(); - - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F, true); - std::string expected_contents = bufferToString(write_buffer_); - uint64_t len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - // Filter passes on the partial buffer, up to the last 4 bytes which it needs to resume the - // decoder on the next call. - std::string contents = bufferToString(write_buffer_); - EXPECT_EQ(len - 4, write_buffer_.length()); - EXPECT_EQ(expected_contents.substr(0, len - 4), contents); - - write_buffer_.drain(write_buffer_.length()); - - // Complete the buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F, false); - expected_contents = expected_contents.substr(len - 4) + bufferToString(write_buffer_); - len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - // Filter buffered bytes from end of first buffer and passes them on now. - contents = bufferToString(write_buffer_); - EXPECT_EQ(len + 4, write_buffer_.length()); - EXPECT_EQ(expected_contents, contents); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(1U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesTApplicationException) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryTApplicationException(write_buffer_, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(0U, store_.counter("test.response_reply").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(1U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesIDLException) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); // set up request - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryIDLException(write_buffer_, 0x0F); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(1U, store_.counter("test.response_reply").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(1U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesInvalidMsgType) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request").value()); - EXPECT_EQ(1U, store_.gauge("test.request_active").value()); - - writeFramedBinaryMessage(write_buffer_, MessageType::Call, 0x0F); // call is invalid for response - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response").value()); - EXPECT_EQ(0U, store_.counter("test.response_success").value()); - EXPECT_EQ(0U, store_.counter("test.response_error").value()); - EXPECT_EQ(0U, store_.counter("test.response_exception").value()); - EXPECT_EQ(1U, store_.counter("test.response_invalid_type").value()); - EXPECT_EQ(0U, store_.gauge("test.request_active").value()); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesProtocolError) { - initializeFilter(); - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x01, // sequence id - 0x00, // struct stop field - }); - uint64_t len = buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); - EXPECT_EQ(len, buffer_.length()); - - // Sniffing is now disabled. - write_buffer_.drain(write_buffer_.length()); - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 1); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); -} - -TEST_F(ThriftFilterTest, OnWriteHandlesProtocolErrorOnData) { - initializeFilter(); - - // Set up a request for the partial write - writeFramedBinaryMessage(buffer_, MessageType::Call, 1); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - buffer_.drain(buffer_.length()); - - // Start the write buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 1, true); - uint64_t len = write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - len -= write_buffer_.length(); - - // Force an error on the next request. - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0xFF, // binary, illegal type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x02, // sequence id - 0x00, // struct stop field - }); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); - - // Complete the read buffer - writePartialFramedBinaryMessage(write_buffer_, MessageType::Reply, 1, false); - len += write_buffer_.length(); - - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - len -= write_buffer_.length(); - EXPECT_EQ(0, len); -} - -TEST_F(ThriftFilterTest, OnEvent) { - // No active calls - { - initializeFilter(); - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(0U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - EXPECT_EQ(0U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - } - - // Close mid-request - { - initializeFilter(); - addSeq(buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0x01, // binary proto, call type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x0F, // seq id - }); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - } - - // Close before response - { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - } - - // Close mid-response - { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - addSeq(write_buffer_, { - 0x00, 0x00, 0x00, 0x1d, // framed: 29 bytes - 0x80, 0x01, 0x00, 0x02, // binary proto, reply type - 0x00, 0x00, 0x00, 0x04, 'n', 'a', 'm', 'e', // message name - 0x00, 0x00, 0x00, 0x0F, // seq id - }); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - filter_->onEvent(Network::ConnectionEvent::RemoteClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_local_with_active_rq").value()); - - filter_->onEvent(Network::ConnectionEvent::LocalClose); - EXPECT_EQ(1U, store_.counter("test.cx_destroy_remote_with_active_rq").value()); - - buffer_.drain(buffer_.length()); - write_buffer_.drain(write_buffer_.length()); - } -} - -TEST_F(ThriftFilterTest, ResponseWithUnknownSequenceID) { - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::Continue); - - writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x10); - EXPECT_EQ(filter_->onWrite(write_buffer_, false), Network::FilterStatus::Continue); - - EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); -} - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc index 76ed15709a2f3..2999de7edbf53 100644 --- a/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/framed_transport_impl_test.cc @@ -4,80 +4,80 @@ #include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" -#include "test/mocks/buffer/mocks.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" -#include "gmock/gmock.h" #include "gtest/gtest.h" -using testing::StrictMock; - namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { TEST(FramedTransportTest, Name) { - StrictMock cb; - FramedTransportImpl transport(cb); + FramedTransportImpl transport; EXPECT_EQ(transport.name(), "framed"); } +TEST(FramedTransportTest, Type) { + FramedTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Framed); +} + TEST(FramedTransportTest, NotEnoughData) { Buffer::OwnedImpl buffer; - StrictMock cb; - FramedTransportImpl transport(cb); + FramedTransportImpl transport; + absl::optional size = 1; - EXPECT_FALSE(transport.decodeFrameStart(buffer)); + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(1), size); addRepeated(buffer, 3, 0); - EXPECT_FALSE(transport.decodeFrameStart(buffer)); + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(1), size); } TEST(FramedTransportTest, InvalidFrameSize) { - StrictMock cb; - FramedTransportImpl transport(cb); + FramedTransportImpl transport; { Buffer::OwnedImpl buffer; addInt32(buffer, -1); - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, + absl::optional size = 1; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, "invalid thrift framed transport frame size -1"); + EXPECT_EQ(absl::optional(1), size); } { Buffer::OwnedImpl buffer; addInt32(buffer, 0x7fffffff); - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, + absl::optional size = 1; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, "invalid thrift framed transport frame size 2147483647"); + EXPECT_EQ(absl::optional(1), size); } } TEST(FramedTransportTest, DecodeFrameStart) { - StrictMock cb; - EXPECT_CALL(cb, transportFrameStart(absl::optional(100U))); - - FramedTransportImpl transport(cb); + FramedTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 100); - EXPECT_EQ(buffer.length(), 4); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100U), size); EXPECT_EQ(buffer.length(), 0); } TEST(FramedTransportTest, DecodeFrameEnd) { - StrictMock cb; - EXPECT_CALL(cb, transportFrameComplete()); - - FramedTransportImpl transport(cb); + FramedTransportImpl transport; Buffer::OwnedImpl buffer; @@ -85,9 +85,7 @@ TEST(FramedTransportTest, DecodeFrameEnd) { } TEST(FramedTransportTest, EncodeFrame) { - StrictMock cb; - - FramedTransportImpl transport(cb); + FramedTransportImpl transport; { Buffer::OwnedImpl message; diff --git a/test/extensions/filters/network/thrift_proxy/filter_integration_test.cc b/test/extensions/filters/network/thrift_proxy/integration_test.cc similarity index 88% rename from test/extensions/filters/network/thrift_proxy/filter_integration_test.cc rename to test/extensions/filters/network/thrift_proxy/integration_test.cc index 8317a11d8a123..fe2b5cafb054a 100644 --- a/test/extensions/filters/network/thrift_proxy/filter_integration_test.cc +++ b/test/extensions/filters/network/thrift_proxy/integration_test.cc @@ -29,11 +29,11 @@ enum class CallResult { Exception, }; -class ThriftFilterIntegrationTest +class ThriftConnManagerIntegrationTest : public BaseIntegrationTest, public TestWithParam> { public: - ThriftFilterIntegrationTest() + ThriftConnManagerIntegrationTest() : BaseIntegrationTest(Network::Address::IpVersion::v4, thrift_config) {} static void SetUpTestCase() { @@ -43,10 +43,17 @@ class ThriftFilterIntegrationTest - name: envoy.filters.network.thrift_proxy config: stat_prefix: thrift_stats - - name: envoy.tcp_proxy - config: - stat_prefix: tcp_stats - cluster: cluster_0 + route_config: + name: "routes" + routes: + - match: + method_name: "execute" + route: + cluster: "cluster_0" + - match: + method_name: "poke" + route: + cluster: "cluster_0" )EOF"; } @@ -162,12 +169,12 @@ paramToString(const TestParamInfo>& p } INSTANTIATE_TEST_CASE_P( - TransportAndProtocol, ThriftFilterIntegrationTest, + TransportAndProtocol, ThriftConnManagerIntegrationTest, Combine(Values(TransportNames::get().FRAMED, TransportNames::get().UNFRAMED), Values(ProtocolNames::get().BINARY, ProtocolNames::get().COMPACT), Values(false, true)), paramToString); -TEST_P(ThriftFilterIntegrationTest, Success) { +TEST_P(ThriftConnManagerIntegrationTest, Success) { initializeCall(CallResult::Success); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); @@ -176,13 +183,12 @@ TEST_P(ThriftFilterIntegrationTest, Success) { FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); Buffer::OwnedImpl upstream_request( fake_upstream_connection->waitForData(request_bytes_.length())); - EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); fake_upstream_connection->write(response_bytes_.toString()); tcp_client->waitForData(response_bytes_.toString()); tcp_client->close(); - fake_upstream_connection->waitForDisconnect(); EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); @@ -192,7 +198,7 @@ TEST_P(ThriftFilterIntegrationTest, Success) { EXPECT_EQ(1U, counter->value()); } -TEST_P(ThriftFilterIntegrationTest, IDLException) { +TEST_P(ThriftConnManagerIntegrationTest, IDLException) { initializeCall(CallResult::IDLException); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); @@ -201,13 +207,12 @@ TEST_P(ThriftFilterIntegrationTest, IDLException) { FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); Buffer::OwnedImpl upstream_request( fake_upstream_connection->waitForData(request_bytes_.length())); - EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); fake_upstream_connection->write(response_bytes_.toString()); tcp_client->waitForData(response_bytes_.toString()); tcp_client->close(); - fake_upstream_connection->waitForDisconnect(); EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); @@ -217,7 +222,7 @@ TEST_P(ThriftFilterIntegrationTest, IDLException) { EXPECT_EQ(1U, counter->value()); } -TEST_P(ThriftFilterIntegrationTest, Exception) { +TEST_P(ThriftConnManagerIntegrationTest, Exception) { initializeCall(CallResult::Exception); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); @@ -226,13 +231,12 @@ TEST_P(ThriftFilterIntegrationTest, Exception) { FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); Buffer::OwnedImpl upstream_request( fake_upstream_connection->waitForData(request_bytes_.length())); - EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); fake_upstream_connection->write(response_bytes_.toString()); tcp_client->waitForData(response_bytes_.toString()); tcp_client->close(); - fake_upstream_connection->waitForDisconnect(); EXPECT_TRUE(TestUtility::buffersEqual(Buffer::OwnedImpl(tcp_client->data()), response_bytes_)); @@ -242,7 +246,7 @@ TEST_P(ThriftFilterIntegrationTest, Exception) { EXPECT_EQ(1U, counter->value()); } -TEST_P(ThriftFilterIntegrationTest, Oneway) { +TEST_P(ThriftConnManagerIntegrationTest, Oneway) { initializeOneway(); IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); @@ -251,10 +255,9 @@ TEST_P(ThriftFilterIntegrationTest, Oneway) { FakeRawConnectionPtr fake_upstream_connection = fake_upstreams_[0]->waitForRawConnection(); Buffer::OwnedImpl upstream_request( fake_upstream_connection->waitForData(request_bytes_.length())); - EXPECT_TRUE(TestUtility::buffersEqual(upstream_request, request_bytes_)); + EXPECT_EQ(request_bytes_.toString(), upstream_request.toString()); tcp_client->close(); - fake_upstream_connection->waitForDisconnect(); Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_oneway"); EXPECT_EQ(1U, counter->value()); diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index b44d9b95dab57..caa93654233e8 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -2,25 +2,77 @@ #include "gtest/gtest.h" +using testing::Return; using testing::ReturnRef; +using testing::_; namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -MockTransportCallbacks::MockTransportCallbacks() {} -MockTransportCallbacks::~MockTransportCallbacks() {} +MockConfig::MockConfig() {} +MockConfig::~MockConfig() {} -MockTransport::MockTransport() { ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); } +MockTransport::MockTransport() { + ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); + ON_CALL(*this, type()).WillByDefault(Return(type_)); +} MockTransport::~MockTransport() {} -MockProtocolCallbacks::MockProtocolCallbacks() {} -MockProtocolCallbacks::~MockProtocolCallbacks() {} - -MockProtocol::MockProtocol() { ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); } +MockProtocol::MockProtocol() { + ON_CALL(*this, name()).WillByDefault(ReturnRef(name_)); + ON_CALL(*this, type()).WillByDefault(Return(type_)); +} MockProtocol::~MockProtocol() {} +MockDecoderCallbacks::MockDecoderCallbacks() {} +MockDecoderCallbacks::~MockDecoderCallbacks() {} + +namespace ThriftFilters { + +MockDecoderFilter::MockDecoderFilter() { + ON_CALL(*this, transportBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, transportEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, structBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, structEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, fieldBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, fieldEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, boolValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, byteValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int16Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int32Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int64Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, doubleValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, stringValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, mapBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, mapEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, listBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, listEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, setBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, setEnd()).WillByDefault(Return(FilterStatus::Continue)); +} +MockDecoderFilter::~MockDecoderFilter() {} + +MockDecoderFilterCallbacks::MockDecoderFilterCallbacks() { + ON_CALL(*this, streamId()).WillByDefault(Return(stream_id_)); + ON_CALL(*this, connection()).WillByDefault(Return(&connection_)); +} +MockDecoderFilterCallbacks::~MockDecoderFilterCallbacks() {} + +} // namespace ThriftFilters + +namespace Router { + +MockRouteEntry::MockRouteEntry() {} +MockRouteEntry::~MockRouteEntry() {} + +MockRoute::MockRoute() {} +MockRoute::~MockRoute() {} + +} // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index 1668c42593ef9..f932bc808d418 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -1,25 +1,33 @@ #pragma once +#include "extensions/filters/network/thrift_proxy/conn_manager.h" +#include "extensions/filters/network/thrift_proxy/filters/filter.h" #include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/router/router.h" #include "extensions/filters/network/thrift_proxy/transport.h" +#include "test/mocks/network/mocks.h" #include "test/test_common/printers.h" #include "gmock/gmock.h" +using testing::NiceMock; + namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -class MockTransportCallbacks : public TransportCallbacks { +class MockConfig : public Config { public: - MockTransportCallbacks(); - ~MockTransportCallbacks(); - - // ThriftProxy::TransportCallbacks - MOCK_METHOD1(transportFrameStart, void(absl::optional size)); - MOCK_METHOD0(transportFrameComplete, void()); + MockConfig(); + ~MockConfig(); + + // ThriftProxy::Config + MOCK_METHOD0(filterFactory, ThriftFilters::FilterChainFactory&()); + MOCK_METHOD0(stats, ThriftFilterStats&()); + MOCK_METHOD1(createDecoder, DecoderPtr(DecoderCallbacks&)); + MOCK_METHOD0(routerConfig, Router::Config&()); }; class MockTransport : public Transport { @@ -29,24 +37,13 @@ class MockTransport : public Transport { // ThriftProxy::Transport MOCK_CONST_METHOD0(name, const std::string&()); - MOCK_METHOD1(decodeFrameStart, bool(Buffer::Instance&)); + MOCK_CONST_METHOD0(type, TransportType()); + MOCK_METHOD2(decodeFrameStart, bool(Buffer::Instance&, absl::optional&)); MOCK_METHOD1(decodeFrameEnd, bool(Buffer::Instance&)); MOCK_METHOD2(encodeFrame, void(Buffer::Instance&, Buffer::Instance&)); std::string name_{"mock"}; -}; - -class MockProtocolCallbacks : public ProtocolCallbacks { -public: - MockProtocolCallbacks(); - ~MockProtocolCallbacks(); - - // ThriftProxy::ProtocolCallbacks - MOCK_METHOD3(messageStart, void(const absl::string_view, MessageType, int32_t)); - MOCK_METHOD1(structBegin, void(const absl::string_view)); - MOCK_METHOD3(structField, void(const absl::string_view, FieldType, int16_t)); - MOCK_METHOD0(structEnd, void()); - MOCK_METHOD0(messageComplete, void()); + TransportType type_{TransportType::Auto}; }; class MockProtocol : public Protocol { @@ -56,6 +53,7 @@ class MockProtocol : public Protocol { // ThriftProxy::Protocol MOCK_CONST_METHOD0(name, const std::string&()); + MOCK_CONST_METHOD0(type, ProtocolType()); MOCK_METHOD4(readMessageBegin, bool(Buffer::Instance& buffer, std::string& name, MessageType& msg_type, int32_t& seq_id)); MOCK_METHOD1(readMessageEnd, bool(Buffer::Instance& buffer)); @@ -105,8 +103,100 @@ class MockProtocol : public Protocol { MOCK_METHOD2(writeBinary, void(Buffer::Instance& buffer, const std::string& value)); std::string name_{"mock"}; + ProtocolType type_{ProtocolType::Auto}; +}; + +class MockDecoderCallbacks : public DecoderCallbacks { +public: + MockDecoderCallbacks(); + ~MockDecoderCallbacks(); + + // ThriftProxy::DecoderCallbacks + MOCK_METHOD0(newDecoderFilter, ThriftFilters::DecoderFilter&()); +}; + +namespace ThriftFilters { + +class MockDecoderFilter : public DecoderFilter { +public: + MockDecoderFilter(); + ~MockDecoderFilter(); + + // ThriftProxy::ThriftFilters::DecoderFilter + MOCK_METHOD0(onDestroy, void()); + MOCK_METHOD1(setDecoderFilterCallbacks, void(DecoderFilterCallbacks& callbacks)); + MOCK_METHOD0(resetUpstreamConnection, void()); + MOCK_METHOD1(transportBegin, FilterStatus(absl::optional size)); + MOCK_METHOD0(transportEnd, FilterStatus()); + MOCK_METHOD3(messageBegin, + FilterStatus(const absl::string_view name, MessageType msg_type, int32_t seq_id)); + MOCK_METHOD0(messageEnd, FilterStatus()); + MOCK_METHOD1(structBegin, FilterStatus(const absl::string_view name)); + MOCK_METHOD0(structEnd, FilterStatus()); + MOCK_METHOD3(fieldBegin, + FilterStatus(const absl::string_view name, FieldType msg_type, int16_t field_id)); + MOCK_METHOD0(fieldEnd, FilterStatus()); + MOCK_METHOD1(boolValue, FilterStatus(bool value)); + MOCK_METHOD1(byteValue, FilterStatus(uint8_t value)); + MOCK_METHOD1(int16Value, FilterStatus(int16_t value)); + MOCK_METHOD1(int32Value, FilterStatus(int32_t value)); + MOCK_METHOD1(int64Value, FilterStatus(int64_t value)); + MOCK_METHOD1(doubleValue, FilterStatus(double value)); + MOCK_METHOD1(stringValue, FilterStatus(absl::string_view value)); + MOCK_METHOD3(mapBegin, FilterStatus(FieldType key_type, FieldType value_type, uint32_t size)); + MOCK_METHOD0(mapEnd, FilterStatus()); + MOCK_METHOD2(listBegin, FilterStatus(FieldType elem_type, uint32_t size)); + MOCK_METHOD0(listEnd, FilterStatus()); + MOCK_METHOD2(setBegin, FilterStatus(FieldType elem_type, uint32_t size)); + MOCK_METHOD0(setEnd, FilterStatus()); +}; + +class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { +public: + MockDecoderFilterCallbacks(); + ~MockDecoderFilterCallbacks(); + + // ThriftProxy::ThriftFilters::DecoderFilterCallbacks + MOCK_CONST_METHOD0(streamId, uint64_t()); + MOCK_CONST_METHOD0(connection, const Network::Connection*()); + MOCK_METHOD0(continueDecoding, void()); + MOCK_METHOD0(route, Router::RouteConstSharedPtr()); + MOCK_CONST_METHOD0(downstreamTransportType, TransportType()); + MOCK_CONST_METHOD0(downstreamProtocolType, ProtocolType()); + void sendLocalReply(DirectResponsePtr&& response) override { sendLocalReply_(response); } + MOCK_METHOD2(startUpstreamResponse, void(TransportType, ProtocolType)); + MOCK_METHOD1(upstreamData, bool(Buffer::Instance&)); + MOCK_METHOD0(resetDownstreamConnection, void()); + + MOCK_METHOD1(sendLocalReply_, void(DirectResponsePtr&)); + + uint64_t stream_id_{1}; + NiceMock connection_; +}; + +} // namespace ThriftFilters + +namespace Router { + +class MockRouteEntry : public RouteEntry { +public: + MockRouteEntry(); + ~MockRouteEntry(); + + // ThriftProxy::Router::RouteEntry + MOCK_CONST_METHOD0(clusterName, const std::string&()); +}; + +class MockRoute : public Route { +public: + MockRoute(); + ~MockRoute(); + + // ThriftProxy::Router::Route + MOCK_CONST_METHOD0(routeEntry, const RouteEntry*()); }; +} // namespace Router } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc index 58420a0492622..7a8fef74a1490 100644 --- a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc @@ -24,10 +24,16 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +TEST(ProtocolNames, FromType) { + for (int i = 0; i <= static_cast(ProtocolType::LastProtocolType); i++) { + ProtocolType type = static_cast(i); + EXPECT_NE("", ProtocolNames::get().fromType(type)); + } +} + TEST(AutoProtocolTest, NotEnoughData) { Buffer::OwnedImpl buffer; - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -41,8 +47,7 @@ TEST(AutoProtocolTest, NotEnoughData) { TEST(AutoProtocolTest, UnknownProtocol) { Buffer::OwnedImpl buffer; - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -59,8 +64,7 @@ TEST(AutoProtocolTest, UnknownProtocol) { TEST(AutoProtocolTest, ReadMessageBegin) { // Binary Protocol { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = -1; @@ -79,12 +83,12 @@ TEST(AutoProtocolTest, ReadMessageBegin) { EXPECT_EQ(seq_id, 1); EXPECT_EQ(buffer.length(), 0); EXPECT_EQ(proto.name(), "binary(auto)"); + EXPECT_EQ(proto.type(), ProtocolType::Binary); } // Compact protocol { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; std::string name = "-"; MessageType msg_type = MessageType::Oneway; int32_t seq_id = 1; @@ -101,13 +105,13 @@ TEST(AutoProtocolTest, ReadMessageBegin) { EXPECT_EQ(seq_id, 0x0102); EXPECT_EQ(buffer.length(), 0); EXPECT_EQ(proto.name(), "compact(auto)"); + EXPECT_EQ(proto.type(), ProtocolType::Compact); } } TEST(AutoProtocolTest, ReadDelegation) { NiceMock* proto = new NiceMock(); - NiceMock dummy_cb; - AutoProtocolImpl auto_proto(dummy_cb); + AutoProtocolImpl auto_proto; auto_proto.setProtocol(ProtocolPtr{proto}); // readMessageBegin @@ -232,8 +236,7 @@ TEST(AutoProtocolTest, ReadDelegation) { TEST(AutoProtocolTest, WriteDelegation) { NiceMock* proto = new NiceMock(); - NiceMock dummy_cb; - AutoProtocolImpl auto_proto(dummy_cb); + AutoProtocolImpl auto_proto; auto_proto.setProtocol(ProtocolPtr{proto}); // writeMessageBegin @@ -319,11 +322,15 @@ TEST(AutoProtocolTest, WriteDelegation) { } TEST(AutoProtocolTest, Name) { - NiceMock cb; - AutoProtocolImpl proto(cb); + AutoProtocolImpl proto; EXPECT_EQ(proto.name(), "auto"); } +TEST(AutoProtocolTest, Type) { + AutoProtocolImpl proto; + EXPECT_EQ(proto.type(), ProtocolType::Auto); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/router_test.cc b/test/extensions/filters/network/thrift_proxy/router_test.cc new file mode 100644 index 0000000000000..9b623f37df2d4 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/router_test.cc @@ -0,0 +1,631 @@ +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.pb.h" +#include "envoy/extensions/filters/network/thrift_proxy/v2alpha1/route.pb.validate.h" +#include "envoy/tcp/conn_pool.h" + +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/router/config.h" +#include "extensions/filters/network/thrift_proxy/router/router_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/mocks/network/mocks.h" +#include "test/mocks/server/mocks.h" +#include "test/mocks/upstream/mocks.h" +#include "test/test_common/printers.h" +#include "test/test_common/registry.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::ContainsRegex; +using testing::Invoke; +using testing::NiceMock; +using testing::Ref; +using testing::Return; +using testing::ReturnRef; +using testing::Test; +using testing::TestWithParam; +using testing::Values; +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace Router { + +namespace { + +envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration +parseRouteConfigurationFromV2Yaml(const std::string& yaml) { + envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration route_config; + MessageUtil::loadFromYaml(yaml, route_config); + MessageUtil::validate(route_config); + return route_config; +} + +class TestNamedTransportConfigFactory : public NamedTransportConfigFactory { +public: + TestNamedTransportConfigFactory(std::function f) : f_(f) {} + + TransportPtr createTransport() override { return TransportPtr{f_()}; } + std::string name() override { return TransportNames::get().FRAMED; } + + std::function f_; +}; + +class TestNamedProtocolConfigFactory : public NamedProtocolConfigFactory { +public: + TestNamedProtocolConfigFactory(std::function f) : f_(f) {} + + ProtocolPtr createProtocol() override { return ProtocolPtr{f_()}; } + std::string name() override { return ProtocolNames::get().BINARY; } + + std::function f_; +}; + +} // namespace + +class ThriftRouterTestBase { +public: + ThriftRouterTestBase() + : transport_factory_([&]() -> MockTransport* { return transport_; }), + protocol_factory_([&]() -> MockProtocol* { return protocol_; }), + transport_register_(transport_factory_), protocol_register_(protocol_factory_) {} + + void initializeRouter() { + route_ = new NiceMock(); + route_ptr_.reset(route_); + + host_ = new NiceMock(); + host_ptr_.reset(host_); + + router_.reset(new Router(context_.clusterManager())); + + EXPECT_EQ(nullptr, router_->downstreamConnection()); + + router_->setDecoderFilterCallbacks(callbacks_); + } + + void startRequest(MessageType msg_type) { + msg_type_ = msg_type; + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportBegin({})); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + + EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, newConnection(_)) + .WillOnce( + Invoke([&](Tcp::ConnectionPool::Callbacks& cb) -> Tcp::ConnectionPool::Cancellable* { + conn_pool_callbacks_ = &cb; + return &handle_; + })); + + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, msg_type_, seq_id_)); + EXPECT_NE(nullptr, conn_pool_callbacks_); + + NiceMock connection; + EXPECT_CALL(callbacks_, connection()).WillRepeatedly(Return(&connection)); + EXPECT_EQ(&connection, router_->downstreamConnection()); + + // Not yet implemented: + EXPECT_EQ(absl::optional(), router_->computeHashKey()); + EXPECT_EQ(nullptr, router_->metadataMatchCriteria()); + EXPECT_EQ(nullptr, router_->downstreamHeaders()); + } + + void connectUpstream() { + EXPECT_CALL(conn_data_, addUpstreamCallbacks(_)) + .WillOnce(Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void { + upstream_callbacks_ = &cb; + })); + + EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed)); + transport_ = new NiceMock(); + ON_CALL(*transport_, type()).WillByDefault(Return(TransportType::Framed)); + + EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary)); + protocol_ = new NiceMock(); + ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary)); + EXPECT_CALL(*protocol_, writeMessageBegin(_, method_name_, msg_type_, seq_id_)); + + EXPECT_CALL(callbacks_, continueDecoding()); + conn_pool_callbacks_->onPoolReady(conn_data_, host_ptr_); + EXPECT_NE(nullptr, upstream_callbacks_); + } + + void sendTrivialStruct(FieldType field_type) { + EXPECT_CALL(*protocol_, writeStructBegin(_, "")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structBegin({})); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", field_type, 1)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldBegin({}, field_type, 1)); + + sendTrivialValue(field_type); + + EXPECT_CALL(*protocol_, writeFieldEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldEnd()); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", FieldType::Stop, 0)); + EXPECT_CALL(*protocol_, writeStructEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structEnd()); + } + + void sendTrivialValue(FieldType field_type) { + switch (field_type) { + case FieldType::Bool: + EXPECT_CALL(*protocol_, writeBool(_, true)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->boolValue(true)); + break; + case FieldType::Byte: + EXPECT_CALL(*protocol_, writeByte(_, 2)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->byteValue(2)); + break; + case FieldType::I16: + EXPECT_CALL(*protocol_, writeInt16(_, 3)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int16Value(3)); + break; + case FieldType::I32: + EXPECT_CALL(*protocol_, writeInt32(_, 4)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(4)); + break; + case FieldType::I64: + EXPECT_CALL(*protocol_, writeInt64(_, 5)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int64Value(5)); + break; + case FieldType::Double: + EXPECT_CALL(*protocol_, writeDouble(_, 6.0)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->doubleValue(6.0)); + break; + case FieldType::String: + EXPECT_CALL(*protocol_, writeString(_, "seven")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->stringValue("seven")); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + } + + void completeRequest() { + EXPECT_CALL(*protocol_, writeMessageEnd(_)); + EXPECT_CALL(*transport_, encodeFrame(_, _)); + EXPECT_CALL(conn_data_.connection_, write(_, false)); + + if (msg_type_ == MessageType::Oneway) { + EXPECT_CALL(conn_data_, release()); + } + + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->messageEnd()); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportEnd()); + } + + void returnResponse() { + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); + upstream_callbacks_->onUpstreamData(buffer, false); + + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(conn_data_, release()); + upstream_callbacks_->onUpstreamData(buffer, false); + } + + void destroyRouter() { + router_->onDestroy(); + router_.reset(); + } + + TestNamedTransportConfigFactory transport_factory_; + TestNamedProtocolConfigFactory protocol_factory_; + Registry::InjectFactory transport_register_; + Registry::InjectFactory protocol_register_; + + NiceMock context_; + NiceMock callbacks_; + NiceMock* transport_{}; + NiceMock* protocol_{}; + NiceMock* route_{}; + NiceMock route_entry_; + NiceMock* host_{}; + + RouteConstSharedPtr route_ptr_; + Upstream::HostDescriptionConstSharedPtr host_ptr_; + + std::unique_ptr router_; + + std::string cluster_name_{"cluster"}; + + std::string method_name_{"method"}; + MessageType msg_type_{MessageType::Call}; + int32_t seq_id_{1}; + + NiceMock handle_; + NiceMock conn_data_; + Tcp::ConnectionPool::Callbacks* conn_pool_callbacks_{}; + Tcp::ConnectionPool::UpstreamCallbacks* upstream_callbacks_{}; +}; + +class ThriftRouterTest : public ThriftRouterTestBase, public Test { +public: + ThriftRouterTest() {} +}; + +class ThriftRouterFieldTypeTest : public ThriftRouterTestBase, public TestWithParam { +public: + ThriftRouterFieldTypeTest() {} +}; + +INSTANTIATE_TEST_CASE_P(PrimitiveFieldTypes, ThriftRouterFieldTypeTest, + Values(FieldType::Bool, FieldType::Byte, FieldType::I16, FieldType::I32, + FieldType::I64, FieldType::Double, FieldType::String), + fieldTypeParamToString); + +class ThriftRouterContainerTest : public ThriftRouterTestBase, public TestWithParam { +public: + ThriftRouterContainerTest() {} +}; + +INSTANTIATE_TEST_CASE_P(ContainerFieldTypes, ThriftRouterContainerTest, + Values(FieldType::Map, FieldType::List, FieldType::Set), + fieldTypeParamToString); + +TEST_F(ThriftRouterTest, PoolRemoteConnectionFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + conn_pool_callbacks_->onPoolFailure( + Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure, host_ptr_); +} + +TEST_F(ThriftRouterTest, PoolLocalConnectionFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + conn_pool_callbacks_->onPoolFailure( + Tcp::ConnectionPool::PoolFailureReason::LocalConnectionFailure, host_ptr_); +} + +TEST_F(ThriftRouterTest, PoolTimeout) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*connection failure.*")); + })); + conn_pool_callbacks_->onPoolFailure(Tcp::ConnectionPool::PoolFailureReason::Timeout, host_ptr_); +} + +TEST_F(ThriftRouterTest, PoolOverflowFailure) { + initializeRouter(); + + startRequest(MessageType::Call); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*too many connections.*")); + })); + conn_pool_callbacks_->onPoolFailure(Tcp::ConnectionPool::PoolFailureReason::Overflow, host_ptr_); +} + +TEST_F(ThriftRouterTest, NoRoute) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(nullptr)); + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + if (app_ex != nullptr) { + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::UnknownMethod, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*no route.*")); + } + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, NoCluster) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(context_.cluster_manager_, get(cluster_name_)).WillOnce(Return(nullptr)); + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*unknown cluster.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, ClusterMaintenanceMode) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(*context_.cluster_manager_.thread_local_cluster_.cluster_.info_, maintenanceMode()) + .WillOnce(Return(true)); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*maintenance mode.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, NoHealthyHosts) { + initializeRouter(); + + EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_)); + EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_)); + EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_)); + EXPECT_CALL(context_.cluster_manager_, tcpConnPoolForCluster(cluster_name_, _, _)) + .WillOnce(Return(nullptr)); + + EXPECT_CALL(callbacks_, sendLocalReply_(_)) + .WillOnce(Invoke([&](ThriftFilters::DirectResponsePtr& response) -> void { + auto* app_ex = dynamic_cast(response.get()); + EXPECT_NE(nullptr, app_ex); + EXPECT_EQ(method_name_, app_ex->method_name_); + EXPECT_EQ(seq_id_, app_ex->seq_id_); + EXPECT_EQ(AppExceptionType::InternalError, app_ex->type_); + EXPECT_THAT(app_ex->error_message_, ContainsRegex(".*no healthy upstream.*")); + })); + EXPECT_EQ(ThriftFilters::FilterStatus::StopIteration, + router_->messageBegin(method_name_, MessageType::Call, seq_id_)); +} + +TEST_F(ThriftRouterTest, TruncatedResponse) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(FieldType::String); + completeRequest(); + + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))).WillOnce(Return(false)); + EXPECT_CALL(conn_data_, release()); + EXPECT_CALL(callbacks_, resetDownstreamConnection()); + + upstream_callbacks_->onUpstreamData(buffer, true); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UpstreamDataTriggersReset) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(FieldType::String); + completeRequest(); + + Buffer::OwnedImpl buffer; + + EXPECT_CALL(callbacks_, startUpstreamResponse(TransportType::Framed, ProtocolType::Binary)); + EXPECT_CALL(callbacks_, upstreamData(Ref(buffer))) + .WillOnce(Invoke([&](Buffer::Instance&) -> bool { + router_->resetUpstreamConnection(); + return true; + })); + EXPECT_CALL(conn_data_.connection_, close(Network::ConnectionCloseType::NoFlush)); + + upstream_callbacks_->onUpstreamData(buffer, true); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UnexpectedRouterDestroyBeforeUpstreamConnect) { + initializeRouter(); + startRequest(MessageType::Call); + destroyRouter(); +} + +TEST_F(ThriftRouterTest, UnexpectedRouterDestroy) { + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + EXPECT_CALL(conn_data_.connection_, close(Network::ConnectionCloseType::NoFlush)); + destroyRouter(); +} + +TEST_P(ThriftRouterFieldTypeTest, OneWay) { + FieldType field_type = GetParam(); + + initializeRouter(); + startRequest(MessageType::Oneway); + connectUpstream(); + sendTrivialStruct(field_type); + completeRequest(); + destroyRouter(); +} + +TEST_P(ThriftRouterFieldTypeTest, Call) { + FieldType field_type = GetParam(); + + initializeRouter(); + startRequest(MessageType::Call); + connectUpstream(); + sendTrivialStruct(field_type); + completeRequest(); + returnResponse(); + destroyRouter(); +} + +TEST_P(ThriftRouterContainerTest, DecoderFilterCallbacks) { + FieldType field_type = GetParam(); + + initializeRouter(); + + startRequest(MessageType::Oneway); + connectUpstream(); + + EXPECT_CALL(*protocol_, writeStructBegin(_, "")); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structBegin({})); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, "", field_type, 1)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldBegin({}, field_type, 1)); + + switch (field_type) { + case FieldType::Map: + EXPECT_CALL(*protocol_, writeMapBegin(_, FieldType::I32, FieldType::I32, 2)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, + router_->mapBegin(FieldType::I32, FieldType::I32, 2)); + for (int i = 0; i < 2; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + EXPECT_CALL(*protocol_, writeInt32(_, i + 100)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i + 100)); + } + EXPECT_CALL(*protocol_, writeMapEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->mapEnd()); + break; + case FieldType::List: + EXPECT_CALL(*protocol_, writeListBegin(_, FieldType::I32, 3)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->listBegin(FieldType::I32, 3)); + for (int i = 0; i < 3; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + } + EXPECT_CALL(*protocol_, writeListEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->listEnd()); + break; + case FieldType::Set: + EXPECT_CALL(*protocol_, writeSetBegin(_, FieldType::I32, 4)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->setBegin(FieldType::I32, 4)); + for (int i = 0; i < 4; i++) { + EXPECT_CALL(*protocol_, writeInt32(_, i)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->int32Value(i)); + } + EXPECT_CALL(*protocol_, writeSetEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->setEnd()); + break; + default: + NOT_REACHED_GCOVR_EXCL_LINE; + } + + EXPECT_CALL(*protocol_, writeFieldEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->fieldEnd()); + + EXPECT_CALL(*protocol_, writeFieldBegin(_, _, FieldType::Stop, 0)); + EXPECT_CALL(*protocol_, writeStructEnd(_)); + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structEnd()); + + completeRequest(); + destroyRouter(); +} + +TEST(RouteMatcherTest, Route) { + const std::string yaml = R"EOF( +name: config +routes: + - match: + method: "method1" + route: + cluster: "cluster1" + - match: + method: "method2" + route: + cluster: "cluster2" +)EOF"; + + envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration config = + parseRouteConfigurationFromV2Yaml(yaml); + + RouteMatcher matcher(config); + EXPECT_EQ(nullptr, matcher.route("unknown")); + EXPECT_EQ(nullptr, matcher.route("METHOD1")); + + RouteConstSharedPtr route = matcher.route("method1"); + EXPECT_NE(nullptr, route); + EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); + + RouteConstSharedPtr route2 = matcher.route("method2"); + EXPECT_NE(nullptr, route2); + EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); +} + +TEST(RouteMatcherTest, RouteMatchAny) { + const std::string yaml = R"EOF( +name: config +routes: + - match: + method: "method1" + route: + cluster: "cluster1" + - match: {} + route: + cluster: "cluster2" +)EOF"; + + envoy::extensions::filters::network::thrift_proxy::v2alpha1::RouteConfiguration config = + parseRouteConfigurationFromV2Yaml(yaml); + + RouteMatcher matcher(config); + RouteConstSharedPtr route = matcher.route("method1"); + EXPECT_NE(nullptr, route); + EXPECT_EQ("cluster1", route->routeEntry()->clusterName()); + + RouteConstSharedPtr route2 = matcher.route("anything"); + EXPECT_NE(nullptr, route2); + EXPECT_EQ("cluster2", route2->routeEntry()->clusterName()); +} + +} // namespace Router +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc index ea6d4e1e1656f..64ab3cce319be 100644 --- a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc @@ -14,28 +14,35 @@ using testing::NiceMock; using testing::Ref; -using testing::StrictMock; namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +TEST(TransportNames, FromType) { + for (int i = 0; i <= static_cast(TransportType::LastTransportType); i++) { + TransportType type = static_cast(i); + EXPECT_NE("", TransportNames::get().fromType(type)); + } +} + TEST(AutoTransportTest, NotEnoughData) { Buffer::OwnedImpl buffer; - StrictMock cb; - AutoTransportImpl transport(cb); + AutoTransportImpl transport; + absl::optional size = 100; - EXPECT_FALSE(transport.decodeFrameStart(buffer)); + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100), size); addRepeated(buffer, 7, 0); - EXPECT_FALSE(transport.decodeFrameStart(buffer)); + EXPECT_FALSE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(100), size); } TEST(AutoTransportTest, UnknownTransport) { - StrictMock cb; - AutoTransportImpl transport(cb); + AutoTransportImpl transport; // Looks like unframed, but fails protocol check. { @@ -43,8 +50,10 @@ TEST(AutoTransportTest, UnknownTransport) { addInt32(buffer, 0); addInt32(buffer, 0); - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, + absl::optional size = 100; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, "unknown thrift auto transport frame start 00 00 00 00 00 00 00 00"); + EXPECT_EQ(absl::optional(100), size); } // Looks like framed, but fails protocol check. @@ -53,91 +62,95 @@ TEST(AutoTransportTest, UnknownTransport) { addInt32(buffer, 0xFF); addInt32(buffer, 0); - EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer), EnvoyException, + absl::optional size = 100; + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, size), EnvoyException, "unknown thrift auto transport frame start 00 00 00 ff 00 00 00 00"); + EXPECT_EQ(absl::optional(100), size); } } TEST(AutoTransportTest, DecodeFrameStart) { - StrictMock cb; - // Framed transport + binary protocol { - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 0xFF); addInt16(buffer, 0x8001); addInt16(buffer, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional(255U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(255), size); EXPECT_EQ(transport.name(), "framed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Framed); EXPECT_EQ(buffer.length(), 4); } // Framed transport + compact protocol { - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 0xFFF); addInt16(buffer, 0x8201); addInt16(buffer, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional(4095U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_EQ(absl::optional(4095), size); EXPECT_EQ(transport.name(), "framed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Framed); EXPECT_EQ(buffer.length(), 4); } // Unframed transport + binary protocol { - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt16(buffer, 0x8001); addRepeated(buffer, 6, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional())); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); EXPECT_EQ(transport.name(), "unframed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Unframed); EXPECT_EQ(buffer.length(), 8); } // Unframed transport + compact protocol { - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt16(buffer, 0x8201); addRepeated(buffer, 6, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional())); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); EXPECT_EQ(transport.name(), "unframed(auto)"); + EXPECT_EQ(transport.type(), TransportType::Unframed); EXPECT_EQ(buffer.length(), 8); } } TEST(AutoTransportTest, DecodeFrameEnd) { - StrictMock cb; - - AutoTransportImpl transport(cb); + AutoTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 0xFF); addInt16(buffer, 0x8001); addInt16(buffer, 0); - EXPECT_CALL(cb, transportFrameStart(absl::optional(255U))); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + absl::optional size; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); EXPECT_EQ(buffer.length(), 4); - EXPECT_CALL(cb, transportFrameComplete()); EXPECT_TRUE(transport.decodeFrameEnd(buffer)); } TEST(AutoTransportTest, EncodeFrame) { - StrictMock cb; MockTransport* mock_transport = new NiceMock(); - AutoTransportImpl transport(cb); + AutoTransportImpl transport; transport.setTransport(TransportPtr{mock_transport}); Buffer::OwnedImpl buffer; @@ -148,11 +161,15 @@ TEST(AutoTransportTest, EncodeFrame) { } TEST(AutoTransportTest, Name) { - StrictMock cb; - AutoTransportImpl transport(cb); + AutoTransportImpl transport; EXPECT_EQ(transport.name(), "auto"); } +TEST(AutoTransportTest, Type) { + AutoTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Auto); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc index f26a75219ab61..f83119ffaf383 100644 --- a/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/unframed_transport_impl_test.cc @@ -2,55 +2,49 @@ #include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" -#include "test/extensions/filters/network/thrift_proxy/mocks.h" #include "test/extensions/filters/network/thrift_proxy/utility.h" #include "test/test_common/printers.h" #include "test/test_common/utility.h" -#include "gmock/gmock.h" #include "gtest/gtest.h" -using testing::StrictMock; - namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { TEST(UnframedTransportTest, Name) { - StrictMock cb; - UnframedTransportImpl transport(cb); + UnframedTransportImpl transport; EXPECT_EQ(transport.name(), "unframed"); } -TEST(UnframedTransportTest, DecodeFrameStart) { - StrictMock cb; - EXPECT_CALL(cb, transportFrameStart(absl::optional())); +TEST(UnframedTransportTest, Type) { + UnframedTransportImpl transport; + EXPECT_EQ(transport.type(), TransportType::Unframed); +} - UnframedTransportImpl transport(cb); +TEST(UnframedTransportTest, DecodeFrameStart) { + UnframedTransportImpl transport; Buffer::OwnedImpl buffer; addInt32(buffer, 0xDEADBEEF); - EXPECT_EQ(buffer.length(), 4); - EXPECT_TRUE(transport.decodeFrameStart(buffer)); + + absl::optional size = 1; + EXPECT_TRUE(transport.decodeFrameStart(buffer, size)); + EXPECT_FALSE(size.has_value()); EXPECT_EQ(buffer.length(), 4); } TEST(UnframedTransportTest, DecodeFrameEnd) { - StrictMock cb; - EXPECT_CALL(cb, transportFrameComplete()); - - UnframedTransportImpl transport(cb); + UnframedTransportImpl transport; Buffer::OwnedImpl buffer; EXPECT_TRUE(transport.decodeFrameEnd(buffer)); } TEST(UnframedTransportTest, EncodeFrame) { - StrictMock cb; - - UnframedTransportImpl transport(cb); + UnframedTransportImpl transport; Buffer::OwnedImpl message; message.add("fake message"); diff --git a/test/mocks/tcp/mocks.cc b/test/mocks/tcp/mocks.cc index 1713d04e61285..757f7556f32b1 100644 --- a/test/mocks/tcp/mocks.cc +++ b/test/mocks/tcp/mocks.cc @@ -1,7 +1,8 @@ #include "mocks.h" #include "gmock/gmock.h" -#include "gtest/gtest.h" + +using testing::ReturnRef; using testing::Invoke; using testing::ReturnRef;